Move away from deprecated asserts
- assertEquals -> assertEqual - assertRaisesRegexp -> assertRegexpMatches - assertRegexpMatches -> assertRegex PiperOrigin-RevId: 319118081 Change-Id: Ieb457128522920ab55d6b69a7f244ab798a7d689
This commit is contained in:
parent
ec07b637ce
commit
f618ab4955
tensorflow
compiler/tests
add_n_test.pybucketize_op_test.pyconcat_ops_test.pycond_test.pyeager_test.pyensure_shape_op_test.pyfifo_queue_test.pyimage_ops_test.pymomentum_test.pytensor_array_ops_test.pytensor_list_ops_test.pytridiagonal_solve_ops_test.pyvariable_ops_test.pyxla_ops_test.py
lite
python
autograph
converters
impl
lang
operators
pyct
client
compiler
data
experimental/kernel_tests
assert_cardinality_test.pydense_to_sparse_batch_test.pydirected_interleave_dataset_test.pyget_single_element_test.pygroup_by_reducer_test.pygroup_by_window_test.pymake_batched_features_dataset_test.pymake_csv_dataset_test.pymap_and_batch_test.pymap_defun_op_test.py
optimization
rebatch_dataset_test.pyscan_test.pyserialization
kernel_tests
concatenate_test.pydataset_test.pyfrom_generator_test.pyiterator_test.pymap_test.pyoptions_test.pypadded_batch_test.py
util
debug
cli
analyzer_cli_test.pycli_shared_test.pycommand_parser_test.pycurses_ui_test.pycurses_widgets_test.pydebugger_cli_common_test.pyevaluator_test.pyprofile_analyzer_cli_test.pyreadline_ui_test.pytensor_format_test.py
lib
debug_data_test.pydebug_events_writer_test.pydebug_gradients_test.pydebug_graphs_test.pydebug_v2_ops_test.pydumping_callback_test.pysession_debug_grpc_test.pysource_utils_test.py
wrappers
distribute
all_reduce_test.py
cluster_resolver
custom_training_loop_input_test.pydistribute_lib_test.pymirrored_strategy_test.pymirrored_variable_test.pymulti_process_runner_no_init_test.pymulti_process_runner_test.pymulti_worker_util_test.pysharded_variable_test.pyshared_variable_creator_test.pystrategy_combinations_test.pyeager
@ -50,7 +50,7 @@ class XlaAddNTest(xla_test.XLATestCase):
|
||||
l2 = list_ops.tensor_list_reserve(
|
||||
element_shape=[], element_dtype=dtypes.float32, num_elements=3)
|
||||
l = math_ops.add_n([l1, l2])
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
"TensorList arguments to AddN must all have the same shape"):
|
||||
list_ops.tensor_list_stack(l, element_dtype=dtypes.float32).eval()
|
||||
@ -70,7 +70,7 @@ class XlaAddNTest(xla_test.XLATestCase):
|
||||
element_dtype=dtypes.float32,
|
||||
num_elements=3)
|
||||
l = math_ops.add_n([l1, l2])
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
"TensorList arguments to AddN must all have the same shape"):
|
||||
session.run(
|
||||
|
@ -64,13 +64,13 @@ class BucketizationOpTest(xla_test.XLATestCase):
|
||||
p = array_ops.placeholder(dtypes.int32)
|
||||
with self.test_scope():
|
||||
op = math_ops._bucketize(p, boundaries=[0, 8, 3, 11])
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
"Expected sorted boundaries"):
|
||||
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
|
||||
"Expected sorted boundaries"):
|
||||
sess.run(op, {p: [-5, 0]})
|
||||
|
||||
def testBoundariesNotList(self):
|
||||
with self.session():
|
||||
with self.assertRaisesRegexp(TypeError, "Expected list.*"):
|
||||
with self.assertRaisesRegex(TypeError, "Expected list.*"):
|
||||
p = array_ops.placeholder(dtypes.int32)
|
||||
with self.test_scope():
|
||||
math_ops._bucketize(p, boundaries=0)
|
||||
|
@ -288,7 +288,7 @@ class ConcatTest(xla_test.XLATestCase):
|
||||
with self.test_scope():
|
||||
scalar = constant_op.constant(7)
|
||||
dim = array_ops.placeholder(dtypes.int32)
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"Can't concatenate scalars \(use tf\.stack instead\)"):
|
||||
array_ops.concat([scalar, scalar, scalar], dim)
|
||||
|
||||
|
@ -175,8 +175,8 @@ class CondTest(xla_test.XLATestCase):
|
||||
output = control_flow_ops.cond(
|
||||
constant_op.constant(True), if_true, if_false)
|
||||
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"must be a compile-time constant"):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
"must be a compile-time constant"):
|
||||
sess.run(
|
||||
output, feed_dict={
|
||||
x: [0., 1., 2.],
|
||||
@ -209,8 +209,8 @@ class CondTest(xla_test.XLATestCase):
|
||||
|
||||
output = xla.compile(f)
|
||||
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"must be a compile-time constant"):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
"must be a compile-time constant"):
|
||||
sess.run(
|
||||
output, feed_dict={
|
||||
x: [0., 1., 2.],
|
||||
|
@ -704,8 +704,8 @@ class EagerFunctionTest(xla_test.XLATestCase):
|
||||
self.assertAllEqual([0.0, 4.0], r_y)
|
||||
if context.executing_eagerly():
|
||||
# backing_device is only available for eager tensors.
|
||||
self.assertRegexpMatches(r_x.backing_device, self.device)
|
||||
self.assertRegexpMatches(r_y.backing_device, self.device)
|
||||
self.assertRegex(r_x.backing_device, self.device)
|
||||
self.assertRegex(r_y.backing_device, self.device)
|
||||
|
||||
# When function is executed op-by-op, requested devices will be
|
||||
# respected.
|
||||
@ -714,8 +714,8 @@ class EagerFunctionTest(xla_test.XLATestCase):
|
||||
self.assertAllEqual([0.0, 4.0], r_y)
|
||||
if context.executing_eagerly():
|
||||
# backing_device is only available for eager tensors.
|
||||
self.assertRegexpMatches(r_x.backing_device, self.device)
|
||||
self.assertRegexpMatches(r_y.backing_device, 'device:CPU:0')
|
||||
self.assertRegex(r_x.backing_device, self.device)
|
||||
self.assertRegex(r_y.backing_device, 'device:CPU:0')
|
||||
|
||||
|
||||
class ExcessivePaddingTest(xla_test.XLATestCase):
|
||||
|
@ -42,8 +42,8 @@ class EnsureShapeOpTest(xla_test.XLATestCase):
|
||||
p = array_ops.placeholder(dtypes.int32)
|
||||
with self.test_scope():
|
||||
op = check_ops.ensure_shape(p, (None, 3, 3))
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
"is not compatible with expected shape"):
|
||||
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
|
||||
"is not compatible with expected shape"):
|
||||
sess.run(op, {p: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]})
|
||||
|
||||
|
||||
|
@ -66,7 +66,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
|
||||
def testEnqueueDictWithoutNames(self):
|
||||
with self.session(), self.test_scope():
|
||||
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
|
||||
with self.assertRaisesRegexp(ValueError, "must have names"):
|
||||
with self.assertRaisesRegex(ValueError, "must have names"):
|
||||
q.enqueue({"a": 12.0})
|
||||
|
||||
def testParallelEnqueue(self):
|
||||
|
@ -297,7 +297,7 @@ class AdjustHueTest(xla_test.XLATestCase):
|
||||
x_np = np.random.rand(2, 3) * 255.
|
||||
delta_h = np.random.rand() * 2.0 - 1.0
|
||||
fused = False
|
||||
with self.assertRaisesRegexp(ValueError, "Shape must be at least rank 3"):
|
||||
with self.assertRaisesRegex(ValueError, "Shape must be at least rank 3"):
|
||||
self._adjustHueTf(x_np, delta_h)
|
||||
x_np = np.random.rand(4, 2, 4) * 255.
|
||||
delta_h = np.random.rand() * 2.0 - 1.0
|
||||
|
@ -54,10 +54,10 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
|
||||
# Check we have slots
|
||||
self.assertEqual(["momentum"], mom_opt.get_slot_names())
|
||||
slot0 = mom_opt.get_slot(var0, "momentum")
|
||||
self.assertEquals(slot0.get_shape(), var0.get_shape())
|
||||
self.assertEqual(slot0.get_shape(), var0.get_shape())
|
||||
self.assertFalse(slot0 in variables.trainable_variables())
|
||||
slot1 = mom_opt.get_slot(var1, "momentum")
|
||||
self.assertEquals(slot1.get_shape(), var1.get_shape())
|
||||
self.assertEqual(slot1.get_shape(), var1.get_shape())
|
||||
self.assertFalse(slot1 in variables.trainable_variables())
|
||||
|
||||
# Fetch params to validate initial values
|
||||
@ -140,10 +140,10 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
|
||||
# Check we have slots
|
||||
self.assertEqual(["momentum"], mom_opt.get_slot_names())
|
||||
slot0 = mom_opt.get_slot(var0, "momentum")
|
||||
self.assertEquals(slot0.get_shape(), var0.get_shape())
|
||||
self.assertEqual(slot0.get_shape(), var0.get_shape())
|
||||
self.assertFalse(slot0 in variables.trainable_variables())
|
||||
slot1 = mom_opt.get_slot(var1, "momentum")
|
||||
self.assertEquals(slot1.get_shape(), var1.get_shape())
|
||||
self.assertEqual(slot1.get_shape(), var1.get_shape())
|
||||
self.assertFalse(slot1 in variables.trainable_variables())
|
||||
|
||||
# Fetch params to validate initial values
|
||||
|
@ -393,9 +393,8 @@ class TensorArrayTest(xla_test.XLATestCase):
|
||||
# Test writing the wrong datatype.
|
||||
# TODO(b/129870929): Remove InvalidArgumentError/second regexp after all
|
||||
# callers provide proper init dtype.
|
||||
with self.assertRaisesRegexp(
|
||||
(ValueError, errors.InvalidArgumentError),
|
||||
r"("
|
||||
with self.assertRaisesRegex(
|
||||
(ValueError, errors.InvalidArgumentError), r"("
|
||||
r"conversion requested dtype float32 for Tensor with dtype int32"
|
||||
r"|"
|
||||
r"TensorArray dtype is float but op has dtype int32"
|
||||
|
@ -103,8 +103,8 @@ class ListOpsTest(parameterized.TestCase, xla_test.XLATestCase):
|
||||
l = list_ops.tensor_list_push_back(
|
||||
l, constant_op.constant(1.0, shape=(7, 15)))
|
||||
_, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"Set the max number of elements"):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
"Set the max number of elements"):
|
||||
self.assertAllEqual(sess.run(e), 1.0 * np.ones((7, 15)))
|
||||
|
||||
def testEmptyTensorListMax(self):
|
||||
@ -174,7 +174,7 @@ class ListOpsTest(parameterized.TestCase, xla_test.XLATestCase):
|
||||
element_dtype=dtypes.float32, element_shape=None, max_num_elements=2)
|
||||
l = list_ops.tensor_list_push_back(l, [3.0, 4.0])
|
||||
# Pushing an element with a different shape should raise an error.
|
||||
with self.assertRaisesRegexp(errors.InternalError, "shape"):
|
||||
with self.assertRaisesRegex(errors.InternalError, "shape"):
|
||||
l = list_ops.tensor_list_push_back(l, 5.)
|
||||
self.evaluate(
|
||||
list_ops.tensor_list_stack(l, element_dtype=dtypes.float32))
|
||||
|
@ -223,7 +223,7 @@ class TridiagonalSolveOpsTest(xla_test.XLATestCase):
|
||||
num_rhs)).astype(np.float32)
|
||||
|
||||
with self.session() as sess, self.test_scope():
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
errors_impl.UnimplementedError,
|
||||
"Current implementation does not yet support pivoting."):
|
||||
diags = array_ops.placeholder(
|
||||
|
@ -485,8 +485,8 @@ class SliceAssignTest(xla_test.XLATestCase):
|
||||
checker2[None] = [6] # new axis
|
||||
|
||||
def testUninitialized(self):
|
||||
with self.assertRaisesRegexp(errors.FailedPreconditionError,
|
||||
"uninitialized"):
|
||||
with self.assertRaisesRegex(errors.FailedPreconditionError,
|
||||
"uninitialized"):
|
||||
with self.session() as sess, self.test_scope():
|
||||
v = resource_variable_ops.ResourceVariable([1, 2])
|
||||
sess.run(v[:].assign([1, 2]))
|
||||
|
@ -343,7 +343,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
np.array([5, 7]), np.array([2, 3, 4]))
|
||||
with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error:
|
||||
session.run(output)
|
||||
self.assertRegexpMatches(
|
||||
self.assertRegex(
|
||||
invalid_arg_error.exception.message,
|
||||
(r'start_indices must be a vector with length equal to input rank, '
|
||||
r'but input rank is 3 and start_indices has shape \[2\].*'))
|
||||
@ -357,7 +357,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
np.array([5, 7, 3]), np.array([2, 3]))
|
||||
with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error:
|
||||
session.run(output)
|
||||
self.assertRegexpMatches(
|
||||
self.assertRegex(
|
||||
invalid_arg_error.exception.message,
|
||||
(r'size_indices must be a vector with length equal to input rank, '
|
||||
r'but input rank is 3 and size_indices has shape \[2\].*'))
|
||||
|
@ -52,7 +52,7 @@ class InterpreterCustomOpsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testRegistererFailure(self):
|
||||
bogus_name = 'CompletelyBogusRegistererName'
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Looking up symbol \'' + bogus_name + '\' failed'):
|
||||
interpreter_wrapper.InterpreterWithCustomOps(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
@ -69,15 +69,14 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(quantized_dimension, params['quantized_dimension'])
|
||||
|
||||
def testThreads_NegativeValue(self):
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'num_threads should >= 1'):
|
||||
with self.assertRaisesRegex(ValueError, 'num_threads should >= 1'):
|
||||
interpreter_wrapper.Interpreter(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
'testdata/permute_float.tflite'), num_threads=-1)
|
||||
|
||||
def testThreads_WrongType(self):
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'type of num_threads should be int'):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'type of num_threads should be int'):
|
||||
interpreter_wrapper.Interpreter(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
'testdata/permute_float.tflite'), num_threads=4.2)
|
||||
@ -261,13 +260,13 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
||||
class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
|
||||
|
||||
def testInvalidModelContent(self):
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'Model provided has model identifier \''):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'Model provided has model identifier \''):
|
||||
interpreter_wrapper.Interpreter(model_content=six.b('garbage'))
|
||||
|
||||
def testInvalidModelFile(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'Could not open \'totally_invalid_file_name\''):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'Could not open \'totally_invalid_file_name\''):
|
||||
interpreter_wrapper.Interpreter(
|
||||
model_path='totally_invalid_file_name')
|
||||
|
||||
@ -275,12 +274,12 @@ class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
|
||||
interpreter = interpreter_wrapper.Interpreter(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
'testdata/permute_float.tflite'))
|
||||
with self.assertRaisesRegexp(RuntimeError,
|
||||
'Invoke called on model that is not ready'):
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
'Invoke called on model that is not ready'):
|
||||
interpreter.invoke()
|
||||
|
||||
def testInvalidModelFileContent(self):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, '`model_path` or `model_content` must be specified.'):
|
||||
interpreter_wrapper.Interpreter(model_path=None, model_content=None)
|
||||
|
||||
@ -290,9 +289,9 @@ class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
|
||||
'testdata/permute_float.tflite'))
|
||||
interpreter.allocate_tensors()
|
||||
# Invalid tensor index passed.
|
||||
with self.assertRaisesRegexp(ValueError, 'Tensor with no shape found.'):
|
||||
with self.assertRaisesRegex(ValueError, 'Tensor with no shape found.'):
|
||||
interpreter._get_tensor_details(4)
|
||||
with self.assertRaisesRegexp(ValueError, 'Invalid node index'):
|
||||
with self.assertRaisesRegex(ValueError, 'Invalid node index'):
|
||||
interpreter._get_op_details(4)
|
||||
|
||||
|
||||
@ -339,12 +338,10 @@ class InterpreterTensorAccessorTest(test_util.TensorFlowTestCase):
|
||||
def testBaseProtectsFunctions(self):
|
||||
in0 = self.interpreter.tensor(self.input0)()
|
||||
# Make sure we get an exception if we try to run an unsafe operation
|
||||
with self.assertRaisesRegexp(
|
||||
RuntimeError, 'There is at least 1 reference'):
|
||||
with self.assertRaisesRegex(RuntimeError, 'There is at least 1 reference'):
|
||||
_ = self.interpreter.allocate_tensors()
|
||||
# Make sure we get an exception if we try to run an unsafe operation
|
||||
with self.assertRaisesRegexp(
|
||||
RuntimeError, 'There is at least 1 reference'):
|
||||
with self.assertRaisesRegex(RuntimeError, 'There is at least 1 reference'):
|
||||
_ = self.interpreter.invoke()
|
||||
# Now test that we can run
|
||||
del in0 # this is our only buffer reference, so now it is safe to change
|
||||
@ -483,7 +480,7 @@ class InterpreterDelegateTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(lib.get_options_counter(), 2)
|
||||
|
||||
def testFail(self):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
# Due to exception chaining in PY3, we can't be more specific here and check that
|
||||
# the phrase 'Fail argument sent' is present.
|
||||
ValueError,
|
||||
|
@ -255,17 +255,17 @@ class TestSchemaUpgrade(test_util.TensorFlowTestCase):
|
||||
def testNonExistentFile(self):
|
||||
converter = upgrade_schema_lib.Converter()
|
||||
non_existent = tempfile.mktemp(suffix=".json")
|
||||
with self.assertRaisesRegexp(IOError, "No such file or directory"):
|
||||
with self.assertRaisesRegex(IOError, "No such file or directory"):
|
||||
converter.Convert(non_existent, non_existent)
|
||||
|
||||
def testInvalidExtension(self):
|
||||
converter = upgrade_schema_lib.Converter()
|
||||
invalid_extension = tempfile.mktemp(suffix=".foo")
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid extension on input"):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid extension on input"):
|
||||
converter.Convert(invalid_extension, invalid_extension)
|
||||
with tempfile.NamedTemporaryFile(suffix=".json", mode="w+") as in_json:
|
||||
JsonDumpAndFlush(EMPTY_TEST_SCHEMA_V1, in_json)
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid extension on output"):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid extension on output"):
|
||||
converter.Convert(in_json.name, invalid_extension)
|
||||
|
||||
def CheckConversion(self, data_old, data_expected):
|
||||
|
@ -38,7 +38,7 @@ class AssertsTest(converter_testing.TestCase):
|
||||
tr = self.transform(f, (functions, asserts, return_statements))
|
||||
|
||||
op = tr(constant_op.constant(False))
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'):
|
||||
with self.assertRaisesRegex(errors_impl.InvalidArgumentError, 'testmsg'):
|
||||
self.evaluate(op)
|
||||
|
||||
|
||||
|
@ -77,7 +77,7 @@ class DirectivesTest(converter_testing.TestCase):
|
||||
directives.set_loop_options()
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'must be used inside a statement'):
|
||||
with self.assertRaisesRegex(ValueError, 'must be used inside a statement'):
|
||||
self.transform(f, directives_converter, include_ast=True)
|
||||
|
||||
def test_loop_target_not_first(self):
|
||||
@ -88,7 +88,7 @@ class DirectivesTest(converter_testing.TestCase):
|
||||
a = 2
|
||||
directives.set_loop_options(parallel_iterations=10, back_prop=a)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'must be the first statement'):
|
||||
with self.assertRaisesRegex(ValueError, 'must be the first statement'):
|
||||
self.transform(f, directives_converter, include_ast=True)
|
||||
|
||||
def test_value_verification_does_not_trigger_properties(self):
|
||||
|
@ -587,7 +587,7 @@ class ApiTest(test.TestCase):
|
||||
opts = converter.ConversionOptions(internal_convert_user_code=False)
|
||||
|
||||
# f should not be converted, causing len to error out.
|
||||
with self.assertRaisesRegexp(Exception, 'len is not well defined'):
|
||||
with self.assertRaisesRegex(Exception, 'len is not well defined'):
|
||||
api.converted_call(f, (constant_op.constant([0]),), None, options=opts)
|
||||
|
||||
# len on the other hand should work fine.
|
||||
|
@ -62,12 +62,12 @@ class SpecialFunctionsTest(test.TestCase):
|
||||
self.assertAllEqual(self.evaluate(sl), [])
|
||||
|
||||
def test_tensor_list_unsupported_initializer(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'unknown type'):
|
||||
with self.assertRaisesRegex(ValueError, 'unknown type'):
|
||||
special_functions.tensor_list(np.array([1, 2, 3]))
|
||||
|
||||
def test_tensor_list_empty_list_no_type(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'element_dtype and element_shape are required'):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'element_dtype and element_shape are required'):
|
||||
special_functions.tensor_list([])
|
||||
|
||||
def test_tensor_list_from_elements(self):
|
||||
|
@ -48,7 +48,7 @@ class IfExpTest(test.TestCase):
|
||||
conditional_expressions.if_exp(
|
||||
constant_op.constant(True), lambda: 1.0, lambda: 2, 'expr_repr')
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"'expr_repr' has dtype float32 in the main.*int32 in the else"):
|
||||
test_fn()
|
||||
|
@ -685,7 +685,7 @@ class WhileLoopTest(test.TestCase):
|
||||
if not __debug__:
|
||||
self.skipTest('Feature disabled in optimized mode.')
|
||||
with test.mock.patch.object(control_flow, 'PYTHON_MAX_ITERATIONS', 100):
|
||||
with self.assertRaisesRegexp(ValueError, 'iteration limit'):
|
||||
with self.assertRaisesRegex(ValueError, 'iteration limit'):
|
||||
control_flow.while_stmt(
|
||||
test=lambda: True,
|
||||
body=lambda: None,
|
||||
@ -698,7 +698,7 @@ class WhileLoopTest(test.TestCase):
|
||||
if not __debug__:
|
||||
self.skipTest('Feature disabled in optimized mode.')
|
||||
with test.mock.patch.object(control_flow, 'PYTHON_MAX_ITERATIONS', 100):
|
||||
with self.assertRaisesRegexp(ValueError, 'iteration limit'):
|
||||
with self.assertRaisesRegex(ValueError, 'iteration limit'):
|
||||
control_flow.for_stmt(
|
||||
iter_=range(101),
|
||||
extra_test=None,
|
||||
|
@ -40,8 +40,8 @@ class ExceptionsTest(test.TestCase):
|
||||
constant_op.constant(False),
|
||||
lambda: constant_op.constant('test message'))
|
||||
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
'test message'):
|
||||
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
|
||||
'test message'):
|
||||
self.evaluate(t)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@ -54,8 +54,8 @@ class ExceptionsTest(test.TestCase):
|
||||
t = exceptions.assert_stmt(
|
||||
constant_op.constant(False), lambda: two_tensors)
|
||||
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
'test message.*another message'):
|
||||
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
|
||||
'test message.*another message'):
|
||||
self.evaluate(t)
|
||||
|
||||
def test_assert_python_untriggered(self):
|
||||
@ -81,7 +81,7 @@ class ExceptionsTest(test.TestCase):
|
||||
side_effect_trace.append(tracer)
|
||||
return 'test message'
|
||||
|
||||
with self.assertRaisesRegexp(AssertionError, 'test message'):
|
||||
with self.assertRaisesRegex(AssertionError, 'test message'):
|
||||
exceptions.assert_stmt(False, expression_with_side_effects)
|
||||
self.assertListEqual(side_effect_trace, [tracer])
|
||||
|
||||
|
@ -211,7 +211,7 @@ class TransformerTest(test.TestCase):
|
||||
node = tr.visit(node)
|
||||
obtained_message = str(cm.exception)
|
||||
expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
|
||||
self.assertRegexpMatches(obtained_message, expected_message)
|
||||
self.assertRegex(obtained_message, expected_message)
|
||||
|
||||
def test_robust_error_on_ast_corruption(self):
|
||||
# A child class should not be able to be so broken that it causes the error
|
||||
|
@ -73,7 +73,7 @@ class PywrapeventsWriterTest(test_util.TensorFlowTestCase):
|
||||
def __str__(self):
|
||||
return "Invalid"
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, "Invalid"):
|
||||
with self.assertRaisesRegex(TypeError, "Invalid"):
|
||||
_pywrap_events_writer.EventsWriter(b"foo").WriteEvent(_Invalid())
|
||||
|
||||
|
||||
|
@ -119,8 +119,8 @@ class PartialRunTest(test_util.TensorFlowTestCase):
|
||||
x = array_ops.placeholder(dtypes.float32, shape=())
|
||||
fetches = [x * 2, x * 3]
|
||||
handle = sess.partial_run_setup(fetches=fetches, feeds=[])
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
'You must feed a value for placeholder'):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
'You must feed a value for placeholder'):
|
||||
sess.partial_run(handle, fetches[0])
|
||||
|
||||
def RunTestPartialRunUnspecifiedFeed(self, sess):
|
||||
@ -130,8 +130,8 @@ class PartialRunTest(test_util.TensorFlowTestCase):
|
||||
r1 = math_ops.add(a, b)
|
||||
|
||||
h = sess.partial_run_setup([r1], [a, b])
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
'was not specified in partial_run_setup.$'):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
'was not specified in partial_run_setup.$'):
|
||||
sess.partial_run(h, r1, feed_dict={a: 1, b: 2, c: 3})
|
||||
|
||||
def RunTestPartialRunUnspecifiedFetch(self, sess):
|
||||
@ -142,8 +142,8 @@ class PartialRunTest(test_util.TensorFlowTestCase):
|
||||
r2 = math_ops.multiply(a, c)
|
||||
|
||||
h = sess.partial_run_setup([r1], [a, b, c])
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
'was not specified in partial_run_setup.$'):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
'was not specified in partial_run_setup.$'):
|
||||
sess.partial_run(h, r2, feed_dict={a: 1, c: 3})
|
||||
|
||||
def RunTestPartialRunAlreadyFed(self, sess):
|
||||
@ -155,8 +155,8 @@ class PartialRunTest(test_util.TensorFlowTestCase):
|
||||
|
||||
h = sess.partial_run_setup([r1, r2], [a, b, c])
|
||||
sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
'has already been fed.$'):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
'has already been fed.$'):
|
||||
sess.partial_run(h, r2, feed_dict={a: 1, c: 3})
|
||||
|
||||
def RunTestPartialRunAlreadyFetched(self, sess):
|
||||
@ -168,8 +168,8 @@ class PartialRunTest(test_util.TensorFlowTestCase):
|
||||
|
||||
h = sess.partial_run_setup([r1, r2], [a, b, c])
|
||||
sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
'has already been fetched.$'):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
'has already been fetched.$'):
|
||||
sess.partial_run(h, r1, feed_dict={c: 3})
|
||||
|
||||
def RunTestPartialRunEmptyFetches(self, sess):
|
||||
@ -185,7 +185,7 @@ class PartialRunTest(test_util.TensorFlowTestCase):
|
||||
def testInvalidPartialRunSetup(self):
|
||||
sess = session.Session()
|
||||
x = array_ops.placeholder(dtypes.float32, shape=[])
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
'specify at least one target to fetch or execute.'):
|
||||
sess.partial_run_setup(fetches=[], feeds=[x])
|
||||
|
@ -1269,11 +1269,11 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testUseEmptyGraph(self):
|
||||
with session.Session() as sess:
|
||||
with self.assertRaisesRegexp(RuntimeError, 'The Session graph is empty.'):
|
||||
with self.assertRaisesRegex(RuntimeError, 'The Session graph is empty.'):
|
||||
sess.run([])
|
||||
with self.assertRaisesRegexp(RuntimeError, 'The Session graph is empty.'):
|
||||
with self.assertRaisesRegex(RuntimeError, 'The Session graph is empty.'):
|
||||
sess.run(())
|
||||
with self.assertRaisesRegexp(RuntimeError, 'The Session graph is empty.'):
|
||||
with self.assertRaisesRegex(RuntimeError, 'The Session graph is empty.'):
|
||||
sess.run({})
|
||||
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
@ -1516,11 +1516,11 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
feed_t = array_ops.placeholder(dtype=dtypes.float32)
|
||||
out_t = array_ops.identity(feed_t)
|
||||
feed_val = constant_op.constant(5.0)
|
||||
with self.assertRaisesRegexp(TypeError, 'cannot be a tf.Tensor object'):
|
||||
with self.assertRaisesRegex(TypeError, 'cannot be a tf.Tensor object'):
|
||||
sess.run(out_t, feed_dict={feed_t: feed_val})
|
||||
with self.assertRaisesRegexp(TypeError, 'cannot be a tf.Tensor object'):
|
||||
with self.assertRaisesRegex(TypeError, 'cannot be a tf.Tensor object'):
|
||||
out_t.eval(feed_dict={feed_t: feed_val})
|
||||
with self.assertRaisesRegexp(TypeError, 'cannot be a tf.Tensor object'):
|
||||
with self.assertRaisesRegex(TypeError, 'cannot be a tf.Tensor object'):
|
||||
out_t.op.run(feed_dict={feed_t: feed_val})
|
||||
|
||||
def testFeedPrecisionLossError(self):
|
||||
@ -1532,11 +1532,11 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
out_t = constant_op.constant(1.0)
|
||||
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
'is not compatible with Tensor type'):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'is not compatible with Tensor type'):
|
||||
sess.run(out_t, feed_dict={feed_int_implicit_int32: largest_int64})
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
'is not compatible with Tensor type'):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'is not compatible with Tensor type'):
|
||||
sess.run(out_t, feed_dict={feed_int_explicit_int32: largest_int64})
|
||||
|
||||
def testStringFetch(self):
|
||||
@ -1598,7 +1598,7 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(c_list[i], out[i].decode('utf-8'))
|
||||
|
||||
def testInvalidTargetFails(self):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
errors.NotFoundError,
|
||||
'No session factory registered for the given session options'):
|
||||
session.Session('INVALID_TARGET')
|
||||
@ -1662,7 +1662,7 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
def testFeedDictKeyException(self):
|
||||
with session.Session() as sess:
|
||||
a = constant_op.constant(1.0, dtypes.float32, name='a')
|
||||
with self.assertRaisesRegexp(TypeError, 'Cannot interpret feed_dict'):
|
||||
with self.assertRaisesRegex(TypeError, 'Cannot interpret feed_dict'):
|
||||
sess.run(a, feed_dict={'a': [2.0]})
|
||||
|
||||
def testPerStepTrace(self):
|
||||
@ -1717,10 +1717,10 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
new_shape = constant_op.constant([2, 2])
|
||||
reshaped_tensor = array_ops.reshape(some_tensor, new_shape)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'Cannot feed value of shape'):
|
||||
with self.assertRaisesRegex(ValueError, 'Cannot feed value of shape'):
|
||||
sess.run(reshaped_tensor, feed_dict={some_tensor: [1.0, 2.0, 3.0]})
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
'Input to reshape is a tensor with 4 values, '
|
||||
'but the requested shape has 21'):
|
||||
@ -1794,7 +1794,7 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
sess2_controller = sess2.as_default()
|
||||
sess2_controller.__enter__()
|
||||
|
||||
with self.assertRaisesRegexp(AssertionError, 'Nesting violated'):
|
||||
with self.assertRaisesRegex(AssertionError, 'Nesting violated'):
|
||||
sess1_controller.__exit__(None, None, None)
|
||||
|
||||
ops._default_session_stack.reset()
|
||||
@ -1818,17 +1818,17 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testReentry(self):
|
||||
sess = session.Session()
|
||||
with self.assertRaisesRegexp(RuntimeError, 'not re-entrant'):
|
||||
with self.assertRaisesRegex(RuntimeError, 'not re-entrant'):
|
||||
with sess:
|
||||
with sess:
|
||||
pass
|
||||
|
||||
def testInvalidArgument(self):
|
||||
with self.assertRaisesRegexp(TypeError, 'target must be a string'):
|
||||
with self.assertRaisesRegex(TypeError, 'target must be a string'):
|
||||
session.Session(37)
|
||||
with self.assertRaisesRegexp(TypeError, 'config must be a tf.ConfigProto'):
|
||||
with self.assertRaisesRegex(TypeError, 'config must be a tf.ConfigProto'):
|
||||
session.Session(config=37)
|
||||
with self.assertRaisesRegexp(TypeError, 'graph must be a tf.Graph'):
|
||||
with self.assertRaisesRegex(TypeError, 'graph must be a tf.Graph'):
|
||||
session.Session(graph=37)
|
||||
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
@ -2061,7 +2061,7 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
def testAutoConvertAndCheckData(self):
|
||||
with self.cached_session() as sess:
|
||||
a = array_ops.placeholder(dtype=dtypes.string)
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, r'Type of feed value 1 with type <(\w+) \'int\'> is not'):
|
||||
sess.run(a, feed_dict={a: 1})
|
||||
|
||||
|
@ -32,8 +32,8 @@ class MLIRImportTest(test.TestCase):
|
||||
self.assertIn('func @main', mlir_module)
|
||||
|
||||
def test_invalid_pbtxt(self):
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
'Could not parse input proto'):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
'Could not parse input proto'):
|
||||
mlir.convert_graph_def('some invalid proto')
|
||||
|
||||
|
||||
|
@ -563,7 +563,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
{_SAVED_MODEL_SIGNATURE_KEY: root.run})
|
||||
|
||||
# Run TRT conversion.
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"Option is_dynamic_op=False is not supported in TF 2.0, "
|
||||
"please set it to True instead."):
|
||||
self._CreateConverterV2(input_saved_model_dir, is_dynamic_op=False)
|
||||
@ -684,16 +684,16 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
gen_resource_variable_ops.destroy_resource_op(
|
||||
handle, ignore_lookup_error=False)
|
||||
|
||||
with self.assertRaisesRegexp(errors.NotFoundError,
|
||||
r"Resource .* does not exist."):
|
||||
with self.assertRaisesRegex(errors.NotFoundError,
|
||||
r"Resource .* does not exist."):
|
||||
_DestroyCache()
|
||||
|
||||
# Load the converted model and make sure the engine cache is populated by
|
||||
# default.
|
||||
root = load.load(output_saved_model_dir)
|
||||
_DestroyCache()
|
||||
with self.assertRaisesRegexp(errors.NotFoundError,
|
||||
r"Resource .* does not exist."):
|
||||
with self.assertRaisesRegex(errors.NotFoundError,
|
||||
r"Resource .* does not exist."):
|
||||
_DestroyCache()
|
||||
|
||||
# Load the converted model again and make sure the engine cache is destroyed
|
||||
@ -701,8 +701,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
root = load.load(output_saved_model_dir)
|
||||
del root
|
||||
gc.collect() # Force GC to destroy the TRT engine cache.
|
||||
with self.assertRaisesRegexp(errors.NotFoundError,
|
||||
r"Resource .* does not exist."):
|
||||
with self.assertRaisesRegex(errors.NotFoundError,
|
||||
r"Resource .* does not exist."):
|
||||
_DestroyCache()
|
||||
|
||||
def _CompareSavedModel(self, model_class):
|
||||
|
@ -103,8 +103,8 @@ class ExperimentalCompileTest(test.TestCase):
|
||||
x = xla_func(inputs)
|
||||
# XLA support is not yet enabled for TF ROCm
|
||||
if not test.is_built_with_rocm():
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"not compilable"):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
"not compilable"):
|
||||
with session.Session(graph=g) as sess:
|
||||
sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
|
||||
|
||||
|
@ -57,7 +57,7 @@ class JITTest(test.TestCase, parameterized.TestCase):
|
||||
@test_util.run_v2_only
|
||||
def testJITInEager(self):
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "xla.experimental.jit_scope is not supported when eager "
|
||||
"execution is enabled. Try use it inside tf.function."):
|
||||
with jit.experimental_jit_scope(True):
|
||||
@ -204,7 +204,7 @@ class CompilationEnabledInGradientTest(test.TestCase, parameterized.TestCase):
|
||||
for cg in c_grad_ops:
|
||||
self.assertTrue(cg.get_attr("_XlaCompile"))
|
||||
for ncg in nc_grad_ops:
|
||||
with self.assertRaisesRegexp(ValueError, "[Nn]o attr named"):
|
||||
with self.assertRaisesRegex(ValueError, "[Nn]o attr named"):
|
||||
ncg.get_attr("_XlaCompile")
|
||||
|
||||
# d/dx (x ** 4) = 4 * (x ** 3)
|
||||
|
@ -112,7 +112,7 @@ class XLACompileContextTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError, 'Non-resource Variables are not supported inside '
|
||||
r'XLA computations \(operator name: Assign\)'):
|
||||
state_ops.assign(a, a + 1)
|
||||
@ -126,8 +126,8 @@ class XLACompileContextTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
context2 = self.create_test_xla_compile_context()
|
||||
context2.Enter()
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'XLA compiled computations cannot be nested'):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'XLA compiled computations cannot be nested'):
|
||||
constant_op.constant(1)
|
||||
context2.Exit()
|
||||
context1.Exit()
|
||||
|
@ -69,8 +69,7 @@ class AssertCardinalityTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset.apply(
|
||||
cardinality.assert_cardinality(asserted_cardinality))
|
||||
get_next = self.getNext(dataset)
|
||||
with self.assertRaisesRegexp(errors.FailedPreconditionError,
|
||||
expected_error):
|
||||
with self.assertRaisesRegex(errors.FailedPreconditionError, expected_error):
|
||||
while True:
|
||||
self.evaluate(get_next())
|
||||
|
||||
|
@ -85,7 +85,7 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testDenseToSparseBatchDatasetWithInvalidShape(self):
|
||||
input_tensor = array_ops.constant([[1]])
|
||||
with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
|
||||
with self.assertRaisesRegex(ValueError, "Dimension -2 must be >= 0"):
|
||||
dataset_ops.Dataset.from_tensors(input_tensor).apply(
|
||||
batching.dense_to_sparse_batch(4, [-2]))
|
||||
|
||||
@ -98,14 +98,14 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
# Initialize with an input tensor of incompatible rank.
|
||||
get_next = self.getNext(dataset_fn([[1]]))
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"incompatible with the row shape"):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
"incompatible with the row shape"):
|
||||
self.evaluate(get_next())
|
||||
|
||||
# Initialize with an input tensor that is larger than `row_shape`.
|
||||
get_next = self.getNext(dataset_fn(np.int32(range(13))))
|
||||
with self.assertRaisesRegexp(errors.DataLossError,
|
||||
"larger than the row shape"):
|
||||
with self.assertRaisesRegex(errors.DataLossError,
|
||||
"larger than the row shape"):
|
||||
self.evaluate(get_next())
|
||||
|
||||
|
||||
|
@ -113,38 +113,38 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase,
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testErrors(self):
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
r"vector of length `len\(datasets\)`"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"vector of length `len\(datasets\)`"):
|
||||
interleave_ops.sample_from_datasets(
|
||||
[dataset_ops.Dataset.range(10),
|
||||
dataset_ops.Dataset.range(20)],
|
||||
weights=[0.25, 0.25, 0.25, 0.25])
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"):
|
||||
with self.assertRaisesRegex(TypeError, "`tf.float32` or `tf.float64`"):
|
||||
interleave_ops.sample_from_datasets(
|
||||
[dataset_ops.Dataset.range(10),
|
||||
dataset_ops.Dataset.range(20)],
|
||||
weights=[1, 1])
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, "must have the same type"):
|
||||
with self.assertRaisesRegex(TypeError, "must have the same type"):
|
||||
interleave_ops.sample_from_datasets([
|
||||
dataset_ops.Dataset.from_tensors(0),
|
||||
dataset_ops.Dataset.from_tensors(0.0)
|
||||
])
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, "tf.int64"):
|
||||
with self.assertRaisesRegex(TypeError, "tf.int64"):
|
||||
interleave_ops.choose_from_datasets([
|
||||
dataset_ops.Dataset.from_tensors(0),
|
||||
dataset_ops.Dataset.from_tensors(1)
|
||||
], choice_dataset=dataset_ops.Dataset.from_tensors(1.0))
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, "scalar"):
|
||||
with self.assertRaisesRegex(TypeError, "scalar"):
|
||||
interleave_ops.choose_from_datasets([
|
||||
dataset_ops.Dataset.from_tensors(0),
|
||||
dataset_ops.Dataset.from_tensors(1)
|
||||
], choice_dataset=dataset_ops.Dataset.from_tensors([1.0]))
|
||||
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError, "out of range"):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError, "out of range"):
|
||||
dataset = interleave_ops.choose_from_datasets(
|
||||
[dataset_ops.Dataset.from_tensors(0)],
|
||||
choice_dataset=dataset_ops.Dataset.from_tensors(
|
||||
|
@ -64,7 +64,7 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertAllEqual([skip], sparse_val.values)
|
||||
self.assertAllEqual([skip], sparse_val.dense_shape)
|
||||
else:
|
||||
with self.assertRaisesRegexp(error, error_msg):
|
||||
with self.assertRaisesRegex(error, error_msg):
|
||||
self.evaluate(get_single_element.get_single_element(dataset))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
|
@ -143,7 +143,7 @@ class GroupByReducerTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
finalize_func=lambda x: x)
|
||||
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"The element types for the new state must match the initial state."):
|
||||
dataset.apply(
|
||||
@ -158,7 +158,7 @@ class GroupByReducerTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
finalize_func=lambda x: x)
|
||||
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "`key_func` must return a single tf.int64 tensor."):
|
||||
dataset.apply(
|
||||
grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))
|
||||
@ -172,7 +172,7 @@ class GroupByReducerTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
finalize_func=lambda x: x)
|
||||
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "`key_func` must return a single tf.int64 tensor."):
|
||||
dataset.apply(
|
||||
grouping.group_by_reducer(lambda _: "wrong", reducer))
|
||||
|
@ -265,7 +265,7 @@ class GroupByWindowTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0))
|
||||
|
||||
get_next = self.getNext(dataset)
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
"Window size must be greater than zero, but got 0."):
|
||||
print(self.evaluate(get_next()))
|
||||
|
@ -223,7 +223,7 @@ class MakeBatchedFeaturesDatasetTest(
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOldStyleReader(self):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, r"The `reader` argument must return a `Dataset` object. "
|
||||
r"`tf.ReaderBase` subclasses are not supported."):
|
||||
_ = readers.make_batched_features_dataset(
|
||||
|
@ -258,8 +258,8 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
compression_type="GZIP",
|
||||
)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"compression_type .ZLIB. is not supported"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"compression_type .ZLIB. is not supported"):
|
||||
self._test_dataset(
|
||||
inputs,
|
||||
expected_output=expected_output,
|
||||
|
@ -226,7 +226,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
def testMapAndBatchFails(self):
|
||||
"""Test a dataset that maps a TF function across its input elements."""
|
||||
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError, "oops"):
|
||||
dataset = dataset_ops.Dataset.from_tensors(
|
||||
array_ops.check_numerics(
|
||||
constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
|
||||
|
@ -185,8 +185,8 @@ class MapDefunTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0),
|
||||
[100, 1])
|
||||
map_defun_op = map_defun.map_defun(defun, [c], [dtypes.int64], [()])[0]
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
r"indices = 10 is not in \[0, 5\)"):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
r"indices = 10 is not in \[0, 5\)"):
|
||||
self.evaluate(map_defun_op)
|
||||
|
||||
@combinations.generate(_test_combinations())
|
||||
|
@ -150,8 +150,8 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
|
||||
expected_error_msg = ("`num_elements_per_branch` must be divisible by "
|
||||
"`ratio_denominator`")
|
||||
if context.executing_eagerly():
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
expected_error_msg):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
expected_error_msg):
|
||||
make_dataset()
|
||||
else:
|
||||
choose_fastest = make_dataset()
|
||||
|
@ -466,8 +466,8 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
# x has leading dimension 5, this will raise an error
|
||||
return array_ops.gather(x, 10)
|
||||
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
r"indices = 10 is not in \[0, 5\)"):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
r"indices = 10 is not in \[0, 5\)"):
|
||||
base_dataset = dataset_ops.Dataset.range(5).repeat(5).batch(
|
||||
5, drop_remainder=True)
|
||||
_, optimized = self._get_test_datasets(base_dataset, map_fn)
|
||||
|
@ -79,8 +79,8 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
def testScalarInputError(self):
|
||||
dataset = dataset_ops.Dataset.range(1024)
|
||||
distribute._RebatchDataset(dataset.batch(4), num_replicas=4)
|
||||
with self.assertRaisesRegexp(ValueError, ("You can fix the issue "
|
||||
"by adding the `batch`")):
|
||||
with self.assertRaisesRegex(ValueError, ("You can fix the issue "
|
||||
"by adding the `batch`")):
|
||||
distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
|
||||
@combinations.generate(
|
||||
|
@ -184,7 +184,7 @@ class ScanTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
start = empty_ta
|
||||
start = start.write(0, -1)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
r"construct a new TensorArray inside the function"):
|
||||
dataset_ops.Dataset.range(6).apply(scan_ops.scan(start, scan_fn))
|
||||
@ -226,7 +226,7 @@ class ScanTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
return constant_op.constant(1, dtype=dtypes.int64), state
|
||||
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"The element types for the new state must match the initial state."):
|
||||
dataset.apply(
|
||||
@ -239,7 +239,7 @@ class ScanTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
return constant_op.constant(1, dtype=dtypes.int64)
|
||||
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"The scan function must return a pair comprising the new state and the "
|
||||
"output value."):
|
||||
|
@ -53,8 +53,8 @@ class SkipDatasetSerializationTest(
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testInvalidSkip(self):
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'Shape must be rank 0 but is rank 1'):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'Shape must be rank 0 but is rank 1'):
|
||||
self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), 0)
|
||||
|
||||
|
||||
@ -83,8 +83,8 @@ class TakeDatasetSerializationTest(
|
||||
self.run_core_tests(lambda: self._build_take_dataset(0), 0)
|
||||
|
||||
def testInvalidTake(self):
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'Shape must be rank 0 but is rank 1'):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'Shape must be rank 0 but is rank 1'):
|
||||
self.run_core_tests(lambda: self._build_take_dataset([1, 2]), 0)
|
||||
|
||||
|
||||
@ -120,8 +120,8 @@ class RepeatDatasetSerializationTest(
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testInvalidRepeat(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'Shape must be rank 0 but is rank 1'):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'Shape must be rank 0 but is rank 1'):
|
||||
self.run_core_tests(lambda: self._build_repeat_dataset([1, 2], 0), 0)
|
||||
|
||||
|
||||
|
@ -44,8 +44,8 @@ class StatsDatasetSerializationTest(
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def test_bytes_produced_stats_invalid_tag_shape(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Shape must be rank 0 but is rank 1"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Shape must be rank 0 but is rank 1"):
|
||||
# pylint: disable=g-long-lambda
|
||||
self.run_core_tests(
|
||||
lambda: dataset_ops.Dataset.range(100).apply(
|
||||
@ -71,8 +71,8 @@ class StatsDatasetSerializationTest(
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def test_latency_stats_invalid_tag_shape(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Shape must be rank 0 but is rank 1"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Shape must be rank 0 but is rank 1"):
|
||||
# pylint: disable=g-long-lambda
|
||||
self.run_core_tests(
|
||||
lambda: dataset_ops.Dataset.range(100).apply(
|
||||
|
@ -110,7 +110,7 @@ class ConcatenateTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices(
|
||||
to_concatenate_components)
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, "have different types"):
|
||||
with self.assertRaisesRegex(TypeError, "have different types"):
|
||||
input_dataset.concatenate(dataset_to_concatenate)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
@ -128,7 +128,7 @@ class ConcatenateTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices(
|
||||
to_concatenate_components)
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, "have different types"):
|
||||
with self.assertRaisesRegex(TypeError, "have different types"):
|
||||
input_dataset.concatenate(dataset_to_concatenate)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
@ -144,7 +144,7 @@ class ConcatenateTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices(
|
||||
to_concatenate_components)
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, "have different types"):
|
||||
with self.assertRaisesRegex(TypeError, "have different types"):
|
||||
input_dataset.concatenate(dataset_to_concatenate)
|
||||
|
||||
|
||||
|
@ -351,7 +351,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
def testSameGraphError(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
with ops.Graph().as_default():
|
||||
with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
|
||||
with self.assertRaisesRegex(ValueError, "must be from the same graph"):
|
||||
dataset = dataset.batch(2)
|
||||
|
||||
@combinations.generate(
|
||||
@ -359,7 +359,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
def testSameGraphErrorOneShot(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
with ops.Graph().as_default():
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Please ensure that all datasets in the pipeline are "
|
||||
"created in the same graph as the iterator."):
|
||||
_ = dataset_ops.make_one_shot_iterator(dataset)
|
||||
@ -369,7 +369,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
def testSameGraphErrorInitializable(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
with ops.Graph().as_default():
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Please ensure that all datasets in the pipeline are "
|
||||
"created in the same graph as the iterator."):
|
||||
_ = dataset_ops.make_initializable_iterator(dataset)
|
||||
|
@ -453,7 +453,7 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
for _ in range(10):
|
||||
yield [20]
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, r"Cannot convert value \[tf.int64\] to a TensorFlow DType"):
|
||||
dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=[dtypes.int64])
|
||||
|
@ -72,7 +72,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = (
|
||||
dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0])
|
||||
.map(lambda x: x + var))
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"`Dataset.make_one_shot_iterator\(\)` does not support "
|
||||
"datasets that capture stateful objects.+myvar"):
|
||||
dataset_ops.make_one_shot_iterator(dataset)
|
||||
@ -213,17 +213,17 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError, ""):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
|
||||
sess.run(next_element)
|
||||
|
||||
# Test that subsequent attempts to use the iterator also fail.
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError, ""):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
|
||||
sess.run(next_element)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
|
||||
def consumer_thread():
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError, ""):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
|
||||
sess.run(next_element)
|
||||
|
||||
num_threads = 8
|
||||
@ -293,8 +293,8 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
with self.assertRaisesRegexp(errors.FailedPreconditionError,
|
||||
"iterator has not been initialized"):
|
||||
with self.assertRaisesRegex(errors.FailedPreconditionError,
|
||||
"iterator has not been initialized"):
|
||||
sess.run(get_next)
|
||||
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
|
@ -1012,7 +1012,7 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
@combinations.generate(_test_combinations())
|
||||
def testReturnValueError(self, apply_map):
|
||||
dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, r"Unsupported return value from function passed to "
|
||||
r"Dataset.map\(\)"):
|
||||
_ = apply_map(dataset, lambda x: Foo)
|
||||
|
@ -68,8 +68,7 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
options1.experimental_optimization.autotune = True
|
||||
options2 = dataset_ops.Options()
|
||||
options2.experimental_optimization.autotune = False
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"Cannot merge incompatible values"):
|
||||
with self.assertRaisesRegex(ValueError, "Cannot merge incompatible values"):
|
||||
dataset_ops.Dataset.range(0).with_options(options1).with_options(options2)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
|
@ -243,14 +243,14 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testPaddedBatchShapeErrorWrongRank(self):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'The padded shape \(1,\) is not compatible with the '
|
||||
r'corresponding input component shape \(\).'):
|
||||
_ = dataset_ops.Dataset.range(10).padded_batch(5, padded_shapes=[1])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testPaddedBatchShapeErrorTooSmall(self):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'The padded shape \(1,\) is not compatible with the '
|
||||
r'corresponding input component shape \(3,\).'):
|
||||
_ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
|
||||
@ -258,7 +258,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testPaddedBatchShapeErrorShapeNotRank1(self):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Padded shape .* must be a 1-D tensor '
|
||||
r'of tf.int64 values, but its shape was \(2, 2\).'):
|
||||
_ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
|
||||
@ -266,7 +266,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testPaddedBatchShapeErrorShapeNotInt(self):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, r'Padded shape .* must be a 1-D tensor '
|
||||
r'of tf.int64 values, but its element type was float32.'):
|
||||
_ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
|
||||
@ -274,7 +274,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testPaddedBatchShapeErrorWrongRankFromTensor(self):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'The padded shape \(1,\) is not compatible with the '
|
||||
r'corresponding input component shape \(\).'):
|
||||
shape_as_tensor = constant_op.constant([1], dtype=dtypes.int64)
|
||||
@ -283,14 +283,14 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testPaddedBatchShapeErrorDefaultShapeWithUnknownRank(self):
|
||||
with self.assertRaisesRegexp(ValueError, r'`padded_shapes`.*unknown rank'):
|
||||
with self.assertRaisesRegex(ValueError, r'`padded_shapes`.*unknown rank'):
|
||||
ds = dataset_ops.Dataset.from_generator(
|
||||
lambda: iter([1, 2, 3]), output_types=dtypes.int32)
|
||||
ds.padded_batch(2)
|
||||
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testPaddedBatchShapeErrorPlaceholder(self):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r'The padded shape \((\?|None), (\?|None)\) is not compatible with the '
|
||||
r'corresponding input component shape \(\).'):
|
||||
|
@ -79,13 +79,13 @@ class ConvertTest(test.TestCase):
|
||||
constant_op.constant([-1],
|
||||
dtype=dtypes.int64))))
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"The given shape .* must be a 1-D tensor of tf.int64 "
|
||||
r"values, but the shape was \(2, 2\)."):
|
||||
convert.partial_shape_to_tensor(constant_op.constant(
|
||||
[[1, 1], [1, 1]], dtype=dtypes.int64))
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, r"The given shape .* must be a 1-D tensor of tf.int64 "
|
||||
r"values, but the element type was float32."):
|
||||
convert.partial_shape_to_tensor(constant_op.constant([1., 1.]))
|
||||
|
@ -58,10 +58,10 @@ class NestTest(test.TestCase):
|
||||
self.assertEqual(
|
||||
np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Structure is a scalar"):
|
||||
with self.assertRaisesRegex(ValueError, "Structure is a scalar"):
|
||||
nest.pack_sequence_as("scalar", [4, 5])
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, "flat_sequence"):
|
||||
with self.assertRaisesRegex(TypeError, "flat_sequence"):
|
||||
nest.pack_sequence_as([4, 5], "bad_sequence")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
@ -191,20 +191,20 @@ class NestTest(test.TestCase):
|
||||
nest.assert_same_structure("abc", np.array([0, 1]))
|
||||
nest.assert_same_structure("abc", constant_op.constant([0, 1]))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"don't have the same nested structure"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"don't have the same nested structure"):
|
||||
nest.assert_same_structure(structure1, structure_different_num_elements)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"don't have the same nested structure"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"don't have the same nested structure"):
|
||||
nest.assert_same_structure((0, 1), np.array([0, 1]))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"don't have the same nested structure"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"don't have the same nested structure"):
|
||||
nest.assert_same_structure(0, (0, 1))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"don't have the same nested structure"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"don't have the same nested structure"):
|
||||
nest.assert_same_structure(structure1, structure_different_nesting)
|
||||
|
||||
named_type_0 = collections.namedtuple("named_0", ("a", "b"))
|
||||
@ -217,24 +217,23 @@ class NestTest(test.TestCase):
|
||||
self.assertRaises(TypeError, nest.assert_same_structure,
|
||||
named_type_0(3, 4), named_type_1(3, 4))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"don't have the same nested structure"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"don't have the same nested structure"):
|
||||
nest.assert_same_structure(named_type_0(3, 4), named_type_0((3,), 4))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"don't have the same nested structure"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"don't have the same nested structure"):
|
||||
nest.assert_same_structure(((3,), 4), (3, (4,)))
|
||||
|
||||
structure1_list = {"a": ((1, 2), 3), "b": 4, "c": (5, 6)}
|
||||
structure2_list = {"a": ((1, 2), 3), "b": 4, "d": (5, 6)}
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
"don't have the same sequence type"):
|
||||
with self.assertRaisesRegex(TypeError, "don't have the same sequence type"):
|
||||
nest.assert_same_structure(structure1, structure1_list)
|
||||
nest.assert_same_structure(structure1, structure2, check_types=False)
|
||||
nest.assert_same_structure(structure1, structure1_list, check_types=False)
|
||||
with self.assertRaisesRegexp(ValueError, "don't have the same set of keys"):
|
||||
with self.assertRaisesRegex(ValueError, "don't have the same set of keys"):
|
||||
nest.assert_same_structure(structure1_list, structure2_list)
|
||||
with self.assertRaisesRegexp(ValueError, "don't have the same set of keys"):
|
||||
with self.assertRaisesRegex(ValueError, "don't have the same set of keys"):
|
||||
nest.assert_same_structure(structure_dictionary,
|
||||
structure_dictionary_diff_nested)
|
||||
nest.assert_same_structure(
|
||||
@ -262,26 +261,26 @@ class NestTest(test.TestCase):
|
||||
|
||||
self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, "callable"):
|
||||
with self.assertRaisesRegex(TypeError, "callable"):
|
||||
nest.map_structure("bad", structure1_plus1)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "same nested structure"):
|
||||
with self.assertRaisesRegex(ValueError, "same nested structure"):
|
||||
nest.map_structure(lambda x, y: None, 3, (3,))
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, "same sequence type"):
|
||||
with self.assertRaisesRegex(TypeError, "same sequence type"):
|
||||
nest.map_structure(lambda x, y: None, ((3, 4), 5), {"a": (3, 4), "b": 5})
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "same nested structure"):
|
||||
with self.assertRaisesRegex(ValueError, "same nested structure"):
|
||||
nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "same nested structure"):
|
||||
with self.assertRaisesRegex(ValueError, "same nested structure"):
|
||||
nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
|
||||
check_types=False)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
|
||||
with self.assertRaisesRegex(ValueError, "Only valid keyword argument"):
|
||||
nest.map_structure(lambda x: None, structure1, foo="a")
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
|
||||
with self.assertRaisesRegex(ValueError, "Only valid keyword argument"):
|
||||
nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
|
||||
|
||||
def testAssertShallowStructure(self):
|
||||
@ -290,7 +289,7 @@ class NestTest(test.TestCase):
|
||||
expected_message = (
|
||||
"The two structures don't have the same sequence length. Input "
|
||||
"structure has length 2, while shallow structure has length 3.")
|
||||
with self.assertRaisesRegexp(ValueError, expected_message):
|
||||
with self.assertRaisesRegex(ValueError, expected_message):
|
||||
nest.assert_shallow_structure(inp_abc, inp_ab)
|
||||
|
||||
inp_ab1 = ((1, 1), (2, 2))
|
||||
@ -299,7 +298,7 @@ class NestTest(test.TestCase):
|
||||
"The two structures don't have the same sequence type. Input structure "
|
||||
"has type <(type|class) 'tuple'>, while shallow structure has type "
|
||||
"<(type|class) 'dict'>.")
|
||||
with self.assertRaisesRegexp(TypeError, expected_message):
|
||||
with self.assertRaisesRegex(TypeError, expected_message):
|
||||
nest.assert_shallow_structure(inp_ab2, inp_ab1)
|
||||
nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)
|
||||
|
||||
@ -309,7 +308,7 @@ class NestTest(test.TestCase):
|
||||
r"The two structures don't have the same keys. Input "
|
||||
r"structure has keys \['c'\], while shallow structure has "
|
||||
r"keys \['d'\].")
|
||||
with self.assertRaisesRegexp(ValueError, expected_message):
|
||||
with self.assertRaisesRegex(ValueError, expected_message):
|
||||
nest.assert_shallow_structure(inp_ab2, inp_ab1)
|
||||
|
||||
inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
|
||||
@ -387,14 +386,14 @@ class NestTest(test.TestCase):
|
||||
shallow_tree = ("shallow_tree",)
|
||||
expected_message = ("If shallow structure is a sequence, input must also "
|
||||
"be a sequence. Input has type: <(type|class) 'str'>.")
|
||||
with self.assertRaisesRegexp(TypeError, expected_message):
|
||||
with self.assertRaisesRegex(TypeError, expected_message):
|
||||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
|
||||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
|
||||
self.assertEqual(flattened_shallow_tree, list(shallow_tree))
|
||||
|
||||
input_tree = "input_tree"
|
||||
shallow_tree = ("shallow_tree_9", "shallow_tree_8")
|
||||
with self.assertRaisesRegexp(TypeError, expected_message):
|
||||
with self.assertRaisesRegex(TypeError, expected_message):
|
||||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
|
||||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
|
||||
self.assertEqual(flattened_shallow_tree, list(shallow_tree))
|
||||
@ -404,14 +403,14 @@ class NestTest(test.TestCase):
|
||||
shallow_tree = (9,)
|
||||
expected_message = ("If shallow structure is a sequence, input must also "
|
||||
"be a sequence. Input has type: <(type|class) 'int'>.")
|
||||
with self.assertRaisesRegexp(TypeError, expected_message):
|
||||
with self.assertRaisesRegex(TypeError, expected_message):
|
||||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
|
||||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
|
||||
self.assertEqual(flattened_shallow_tree, list(shallow_tree))
|
||||
|
||||
input_tree = 0
|
||||
shallow_tree = (9, 8)
|
||||
with self.assertRaisesRegexp(TypeError, expected_message):
|
||||
with self.assertRaisesRegex(TypeError, expected_message):
|
||||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
|
||||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
|
||||
self.assertEqual(flattened_shallow_tree, list(shallow_tree))
|
||||
|
@ -417,46 +417,46 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
s_nest = structure.type_spec_from_value(value_nest)
|
||||
flat_nest = structure.to_tensor_list(s_nest, value_nest)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"SparseTensor.* is not convertible to a tensor with "
|
||||
r"dtype.*float32.* and shape \(\)"):
|
||||
structure.to_tensor_list(s_tensor, value_sparse_tensor)
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The two structures don't have the same nested structure."):
|
||||
structure.to_tensor_list(s_tensor, value_nest)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError, "Neither a SparseTensor nor SparseTensorValue"):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
"Neither a SparseTensor nor SparseTensorValue"):
|
||||
structure.to_tensor_list(s_sparse_tensor, value_tensor)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The two structures don't have the same nested structure."):
|
||||
structure.to_tensor_list(s_sparse_tensor, value_nest)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The two structures don't have the same nested structure."):
|
||||
structure.to_tensor_list(s_nest, value_tensor)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The two structures don't have the same nested structure."):
|
||||
structure.to_tensor_list(s_nest, value_sparse_tensor)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, r"Incompatible input:"):
|
||||
with self.assertRaisesRegex(ValueError, r"Incompatible input:"):
|
||||
structure.from_tensor_list(s_tensor, flat_sparse_tensor)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Expected 1 tensors but got 2."):
|
||||
with self.assertRaisesRegex(ValueError, "Expected 1 tensors but got 2."):
|
||||
structure.from_tensor_list(s_tensor, flat_nest)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Incompatible input: "):
|
||||
with self.assertRaisesRegex(ValueError, "Incompatible input: "):
|
||||
structure.from_tensor_list(s_sparse_tensor, flat_tensor)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Expected 1 tensors but got 2."):
|
||||
with self.assertRaisesRegex(ValueError, "Expected 1 tensors but got 2."):
|
||||
structure.from_tensor_list(s_sparse_tensor, flat_nest)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Expected 2 tensors but got 1."):
|
||||
with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 1."):
|
||||
structure.from_tensor_list(s_nest, flat_tensor)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Expected 2 tensors but got 1."):
|
||||
with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 1."):
|
||||
structure.from_tensor_list(s_nest, flat_sparse_tensor)
|
||||
|
||||
def testIncompatibleNestedStructure(self):
|
||||
@ -498,20 +498,20 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
s_2 = structure.type_spec_from_value(value_2)
|
||||
flat_s_2 = structure.to_tensor_list(s_2, value_2)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"SparseTensor.* is not convertible to a tensor with "
|
||||
r"dtype.*int32.* and shape \(3,\)"):
|
||||
structure.to_tensor_list(s_0, value_1)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The two structures don't have the same nested structure."):
|
||||
structure.to_tensor_list(s_0, value_2)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError, "Neither a SparseTensor nor SparseTensorValue"):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
"Neither a SparseTensor nor SparseTensorValue"):
|
||||
structure.to_tensor_list(s_1, value_0)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The two structures don't have the same nested structure."):
|
||||
structure.to_tensor_list(s_1, value_2)
|
||||
|
||||
@ -519,30 +519,30 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
# needs to account for "a" coming before or after "b". It might be worth
|
||||
# adding a deterministic repr for these error messages (among other
|
||||
# improvements).
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The two structures don't have the same nested structure."):
|
||||
structure.to_tensor_list(s_2, value_0)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The two structures don't have the same nested structure."):
|
||||
structure.to_tensor_list(s_2, value_1)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, r"Incompatible input:"):
|
||||
with self.assertRaisesRegex(ValueError, r"Incompatible input:"):
|
||||
structure.from_tensor_list(s_0, flat_s_1)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Expected 2 tensors but got 3."):
|
||||
with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 3."):
|
||||
structure.from_tensor_list(s_0, flat_s_2)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Incompatible input: "):
|
||||
with self.assertRaisesRegex(ValueError, "Incompatible input: "):
|
||||
structure.from_tensor_list(s_1, flat_s_0)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Expected 2 tensors but got 3."):
|
||||
with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 3."):
|
||||
structure.from_tensor_list(s_1, flat_s_2)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Expected 3 tensors but got 2."):
|
||||
with self.assertRaisesRegex(ValueError, "Expected 3 tensors but got 2."):
|
||||
structure.from_tensor_list(s_2, flat_s_0)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Expected 3 tensors but got 2."):
|
||||
with self.assertRaisesRegex(ValueError, "Expected 3 tensors but got 2."):
|
||||
structure.from_tensor_list(s_2, flat_s_1)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
|
@ -1340,24 +1340,24 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
||||
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
|
||||
_cli_config_from_temp_file())
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"Input argument filter_name cannot be empty."):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Input argument filter_name cannot be empty."):
|
||||
analyzer.add_tensor_filter("", lambda datum, tensor: True)
|
||||
|
||||
def testAddTensorFilterNonStrName(self):
|
||||
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
|
||||
_cli_config_from_temp_file())
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError,
|
||||
"Input argument filter_name is expected to be str, ""but is not"):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "Input argument filter_name is expected to be str, "
|
||||
"but is not"):
|
||||
analyzer.add_tensor_filter(1, lambda datum, tensor: True)
|
||||
|
||||
def testAddGetTensorFilterNonCallable(self):
|
||||
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
|
||||
_cli_config_from_temp_file())
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "Input argument filter_callable is expected to be callable, "
|
||||
"but is not."):
|
||||
analyzer.add_tensor_filter("foo_filter", "bar")
|
||||
@ -1367,8 +1367,8 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
||||
_cli_config_from_temp_file())
|
||||
|
||||
analyzer.add_tensor_filter("foo_filter", lambda datum, tensor: True)
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"There is no tensor filter named \"bar\""):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"There is no tensor filter named \"bar\""):
|
||||
analyzer.get_tensor_filter("bar")
|
||||
|
||||
def _findSourceLine(self, annotated_source, line_number):
|
||||
|
@ -101,7 +101,7 @@ class TimeToReadableStrTest(test_util.TensorFlowTestCase):
|
||||
cli_shared.time_to_readable_str(
|
||||
0, force_time_unit=cli_shared.TIME_UNIT_S))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, r"Invalid time unit: ks"):
|
||||
with self.assertRaisesRegex(ValueError, r"Invalid time unit: ks"):
|
||||
cli_shared.time_to_readable_str(100, force_time_unit="ks")
|
||||
|
||||
|
||||
|
@ -121,7 +121,7 @@ class ExtractOutputFilePathTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(output_path, "/tmp/foo.txt")
|
||||
|
||||
def testHasGreaterThanSignButNoFileNameCausesSyntaxError(self):
|
||||
with self.assertRaisesRegexp(SyntaxError, "Redirect file path is empty"):
|
||||
with self.assertRaisesRegex(SyntaxError, "Redirect file path is empty"):
|
||||
command_parser.extract_output_file_path(
|
||||
["pt", "a:0", ">"])
|
||||
|
||||
@ -256,15 +256,15 @@ class ParseIndicesTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual([3, 4, -5], command_parser.parse_indices("3,4,-5"))
|
||||
|
||||
def testParseInvalidIndicesStringsWithoutBrackets(self):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"invalid literal for int\(\) with base 10: 'a'"):
|
||||
self.assertEqual([0], command_parser.parse_indices("0,a"))
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"invalid literal for int\(\) with base 10: '2\]'"):
|
||||
self.assertEqual([0], command_parser.parse_indices("1, 2]"))
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"invalid literal for int\(\) with base 10: ''"):
|
||||
self.assertEqual([0], command_parser.parse_indices("3, 4,"))
|
||||
|
||||
@ -296,20 +296,20 @@ class ParseRangesTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaises(SyntaxError):
|
||||
command_parser.parse_ranges("[[1,2]")
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"Incorrect number of elements in range"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Incorrect number of elements in range"):
|
||||
command_parser.parse_ranges("[1,2,3]")
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"Incorrect number of elements in range"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Incorrect number of elements in range"):
|
||||
command_parser.parse_ranges("[inf]")
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"Incorrect type in the 1st element of range"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Incorrect type in the 1st element of range"):
|
||||
command_parser.parse_ranges("[1j, 1]")
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"Incorrect type in the 2nd element of range"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Incorrect type in the 2nd element of range"):
|
||||
command_parser.parse_ranges("[1, 1j]")
|
||||
|
||||
|
||||
@ -350,11 +350,11 @@ class ParseReadableSizeStrTest(test_util.TensorFlowTestCase):
|
||||
command_parser.parse_readable_size_str("0.25G"))
|
||||
|
||||
def testParseUnsupportedUnitRaisesException(self):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Failed to parsed human-readable byte size str: \"0foo\""):
|
||||
command_parser.parse_readable_size_str("0foo")
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Failed to parsed human-readable byte size str: \"2E\""):
|
||||
command_parser.parse_readable_size_str("2EB")
|
||||
|
||||
@ -377,15 +377,13 @@ class ParseReadableTimeStrTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(2e3, command_parser.parse_readable_time_str("2ms"))
|
||||
|
||||
def testParseUnsupportedUnitRaisesException(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r".*float.*2us.*"):
|
||||
with self.assertRaisesRegex(ValueError, r".*float.*2us.*"):
|
||||
command_parser.parse_readable_time_str("2uss")
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r".*float.*2m.*"):
|
||||
with self.assertRaisesRegex(ValueError, r".*float.*2m.*"):
|
||||
command_parser.parse_readable_time_str("2m")
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"Invalid time -1. Time value must be positive."):
|
||||
command_parser.parse_readable_time_str("-1s")
|
||||
|
||||
@ -393,103 +391,104 @@ class ParseReadableTimeStrTest(test_util.TensorFlowTestCase):
|
||||
class ParseInterval(test_util.TensorFlowTestCase):
|
||||
|
||||
def testParseTimeInterval(self):
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
command_parser.Interval(10, True, 1e3, True),
|
||||
command_parser.parse_time_interval("[10us, 1ms]"))
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
command_parser.Interval(10, False, 1e3, False),
|
||||
command_parser.parse_time_interval("(10us, 1ms)"))
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
command_parser.Interval(10, False, 1e3, True),
|
||||
command_parser.parse_time_interval("(10us, 1ms]"))
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
command_parser.Interval(10, True, 1e3, False),
|
||||
command_parser.parse_time_interval("[10us, 1ms)"))
|
||||
self.assertEquals(command_parser.Interval(0, False, 1e3, True),
|
||||
command_parser.parse_time_interval("<=1ms"))
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
command_parser.Interval(0, False, 1e3, True),
|
||||
command_parser.parse_time_interval("<=1ms"))
|
||||
self.assertEqual(
|
||||
command_parser.Interval(1e3, True, float("inf"), False),
|
||||
command_parser.parse_time_interval(">=1ms"))
|
||||
self.assertEquals(command_parser.Interval(0, False, 1e3, False),
|
||||
command_parser.parse_time_interval("<1ms"))
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
command_parser.Interval(0, False, 1e3, False),
|
||||
command_parser.parse_time_interval("<1ms"))
|
||||
self.assertEqual(
|
||||
command_parser.Interval(1e3, False, float("inf"), False),
|
||||
command_parser.parse_time_interval(">1ms"))
|
||||
|
||||
def testParseTimeGreaterLessThanWithInvalidValueStrings(self):
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid value string after >= "):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid value string after >= "):
|
||||
command_parser.parse_time_interval(">=wms")
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid value string after > "):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid value string after > "):
|
||||
command_parser.parse_time_interval(">Yms")
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid value string after <= "):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid value string after <= "):
|
||||
command_parser.parse_time_interval("<= _ms")
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid value string after < "):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid value string after < "):
|
||||
command_parser.parse_time_interval("<-ms")
|
||||
|
||||
def testParseTimeIntervalsWithInvalidValueStrings(self):
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid first item in interval:"):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid first item in interval:"):
|
||||
command_parser.parse_time_interval("[wms, 10ms]")
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"Invalid second item in interval:"):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid second item in interval:"):
|
||||
command_parser.parse_time_interval("[ 0ms, _ms]")
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid first item in interval:"):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid first item in interval:"):
|
||||
command_parser.parse_time_interval("(xms, _ms]")
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid first item in interval:"):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid first item in interval:"):
|
||||
command_parser.parse_time_interval("((3ms, _ms)")
|
||||
|
||||
def testInvalidTimeIntervalRaisesException(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r"Invalid interval format: \[10us, 1ms. Valid formats are: "
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"Invalid interval format: \[10us, 1ms. Valid formats are: "
|
||||
r"\[min, max\], \(min, max\), <max, >min"):
|
||||
command_parser.parse_time_interval("[10us, 1ms")
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"Incorrect interval format: \[10us, 1ms, 2ms\]. Interval should "
|
||||
r"specify two values: \[min, max\] or \(min, max\)"):
|
||||
command_parser.parse_time_interval("[10us, 1ms, 2ms]")
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"Invalid interval \[1s, 1ms\]. Start must be before end of interval."):
|
||||
command_parser.parse_time_interval("[1s, 1ms]")
|
||||
|
||||
def testParseMemoryInterval(self):
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
command_parser.Interval(1024, True, 2048, True),
|
||||
command_parser.parse_memory_interval("[1k, 2k]"))
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
command_parser.Interval(1024, False, 2048, False),
|
||||
command_parser.parse_memory_interval("(1kB, 2kB)"))
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
command_parser.Interval(1024, False, 2048, True),
|
||||
command_parser.parse_memory_interval("(1k, 2k]"))
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
command_parser.Interval(1024, True, 2048, False),
|
||||
command_parser.parse_memory_interval("[1k, 2k)"))
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
command_parser.Interval(0, False, 2048, True),
|
||||
command_parser.parse_memory_interval("<=2k"))
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
command_parser.Interval(11, True, float("inf"), False),
|
||||
command_parser.parse_memory_interval(">=11"))
|
||||
self.assertEquals(command_parser.Interval(0, False, 2048, False),
|
||||
command_parser.parse_memory_interval("<2k"))
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
command_parser.Interval(0, False, 2048, False),
|
||||
command_parser.parse_memory_interval("<2k"))
|
||||
self.assertEqual(
|
||||
command_parser.Interval(11, False, float("inf"), False),
|
||||
command_parser.parse_memory_interval(">11"))
|
||||
|
||||
def testParseMemoryIntervalsWithInvalidValueStrings(self):
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid value string after >= "):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid value string after >= "):
|
||||
command_parser.parse_time_interval(">=wM")
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid value string after > "):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid value string after > "):
|
||||
command_parser.parse_time_interval(">YM")
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid value string after <= "):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid value string after <= "):
|
||||
command_parser.parse_time_interval("<= _MB")
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid value string after < "):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid value string after < "):
|
||||
command_parser.parse_time_interval("<-MB")
|
||||
|
||||
def testInvalidMemoryIntervalRaisesException(self):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"Invalid interval \[5k, 3k\]. Start of interval must be less than or "
|
||||
"equal to end of interval."):
|
||||
|
@ -1532,8 +1532,8 @@ class CursesTest(test_util.TensorFlowTestCase):
|
||||
class ScrollBarTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testConstructorRaisesExceptionForNotEnoughHeight(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r"Insufficient height for ScrollBar \(2\)"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Insufficient height for ScrollBar \(2\)"):
|
||||
curses_ui.ScrollBar(0, 0, 1, 1, 0, 0)
|
||||
|
||||
def testLayoutIsEmptyForZeroRow(self):
|
||||
|
@ -43,11 +43,11 @@ class CNHTest(test_util.TensorFlowTestCase):
|
||||
self.assertFalse(nav_history.can_go_forward())
|
||||
self.assertFalse(nav_history.can_go_back())
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Empty navigation history"):
|
||||
with self.assertRaisesRegex(ValueError, "Empty navigation history"):
|
||||
nav_history.go_back()
|
||||
with self.assertRaisesRegexp(ValueError, "Empty navigation history"):
|
||||
with self.assertRaisesRegex(ValueError, "Empty navigation history"):
|
||||
nav_history.go_forward()
|
||||
with self.assertRaisesRegexp(ValueError, "Empty navigation history"):
|
||||
with self.assertRaisesRegex(ValueError, "Empty navigation history"):
|
||||
nav_history.update_scroll_position(3)
|
||||
|
||||
def testAddOneItemWorks(self):
|
||||
|
@ -64,7 +64,7 @@ class RichTextLinesTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(2, screen_output.num_lines())
|
||||
|
||||
def testRichTextLinesConstructorWithInvalidType(self):
|
||||
with self.assertRaisesRegexp(ValueError, "Unexpected type in lines"):
|
||||
with self.assertRaisesRegex(ValueError, "Unexpected type in lines"):
|
||||
debugger_cli_common.RichTextLines(123)
|
||||
|
||||
def testRichTextLinesConstructorWithString(self):
|
||||
@ -320,7 +320,7 @@ class CommandHandlerRegistryTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# Attempt to register an empty-string as a command prefix should trigger
|
||||
# an exception.
|
||||
with self.assertRaisesRegexp(ValueError, "Empty command prefix"):
|
||||
with self.assertRaisesRegex(ValueError, "Empty command prefix"):
|
||||
registry.register_command_handler("", self._noop_handler, "")
|
||||
|
||||
def testRegisterAndInvokeHandler(self):
|
||||
@ -335,11 +335,11 @@ class CommandHandlerRegistryTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# Attempt to invoke an unregistered command prefix should trigger an
|
||||
# exception.
|
||||
with self.assertRaisesRegexp(ValueError, "No handler is registered"):
|
||||
with self.assertRaisesRegex(ValueError, "No handler is registered"):
|
||||
registry.dispatch_command("beep", [])
|
||||
|
||||
# Empty command prefix should trigger an exception.
|
||||
with self.assertRaisesRegexp(ValueError, "Prefix is empty"):
|
||||
with self.assertRaisesRegex(ValueError, "Prefix is empty"):
|
||||
registry.dispatch_command("", [])
|
||||
|
||||
def testExitingHandler(self):
|
||||
@ -391,7 +391,7 @@ class CommandHandlerRegistryTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# If the command handler fails to return a RichTextLines instance, an error
|
||||
# should be triggered.
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Return value from command handler.*is not None or a RichTextLines "
|
||||
"instance"):
|
||||
@ -403,7 +403,7 @@ class CommandHandlerRegistryTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# Registering the same command prefix more than once should trigger an
|
||||
# exception.
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "A handler is already registered for command prefix"):
|
||||
registry.register_command_handler("noop", self._noop_handler, "")
|
||||
|
||||
@ -416,8 +416,8 @@ class CommandHandlerRegistryTest(test_util.TensorFlowTestCase):
|
||||
"noop", self._noop_handler, "", prefix_aliases=["n"])
|
||||
|
||||
# Clash with existing alias.
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"clashes with existing prefixes or aliases"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"clashes with existing prefixes or aliases"):
|
||||
registry.register_command_handler(
|
||||
"cols", self._echo_screen_cols, "", prefix_aliases=["n"])
|
||||
|
||||
@ -425,8 +425,8 @@ class CommandHandlerRegistryTest(test_util.TensorFlowTestCase):
|
||||
self.assertFalse(registry.is_registered("cols"))
|
||||
|
||||
# Aliases can also clash with command prefixes.
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"clashes with existing prefixes or aliases"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"clashes with existing prefixes or aliases"):
|
||||
registry.register_command_handler(
|
||||
"cols", self._echo_screen_cols, "", prefix_aliases=["noop"])
|
||||
|
||||
@ -451,13 +451,13 @@ class CommandHandlerRegistryTest(test_util.TensorFlowTestCase):
|
||||
registry = debugger_cli_common.CommandHandlerRegistry()
|
||||
|
||||
# Attempt to register a non-callable handler should fail.
|
||||
with self.assertRaisesRegexp(ValueError, "handler is not callable"):
|
||||
with self.assertRaisesRegex(ValueError, "handler is not callable"):
|
||||
registry.register_command_handler("non_callable", 1, "")
|
||||
|
||||
def testRegisterHandlerWithInvalidHelpInfoType(self):
|
||||
registry = debugger_cli_common.CommandHandlerRegistry()
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "help_info is not a str"):
|
||||
with self.assertRaisesRegex(ValueError, "help_info is not a str"):
|
||||
registry.register_command_handler("noop", self._noop_handler, ["foo"])
|
||||
|
||||
def testGetHelpFull(self):
|
||||
@ -629,7 +629,7 @@ class RegexFindTest(test_util.TensorFlowTestCase):
|
||||
debugger_cli_common.REGEX_MATCH_LINES_KEY])
|
||||
|
||||
def testInvalidRegex(self):
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid regular expression"):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid regular expression"):
|
||||
debugger_cli_common.regex_find(self._orig_screen_output, "[", "yellow")
|
||||
|
||||
def testRegexFindOnPrependedLinesWorks(self):
|
||||
@ -755,11 +755,11 @@ class WrapScreenOutputTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(new_line_indices, [0, 2, 5])
|
||||
|
||||
def testWrappingInvalidArguments(self):
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"Invalid type of input screen_output"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Invalid type of input screen_output"):
|
||||
debugger_cli_common.wrap_rich_text_lines("foo", 12)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid type of input cols"):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid type of input cols"):
|
||||
debugger_cli_common.wrap_rich_text_lines(
|
||||
debugger_cli_common.RichTextLines(["foo", "bar"]), "12")
|
||||
|
||||
@ -813,7 +813,7 @@ class SliceRichTextLinesTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(1, sliced.num_lines())
|
||||
|
||||
def testAttemptSliceWithNegativeIndex(self):
|
||||
with self.assertRaisesRegexp(ValueError, "Encountered negative index"):
|
||||
with self.assertRaisesRegex(ValueError, "Encountered negative index"):
|
||||
self._original.slice(0, -1)
|
||||
|
||||
|
||||
@ -872,8 +872,8 @@ class TabCompletionRegistryTest(test_util.TensorFlowTestCase):
|
||||
self._tc_reg.get_completions("node_info", "node_"))
|
||||
|
||||
def testExtendCompletionItemsNonexistentContext(self):
|
||||
with self.assertRaisesRegexp(
|
||||
KeyError, "Context word \"foo\" has not been registered"):
|
||||
with self.assertRaisesRegex(KeyError,
|
||||
"Context word \"foo\" has not been registered"):
|
||||
self._tc_reg.extend_comp_items("foo", ["node_A:1", "node_A:2"])
|
||||
|
||||
def testRemoveCompletionItems(self):
|
||||
@ -891,8 +891,8 @@ class TabCompletionRegistryTest(test_util.TensorFlowTestCase):
|
||||
self._tc_reg.get_completions("node_info", "node_"))
|
||||
|
||||
def testRemoveCompletionItemsNonexistentContext(self):
|
||||
with self.assertRaisesRegexp(
|
||||
KeyError, "Context word \"foo\" has not been registered"):
|
||||
with self.assertRaisesRegex(KeyError,
|
||||
"Context word \"foo\" has not been registered"):
|
||||
self._tc_reg.remove_comp_items("foo", ["node_a:1", "node_a:2"])
|
||||
|
||||
def testDeregisterContext(self):
|
||||
@ -921,7 +921,7 @@ class TabCompletionRegistryTest(test_util.TensorFlowTestCase):
|
||||
|
||||
self._tc_reg.deregister_context(["print_tensor"])
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
KeyError,
|
||||
"Cannot deregister unregistered context word \"print_tensor\""):
|
||||
self._tc_reg.deregister_context(["print_tensor"])
|
||||
@ -992,7 +992,7 @@ class CommandHistoryTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual([], self._cmd_hist.lookup_prefix("print_tensor", 10))
|
||||
|
||||
def testAddNonStrCommand(self):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "Attempt to enter non-str entry to command history"):
|
||||
self._cmd_hist.add_command(["print_tensor node_a:0"])
|
||||
|
||||
|
@ -102,14 +102,14 @@ class ParseDebugTensorNameTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(0, exec_index)
|
||||
|
||||
def testParseMalformedDebugTensorName(self):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"The debug tensor name in the to-be-evaluated expression is "
|
||||
r"malformed:"):
|
||||
evaluator._parse_debug_tensor_name(
|
||||
"/job:ps/replica:0/task:2/cpu:0:foo:1:DebugNanCount:1337")
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"The debug tensor name in the to-be-evaluated expression is "
|
||||
r"malformed:"):
|
||||
@ -184,7 +184,7 @@ class EvaluatorTest(test_util.TensorFlowTestCase):
|
||||
with test.mock.patch.object(
|
||||
dump, "get_tensors", side_effect=fake_get_tensors, autospec=True):
|
||||
ev = evaluator.ExpressionEvaluator(dump)
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Eval failed due to the value of .* being unavailable"):
|
||||
ev.evaluate("np.matmul(`a:0`, `b:0`)")
|
||||
|
||||
@ -206,7 +206,7 @@ class EvaluatorTest(test_util.TensorFlowTestCase):
|
||||
with test.mock.patch.object(
|
||||
dump, "get_tensors", side_effect=fake_get_tensors, autospec=True):
|
||||
ev = evaluator.ExpressionEvaluator(dump)
|
||||
with self.assertRaisesRegexp(ValueError, r"multiple \(2\) devices"):
|
||||
with self.assertRaisesRegex(ValueError, r"multiple \(2\) devices"):
|
||||
ev.evaluate("`a:0` + `a:0`")
|
||||
|
||||
self.assertAllClose(
|
||||
@ -252,12 +252,12 @@ class EvaluatorTest(test_util.TensorFlowTestCase):
|
||||
def testEvaluateExpressionWithInvalidDebugTensorName(self):
|
||||
dump = test.mock.MagicMock()
|
||||
ev = evaluator.ExpressionEvaluator(dump)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r".* tensor name .* expression .* malformed"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r".* tensor name .* expression .* malformed"):
|
||||
ev.evaluate("np.matmul(`a`, `b`)")
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r".* tensor name .* expression .* malformed"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r".* tensor name .* expression .* malformed"):
|
||||
ev.evaluate("np.matmul(`a:0:DebugIdentity:0`, `b:1:DebugNanCount:2`)")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
|
@ -79,7 +79,7 @@ class ProfileAnalyzerListProfileTest(test_util.TensorFlowTestCase):
|
||||
|
||||
prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
|
||||
prof_output = prof_analyzer.list_profile([]).lines
|
||||
self.assertEquals([""], prof_output)
|
||||
self.assertEqual([""], prof_output)
|
||||
|
||||
def testSingleDevice(self):
|
||||
node1 = step_stats_pb2.NodeExecStats(
|
||||
@ -211,22 +211,22 @@ class ProfileAnalyzerListProfileTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# Default sort by start time (i.e. all_start_micros).
|
||||
prof_output = prof_analyzer.list_profile([]).lines
|
||||
self.assertRegexpMatches("".join(prof_output), r"Mul/456.*Add/123")
|
||||
self.assertRegex("".join(prof_output), r"Mul/456.*Add/123")
|
||||
# Default sort in reverse.
|
||||
prof_output = prof_analyzer.list_profile(["-r"]).lines
|
||||
self.assertRegexpMatches("".join(prof_output), r"Add/123.*Mul/456")
|
||||
self.assertRegex("".join(prof_output), r"Add/123.*Mul/456")
|
||||
# Sort by name.
|
||||
prof_output = prof_analyzer.list_profile(["-s", "node"]).lines
|
||||
self.assertRegexpMatches("".join(prof_output), r"Add/123.*Mul/456")
|
||||
self.assertRegex("".join(prof_output), r"Add/123.*Mul/456")
|
||||
# Sort by op time (i.e. op_end_rel_micros - op_start_rel_micros).
|
||||
prof_output = prof_analyzer.list_profile(["-s", "op_time"]).lines
|
||||
self.assertRegexpMatches("".join(prof_output), r"Mul/456.*Add/123")
|
||||
self.assertRegex("".join(prof_output), r"Mul/456.*Add/123")
|
||||
# Sort by exec time (i.e. all_end_rel_micros).
|
||||
prof_output = prof_analyzer.list_profile(["-s", "exec_time"]).lines
|
||||
self.assertRegexpMatches("".join(prof_output), r"Add/123.*Mul/456")
|
||||
self.assertRegex("".join(prof_output), r"Add/123.*Mul/456")
|
||||
# Sort by line number.
|
||||
prof_output = prof_analyzer.list_profile(["-s", "line"]).lines
|
||||
self.assertRegexpMatches("".join(prof_output), r"Mul/456.*Add/123")
|
||||
self.assertRegex("".join(prof_output), r"Mul/456.*Add/123")
|
||||
|
||||
def testFiltering(self):
|
||||
node1 = step_stats_pb2.NodeExecStats(
|
||||
|
@ -88,13 +88,13 @@ class CursesTest(test_util.TensorFlowTestCase):
|
||||
self.assertIsInstance(ui, readline_ui.ReadlineUI)
|
||||
|
||||
def testUIFactoryRaisesExceptionOnInvalidUIType(self):
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid ui_type: 'foobar'"):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid ui_type: 'foobar'"):
|
||||
ui_factory.get_ui(
|
||||
"foobar",
|
||||
config=cli_config.CLIConfig(config_file_path=self._tmp_config_path))
|
||||
|
||||
def testUIFactoryRaisesExceptionOnInvalidUITypeGivenAvailable(self):
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid ui_type: 'readline'"):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid ui_type: 'readline'"):
|
||||
ui_factory.get_ui(
|
||||
"readline",
|
||||
available_ui_types=["curses"],
|
||||
|
@ -373,16 +373,13 @@ class RichTextLinesTest(test_util.TensorFlowTestCase):
|
||||
|
||||
self._checkTensorElementLocations(out, a)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Indices exceed tensor dimensions"):
|
||||
with self.assertRaisesRegex(ValueError, "Indices exceed tensor dimensions"):
|
||||
tensor_format.locate_tensor_element(out, [20])
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Indices contain negative"):
|
||||
with self.assertRaisesRegex(ValueError, "Indices contain negative"):
|
||||
tensor_format.locate_tensor_element(out, [-1])
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Dimensions mismatch"):
|
||||
with self.assertRaisesRegex(ValueError, "Dimensions mismatch"):
|
||||
tensor_format.locate_tensor_element(out, [0, 0])
|
||||
|
||||
def testLocateTensorElement1DNoEllipsisBatchMode(self):
|
||||
@ -407,18 +404,17 @@ class RichTextLinesTest(test_util.TensorFlowTestCase):
|
||||
self, ["Tensor \"a\":", ""], out.lines[:2])
|
||||
self.assertEqual(repr(a).split("\n"), out.lines[2:])
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Dimensions mismatch"):
|
||||
with self.assertRaisesRegex(ValueError, "Dimensions mismatch"):
|
||||
tensor_format.locate_tensor_element(out, [[0, 0], [0]])
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"Indices exceed tensor dimensions"):
|
||||
with self.assertRaisesRegex(ValueError, "Indices exceed tensor dimensions"):
|
||||
tensor_format.locate_tensor_element(out, [[0], [20]])
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
r"Indices contain negative value\(s\)"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Indices contain negative value\(s\)"):
|
||||
tensor_format.locate_tensor_element(out, [[0], [-1]])
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Input indices sets are not in ascending order"):
|
||||
tensor_format.locate_tensor_element(out, [[5], [0]])
|
||||
|
||||
@ -447,16 +443,13 @@ class RichTextLinesTest(test_util.TensorFlowTestCase):
|
||||
|
||||
self._checkTensorElementLocations(out, a)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Indices exceed tensor dimensions"):
|
||||
with self.assertRaisesRegex(ValueError, "Indices exceed tensor dimensions"):
|
||||
tensor_format.locate_tensor_element(out, [1, 4])
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Indices contain negative"):
|
||||
with self.assertRaisesRegex(ValueError, "Indices contain negative"):
|
||||
tensor_format.locate_tensor_element(out, [-1, 2])
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Dimensions mismatch"):
|
||||
with self.assertRaisesRegex(ValueError, "Dimensions mismatch"):
|
||||
tensor_format.locate_tensor_element(out, [0])
|
||||
|
||||
def testLocateTensorElement2DNoEllipsisWithNumericSummary(self):
|
||||
@ -479,16 +472,13 @@ class RichTextLinesTest(test_util.TensorFlowTestCase):
|
||||
|
||||
self._checkTensorElementLocations(out, a)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Indices exceed tensor dimensions"):
|
||||
with self.assertRaisesRegex(ValueError, "Indices exceed tensor dimensions"):
|
||||
tensor_format.locate_tensor_element(out, [1, 4])
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Indices contain negative"):
|
||||
with self.assertRaisesRegex(ValueError, "Indices contain negative"):
|
||||
tensor_format.locate_tensor_element(out, [-1, 2])
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Dimensions mismatch"):
|
||||
with self.assertRaisesRegex(ValueError, "Dimensions mismatch"):
|
||||
tensor_format.locate_tensor_element(out, [0])
|
||||
|
||||
def testLocateTensorElement3DWithEllipses(self):
|
||||
@ -564,16 +554,13 @@ class RichTextLinesTest(test_util.TensorFlowTestCase):
|
||||
self.assertIsNone(start_col) # Past ellipsis.
|
||||
self.assertIsNone(end_col)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Indices exceed tensor dimensions"):
|
||||
with self.assertRaisesRegex(ValueError, "Indices exceed tensor dimensions"):
|
||||
tensor_format.locate_tensor_element(out, [11, 5, 5])
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Indices contain negative"):
|
||||
with self.assertRaisesRegex(ValueError, "Indices contain negative"):
|
||||
tensor_format.locate_tensor_element(out, [-1, 5, 5])
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Dimensions mismatch"):
|
||||
with self.assertRaisesRegex(ValueError, "Dimensions mismatch"):
|
||||
tensor_format.locate_tensor_element(out, [5, 5])
|
||||
|
||||
def testLocateTensorElement3DWithEllipsesBatchMode(self):
|
||||
@ -633,7 +620,7 @@ class RichTextLinesTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(["Tensor \"a\":", "", "Uninitialized tensor:"],
|
||||
out.lines[:3])
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
AttributeError, "tensor_metadata is not available in annotations"):
|
||||
tensor_format.locate_tensor_element(out, [0])
|
||||
|
||||
|
@ -182,7 +182,7 @@ class DebugDumpDirTest(test_util.TensorFlowTestCase):
|
||||
gpu_1_dir, "node_foo_1_2_DebugIdentity_1472563253536387"), "wb")
|
||||
|
||||
def testDebugDumpDir_nonexistentDumpRoot(self):
|
||||
with self.assertRaisesRegexp(IOError, "does not exist"):
|
||||
with self.assertRaisesRegex(IOError, "does not exist"):
|
||||
debug_data.DebugDumpDir(tempfile.mktemp() + "_foo")
|
||||
|
||||
def testDebugDumpDir_invalidFileNamingPattern(self):
|
||||
@ -194,8 +194,8 @@ class DebugDumpDirTest(test_util.TensorFlowTestCase):
|
||||
os.makedirs(device_dir)
|
||||
open(os.path.join(device_dir, "node1_DebugIdentity_1234"), "wb")
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"does not conform to the naming pattern"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"does not conform to the naming pattern"):
|
||||
debug_data.DebugDumpDir(self._dump_root)
|
||||
|
||||
def testDebugDumpDir_validDuplicateNodeNamesWithMultipleDevices(self):
|
||||
@ -228,8 +228,7 @@ class DebugDumpDirTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(1472563253536385, dump_dir.t0)
|
||||
self.assertEqual(3, dump_dir.size)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r"Invalid device name: "):
|
||||
with self.assertRaisesRegex(ValueError, r"Invalid device name: "):
|
||||
dump_dir.nodes("/job:localhost/replica:0/task:0/device:GPU:2")
|
||||
self.assertItemsEqual(["node_foo_1", "node_foo_1", "node_foo_1"],
|
||||
dump_dir.nodes())
|
||||
@ -259,8 +258,7 @@ class DebugDumpDirTest(test_util.TensorFlowTestCase):
|
||||
node.op = "FooOp"
|
||||
node.device = "/job:localhost/replica:0/task:0/device:GPU:1"
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r"Duplicate node name on device "):
|
||||
with self.assertRaisesRegex(ValueError, r"Duplicate node name on device "):
|
||||
debug_data.DebugDumpDir(
|
||||
self._dump_root,
|
||||
partition_graphs=[graph_cpu_0, graph_gpu_0, graph_gpu_1])
|
||||
|
@ -674,8 +674,8 @@ class MultiSetReaderTest(dumping_callback_test_lib.DumpingCallbackTestBase):
|
||||
re.sub(r"(tfdbg_events\.\d+)", r"\g<1>1", os.path.basename(src_path)))
|
||||
os.rename(src_path, dst_path)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
r"Found multiple \(2\) tfdbg2 runs"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Found multiple \(2\) tfdbg2 runs"):
|
||||
debug_events_reader.DebugDataReader(dump_root_0)
|
||||
|
||||
|
||||
|
@ -119,8 +119,8 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
||||
def testCallingIdentifyGradientTwiceWithTheSameGradientsDebuggerErrors(self):
|
||||
grad_debugger = debug_gradients.GradientsDebugger()
|
||||
grad_debugger.identify_gradient(self.w)
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"The graph already contains an op named .*"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"The graph already contains an op named .*"):
|
||||
grad_debugger.identify_gradient(self.w)
|
||||
|
||||
def testIdentifyGradientWorksOnMultipleLosses(self):
|
||||
@ -162,18 +162,18 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
||||
# registered.
|
||||
gradients_impl.gradients(y, [self.u, self.v])
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
LookupError,
|
||||
r"This GradientsDebugger has not received any gradient tensor for "):
|
||||
grad_debugger_1.gradient_tensor(self.w)
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
LookupError,
|
||||
r"This GradientsDebugger has not received any gradient tensor for "):
|
||||
grad_debugger_2.gradient_tensor(self.w)
|
||||
|
||||
def testIdentifyGradientRaisesTypeErrorForNonTensorOrTensorNameInput(self):
|
||||
grad_debugger = debug_gradients.GradientsDebugger()
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r"x_tensor must be a str or tf\.Tensor or tf\.Variable, but instead "
|
||||
r"has type .*Operation.*"):
|
||||
@ -370,7 +370,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(1, len(u_grad_values))
|
||||
self.assertAllClose(30.0, u_grad_values[0])
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
LookupError,
|
||||
r"This GradientsDebugger has not received any gradient tensor for "
|
||||
r"x-tensor v:0"):
|
||||
|
@ -91,20 +91,20 @@ class ParseDebugNodeNameTest(test_util.TensorFlowTestCase):
|
||||
def testParseDebugNodeName_invalidPrefix(self):
|
||||
invalid_debug_node_name_1 = "__copy_ns_a/ns_b/node_c:1_0_DebugIdentity"
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid prefix"):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid prefix"):
|
||||
debug_graphs.parse_debug_node_name(invalid_debug_node_name_1)
|
||||
|
||||
def testParseDebugNodeName_missingDebugOpIndex(self):
|
||||
invalid_debug_node_name_1 = "__dbg_node1:0_DebugIdentity"
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid debug node name"):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid debug node name"):
|
||||
debug_graphs.parse_debug_node_name(invalid_debug_node_name_1)
|
||||
|
||||
def testParseDebugNodeName_invalidWatchedTensorName(self):
|
||||
invalid_debug_node_name_1 = "__dbg_node1_0_DebugIdentity"
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"Invalid tensor name in debug node name"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Invalid tensor name in debug node name"):
|
||||
debug_graphs.parse_debug_node_name(invalid_debug_node_name_1)
|
||||
|
||||
|
||||
|
@ -749,7 +749,7 @@ class DebugNumericSummaryV2Test(test_util.TensorFlowTestCase):
|
||||
with self.session(graph=ops.Graph()):
|
||||
t1 = constant_op.constant([-1.0, 1.0])
|
||||
t2 = constant_op.constant([0.0, 0.0])
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
r"pass through test.*had -Inf and \+Inf values"):
|
||||
self.evaluate(
|
||||
@ -760,7 +760,7 @@ class DebugNumericSummaryV2Test(test_util.TensorFlowTestCase):
|
||||
with self.session(graph=ops.Graph()):
|
||||
t1 = constant_op.constant([-1.0, 1.0, 0.0])
|
||||
t2 = constant_op.constant([0.0, 0.0, 0.0])
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
r"pass through test.*had -Inf, \+Inf, and NaN values"):
|
||||
self.evaluate(
|
||||
@ -771,7 +771,7 @@ class DebugNumericSummaryV2Test(test_util.TensorFlowTestCase):
|
||||
with self.session(graph=ops.Graph()):
|
||||
t1 = constant_op.constant([0.0, 1.0])
|
||||
t2 = constant_op.constant([0.0, 0.0])
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
r"pass through test.*had \+Inf and NaN values"):
|
||||
self.evaluate(
|
||||
|
@ -84,9 +84,8 @@ class DumpingCallbackTest(
|
||||
return "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
|
||||
def testInvalidTensorDebugModeCausesError(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r"Invalid value in tensor_debug_mode \(\'NONSENSICAL\'\).*"
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"Invalid value in tensor_debug_mode \(\'NONSENSICAL\'\).*"
|
||||
r"Valid options.*NO_TENSOR.*"):
|
||||
dumping_callback.enable_dump_debug_info(
|
||||
self.dump_root, tensor_debug_mode="NONSENSICAL")
|
||||
@ -947,19 +946,16 @@ class DumpingCallbackTest(
|
||||
tensor_values[2], np.log(5.0) + 1.0) # 2nd AddV2 op.
|
||||
|
||||
def testIncorrectTensorDTypeArgFormatLeadsToError(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r".*expected.*list.*tuple.*callable.*but received.*\{\}"):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r".*expected.*list.*tuple.*callable.*but received.*\{\}"):
|
||||
dumping_callback.enable_dump_debug_info(self.dump_root,
|
||||
tensor_dtypes=dict())
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r".*expected.*list.*tuple.*callable.*but received.*"):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r".*expected.*list.*tuple.*callable.*but received.*"):
|
||||
dumping_callback.enable_dump_debug_info(self.dump_root,
|
||||
tensor_dtypes="float32")
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r".*expected.*list.*tuple.*callable.*but received.*"):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r".*expected.*list.*tuple.*callable.*but received.*"):
|
||||
dumping_callback.enable_dump_debug_info(
|
||||
self.dump_root, tensor_dtypes=dtypes.float32)
|
||||
with self.assertRaises(TypeError):
|
||||
@ -1220,7 +1216,7 @@ class DumpingCallbackTest(
|
||||
# array.
|
||||
self.assertAllEqual(tensor_value, [])
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"already.*NO_TENSOR.*FULL_TENSOR.*not be honored"):
|
||||
dumping_callback.enable_dump_debug_info(
|
||||
self.dump_root, tensor_debug_mode="FULL_TENSOR")
|
||||
|
@ -54,8 +54,8 @@ class GrpcDebugServerTest(test_util.TensorFlowTestCase):
|
||||
# The server is started asynchronously. It needs to be polled till its state
|
||||
# has become started.
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Server has already started running"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Server has already started running"):
|
||||
server.run_server()
|
||||
|
||||
server.stop_server().wait()
|
||||
@ -68,7 +68,7 @@ class GrpcDebugServerTest(test_util.TensorFlowTestCase):
|
||||
server.stop_server().wait()
|
||||
server_thread.join()
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Server has already stopped"):
|
||||
with self.assertRaisesRegex(ValueError, "Server has already stopped"):
|
||||
server.stop_server().wait()
|
||||
|
||||
def testRunServerAfterStopRaisesException(self):
|
||||
@ -78,7 +78,7 @@ class GrpcDebugServerTest(test_util.TensorFlowTestCase):
|
||||
server.stop_server().wait()
|
||||
server_thread.join()
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Server has already stopped"):
|
||||
with self.assertRaisesRegex(ValueError, "Server has already stopped"):
|
||||
server.run_server()
|
||||
|
||||
def testStartServerWithoutBlocking(self):
|
||||
@ -131,14 +131,14 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
|
||||
def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException(self):
|
||||
sess = session.Session(
|
||||
config=session_debug_testlib.no_rewrite_session_config())
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "Expected type str or list in grpc_debug_server_addresses"):
|
||||
grpc_wrapper.GrpcDebugWrapperSession(sess, 1337)
|
||||
|
||||
def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException2(self):
|
||||
sess = session.Session(
|
||||
config=session_debug_testlib.no_rewrite_session_config())
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "Expected type str in list grpc_debug_server_addresses"):
|
||||
grpc_wrapper.GrpcDebugWrapperSession(sess, ["localhost:1337", 1338])
|
||||
|
||||
@ -307,11 +307,10 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
|
||||
|
||||
# Check that the server has _not_ received any tracebacks, as a result of
|
||||
# the disabling above.
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r"Op .*u/read.* does not exist"):
|
||||
with self.assertRaisesRegex(ValueError, r"Op .*u/read.* does not exist"):
|
||||
self.assertTrue(self._server.query_op_traceback("u/read"))
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r".* has not received any source file"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r".* has not received any source file"):
|
||||
self._server.query_source_file_line(__file__, 1)
|
||||
|
||||
def testConstructGrpcDebugHookWithOrWithouGrpcInUrlWorks(self):
|
||||
@ -693,11 +692,11 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# No op traceback or source code should have been received by the debug
|
||||
# server due to the disabling above.
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r"Op .*delta_1.* does not exist"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Op .*delta_1.* does not exist"):
|
||||
self.assertTrue(self._server_1.query_op_traceback("delta_1"))
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r".* has not received any source file"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r".* has not received any source file"):
|
||||
self._server_1.query_source_file_line(__file__, 1)
|
||||
|
||||
def testGetGrpcDebugWatchesReturnsCorrectAnswer(self):
|
||||
|
@ -287,8 +287,8 @@ class SourceHelperTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testLoadNonexistentNonParPathFailsWithIOError(self):
|
||||
bad_path = os.path.join(self.get_temp_dir(), "nonexistent.py")
|
||||
with self.assertRaisesRegexp(
|
||||
IOError, "neither exists nor can be loaded.*par.*"):
|
||||
with self.assertRaisesRegex(IOError,
|
||||
"neither exists nor can be loaded.*par.*"):
|
||||
source_utils.load_source(bad_path)
|
||||
|
||||
def testLoadingPythonSourceFileInParFileSucceeds(self):
|
||||
@ -315,8 +315,8 @@ class SourceHelperTest(test_util.TensorFlowTestCase):
|
||||
zf.write(temp_file_path, os.path.join("tensorflow_models", "model.py"))
|
||||
|
||||
source_path = os.path.join(par_path, "tensorflow_models", "nonexistent.py")
|
||||
with self.assertRaisesRegexp(
|
||||
IOError, "neither exists nor can be loaded.*par.*"):
|
||||
with self.assertRaisesRegex(IOError,
|
||||
"neither exists nor can be loaded.*par.*"):
|
||||
source_utils.load_source(source_path)
|
||||
|
||||
|
||||
|
@ -73,7 +73,7 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
os.mkdir(dir_path)
|
||||
self.assertTrue(os.path.isdir(dir_path))
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "session_root path points to a non-empty directory"):
|
||||
dumping_wrapper.DumpingDebugWrapperSession(
|
||||
session.Session(), session_root=self.session_root, log_usage=False)
|
||||
@ -83,8 +83,8 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
open(file_path, "a").close() # Create the file
|
||||
self.assertTrue(gfile.Exists(file_path))
|
||||
self.assertFalse(gfile.IsDirectory(file_path))
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"session_root path points to a file"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"session_root path points to a file"):
|
||||
dumping_wrapper.DumpingDebugWrapperSession(
|
||||
session.Session(), session_root=file_path, log_usage=False)
|
||||
|
||||
@ -161,7 +161,7 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testUsingNonCallableAsWatchFnRaisesTypeError(self):
|
||||
bad_watch_fn = "bad_watch_fn"
|
||||
with self.assertRaisesRegexp(TypeError, "watch_fn is not callable"):
|
||||
with self.assertRaisesRegex(TypeError, "watch_fn is not callable"):
|
||||
dumping_wrapper.DumpingDebugWrapperSession(
|
||||
self.sess,
|
||||
session_root=self.session_root,
|
||||
|
@ -273,11 +273,11 @@ class DebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
"""Attempt to wrap a non-Session-type object should cause an exception."""
|
||||
|
||||
wrapper = TestDebugWrapperSessionBadAction(self._sess)
|
||||
with self.assertRaisesRegexp(TypeError, "Expected type .*; got type .*"):
|
||||
with self.assertRaisesRegex(TypeError, "Expected type .*; got type .*"):
|
||||
TestDebugWrapperSessionBadAction(wrapper)
|
||||
|
||||
def testSessionInitBadActionValue(self):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Invalid OnSessionInitAction value: nonsense_action"):
|
||||
TestDebugWrapperSessionBadAction(
|
||||
self._sess, bad_init_action="nonsense_action")
|
||||
@ -286,7 +286,7 @@ class DebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
wrapper = TestDebugWrapperSessionBadAction(
|
||||
self._sess, bad_run_start_action="nonsense_action")
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Invalid OnRunStartAction value: nonsense_action"):
|
||||
wrapper.run(self._s)
|
||||
|
||||
@ -296,7 +296,7 @@ class DebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
wrapper = TestDebugWrapperSessionBadAction(
|
||||
self._sess, bad_debug_urls="file://foo")
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, "Expected type .*; got type .*"):
|
||||
with self.assertRaisesRegex(TypeError, "Expected type .*; got type .*"):
|
||||
wrapper.run(self._s)
|
||||
|
||||
def testErrorDuringRun(self):
|
||||
|
@ -191,7 +191,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
os.mkdir(dir_path)
|
||||
self.assertTrue(os.path.isdir(dir_path))
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "dump_root path points to a non-empty directory"):
|
||||
local_cli_wrapper.LocalCLIDebugWrapperSession(
|
||||
session.Session(), dump_root=self._tmp_dir, log_usage=False)
|
||||
@ -201,7 +201,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
file_path = os.path.join(self._tmp_dir, "foo")
|
||||
open(file_path, "a").close() # Create the file
|
||||
self.assertTrue(os.path.isfile(file_path))
|
||||
with self.assertRaisesRegexp(ValueError, "dump_root path points to a file"):
|
||||
with self.assertRaisesRegex(ValueError, "dump_root path points to a file"):
|
||||
local_cli_wrapper.LocalCLIDebugWrapperSession(
|
||||
session.Session(), dump_root=file_path, log_usage=False)
|
||||
|
||||
@ -540,7 +540,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
||||
[["run"]], self.sess, dump_root=self._tmp_dir)
|
||||
with self.assertRaisesRegexp(errors.OpError, r".*[Dd]evice.*1337.*"):
|
||||
with self.assertRaisesRegex(errors.OpError, r".*[Dd]evice.*1337.*"):
|
||||
wrapped_sess.run(w)
|
||||
|
||||
def testRunTillFilterPassesShouldLaunchCLIAtCorrectRun(self):
|
||||
@ -811,7 +811,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
def testCallingShouldStopMethodOnNonWrappedNonMonitoredSessionErrors(self):
|
||||
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
||||
[["run"], ["run"]], self.sess)
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"The wrapped session .* does not have a method .*should_stop.*"):
|
||||
wrapped_sess.should_stop()
|
||||
|
@ -40,8 +40,7 @@ class AllReduceTest(test_util.TensorFlowTestCase):
|
||||
@test_util.run_deprecated_v1
|
||||
def testFlattenTensorsShapesDefined(self):
|
||||
x = array_ops.placeholder(types_pb2.DT_FLOAT, [None])
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"must have statically known shape"):
|
||||
with self.assertRaisesRegex(ValueError, "must have statically known shape"):
|
||||
ar._flatten_tensors([x, x])
|
||||
|
||||
def testRingPermutations(self):
|
||||
|
@ -335,7 +335,7 @@ class GCEClusterResolverTest(test.TestCase):
|
||||
credentials=None,
|
||||
service=self.gen_standard_mock_service_client(name_to_ip))
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, 'You cannot reset the task_type '
|
||||
'of the GCEClusterResolver after it has '
|
||||
'been created.'):
|
||||
|
@ -134,7 +134,7 @@ class KubernetesClusterResolverTest(test.TestCase):
|
||||
{'job-name=tensorflow': ret}))
|
||||
|
||||
error_msg = 'Pod "tensorflow-abc123" is not running; phase: "Failed"'
|
||||
with self.assertRaisesRegexp(RuntimeError, error_msg):
|
||||
with self.assertRaisesRegex(RuntimeError, error_msg):
|
||||
cluster_resolver.cluster_spec()
|
||||
|
||||
def testMultiplePodSelectorsAndWorkers(self):
|
||||
|
@ -140,8 +140,8 @@ class TPUClusterResolverTest(test.TestCase):
|
||||
|
||||
@mock.patch.object(resolver, 'is_running_in_gce', mock_is_running_in_gce)
|
||||
def testCheckRunningInGceWithNoTpuName(self):
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'Please provide a TPU Name to connect to.*'):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'Please provide a TPU Name to connect to.*'):
|
||||
resolver.TPUClusterResolver(tpu='')
|
||||
|
||||
@mock.patch.object(six.moves.urllib.request, 'urlopen',
|
||||
|
@ -191,8 +191,8 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
|
||||
|
||||
input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
|
||||
|
||||
with self.assertRaisesRegexp(NotImplementedError,
|
||||
"does not support pure eager execution"):
|
||||
with self.assertRaisesRegex(NotImplementedError,
|
||||
"does not support pure eager execution"):
|
||||
distribution.run(train_step, args=(next(input_iterator),))
|
||||
|
||||
@combinations.generate(
|
||||
|
@ -150,7 +150,7 @@ def _run_in_and_out_of_scope(unbound_test_method):
|
||||
# When run under a different strategy the test method should fail.
|
||||
another_strategy = _TestStrategy()
|
||||
msg = "Mixing different .*Strategy objects"
|
||||
with test_case.assertRaisesRegexp(RuntimeError, msg):
|
||||
with test_case.assertRaisesRegex(RuntimeError, msg):
|
||||
with another_strategy.scope():
|
||||
unbound_test_method(test_case, dist)
|
||||
return wrapper
|
||||
@ -206,7 +206,7 @@ class TestStrategyTest(test.TestCase):
|
||||
scope.__enter__()
|
||||
self.assertIs(dist, ds_context.get_strategy())
|
||||
with ops.device("/device:CPU:0"):
|
||||
with self.assertRaisesRegexp(RuntimeError, "Device scope nesting error"):
|
||||
with self.assertRaisesRegex(RuntimeError, "Device scope nesting error"):
|
||||
scope.__exit__(None, None, None)
|
||||
scope.__exit__(None, None, None)
|
||||
_assert_in_default_state(self)
|
||||
@ -222,8 +222,8 @@ class TestStrategyTest(test.TestCase):
|
||||
scope.__enter__()
|
||||
self.assertIs(dist, ds_context.get_strategy())
|
||||
with variable_scope.variable_creator_scope(creator):
|
||||
with self.assertRaisesRegexp(RuntimeError,
|
||||
"Variable creator scope nesting error"):
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"Variable creator scope nesting error"):
|
||||
scope.__exit__(None, None, None)
|
||||
scope.__exit__(None, None, None)
|
||||
_assert_in_default_state(self)
|
||||
@ -239,8 +239,8 @@ class TestStrategyTest(test.TestCase):
|
||||
scope.__enter__()
|
||||
self.assertIs(dist, ds_context.get_strategy())
|
||||
with variable_scope.variable_scope("AA"):
|
||||
with self.assertRaisesRegexp(RuntimeError,
|
||||
"Variable scope nesting error"):
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"Variable scope nesting error"):
|
||||
scope.__exit__(None, None, None)
|
||||
_assert_in_default_state(self)
|
||||
|
||||
@ -284,15 +284,15 @@ class TestStrategyTest(test.TestCase):
|
||||
_assert_in_default_state(self)
|
||||
dist = _TestStrategy()
|
||||
with dist.scope():
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Must not be called inside a `tf.distribute.Strategy` scope"):
|
||||
ds_context.experimental_set_strategy(_TestStrategy())
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Must not be called inside a `tf.distribute.Strategy` scope"):
|
||||
ds_context.experimental_set_strategy(dist)
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Must not be called inside a `tf.distribute.Strategy` scope"):
|
||||
ds_context.experimental_set_strategy(None)
|
||||
@ -313,9 +313,8 @@ class TestStrategyTest(test.TestCase):
|
||||
self.assertIs(dist, ds_context.get_strategy())
|
||||
dist2 = _TestStrategy()
|
||||
scope2 = dist2.scope()
|
||||
with self.assertRaisesRegexp(
|
||||
RuntimeError,
|
||||
"Mixing different tf.distribute.Strategy objects"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Mixing different tf.distribute.Strategy objects"):
|
||||
with scope2:
|
||||
pass
|
||||
_assert_in_default_state(self)
|
||||
@ -496,7 +495,7 @@ class DefaultDistributionStrategyTest(test.TestCase, parameterized.TestCase):
|
||||
_assert_in_default_state(self)
|
||||
|
||||
with test_strategy.scope():
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Mixing different tf.distribute.Strategy objects"):
|
||||
variable_scope.variable(1.0, name="error")
|
||||
|
||||
@ -504,7 +503,7 @@ class DefaultDistributionStrategyTest(test.TestCase, parameterized.TestCase):
|
||||
_assert_in_default_state(self)
|
||||
|
||||
with test_strategy.scope():
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Mixing different tf.distribute.Strategy objects"):
|
||||
variable_scope.variable(1.0, name="also_error")
|
||||
|
||||
|
@ -438,7 +438,7 @@ class MirroredStrategyCallForEachReplicaTest(test.TestCase):
|
||||
return control_flow_ops.while_loop_v2(lambda i: i < 2, body_fn, [0])
|
||||
|
||||
with distribution.scope():
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "`merge_call` called while defining a new graph."):
|
||||
distribution.extended.call_for_each_replica(model_fn)
|
||||
|
||||
@ -457,7 +457,7 @@ class MirroredStrategyCallForEachReplicaTest(test.TestCase):
|
||||
return model_fn_nested()
|
||||
|
||||
with distribution.scope():
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "`merge_call` called while defining a new graph."):
|
||||
distribution.extended.call_for_each_replica(model_fn)
|
||||
|
||||
@ -706,7 +706,7 @@ class MirroredVariableUpdateTest(test.TestCase):
|
||||
def model_fn():
|
||||
return mirrored_var.assign(5.0)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "A non-DistributedValues value 5.0 cannot be reduced "
|
||||
"with the given reduce op ReduceOp.SUM."):
|
||||
self.evaluate(distribution.experimental_local_results(
|
||||
|
@ -377,7 +377,7 @@ class MirroredVariableCreationTest(test.TestCase):
|
||||
|
||||
def testNoneSynchronizationWithGetVariable(self, distribution):
|
||||
with distribution.scope():
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "`NONE` variable synchronization mode is not "
|
||||
"supported with `Mirrored` distribution strategy. Please change "
|
||||
"the `synchronization` for variable: v"):
|
||||
@ -387,7 +387,7 @@ class MirroredVariableCreationTest(test.TestCase):
|
||||
|
||||
def testNoneSynchronizationWithVariable(self, distribution):
|
||||
with distribution.scope():
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "`NONE` variable synchronization mode is not "
|
||||
"supported with `Mirrored` distribution strategy. Please change "
|
||||
"the `synchronization` for variable: v"):
|
||||
@ -398,14 +398,14 @@ class MirroredVariableCreationTest(test.TestCase):
|
||||
|
||||
def testInvalidSynchronizationWithVariable(self, distribution):
|
||||
with distribution.scope():
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Invalid variable synchronization mode: Invalid for "
|
||||
"variable: v"):
|
||||
variable_scope.variable(1.0, name="v", synchronization="Invalid")
|
||||
|
||||
def testInvalidAggregationWithGetVariable(self, distribution):
|
||||
with distribution.scope():
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Invalid variable aggregation mode: invalid for "
|
||||
"variable: v"):
|
||||
variable_scope.get_variable(
|
||||
@ -415,7 +415,7 @@ class MirroredVariableCreationTest(test.TestCase):
|
||||
|
||||
def testInvalidAggregationWithVariable(self, distribution):
|
||||
with distribution.scope():
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Invalid variable aggregation mode: invalid for "
|
||||
"variable: v"):
|
||||
variable_scope.variable(
|
||||
|
@ -30,8 +30,8 @@ class MultiProcessRunnerNoInitTest(test.TestCase):
|
||||
def simple_func():
|
||||
return 'foobar'
|
||||
|
||||
with self.assertRaisesRegexp(RuntimeError,
|
||||
'`multi_process_runner` is not initialized.'):
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
'`multi_process_runner` is not initialized.'):
|
||||
multi_process_runner.run(
|
||||
simple_func,
|
||||
multi_worker_test_base.create_cluster_spec(num_workers=1))
|
||||
|
@ -97,7 +97,7 @@ class MultiProcessRunnerTest(test.TestCase):
|
||||
multi_worker_test_base.create_cluster_spec(num_workers=1, num_ps=1),
|
||||
max_run_time=20)
|
||||
runner.start()
|
||||
with self.assertRaisesRegexp(ValueError, 'This is an error.'):
|
||||
with self.assertRaisesRegex(ValueError, 'This is an error.'):
|
||||
runner.join()
|
||||
|
||||
def test_multi_process_runner_queue_emptied_between_runs(self):
|
||||
@ -287,7 +287,7 @@ class MultiProcessRunnerTest(test.TestCase):
|
||||
mpr.start()
|
||||
time.sleep(60)
|
||||
mpr.terminate_all()
|
||||
with self.assertRaisesRegexp(ValueError, 'This is an error.'):
|
||||
with self.assertRaisesRegex(ValueError, 'This is an error.'):
|
||||
mpr.join()
|
||||
|
||||
def test_barrier(self):
|
||||
@ -402,7 +402,7 @@ class MultiProcessPoolRunnerTest(test.TestCase):
|
||||
cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
|
||||
runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
|
||||
pid = runner.run(proc_func_that_returns_pid)
|
||||
with self.assertRaisesRegexp(ValueError, 'This is an error.'):
|
||||
with self.assertRaisesRegex(ValueError, 'This is an error.'):
|
||||
runner.run(proc_func_that_errors)
|
||||
self.assertAllEqual(runner.run(proc_func_that_returns_pid), pid)
|
||||
|
||||
|
@ -71,7 +71,7 @@ class NormalizeClusterSpecTest(test.TestCase):
|
||||
def testUnexpectedInput(self):
|
||||
cluster_spec = ["127.0.0.1:8964", "127.0.0.1:2333"]
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
|
||||
"`tf.train.ClusterDef` object"):
|
||||
@ -94,11 +94,11 @@ class IsChiefTest(test.TestCase):
|
||||
self.assertTrue(multi_worker_util.is_chief(cluster_spec, "worker", 0))
|
||||
self.assertFalse(multi_worker_util.is_chief(cluster_spec, "worker", 1))
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "`task_type` 'chief' not found in cluster_spec."):
|
||||
multi_worker_util.is_chief(cluster_spec, "chief", 0)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The `task_id` 2 exceeds the maximum id of worker."):
|
||||
multi_worker_util.is_chief(cluster_spec, "worker", 2)
|
||||
|
||||
@ -135,7 +135,7 @@ class NumWorkersTest(test.TestCase):
|
||||
|
||||
def testTaskTypeNotFound(self):
|
||||
cluster_spec = {}
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "`task_type` 'worker' not found in cluster_spec."):
|
||||
multi_worker_util.worker_count(cluster_spec, task_type="worker")
|
||||
|
||||
@ -145,7 +145,7 @@ class NumWorkersTest(test.TestCase):
|
||||
"ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
|
||||
}
|
||||
# A "ps" job shouldn't call this method.
|
||||
with self.assertRaisesRegexp(ValueError, "Unexpected `task_type` 'ps'"):
|
||||
with self.assertRaisesRegex(ValueError, "Unexpected `task_type` 'ps'"):
|
||||
multi_worker_util.worker_count(cluster_spec, task_type="ps")
|
||||
|
||||
|
||||
@ -187,16 +187,16 @@ class IdInClusterTest(test.TestCase):
|
||||
|
||||
def testPsId(self):
|
||||
cluster_spec = {"chief": ["127.0.0.1:1234"], "ps": ["127.0.0.1:7566"]}
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"There is no id for task_type 'ps'"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"There is no id for task_type 'ps'"):
|
||||
multi_worker_util.id_in_cluster(cluster_spec, "ps", 0)
|
||||
|
||||
def testMultipleChiefs(self):
|
||||
cluster_spec = {
|
||||
"chief": ["127.0.0.1:8258", "127.0.0.1:7566"],
|
||||
}
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"There must be at most one 'chief' job."):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"There must be at most one 'chief' job."):
|
||||
multi_worker_util.id_in_cluster(cluster_spec, "chief", 0)
|
||||
|
||||
|
||||
@ -257,7 +257,7 @@ class ClusterSpecValidationTest(test.TestCase):
|
||||
"ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
|
||||
}
|
||||
multi_worker_util._validate_cluster_spec(cluster_spec, "evaluator", 0)
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "`task_type` 'worker' not found in cluster_spec."):
|
||||
multi_worker_util._validate_cluster_spec(cluster_spec, "worker", 0)
|
||||
|
||||
|
@ -173,23 +173,23 @@ class ShardedVariableTest(test.TestCase):
|
||||
self.assertAllEqual([3., 2.], root.train([0, 1]).numpy())
|
||||
|
||||
def test_validation_errors(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'Expected a list of '):
|
||||
with self.assertRaisesRegex(ValueError, 'Expected a list of '):
|
||||
sharded_variable.ShardedVariable(
|
||||
[variables_lib.Variable([0]), 'not-a-variable'])
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'must have the same dtype'):
|
||||
with self.assertRaisesRegex(ValueError, 'must have the same dtype'):
|
||||
sharded_variable.ShardedVariable([
|
||||
variables_lib.Variable([0], dtype='int64'),
|
||||
variables_lib.Variable([1], dtype='int32')
|
||||
])
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'the same shapes except'):
|
||||
with self.assertRaisesRegex(ValueError, 'the same shapes except'):
|
||||
sharded_variable.ShardedVariable([
|
||||
variables_lib.Variable(array_ops.ones((5, 10))),
|
||||
variables_lib.Variable(array_ops.ones((5, 20)))
|
||||
])
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, '`SaveSliceInfo` should not'):
|
||||
with self.assertRaisesRegex(ValueError, '`SaveSliceInfo` should not'):
|
||||
v = variables_lib.Variable([0])
|
||||
v._set_save_slice_info(
|
||||
variables_lib.Variable.SaveSliceInfo(
|
||||
|
@ -30,18 +30,18 @@ class CanonicalizeVariableNameTest(test.TestCase):
|
||||
return shared_variable_creator._canonicalize_variable_name(name)
|
||||
|
||||
def testNoName(self):
|
||||
self.assertEquals("Variable", self._canonicalize(None))
|
||||
self.assertEqual("Variable", self._canonicalize(None))
|
||||
|
||||
def testPatternInMiddle(self):
|
||||
self.assertEquals("foo/bar/baz", self._canonicalize("foo_1/bar_1/baz"))
|
||||
self.assertEqual("foo/bar/baz", self._canonicalize("foo_1/bar_1/baz"))
|
||||
|
||||
def testPatternAtEnd(self):
|
||||
self.assertEquals("foo", self._canonicalize("foo_1"))
|
||||
self.assertEqual("foo", self._canonicalize("foo_1"))
|
||||
|
||||
def testWrongPatterns(self):
|
||||
self.assertEquals("foo_1:0", self._canonicalize("foo_1:0"))
|
||||
self.assertEquals("foo1", self._canonicalize("foo1"))
|
||||
self.assertEquals("foo_a", self._canonicalize("foo_a"))
|
||||
self.assertEqual("foo_1:0", self._canonicalize("foo_1:0"))
|
||||
self.assertEqual("foo1", self._canonicalize("foo1"))
|
||||
self.assertEqual("foo_a", self._canonicalize("foo_a"))
|
||||
|
||||
|
||||
class SharedVariableCreatorTest(test.TestCase):
|
||||
|
@ -52,7 +52,7 @@ class VirtualDevicesTest(test.TestCase, parameterized.TestCase):
|
||||
def testSetVirtualCPUsErrors(self):
|
||||
with self.assertRaises(ValueError):
|
||||
strategy_combinations.set_virtual_cpus_to_at_least(0)
|
||||
with self.assertRaisesRegexp(RuntimeError, "with 3 < 5 virtual CPUs"):
|
||||
with self.assertRaisesRegex(RuntimeError, "with 3 < 5 virtual CPUs"):
|
||||
strategy_combinations.set_virtual_cpus_to_at_least(5)
|
||||
|
||||
@combinations.generate(combinations.combine(
|
||||
|
@ -799,7 +799,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
||||
y = control_flow_ops.cond(x < x, true_fn, false_fn)
|
||||
|
||||
if not context.executing_eagerly():
|
||||
with self.assertRaisesRegexp(NotImplementedError, 'tf.gradients'):
|
||||
with self.assertRaisesRegex(NotImplementedError, 'tf.gradients'):
|
||||
dy = g.gradient(y, [x])[0]
|
||||
else:
|
||||
dy = g.gradient(y, [x])[0]
|
||||
@ -822,7 +822,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
||||
_, y = control_flow_ops.while_loop(cond, body, [i, x])
|
||||
|
||||
if not context.executing_eagerly():
|
||||
with self.assertRaisesRegexp(NotImplementedError, 'tf.gradients'):
|
||||
with self.assertRaisesRegex(NotImplementedError, 'tf.gradients'):
|
||||
dy = g.gradient(y, [x])[0]
|
||||
else:
|
||||
dy = g.gradient(y, [x])[0]
|
||||
@ -836,7 +836,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
||||
y = x * x
|
||||
z = y * y
|
||||
g.gradient(z, [x])
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, 'GradientTape.gradient can only be called once'):
|
||||
g.gradient(y, [x])
|
||||
|
||||
@ -958,7 +958,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
||||
with backprop.GradientTape() as g:
|
||||
g.watch([x, y])
|
||||
z = y * 2
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
|
||||
g.gradient(z, x, unconnected_gradients='nonsense')
|
||||
|
||||
@ -989,8 +989,8 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
||||
with backprop.GradientTape() as g:
|
||||
g.watch(x)
|
||||
tape_lib.record_operation('InvalidBackprop', [y], [x], lambda dy: [])
|
||||
with self.assertRaisesRegexp(errors_impl.InternalError,
|
||||
'InvalidBackprop.*too few gradients'):
|
||||
with self.assertRaisesRegex(errors_impl.InternalError,
|
||||
'InvalidBackprop.*too few gradients'):
|
||||
g.gradient(y, x)
|
||||
|
||||
@test_util.assert_no_new_tensors
|
||||
@ -1295,13 +1295,13 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
||||
y = constant_op.constant(2)
|
||||
|
||||
loss_grads_fn = backprop.implicit_val_and_grad(fn)
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Cannot differentiate a function that returns None; '
|
||||
'did you forget to return a value from fn?'):
|
||||
loss_grads_fn(x, y)
|
||||
|
||||
val_and_grads_fn = backprop.val_and_grad_function(fn)
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Cannot differentiate a function that returns None; '
|
||||
'did you forget to return a value from fn?'):
|
||||
val_and_grads_fn(x, y)
|
||||
@ -1504,7 +1504,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testWatchBadThing(self):
|
||||
g = backprop.GradientTape()
|
||||
with self.assertRaisesRegexp(ValueError, 'ndarray'):
|
||||
with self.assertRaisesRegex(ValueError, 'ndarray'):
|
||||
g.watch(np.array(1.))
|
||||
|
||||
def testWatchComposite(self):
|
||||
@ -1659,7 +1659,7 @@ class JacobianTest(test.TestCase):
|
||||
x = constant_op.constant([1.0, 2.0])
|
||||
g.watch(x)
|
||||
y = x * x
|
||||
with self.assertRaisesRegexp(RuntimeError, 'persistent'):
|
||||
with self.assertRaisesRegex(RuntimeError, 'persistent'):
|
||||
g.jacobian(y, x, experimental_use_pfor=False)
|
||||
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
@ -1749,28 +1749,28 @@ class BatchJacobianTest(test.TestCase, parameterized.TestCase):
|
||||
x = constant_op.constant([[1.0, 2.0]])
|
||||
g.watch(x)
|
||||
y = x * x
|
||||
with self.assertRaisesRegexp(RuntimeError, 'persistent'):
|
||||
with self.assertRaisesRegex(RuntimeError, 'persistent'):
|
||||
g.batch_jacobian(y, x, experimental_use_pfor=False)
|
||||
|
||||
def testBadShape(self):
|
||||
x = random_ops.random_uniform([2, 3])
|
||||
with backprop.GradientTape() as g:
|
||||
y = array_ops.concat([x, x], axis=0)
|
||||
with self.assertRaisesRegexp(ValueError, 'Need first dimension'):
|
||||
with self.assertRaisesRegex(ValueError, 'Need first dimension'):
|
||||
g.batch_jacobian(y, x)
|
||||
|
||||
def testBadInputRank(self):
|
||||
x = random_ops.random_uniform([2])
|
||||
with backprop.GradientTape() as g:
|
||||
y = random_ops.random_uniform([2, 2])
|
||||
with self.assertRaisesRegexp(ValueError, 'must have rank at least 2'):
|
||||
with self.assertRaisesRegex(ValueError, 'must have rank at least 2'):
|
||||
g.batch_jacobian(y, x)
|
||||
|
||||
def testBadOutputRank(self):
|
||||
x = random_ops.random_uniform([2, 2])
|
||||
with backprop.GradientTape() as g:
|
||||
y = random_ops.random_uniform([2])
|
||||
with self.assertRaisesRegexp(ValueError, 'must have rank at least 2'):
|
||||
with self.assertRaisesRegex(ValueError, 'must have rank at least 2'):
|
||||
g.batch_jacobian(y, x)
|
||||
|
||||
def test_parallel_iterations(self):
|
||||
|
@ -89,7 +89,7 @@ class TFETest(test_util.TensorFlowTestCase):
|
||||
else:
|
||||
# TODO(gjn): Figure out how to make this work for tf.Tensor
|
||||
# self.assertNotIsInstance(b, collections.Hashable)
|
||||
with self.assertRaisesRegexp(TypeError, 'unhashable'):
|
||||
with self.assertRaisesRegex(TypeError, 'unhashable'):
|
||||
set([a, b])
|
||||
|
||||
def testEquality(self):
|
||||
@ -464,7 +464,7 @@ class TFETest(test_util.TensorFlowTestCase):
|
||||
def testContextConfig(self):
|
||||
ctx = context.Context(config=config_pb2.ConfigProto(
|
||||
device_count={'GPU': 0}))
|
||||
self.assertEquals(0, ctx.num_gpus())
|
||||
self.assertEqual(0, ctx.num_gpus())
|
||||
|
||||
def testPickle(self):
|
||||
tmp_dir = self.get_temp_dir()
|
||||
@ -485,7 +485,7 @@ class TFETest(test_util.TensorFlowTestCase):
|
||||
self.assertEndsWith(current_device(), 'CPU:0')
|
||||
gpu.__enter__()
|
||||
self.assertEndsWith(current_device(), 'GPU:0')
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, 'Exiting device scope without proper scope nesting'):
|
||||
cpu.__exit__()
|
||||
self.assertEndsWith(current_device(), 'GPU:0')
|
||||
@ -926,7 +926,7 @@ class TFETest(test_util.TensorFlowTestCase):
|
||||
|
||||
x = constant_op.constant(1)
|
||||
three_x = add(add(x, x), x)
|
||||
self.assertEquals(dtypes.int32, three_x.dtype)
|
||||
self.assertEqual(dtypes.int32, three_x.dtype)
|
||||
self.assertAllEqual(3, three_x)
|
||||
|
||||
@test_util.run_gpu_only
|
||||
@ -953,7 +953,7 @@ class TFETest(test_util.TensorFlowTestCase):
|
||||
types, tensors = execute_lib.convert_to_mixed_eager_tensors(
|
||||
[array, tensor], context.context())
|
||||
for typ, t in zip(types, tensors):
|
||||
self.assertEquals(typ, dtypes.float32)
|
||||
self.assertEqual(typ, dtypes.float32)
|
||||
self.assertIsInstance(t, ops.EagerTensor)
|
||||
|
||||
def testConvertMixedEagerTensorsWithVariables(self):
|
||||
|
@ -41,7 +41,7 @@ class CustomDeviceTest(test.TestCase):
|
||||
# There was no copy onto the device. Actually I'm not sure how to trigger
|
||||
# that from Python.
|
||||
self.assertFalse(custom_device_testutil.FlagValue(arrived_flag))
|
||||
with self.assertRaisesRegexp(errors.InternalError, 'Trying to copy'):
|
||||
with self.assertRaisesRegex(errors.InternalError, 'Trying to copy'):
|
||||
y.numpy()
|
||||
|
||||
|
||||
|
@ -218,8 +218,8 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
state.append(variables.Variable(2.0 * x))
|
||||
return state[0] * x
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
lift_to_graph.UnliftableError, r'transitively.* mul .* x'):
|
||||
with self.assertRaisesRegex(lift_to_graph.UnliftableError,
|
||||
r'transitively.* mul .* x'):
|
||||
fn(constant_op.constant(3.0))
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
@ -393,8 +393,8 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
outputs.append(inputs[t])
|
||||
return outputs
|
||||
|
||||
with self.assertRaisesRegexp(errors.InaccessibleTensorError,
|
||||
'defined in another function or code block'):
|
||||
with self.assertRaisesRegex(errors.InaccessibleTensorError,
|
||||
'defined in another function or code block'):
|
||||
f(array_ops.zeros(shape=(8, 42, 3)))
|
||||
|
||||
@test_util.disable_tfrt('Control flow is not supported')
|
||||
@ -472,7 +472,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
with ops.init_scope():
|
||||
_ = a + a
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
re.compile('An op outside of the function.*passed.*Const', re.DOTALL)):
|
||||
failing_function()
|
||||
@ -627,7 +627,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
return a[0].read_value()
|
||||
|
||||
create_variable()
|
||||
self.assertRegexpMatches(a[0].device, 'CPU')
|
||||
self.assertRegex(a[0].device, 'CPU')
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
@test_util.run_gpu_only
|
||||
@ -647,8 +647,8 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
with ops.device('CPU:0'):
|
||||
create_variable()
|
||||
self.assertRegexpMatches(a[0].device, 'CPU')
|
||||
self.assertRegexpMatches(initial_value[0].device, 'CPU')
|
||||
self.assertRegex(a[0].device, 'CPU')
|
||||
self.assertRegex(initial_value[0].device, 'CPU')
|
||||
|
||||
def testDecorate(self):
|
||||
func = def_function.function(lambda: 1)
|
||||
@ -727,7 +727,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
func = def_function.function(lambda: 1)
|
||||
self.assertEqual(func().numpy(), 1)
|
||||
msg = 'Functions cannot be decorated after they have been traced.'
|
||||
with self.assertRaisesRegexp(ValueError, msg):
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
func._decorate(lambda f: f)
|
||||
|
||||
def testGetConcreteFunctionGraphLifetime(self):
|
||||
|
@ -142,8 +142,8 @@ class DefFunctionTest(test.TestCase):
|
||||
func = def_function.function(fn2, experimental_compile=False)
|
||||
inputs = constant_op.constant([1, 2, 2, 3, 3])
|
||||
if not test.is_built_with_rocm():
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
'not compilable'):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
'not compilable'):
|
||||
func(inputs)
|
||||
|
||||
def testUnsupportedOps(self):
|
||||
@ -156,7 +156,7 @@ class DefFunctionTest(test.TestCase):
|
||||
|
||||
inputs = constant_op.constant([1, 2, 2, 3, 3])
|
||||
self.assertAllClose([1, 2, 3], func(inputs))
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError, 'not compilable'):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError, 'not compilable'):
|
||||
xla_func(inputs)
|
||||
|
||||
def testFunctionGradient(self):
|
||||
@ -236,7 +236,7 @@ class DefFunctionTest(test.TestCase):
|
||||
|
||||
inputs = constant_op.constant([1, 2, 2, 3, 3])
|
||||
c = C()
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError, 'not compilable'):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError, 'not compilable'):
|
||||
c.f1(inputs)
|
||||
|
||||
def testMustBeConstantPropagation(self):
|
||||
@ -285,9 +285,8 @@ class DefFunctionTest(test.TestCase):
|
||||
x = constant_op.constant(3.14)
|
||||
with backprop.GradientTape() as tape:
|
||||
tape.watch(x)
|
||||
with self.assertRaisesRegexp(
|
||||
errors.UnimplementedError,
|
||||
'TensorList crossing the XLA/TF boundary'):
|
||||
with self.assertRaisesRegex(errors.UnimplementedError,
|
||||
'TensorList crossing the XLA/TF boundary'):
|
||||
y = f(x)
|
||||
tape.gradient(y, x)
|
||||
|
||||
|
@ -282,7 +282,7 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
|
||||
def testJVPFunctionRaisesError(self):
|
||||
sum_outputs = (constant_op.constant(6.),)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, r".*was expected to be of shape*"):
|
||||
with self.assertRaisesRegex(ValueError, r".*was expected to be of shape*"):
|
||||
forwardprop._jvp_dispatch(
|
||||
op_name="Add",
|
||||
attr_tuple=(),
|
||||
@ -343,7 +343,7 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testMultipleWatchesAdd(self):
|
||||
x = constant_op.constant(-2.)
|
||||
with self.assertRaisesRegexp(ValueError, "multiple times"):
|
||||
with self.assertRaisesRegex(ValueError, "multiple times"):
|
||||
with forwardprop.ForwardAccumulator(
|
||||
[x, x], [1., 2.]):
|
||||
pass
|
||||
@ -365,7 +365,7 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllClose(1.5, acc.jvp(x))
|
||||
y = 4. * x
|
||||
self.assertAllClose(6., acc.jvp(y))
|
||||
with self.assertRaisesRegexp(ValueError, "already recording"):
|
||||
with self.assertRaisesRegex(ValueError, "already recording"):
|
||||
with acc:
|
||||
pass
|
||||
z = 4. * x
|
||||
@ -434,8 +434,8 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
|
||||
def f(x):
|
||||
return math_ops.reduce_prod(math_ops.tanh(x)**2)
|
||||
|
||||
with self.assertRaisesRegexp(NotImplementedError,
|
||||
"recompute_grad tried to transpose"):
|
||||
with self.assertRaisesRegex(NotImplementedError,
|
||||
"recompute_grad tried to transpose"):
|
||||
primals = [constant_op.constant([1.])]
|
||||
sym_jac_fwd = _jacfwd(f, primals)
|
||||
|
||||
@ -450,7 +450,7 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
|
||||
c = constant_op.constant(1.)
|
||||
d = constant_op.constant(2.)
|
||||
with forwardprop.ForwardAccumulator(c, d):
|
||||
with self.assertRaisesRegexp(ValueError, "test_error_string"):
|
||||
with self.assertRaisesRegex(ValueError, "test_error_string"):
|
||||
f(c)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
|
@ -55,9 +55,9 @@ class DefunCollectionTest(test.TestCase, parameterized.TestCase):
|
||||
return z
|
||||
|
||||
self.assertEqual(7, int(self.evaluate(fn())))
|
||||
self.assertEquals(ops.get_collection('x'), [2])
|
||||
self.assertEquals(ops.get_collection('y'), [5])
|
||||
self.assertEquals(ops.get_collection('z'), [])
|
||||
self.assertEqual(ops.get_collection('x'), [2])
|
||||
self.assertEqual(ops.get_collection('y'), [5])
|
||||
self.assertEqual(ops.get_collection('z'), [])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name='Defun', function_decorator=function.defun),
|
||||
@ -76,8 +76,7 @@ class DefunCollectionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertEqual(1.0, float(self.evaluate(f())))
|
||||
self.assertEquals(
|
||||
len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 1)
|
||||
self.assertLen(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES), 1)
|
||||
|
||||
def testCollectionVariableValueWrite(self):
|
||||
"""Write variable value inside defun."""
|
||||
@ -92,8 +91,7 @@ class DefunCollectionTest(test.TestCase, parameterized.TestCase):
|
||||
_ = f.get_concrete_function()
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertEqual(2.0, float(self.evaluate(f())))
|
||||
self.assertEquals(
|
||||
len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 1)
|
||||
self.assertLen(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES), 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user