Add @run_deprecated_v1 annotation to tests failing in v2
PiperOrigin-RevId: 223422907
This commit is contained in:
parent
bbad7c0a07
commit
24f578cd66
@ -19,10 +19,12 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
|
||||||
|
|
||||||
class FactTest(tf.test.TestCase):
|
class FactTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test(self):
|
def test(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
print(tf.user_ops.my_fact().eval())
|
print(tf.user_ops.my_fact().eval())
|
||||||
|
@ -23,10 +23,12 @@ import os.path
|
|||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.examples.adding_an_op import zero_out_op_1
|
from tensorflow.examples.adding_an_op import zero_out_op_1
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
|
||||||
|
|
||||||
class ZeroOut1Test(tf.test.TestCase):
|
class ZeroOut1Test(tf.test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test(self):
|
def test(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
result = zero_out_op_1.zero_out([5, 4, 3, 2, 1])
|
result = zero_out_op_1.zero_out([5, 4, 3, 2, 1])
|
||||||
|
@ -24,20 +24,24 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from tensorflow.examples.adding_an_op import zero_out_grad_2 # pylint: disable=unused-import
|
from tensorflow.examples.adding_an_op import zero_out_grad_2 # pylint: disable=unused-import
|
||||||
from tensorflow.examples.adding_an_op import zero_out_op_2
|
from tensorflow.examples.adding_an_op import zero_out_op_2
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
|
||||||
|
|
||||||
class ZeroOut2Test(tf.test.TestCase):
|
class ZeroOut2Test(tf.test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test(self):
|
def test(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
result = zero_out_op_2.zero_out([5, 4, 3, 2, 1])
|
result = zero_out_op_2.zero_out([5, 4, 3, 2, 1])
|
||||||
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
|
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_2d(self):
|
def test_2d(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
result = zero_out_op_2.zero_out([[6, 5, 4], [3, 2, 1]])
|
result = zero_out_op_2.zero_out([[6, 5, 4], [3, 2, 1]])
|
||||||
self.assertAllEqual(result.eval(), [[6, 0, 0], [0, 0, 0]])
|
self.assertAllEqual(result.eval(), [[6, 0, 0], [0, 0, 0]])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_grad(self):
|
def test_grad(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
shape = (5,)
|
shape = (5,)
|
||||||
@ -46,6 +50,7 @@ class ZeroOut2Test(tf.test.TestCase):
|
|||||||
err = tf.test.compute_gradient_error(x, shape, y, shape)
|
err = tf.test.compute_gradient_error(x, shape, y, shape)
|
||||||
self.assertLess(err, 1e-4)
|
self.assertLess(err, 1e-4)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_grad_2d(self):
|
def test_grad_2d(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
shape = (2, 3)
|
shape = (2, 3)
|
||||||
|
@ -21,26 +21,31 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.examples.adding_an_op import zero_out_op_3
|
from tensorflow.examples.adding_an_op import zero_out_op_3
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
|
||||||
|
|
||||||
class ZeroOut3Test(tf.test.TestCase):
|
class ZeroOut3Test(tf.test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test(self):
|
def test(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1])
|
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1])
|
||||||
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
|
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testAttr(self):
|
def testAttr(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=3)
|
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=3)
|
||||||
self.assertAllEqual(result.eval(), [0, 0, 0, 2, 0])
|
self.assertAllEqual(result.eval(), [0, 0, 0, 2, 0])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNegative(self):
|
def testNegative(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=-1)
|
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=-1)
|
||||||
with self.assertRaisesOpError("Need preserve_index >= 0, got -1"):
|
with self.assertRaisesOpError("Need preserve_index >= 0, got -1"):
|
||||||
self.evaluate(result)
|
self.evaluate(result)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testLarge(self):
|
def testLarge(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=17)
|
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=17)
|
||||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow.python import autograph
|
from tensorflow.python import autograph
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
|
||||||
|
|
||||||
class MinimalKeras(tf.keras.Model):
|
class MinimalKeras(tf.keras.Model):
|
||||||
@ -84,6 +85,7 @@ class KerasTest(tf.test.TestCase):
|
|||||||
model = ModelWithStaticConditional(True)
|
model = ModelWithStaticConditional(True)
|
||||||
self.assertEqual(model.call(), 25)
|
self.assertEqual(model.call(), 25)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_recursive_true(self):
|
def test_recursive_true(self):
|
||||||
with self.assertRaisesRegexp(NotImplementedError,
|
with self.assertRaisesRegexp(NotImplementedError,
|
||||||
'Object conversion is not yet supported.'):
|
'Object conversion is not yet supported.'):
|
||||||
|
@ -19,11 +19,13 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.examples.speech_commands import freeze
|
from tensorflow.examples.speech_commands import freeze
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class FreezeTest(test.TestCase):
|
class FreezeTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCreateInferenceGraphWithMfcc(self):
|
def testCreateInferenceGraphWithMfcc(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
freeze.create_inference_graph(
|
freeze.create_inference_graph(
|
||||||
@ -43,6 +45,7 @@ class FreezeTest(test.TestCase):
|
|||||||
ops = [node.op for node in sess.graph_def.node]
|
ops = [node.op for node in sess.graph_def.node]
|
||||||
self.assertEqual(1, ops.count('Mfcc'))
|
self.assertEqual(1, ops.count('Mfcc'))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCreateInferenceGraphWithoutMfcc(self):
|
def testCreateInferenceGraphWithoutMfcc(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
freeze.create_inference_graph(
|
freeze.create_inference_graph(
|
||||||
@ -62,6 +65,7 @@ class FreezeTest(test.TestCase):
|
|||||||
ops = [node.op for node in sess.graph_def.node]
|
ops = [node.op for node in sess.graph_def.node]
|
||||||
self.assertEqual(0, ops.count('Mfcc'))
|
self.assertEqual(0, ops.count('Mfcc'))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFeatureBinCount(self):
|
def testFeatureBinCount(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
freeze.create_inference_graph(
|
freeze.create_inference_graph(
|
||||||
|
@ -26,6 +26,7 @@ import tensorflow as tf
|
|||||||
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
|
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
|
||||||
from tensorflow.examples.speech_commands import input_data
|
from tensorflow.examples.speech_commands import input_data
|
||||||
from tensorflow.examples.speech_commands import models
|
from tensorflow.examples.speech_commands import models
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -96,6 +97,7 @@ class InputDataTest(test.TestCase):
|
|||||||
input_data.which_set("foo_nohash_0.wav", 10, 10),
|
input_data.which_set("foo_nohash_0.wav", 10, 10),
|
||||||
input_data.which_set("foo_nohash_1.wav", 10, 10))
|
input_data.which_set("foo_nohash_1.wav", 10, 10))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrepareDataIndex(self):
|
def testPrepareDataIndex(self):
|
||||||
tmp_dir = self.get_temp_dir()
|
tmp_dir = self.get_temp_dir()
|
||||||
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
|
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
|
||||||
@ -125,6 +127,7 @@ class InputDataTest(test.TestCase):
|
|||||||
10, self._model_settings(), tmp_dir)
|
10, self._model_settings(), tmp_dir)
|
||||||
self.assertTrue("Expected to find" in str(e.exception))
|
self.assertTrue("Expected to find" in str(e.exception))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrepareBackgroundData(self):
|
def testPrepareBackgroundData(self):
|
||||||
tmp_dir = self.get_temp_dir()
|
tmp_dir = self.get_temp_dir()
|
||||||
background_dir = os.path.join(tmp_dir, "_background_noise_")
|
background_dir = os.path.join(tmp_dir, "_background_noise_")
|
||||||
@ -156,6 +159,7 @@ class InputDataTest(test.TestCase):
|
|||||||
self.assertIsNotNone(loaded_data)
|
self.assertIsNotNone(loaded_data)
|
||||||
self.assertEqual(16000, len(loaded_data))
|
self.assertEqual(16000, len(loaded_data))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrepareProcessingGraph(self):
|
def testPrepareProcessingGraph(self):
|
||||||
tmp_dir = self.get_temp_dir()
|
tmp_dir = self.get_temp_dir()
|
||||||
wav_dir = os.path.join(tmp_dir, "wavs")
|
wav_dir = os.path.join(tmp_dir, "wavs")
|
||||||
@ -186,15 +190,19 @@ class InputDataTest(test.TestCase):
|
|||||||
self.assertIsNotNone(audio_processor.background_volume_placeholder_)
|
self.assertIsNotNone(audio_processor.background_volume_placeholder_)
|
||||||
self.assertIsNotNone(audio_processor.output_)
|
self.assertIsNotNone(audio_processor.output_)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testGetDataAverage(self):
|
def testGetDataAverage(self):
|
||||||
self._runGetDataTest("average", 10)
|
self._runGetDataTest("average", 10)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testGetDataAverageLongWindow(self):
|
def testGetDataAverageLongWindow(self):
|
||||||
self._runGetDataTest("average", 30)
|
self._runGetDataTest("average", 30)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testGetDataMfcc(self):
|
def testGetDataMfcc(self):
|
||||||
self._runGetDataTest("mfcc", 30)
|
self._runGetDataTest("mfcc", 30)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testGetUnprocessedData(self):
|
def testGetUnprocessedData(self):
|
||||||
tmp_dir = self.get_temp_dir()
|
tmp_dir = self.get_temp_dir()
|
||||||
wav_dir = os.path.join(tmp_dir, "wavs")
|
wav_dir = os.path.join(tmp_dir, "wavs")
|
||||||
@ -216,6 +224,7 @@ class InputDataTest(test.TestCase):
|
|||||||
self.assertEqual(10, len(result_data))
|
self.assertEqual(10, len(result_data))
|
||||||
self.assertEqual(10, len(result_labels))
|
self.assertEqual(10, len(result_labels))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testGetFeaturesForWav(self):
|
def testGetFeaturesForWav(self):
|
||||||
tmp_dir = self.get_temp_dir()
|
tmp_dir = self.get_temp_dir()
|
||||||
wav_dir = os.path.join(tmp_dir, "wavs")
|
wav_dir = os.path.join(tmp_dir, "wavs")
|
||||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow.examples.speech_commands import models
|
from tensorflow.examples.speech_commands import models
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -47,6 +48,7 @@ class ModelsTest(test.TestCase):
|
|||||||
feature_bin_count=40,
|
feature_bin_count=40,
|
||||||
preprocess="mfcc"))
|
preprocess="mfcc"))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCreateModelConvTraining(self):
|
def testCreateModelConvTraining(self):
|
||||||
model_settings = self._modelSettings()
|
model_settings = self._modelSettings()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
@ -58,6 +60,7 @@ class ModelsTest(test.TestCase):
|
|||||||
self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
|
self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
|
||||||
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
|
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCreateModelConvInference(self):
|
def testCreateModelConvInference(self):
|
||||||
model_settings = self._modelSettings()
|
model_settings = self._modelSettings()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
@ -67,6 +70,7 @@ class ModelsTest(test.TestCase):
|
|||||||
self.assertIsNotNone(logits)
|
self.assertIsNotNone(logits)
|
||||||
self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
|
self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCreateModelLowLatencyConvTraining(self):
|
def testCreateModelLowLatencyConvTraining(self):
|
||||||
model_settings = self._modelSettings()
|
model_settings = self._modelSettings()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
@ -78,6 +82,7 @@ class ModelsTest(test.TestCase):
|
|||||||
self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
|
self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
|
||||||
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
|
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCreateModelFullyConnectedTraining(self):
|
def testCreateModelFullyConnectedTraining(self):
|
||||||
model_settings = self._modelSettings()
|
model_settings = self._modelSettings()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
@ -98,6 +103,7 @@ class ModelsTest(test.TestCase):
|
|||||||
"bad_architecture", True)
|
"bad_architecture", True)
|
||||||
self.assertTrue("not recognized" in str(e.exception))
|
self.assertTrue("not recognized" in str(e.exception))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCreateModelTinyConvTraining(self):
|
def testCreateModelTinyConvTraining(self):
|
||||||
model_settings = self._modelSettings()
|
model_settings = self._modelSettings()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
|
@ -24,6 +24,7 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
|
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
|
||||||
from tensorflow.examples.speech_commands import wav_to_features
|
from tensorflow.examples.speech_commands import wav_to_features
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -49,6 +50,7 @@ class WavToFeaturesTest(test.TestCase):
|
|||||||
file_path = os.path.join(dir_name, "some_audio_%d.wav" % i)
|
file_path = os.path.join(dir_name, "some_audio_%d.wav" % i)
|
||||||
self._saveTestWavFile(file_path, wav_data)
|
self._saveTestWavFile(file_path, wav_data)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWavToFeatures(self):
|
def testWavToFeatures(self):
|
||||||
tmp_dir = self.get_temp_dir()
|
tmp_dir = self.get_temp_dir()
|
||||||
wav_dir = os.path.join(tmp_dir, "wavs")
|
wav_dir = os.path.join(tmp_dir, "wavs")
|
||||||
|
@ -23,12 +23,14 @@ from tensorflow.python.autograph.converters import side_effect_guards
|
|||||||
from tensorflow.python.autograph.core import converter_testing
|
from tensorflow.python.autograph.core import converter_testing
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import gen_control_flow_ops
|
from tensorflow.python.ops import gen_control_flow_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class AssertsTest(converter_testing.TestCase):
|
class AssertsTest(converter_testing.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
|
|
||||||
def test_fn(a):
|
def test_fn(a):
|
||||||
|
@ -24,12 +24,14 @@ from tensorflow.python.autograph.converters import builtin_functions
|
|||||||
from tensorflow.python.autograph.core import converter_testing
|
from tensorflow.python.autograph.core import converter_testing
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class BuiltinFunctionsTest(converter_testing.TestCase):
|
class BuiltinFunctionsTest(converter_testing.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_len(self):
|
def test_len(self):
|
||||||
|
|
||||||
def test_fn(a):
|
def test_fn(a):
|
||||||
@ -41,6 +43,7 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
|
|||||||
ops = result.test_fn(p)
|
ops = result.test_fn(p)
|
||||||
self.assertEqual(sess.run(ops, {p: [0, 0, 0]}), 3)
|
self.assertEqual(sess.run(ops, {p: [0, 0, 0]}), 3)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_print(self):
|
def test_print(self):
|
||||||
|
|
||||||
if six.PY2:
|
if six.PY2:
|
||||||
@ -54,6 +57,7 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
|
|||||||
with self.assertPrints('a\n'):
|
with self.assertPrints('a\n'):
|
||||||
sess.run(result.test_fn('a'))
|
sess.run(result.test_fn('a'))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_print_multiple_values(self):
|
def test_print_multiple_values(self):
|
||||||
|
|
||||||
if six.PY2:
|
if six.PY2:
|
||||||
|
@ -23,6 +23,7 @@ from tensorflow.python.autograph.core import converter_testing
|
|||||||
from tensorflow.python.autograph.pyct import transformer
|
from tensorflow.python.autograph.pyct import transformer
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -36,6 +37,7 @@ class ControlFlowTest(converter_testing.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(sess.run(result.test_fn(*inputs)), expected)
|
self.assertEqual(sess.run(result.test_fn(*inputs)), expected)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_while_basic(self):
|
def test_while_basic(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def test_fn(n):
|
||||||
@ -48,6 +50,7 @@ class ControlFlowTest(converter_testing.TestCase):
|
|||||||
|
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 5, 5))
|
self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 5, 5))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_while_nested(self):
|
def test_while_nested(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def test_fn(n):
|
||||||
@ -66,6 +69,7 @@ class ControlFlowTest(converter_testing.TestCase):
|
|||||||
self.assertTransformedResult(test_fn, constant_op.constant(5),
|
self.assertTransformedResult(test_fn, constant_op.constant(5),
|
||||||
(25, 5, 0, 5))
|
(25, 5, 0, 5))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_while_single_output(self):
|
def test_while_single_output(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def test_fn(n):
|
||||||
@ -86,6 +90,7 @@ class ControlFlowTest(converter_testing.TestCase):
|
|||||||
with self.assertRaises(NameError):
|
with self.assertRaises(NameError):
|
||||||
control_flow.transform(node, ctx)
|
control_flow.transform(node, ctx)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_if_basic(self):
|
def test_if_basic(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def test_fn(n):
|
||||||
@ -100,6 +105,7 @@ class ControlFlowTest(converter_testing.TestCase):
|
|||||||
self.assertTransformedResult(test_fn, constant_op.constant(1), (-1, 0))
|
self.assertTransformedResult(test_fn, constant_op.constant(1), (-1, 0))
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(-1), (0, -2))
|
self.assertTransformedResult(test_fn, constant_op.constant(-1), (0, -2))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_if_complex_outputs(self):
|
def test_if_complex_outputs(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
@ -124,6 +130,7 @@ class ControlFlowTest(converter_testing.TestCase):
|
|||||||
res_obj = result.test_fn(constant_op.constant(-1), TestClass(0, 0))
|
res_obj = result.test_fn(constant_op.constant(-1), TestClass(0, 0))
|
||||||
self.assertEqual(sess.run((res_obj.a, res_obj.b)), (0, -2))
|
self.assertEqual(sess.run((res_obj.a, res_obj.b)), (0, -2))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_if_single_output(self):
|
def test_if_single_output(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def test_fn(n):
|
||||||
@ -133,6 +140,7 @@ class ControlFlowTest(converter_testing.TestCase):
|
|||||||
|
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(1), -1)
|
self.assertTransformedResult(test_fn, constant_op.constant(1), -1)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_if_semi(self):
|
def test_if_semi(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def test_fn(n):
|
||||||
@ -143,6 +151,7 @@ class ControlFlowTest(converter_testing.TestCase):
|
|||||||
self.assertTransformedResult(test_fn, constant_op.constant(2), 3)
|
self.assertTransformedResult(test_fn, constant_op.constant(2), 3)
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(-3), -3)
|
self.assertTransformedResult(test_fn, constant_op.constant(-3), -3)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_if_local_var(self):
|
def test_if_local_var(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def test_fn(n):
|
||||||
@ -154,6 +163,7 @@ class ControlFlowTest(converter_testing.TestCase):
|
|||||||
self.assertTransformedResult(test_fn, constant_op.constant(1), 5)
|
self.assertTransformedResult(test_fn, constant_op.constant(1), 5)
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)
|
self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_if_no_outputs(self):
|
def test_if_no_outputs(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def test_fn(n):
|
||||||
@ -177,6 +187,7 @@ class ControlFlowTest(converter_testing.TestCase):
|
|||||||
with self.assertRaises(transformer.AutographParseError):
|
with self.assertRaises(transformer.AutographParseError):
|
||||||
control_flow.transform(node, ctx)
|
control_flow.transform(node, ctx)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_simple_for(self):
|
def test_simple_for(self):
|
||||||
|
|
||||||
def test_fn(l):
|
def test_fn(l):
|
||||||
@ -191,6 +202,7 @@ class ControlFlowTest(converter_testing.TestCase):
|
|||||||
empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
|
empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
|
||||||
self.assertTransformedResult(test_fn, empty_vector, (0, 0))
|
self.assertTransformedResult(test_fn, empty_vector, (0, 0))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_for_single_output(self):
|
def test_for_single_output(self):
|
||||||
|
|
||||||
def test_fn(l):
|
def test_fn(l):
|
||||||
@ -235,6 +247,7 @@ class ControlFlowTest(converter_testing.TestCase):
|
|||||||
with self.assertRaises(NameError):
|
with self.assertRaises(NameError):
|
||||||
control_flow.transform(node, ctx)
|
control_flow.transform(node, ctx)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_for_tuple_unpacking(self):
|
def test_for_tuple_unpacking(self):
|
||||||
def test_fn(x_list):
|
def test_fn(x_list):
|
||||||
z = tf.constant(0) # pylint:disable=undefined-variable
|
z = tf.constant(0) # pylint:disable=undefined-variable
|
||||||
|
@ -22,11 +22,13 @@ from tensorflow.python.autograph.converters import function_scopes
|
|||||||
from tensorflow.python.autograph.core import converter_testing
|
from tensorflow.python.autograph.core import converter_testing
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class FunctionBodyTransformerTest(converter_testing.TestCase):
|
class FunctionBodyTransformerTest(converter_testing.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
|
|
||||||
def test_fn(l):
|
def test_fn(l):
|
||||||
@ -40,6 +42,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
|
|||||||
self.assertIn('test_fn/', result_op.op.name)
|
self.assertIn('test_fn/', result_op.op.name)
|
||||||
self.assertEqual('Docstring.', result.test_fn.__doc__)
|
self.assertEqual('Docstring.', result.test_fn.__doc__)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_multiline_docstring(self):
|
def test_multiline_docstring(self):
|
||||||
|
|
||||||
tf = None
|
tf = None
|
||||||
@ -58,6 +61,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
|
|||||||
self.assertIn('First sentence.', result.test_fn.__doc__)
|
self.assertIn('First sentence.', result.test_fn.__doc__)
|
||||||
self.assertIn('Second sentence.', result.test_fn.__doc__)
|
self.assertIn('Second sentence.', result.test_fn.__doc__)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_nested_functions(self):
|
def test_nested_functions(self):
|
||||||
|
|
||||||
def test_fn(l):
|
def test_fn(l):
|
||||||
@ -74,6 +78,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
|
|||||||
self.assertNotIn('inner_fn', first.op.name)
|
self.assertNotIn('inner_fn', first.op.name)
|
||||||
self.assertIn('test_fn/inner_fn/', second.op.name)
|
self.assertIn('test_fn/inner_fn/', second.op.name)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_method(self):
|
def test_method(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
|
@ -21,11 +21,13 @@ from __future__ import print_function
|
|||||||
from tensorflow.python.autograph.converters import logical_expressions
|
from tensorflow.python.autograph.converters import logical_expressions
|
||||||
from tensorflow.python.autograph.core import converter_testing
|
from tensorflow.python.autograph.core import converter_testing
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class LogicalExpressionTest(converter_testing.TestCase):
|
class LogicalExpressionTest(converter_testing.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_equals(self):
|
def test_equals(self):
|
||||||
|
|
||||||
def test_fn(a, b):
|
def test_fn(a, b):
|
||||||
@ -36,6 +38,7 @@ class LogicalExpressionTest(converter_testing.TestCase):
|
|||||||
self.assertTrue(sess.run(result.test_fn(constant_op.constant(1), 1)))
|
self.assertTrue(sess.run(result.test_fn(constant_op.constant(1), 1)))
|
||||||
self.assertFalse(sess.run(result.test_fn(constant_op.constant(1), 2)))
|
self.assertFalse(sess.run(result.test_fn(constant_op.constant(1), 2)))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_bool_ops(self):
|
def test_bool_ops(self):
|
||||||
|
|
||||||
def test_fn(a, b, c):
|
def test_fn(a, b, c):
|
||||||
@ -48,6 +51,7 @@ class LogicalExpressionTest(converter_testing.TestCase):
|
|||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
sess.run(result.test_fn(constant_op.constant(True), False, True)))
|
sess.run(result.test_fn(constant_op.constant(True), False, True)))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_comparison(self):
|
def test_comparison(self):
|
||||||
|
|
||||||
def test_fn(a, b, c, d):
|
def test_fn(a, b, c, d):
|
||||||
|
@ -23,6 +23,7 @@ from tensorflow.python.autograph.core import converter_testing
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
@ -34,6 +35,7 @@ tf = None # Will be replaced by a mock.
|
|||||||
|
|
||||||
class SideEffectGuardsTest(converter_testing.TestCase):
|
class SideEffectGuardsTest(converter_testing.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_side_effect_on_return_only_variable(self):
|
def test_side_effect_on_return_only_variable(self):
|
||||||
|
|
||||||
def test_fn(a):
|
def test_fn(a):
|
||||||
@ -75,6 +77,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
|||||||
# Right now it's 3 or 4 based on whether the read is synchronized.
|
# Right now it's 3 or 4 based on whether the read is synchronized.
|
||||||
self.assertEqual(3, self.evaluate(v))
|
self.assertEqual(3, self.evaluate(v))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_side_effect_on_tensor(self):
|
def test_side_effect_on_tensor(self):
|
||||||
|
|
||||||
def test_fn(a):
|
def test_fn(a):
|
||||||
|
@ -22,6 +22,7 @@ from tensorflow.python.autograph.core import errors
|
|||||||
from tensorflow.python.autograph.pyct import origin_info
|
from tensorflow.python.autograph.pyct import origin_info
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors as tf_errors
|
from tensorflow.python.framework import errors as tf_errors
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.util import tf_inspect
|
from tensorflow.python.util import tf_inspect
|
||||||
@ -47,6 +48,7 @@ class RuntimeErrorsTest(test.TestCase):
|
|||||||
'test_comment')
|
'test_comment')
|
||||||
return loc, origin
|
return loc, origin
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_improved_errors_basic(self):
|
def test_improved_errors_basic(self):
|
||||||
loc, origin = self.fake_origin(zero_div, 2)
|
loc, origin = self.fake_origin(zero_div, 2)
|
||||||
zero_div_caller.ag_source_map = {loc: origin}
|
zero_div_caller.ag_source_map = {loc: origin}
|
||||||
@ -62,6 +64,7 @@ class RuntimeErrorsTest(test.TestCase):
|
|||||||
self.assertNotEqual('zero_div', function_name)
|
self.assertNotEqual('zero_div', function_name)
|
||||||
self.assertIn(origin.as_frame(), set(cm.exception.custom_traceback))
|
self.assertIn(origin.as_frame(), set(cm.exception.custom_traceback))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_improved_errors_no_matching_lineno(self):
|
def test_improved_errors_no_matching_lineno(self):
|
||||||
loc, origin = self.fake_origin(zero_div, -1)
|
loc, origin = self.fake_origin(zero_div, -1)
|
||||||
zero_div_caller.ag_source_map = {loc: origin}
|
zero_div_caller.ag_source_map = {loc: origin}
|
||||||
@ -79,6 +82,7 @@ class RuntimeErrorsTest(test.TestCase):
|
|||||||
self.assertNotEqual('test_function_name', function_name)
|
self.assertNotEqual('test_function_name', function_name)
|
||||||
self.assertIn('zero_div', all_function_names)
|
self.assertIn('zero_div', all_function_names)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_improved_errors_failures(self):
|
def test_improved_errors_failures(self):
|
||||||
loc, _ = self.fake_origin(zero_div, 2)
|
loc, _ = self.fake_origin(zero_div, 2)
|
||||||
zero_div_caller.ag_source_map = {loc: 'bogus object'}
|
zero_div_caller.ag_source_map = {loc: 'bogus object'}
|
||||||
|
@ -20,11 +20,13 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.python.autograph.core import function_wrapping
|
from tensorflow.python.autograph.core import function_wrapping
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class FunctionWrappingTest(test.TestCase):
|
class FunctionWrappingTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_function_scope_name(self):
|
def test_function_scope_name(self):
|
||||||
with function_wrapping.function_scope('test_name'):
|
with function_wrapping.function_scope('test_name'):
|
||||||
t = constant_op.constant(1)
|
t = constant_op.constant(1)
|
||||||
|
@ -29,6 +29,7 @@ from tensorflow.python.autograph.impl import api
|
|||||||
from tensorflow.python.autograph.pyct import parser
|
from tensorflow.python.autograph.pyct import parser
|
||||||
from tensorflow.python.autograph.utils import py_func
|
from tensorflow.python.autograph.utils import py_func
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.keras.engine import sequential
|
from tensorflow.python.keras.engine import sequential
|
||||||
from tensorflow.python.keras.layers import core
|
from tensorflow.python.keras.layers import core
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
@ -44,6 +45,7 @@ class TestResource(str):
|
|||||||
|
|
||||||
class ApiTest(test.TestCase):
|
class ApiTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_decorator_recurses(self):
|
def test_decorator_recurses(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
@ -66,6 +68,7 @@ class ApiTest(test.TestCase):
|
|||||||
constant_op.constant(-2))
|
constant_op.constant(-2))
|
||||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_decorator_does_not_recurse(self):
|
def test_decorator_does_not_recurse(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
@ -86,6 +89,7 @@ class ApiTest(test.TestCase):
|
|||||||
constant_op.constant(-2))
|
constant_op.constant(-2))
|
||||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_decorator_calls_unconverted_graph(self):
|
def test_decorator_calls_unconverted_graph(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
@ -107,6 +111,7 @@ class ApiTest(test.TestCase):
|
|||||||
constant_op.constant(-2))
|
constant_op.constant(-2))
|
||||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_decorator_calls_unconverted_py_func(self):
|
def test_decorator_calls_unconverted_py_func(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
@ -133,6 +138,7 @@ class ApiTest(test.TestCase):
|
|||||||
constant_op.constant(-2))
|
constant_op.constant(-2))
|
||||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_decorator_calls_decorated(self):
|
def test_decorator_calls_decorated(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
@ -172,6 +178,7 @@ class ApiTest(test.TestCase):
|
|||||||
list(tf_inspect.getfullargspec(tc.called_member)),
|
list(tf_inspect.getfullargspec(tc.called_member)),
|
||||||
list(tf_inspect.getfullargspec(tc.called_member_converted)))
|
list(tf_inspect.getfullargspec(tc.called_member_converted)))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_convert_call_site_decorator(self):
|
def test_convert_call_site_decorator(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
@ -326,6 +333,7 @@ class ApiTest(test.TestCase):
|
|||||||
constant_op.constant(0))
|
constant_op.constant(0))
|
||||||
self.assertTrue(self.evaluate(x))
|
self.assertTrue(self.evaluate(x))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_converted_call_no_user_code(self):
|
def test_converted_call_no_user_code(self):
|
||||||
|
|
||||||
def f(x):
|
def f(x):
|
||||||
@ -400,6 +408,7 @@ class ApiTest(test.TestCase):
|
|||||||
self.evaluate(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
self.assertAllEqual(True, self.evaluate(x))
|
self.assertAllEqual(True, self.evaluate(x))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_to_graph_basic(self):
|
def test_to_graph_basic(self):
|
||||||
|
|
||||||
def test_fn(x, s):
|
def test_fn(x, s):
|
||||||
@ -413,6 +422,7 @@ class ApiTest(test.TestCase):
|
|||||||
x = compiled_fn(constant_op.constant([4, 8]), 4)
|
x = compiled_fn(constant_op.constant([4, 8]), 4)
|
||||||
self.assertListEqual([1, 2], self.evaluate(x).tolist())
|
self.assertListEqual([1, 2], self.evaluate(x).tolist())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_to_graph_with_defaults(self):
|
def test_to_graph_with_defaults(self):
|
||||||
|
|
||||||
foo = 4
|
foo = 4
|
||||||
|
@ -22,12 +22,14 @@ from tensorflow.python.autograph.operators import control_flow
|
|||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class ForLoopTest(test.TestCase):
|
class ForLoopTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_tensor(self):
|
def test_tensor(self):
|
||||||
s = control_flow.for_stmt(
|
s = control_flow.for_stmt(
|
||||||
constant_op.constant([1, 2, 3, 4]),
|
constant_op.constant([1, 2, 3, 4]),
|
||||||
@ -45,6 +47,7 @@ class ForLoopTest(test.TestCase):
|
|||||||
init_state=(0,))
|
init_state=(0,))
|
||||||
self.assertEqual(10, s)
|
self.assertEqual(10, s)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_dataset(self):
|
def test_dataset(self):
|
||||||
to_int32 = lambda i: math_ops.cast(i, dtypes.int32)
|
to_int32 = lambda i: math_ops.cast(i, dtypes.int32)
|
||||||
s = control_flow.for_stmt(
|
s = control_flow.for_stmt(
|
||||||
@ -58,6 +61,7 @@ class ForLoopTest(test.TestCase):
|
|||||||
|
|
||||||
class WhileLoopTest(test.TestCase):
|
class WhileLoopTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_tensor(self):
|
def test_tensor(self):
|
||||||
n = constant_op.constant(5)
|
n = constant_op.constant(5)
|
||||||
results = control_flow.while_stmt(
|
results = control_flow.while_stmt(
|
||||||
@ -87,6 +91,7 @@ class IfStmtTest(test.TestCase):
|
|||||||
return control_flow.if_stmt(
|
return control_flow.if_stmt(
|
||||||
cond=cond, body=lambda: (1, 2), orelse=lambda: (-1, -2))
|
cond=cond, body=lambda: (1, 2), orelse=lambda: (-1, -2))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_tensor(self):
|
def test_tensor(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = self.single_return_if_stmt(constant_op.constant(True))
|
t = self.single_return_if_stmt(constant_op.constant(True))
|
||||||
@ -98,6 +103,7 @@ class IfStmtTest(test.TestCase):
|
|||||||
self.assertEqual(1, self.single_return_if_stmt(True))
|
self.assertEqual(1, self.single_return_if_stmt(True))
|
||||||
self.assertEqual(-1, self.single_return_if_stmt(False))
|
self.assertEqual(-1, self.single_return_if_stmt(False))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_tensor_multiple_returns(self):
|
def test_tensor_multiple_returns(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = self.multi_return_if_stmt(constant_op.constant(True))
|
t = self.multi_return_if_stmt(constant_op.constant(True))
|
||||||
|
@ -22,6 +22,7 @@ from tensorflow.python.autograph.operators import data_structures
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import list_ops
|
from tensorflow.python.ops import list_ops
|
||||||
from tensorflow.python.ops import tensor_array_ops
|
from tensorflow.python.ops import tensor_array_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -59,6 +60,7 @@ class ListTest(test.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
|
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_tf_tensor_list_new_illegal_input(self):
|
def test_tf_tensor_list_new_illegal_input(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
data_structures.tf_tensor_list_new([3, 4.0])
|
data_structures.tf_tensor_list_new([3, 4.0])
|
||||||
@ -104,6 +106,7 @@ class ListTest(test.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(self.evaluate(t), [[1, 2, 3]])
|
self.assertAllEqual(self.evaluate(t), [[1, 2, 3]])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_append_tensorarray(self):
|
def test_append_tensorarray(self):
|
||||||
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
|
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
|
||||||
l1 = data_structures.list_append(l, 1)
|
l1 = data_structures.list_append(l, 1)
|
||||||
@ -154,6 +157,7 @@ class ListTest(test.TestCase):
|
|||||||
t = data_structures.list_stack(l, opts)
|
t = data_structures.list_stack(l, opts)
|
||||||
self.assertAllEqual(self.evaluate(t), self.evaluate(initial_list))
|
self.assertAllEqual(self.evaluate(t), self.evaluate(initial_list))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_stack_tensor_list_empty(self):
|
def test_stack_tensor_list_empty(self):
|
||||||
l = list_ops.empty_tensor_list(
|
l = list_ops.empty_tensor_list(
|
||||||
element_shape=None, element_dtype=dtypes.variant)
|
element_shape=None, element_dtype=dtypes.variant)
|
||||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
from tensorflow.python.autograph.operators import exceptions
|
from tensorflow.python.autograph.operators import exceptions
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -32,6 +33,7 @@ class ExceptionsTest(test.TestCase):
|
|||||||
constant_op.constant(True), lambda: constant_op.constant('ignored'))
|
constant_op.constant(True), lambda: constant_op.constant('ignored'))
|
||||||
self.evaluate(t)
|
self.evaluate(t)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_assert_tf_triggered(self):
|
def test_assert_tf_triggered(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = exceptions.assert_stmt(
|
t = exceptions.assert_stmt(
|
||||||
@ -42,6 +44,7 @@ class ExceptionsTest(test.TestCase):
|
|||||||
'test message'):
|
'test message'):
|
||||||
self.evaluate(t)
|
self.evaluate(t)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_assert_tf_multiple_printed_values(self):
|
def test_assert_tf_multiple_printed_values(self):
|
||||||
two_tensors = [
|
two_tensors = [
|
||||||
constant_op.constant('test message'),
|
constant_op.constant('test message'),
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.python.autograph.operators import logical
|
from tensorflow.python.autograph.operators import logical
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -42,6 +43,7 @@ class LogicalOperatorsTest(test.TestCase):
|
|||||||
self.assertFalse(logical.and_(lambda: False, lambda: True))
|
self.assertFalse(logical.and_(lambda: False, lambda: True))
|
||||||
self.assertFalse(logical.and_(lambda: False, self.assertNotCalled))
|
self.assertFalse(logical.and_(lambda: False, self.assertNotCalled))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_and_tf(self):
|
def test_and_tf(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = logical.and_(self._tf_true, self._tf_true)
|
t = logical.and_(self._tf_true, self._tf_true)
|
||||||
@ -60,6 +62,7 @@ class LogicalOperatorsTest(test.TestCase):
|
|||||||
self.assertTrue(logical.or_(lambda: False, lambda: True))
|
self.assertTrue(logical.or_(lambda: False, lambda: True))
|
||||||
self.assertTrue(logical.or_(lambda: True, self.assertNotCalled))
|
self.assertTrue(logical.or_(lambda: True, self.assertNotCalled))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_or_tf(self):
|
def test_or_tf(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = logical.or_(self._tf_false, self._tf_true)
|
t = logical.or_(self._tf_false, self._tf_true)
|
||||||
|
@ -27,6 +27,7 @@ from tensorflow.python.autograph.operators import py_builtins
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import tensor_array_ops
|
from tensorflow.python.ops import tensor_array_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -81,6 +82,7 @@ class PyBuiltinsTest(test.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
py_builtins.len_(constant_op.constant(1))
|
py_builtins.len_(constant_op.constant(1))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_len_dynamic_shape(self):
|
def test_len_dynamic_shape(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
|
p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
|
||||||
@ -91,6 +93,7 @@ class PyBuiltinsTest(test.TestCase):
|
|||||||
t = py_builtins.len_(p)
|
t = py_builtins.len_(p)
|
||||||
sess.run(t, {p: 1})
|
sess.run(t, {p: 1})
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_print_tensors(self):
|
def test_print_tensors(self):
|
||||||
try:
|
try:
|
||||||
out_capturer = six.StringIO()
|
out_capturer = six.StringIO()
|
||||||
@ -101,6 +104,7 @@ class PyBuiltinsTest(test.TestCase):
|
|||||||
finally:
|
finally:
|
||||||
sys.stdout = sys.__stdout__
|
sys.stdout = sys.__stdout__
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_print_complex(self):
|
def test_print_complex(self):
|
||||||
try:
|
try:
|
||||||
out_capturer = six.StringIO()
|
out_capturer = six.StringIO()
|
||||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.autograph.utils.misc import alias_tensors
|
from tensorflow.python.autograph.utils.misc import alias_tensors
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.framework.constant_op import constant
|
from tensorflow.python.framework.constant_op import constant
|
||||||
from tensorflow.python.ops.variables import Variable
|
from tensorflow.python.ops.variables import Variable
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -26,6 +27,7 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
class MiscTest(test.TestCase):
|
class MiscTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_alias_single_tensor(self):
|
def test_alias_single_tensor(self):
|
||||||
a = constant(1)
|
a = constant(1)
|
||||||
|
|
||||||
@ -34,6 +36,7 @@ class MiscTest(test.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(1, self.evaluate(new_a))
|
self.assertEqual(1, self.evaluate(new_a))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_alias_tensors(self):
|
def test_alias_tensors(self):
|
||||||
a = constant(1)
|
a = constant(1)
|
||||||
v = Variable(2)
|
v = Variable(2)
|
||||||
|
@ -23,6 +23,7 @@ from tensorflow.python.client.session import Session
|
|||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.framework.constant_op import constant
|
from tensorflow.python.framework.constant_op import constant
|
||||||
from tensorflow.python.ops import list_ops
|
from tensorflow.python.ops import list_ops
|
||||||
from tensorflow.python.ops import tensor_array_ops
|
from tensorflow.python.ops import tensor_array_ops
|
||||||
@ -34,6 +35,7 @@ class TensorListTest(test.TestCase):
|
|||||||
def _shape(self, shape_tuple):
|
def _shape(self, shape_tuple):
|
||||||
return constant(shape_tuple, dtypes.int32)
|
return constant(shape_tuple, dtypes.int32)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_dynamic_list_append(self):
|
def test_dynamic_list_append(self):
|
||||||
l = []
|
l = []
|
||||||
l = tl.dynamic_list_append(l, 1)
|
l = tl.dynamic_list_append(l, 1)
|
||||||
@ -80,6 +82,7 @@ class TensorListTest(test.TestCase):
|
|||||||
l[0] = ops.convert_to_tensor(b)
|
l[0] = ops.convert_to_tensor(b)
|
||||||
self.assertEqual(l[0].numpy(), b.numpy())
|
self.assertEqual(l[0].numpy(), b.numpy())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_list_append_tf(self):
|
def test_list_append_tf(self):
|
||||||
a = constant(3.0)
|
a = constant(3.0)
|
||||||
l = tl.TensorList(a.shape, a.dtype)
|
l = tl.TensorList(a.shape, a.dtype)
|
||||||
|
@ -28,6 +28,7 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
class TypeCheckTest(test.TestCase):
|
class TypeCheckTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_checks(self):
|
def test_checks(self):
|
||||||
self.assertTrue(type_check.is_tensor(constant_op.constant([1, 2, 3])))
|
self.assertTrue(type_check.is_tensor(constant_op.constant([1, 2, 3])))
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
|
@ -188,6 +188,7 @@ class PartialRunTest(test_util.TensorFlowTestCase):
|
|||||||
r = sess.partial_run(h, [b], {})
|
r = sess.partial_run(h, [b], {})
|
||||||
self.assertEqual([6.0], r)
|
self.assertEqual([6.0], r)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testInvalidPartialRunSetup(self):
|
def testInvalidPartialRunSetup(self):
|
||||||
sess = session.Session()
|
sess = session.Session()
|
||||||
x = array_ops.placeholder(dtypes.float32, shape=[])
|
x = array_ops.placeholder(dtypes.float32, shape=[])
|
||||||
@ -196,6 +197,7 @@ class PartialRunTest(test_util.TensorFlowTestCase):
|
|||||||
'specify at least one target to fetch or execute.'):
|
'specify at least one target to fetch or execute.'):
|
||||||
sess.partial_run_setup(fetches=[], feeds=[x])
|
sess.partial_run_setup(fetches=[], feeds=[x])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartialRunSetupNoFeedsPassed(self):
|
def testPartialRunSetupNoFeedsPassed(self):
|
||||||
sess = session.Session()
|
sess = session.Session()
|
||||||
r1 = constant_op.constant([6.0])
|
r1 = constant_op.constant([6.0])
|
||||||
@ -204,80 +206,102 @@ class PartialRunTest(test_util.TensorFlowTestCase):
|
|||||||
result1 = sess.partial_run(h, r1)
|
result1 = sess.partial_run(h, r1)
|
||||||
self.assertEqual([6.0], result1)
|
self.assertEqual([6.0], result1)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartialRunDirect(self):
|
def testPartialRunDirect(self):
|
||||||
self.RunTestPartialRun(session.Session())
|
self.RunTestPartialRun(session.Session())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartialRunIncompleteDirect(self):
|
def testPartialRunIncompleteDirect(self):
|
||||||
self.RunTestPartialRunIncomplete(session.Session())
|
self.RunTestPartialRunIncomplete(session.Session())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testConcurrentPartialRunDirect(self):
|
def testConcurrentPartialRunDirect(self):
|
||||||
self.RunTestConcurrentPartialRun(session.Session())
|
self.RunTestConcurrentPartialRun(session.Session())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testManyPartialRunDirect(self):
|
def testManyPartialRunDirect(self):
|
||||||
self.RunTestManyPartialRun(session.Session())
|
self.RunTestManyPartialRun(session.Session())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testRunAndPartialRunDirect(self):
|
def testRunAndPartialRunDirect(self):
|
||||||
self.RunTestRunAndPartialRun(session.Session())
|
self.RunTestRunAndPartialRun(session.Session())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartialRunMissingPlaceholderFeedExceptionDirect(self):
|
def testPartialRunMissingPlaceholderFeedExceptionDirect(self):
|
||||||
self.RunTestPartialRunMissingPlaceholderFeedException(session.Session())
|
self.RunTestPartialRunMissingPlaceholderFeedException(session.Session())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartialRunUnspecifiedFeedDirect(self):
|
def testPartialRunUnspecifiedFeedDirect(self):
|
||||||
self.RunTestPartialRunUnspecifiedFeed(session.Session())
|
self.RunTestPartialRunUnspecifiedFeed(session.Session())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartialRunUnspecifiedFetchDirect(self):
|
def testPartialRunUnspecifiedFetchDirect(self):
|
||||||
self.RunTestPartialRunUnspecifiedFetch(session.Session())
|
self.RunTestPartialRunUnspecifiedFetch(session.Session())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartialRunAlreadyFedDirect(self):
|
def testPartialRunAlreadyFedDirect(self):
|
||||||
self.RunTestPartialRunAlreadyFed(session.Session())
|
self.RunTestPartialRunAlreadyFed(session.Session())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartialRunAlreadyFetchedDirect(self):
|
def testPartialRunAlreadyFetchedDirect(self):
|
||||||
self.RunTestPartialRunAlreadyFetched(session.Session())
|
self.RunTestPartialRunAlreadyFetched(session.Session())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartialRunEmptyFetchesDirect(self):
|
def testPartialRunEmptyFetchesDirect(self):
|
||||||
self.RunTestPartialRunEmptyFetches(session.Session())
|
self.RunTestPartialRunEmptyFetches(session.Session())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartialRunDist(self):
|
def testPartialRunDist(self):
|
||||||
server = server_lib.Server.create_local_server()
|
server = server_lib.Server.create_local_server()
|
||||||
self.RunTestPartialRun(session.Session(server.target))
|
self.RunTestPartialRun(session.Session(server.target))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartialRunIncompleteDist(self):
|
def testPartialRunIncompleteDist(self):
|
||||||
server = server_lib.Server.create_local_server()
|
server = server_lib.Server.create_local_server()
|
||||||
self.RunTestPartialRunIncomplete(session.Session(server.target))
|
self.RunTestPartialRunIncomplete(session.Session(server.target))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testConcurrentPartialRunDist(self):
|
def testConcurrentPartialRunDist(self):
|
||||||
server = server_lib.Server.create_local_server()
|
server = server_lib.Server.create_local_server()
|
||||||
self.RunTestConcurrentPartialRun(session.Session(server.target))
|
self.RunTestConcurrentPartialRun(session.Session(server.target))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testManyPartialRunDist(self):
|
def testManyPartialRunDist(self):
|
||||||
server = server_lib.Server.create_local_server()
|
server = server_lib.Server.create_local_server()
|
||||||
self.RunTestManyPartialRun(session.Session(server.target))
|
self.RunTestManyPartialRun(session.Session(server.target))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testRunAndPartialRunDist(self):
|
def testRunAndPartialRunDist(self):
|
||||||
server = server_lib.Server.create_local_server()
|
server = server_lib.Server.create_local_server()
|
||||||
self.RunTestRunAndPartialRun(session.Session(server.target))
|
self.RunTestRunAndPartialRun(session.Session(server.target))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartialRunMissingPlaceholderFeedExceptionDist(self):
|
def testPartialRunMissingPlaceholderFeedExceptionDist(self):
|
||||||
server = server_lib.Server.create_local_server()
|
server = server_lib.Server.create_local_server()
|
||||||
self.RunTestPartialRunMissingPlaceholderFeedException(
|
self.RunTestPartialRunMissingPlaceholderFeedException(
|
||||||
session.Session(server.target))
|
session.Session(server.target))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartialRunUnspecifiedFeedDist(self):
|
def testPartialRunUnspecifiedFeedDist(self):
|
||||||
server = server_lib.Server.create_local_server()
|
server = server_lib.Server.create_local_server()
|
||||||
self.RunTestPartialRunUnspecifiedFeed(session.Session(server.target))
|
self.RunTestPartialRunUnspecifiedFeed(session.Session(server.target))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartialRunUnspecifiedFetchDist(self):
|
def testPartialRunUnspecifiedFetchDist(self):
|
||||||
server = server_lib.Server.create_local_server()
|
server = server_lib.Server.create_local_server()
|
||||||
self.RunTestPartialRunUnspecifiedFetch(session.Session(server.target))
|
self.RunTestPartialRunUnspecifiedFetch(session.Session(server.target))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartialRunAlreadyFedDist(self):
|
def testPartialRunAlreadyFedDist(self):
|
||||||
server = server_lib.Server.create_local_server()
|
server = server_lib.Server.create_local_server()
|
||||||
self.RunTestPartialRunAlreadyFed(session.Session(server.target))
|
self.RunTestPartialRunAlreadyFed(session.Session(server.target))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartialRunAlreadyFetchedDist(self):
|
def testPartialRunAlreadyFetchedDist(self):
|
||||||
server = server_lib.Server.create_local_server()
|
server = server_lib.Server.create_local_server()
|
||||||
self.RunTestPartialRunAlreadyFetched(session.Session(server.target))
|
self.RunTestPartialRunAlreadyFetched(session.Session(server.target))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartialRunEmptyFetchesDist(self):
|
def testPartialRunEmptyFetchesDist(self):
|
||||||
server = server_lib.Server.create_local_server()
|
server = server_lib.Server.create_local_server()
|
||||||
self.RunTestPartialRunEmptyFetches(session.Session(server.target))
|
self.RunTestPartialRunEmptyFetches(session.Session(server.target))
|
||||||
|
@ -57,6 +57,7 @@ class TimelineTest(test.TestCase):
|
|||||||
ctf = tl.generate_chrome_trace_format()
|
ctf = tl.generate_chrome_trace_format()
|
||||||
self._validateTrace(ctf)
|
self._validateTrace(ctf)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testTimelineCpu(self):
|
def testTimelineCpu(self):
|
||||||
run_options = config_pb2.RunOptions(
|
run_options = config_pb2.RunOptions(
|
||||||
trace_level=config_pb2.RunOptions.FULL_TRACE)
|
trace_level=config_pb2.RunOptions.FULL_TRACE)
|
||||||
|
@ -35,6 +35,7 @@ from tensorflow.python.util import compat as util_compat
|
|||||||
|
|
||||||
class CopyToDeviceTest(test_base.DatasetTestBase):
|
class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCopyToDevice(self):
|
def testCopyToDevice(self):
|
||||||
host_dataset = dataset_ops.Dataset.range(10)
|
host_dataset = dataset_ops.Dataset.range(10)
|
||||||
device_dataset = host_dataset.apply(
|
device_dataset = host_dataset.apply(
|
||||||
@ -61,6 +62,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCopyToDeviceInt32(self):
|
def testCopyToDeviceInt32(self):
|
||||||
host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
|
host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
|
||||||
device_dataset = host_dataset.apply(
|
device_dataset = host_dataset.apply(
|
||||||
@ -86,6 +88,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCopyToSameDevice(self):
|
def testCopyToSameDevice(self):
|
||||||
host_dataset = dataset_ops.Dataset.range(10)
|
host_dataset = dataset_ops.Dataset.range(10)
|
||||||
device_dataset = host_dataset.apply(
|
device_dataset = host_dataset.apply(
|
||||||
@ -112,6 +115,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCopyToDeviceWithPrefetch(self):
|
def testCopyToDeviceWithPrefetch(self):
|
||||||
host_dataset = dataset_ops.Dataset.range(10)
|
host_dataset = dataset_ops.Dataset.range(10)
|
||||||
device_dataset = host_dataset.apply(
|
device_dataset = host_dataset.apply(
|
||||||
@ -138,6 +142,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCopyDictToDevice(self):
|
def testCopyDictToDevice(self):
|
||||||
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
|
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
|
||||||
device_dataset = host_dataset.apply(
|
device_dataset = host_dataset.apply(
|
||||||
@ -164,6 +169,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCopyDictToDeviceWithPrefetch(self):
|
def testCopyDictToDeviceWithPrefetch(self):
|
||||||
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
|
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
|
||||||
device_dataset = host_dataset.apply(
|
device_dataset = host_dataset.apply(
|
||||||
@ -190,6 +196,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCopySparseTensorsToDevice(self):
|
def testCopySparseTensorsToDevice(self):
|
||||||
|
|
||||||
def make_tensor(i):
|
def make_tensor(i):
|
||||||
@ -224,6 +231,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCopySparseTensorsToDeviceWithPrefetch(self):
|
def testCopySparseTensorsToDeviceWithPrefetch(self):
|
||||||
|
|
||||||
def make_tensor(i):
|
def make_tensor(i):
|
||||||
@ -426,6 +434,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCopyToDeviceWithReInit(self):
|
def testCopyToDeviceWithReInit(self):
|
||||||
host_dataset = dataset_ops.Dataset.range(10)
|
host_dataset = dataset_ops.Dataset.range(10)
|
||||||
device_dataset = host_dataset.apply(
|
device_dataset = host_dataset.apply(
|
||||||
@ -456,6 +465,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCopyToDeviceWithReInitAndPrefetch(self):
|
def testCopyToDeviceWithReInitAndPrefetch(self):
|
||||||
host_dataset = dataset_ops.Dataset.range(10)
|
host_dataset = dataset_ops.Dataset.range(10)
|
||||||
device_dataset = host_dataset.apply(
|
device_dataset = host_dataset.apply(
|
||||||
|
@ -20,11 +20,13 @@ from __future__ import print_function
|
|||||||
from tensorflow.python.data.experimental.ops import counter
|
from tensorflow.python.data.experimental.ops import counter
|
||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class CounterTest(test_base.DatasetTestBase):
|
class CounterTest(test_base.DatasetTestBase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCounter(self):
|
def testCounter(self):
|
||||||
"""Test dataset construction using `count`."""
|
"""Test dataset construction using `count`."""
|
||||||
iterator = (counter.Counter(start=3, step=4)
|
iterator = (counter.Counter(start=3, step=4)
|
||||||
|
@ -24,12 +24,14 @@ from tensorflow.python.data.kernel_tests import test_base
|
|||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testDenseToSparseBatchDataset(self):
|
def testDenseToSparseBatchDataset(self):
|
||||||
components = np.random.randint(12, size=(100,)).astype(np.int32)
|
components = np.random.randint(12, size=(100,)).astype(np.int32)
|
||||||
iterator = (
|
iterator = (
|
||||||
@ -58,6 +60,7 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(get_next)
|
self.evaluate(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testDenseToSparseBatchDatasetWithUnknownShape(self):
|
def testDenseToSparseBatchDatasetWithUnknownShape(self):
|
||||||
components = np.random.randint(5, size=(40,)).astype(np.int32)
|
components = np.random.randint(5, size=(40,)).astype(np.int32)
|
||||||
iterator = (
|
iterator = (
|
||||||
@ -91,12 +94,14 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(get_next)
|
self.evaluate(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testDenseToSparseBatchDatasetWithInvalidShape(self):
|
def testDenseToSparseBatchDatasetWithInvalidShape(self):
|
||||||
input_tensor = array_ops.constant([[1]])
|
input_tensor = array_ops.constant([[1]])
|
||||||
with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
|
with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
|
||||||
dataset_ops.Dataset.from_tensors(input_tensor).apply(
|
dataset_ops.Dataset.from_tensors(input_tensor).apply(
|
||||||
batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator()
|
batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator()
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testDenseToSparseBatchDatasetShapeErrors(self):
|
def testDenseToSparseBatchDatasetShapeErrors(self):
|
||||||
input_tensor = array_ops.placeholder(dtypes.int32)
|
input_tensor = array_ops.placeholder(dtypes.int32)
|
||||||
iterator = (
|
iterator = (
|
||||||
|
@ -24,11 +24,13 @@ from tensorflow.python.data.kernel_tests import test_base
|
|||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import random_seed
|
from tensorflow.python.framework import random_seed
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
|
class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
|
selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
|
||||||
input_datasets = [
|
input_datasets = [
|
||||||
@ -77,6 +79,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
return freqs
|
return freqs
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSampleFromDatasets(self):
|
def testSampleFromDatasets(self):
|
||||||
random_seed.set_random_seed(1619)
|
random_seed.set_random_seed(1619)
|
||||||
num_samples = 5000
|
num_samples = 5000
|
||||||
@ -96,6 +99,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
|
|||||||
freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples)
|
freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples)
|
||||||
self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)
|
self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSelectFromDatasets(self):
|
def testSelectFromDatasets(self):
|
||||||
words = [b"foo", b"bar", b"baz"]
|
words = [b"foo", b"bar", b"baz"]
|
||||||
datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words]
|
datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words]
|
||||||
|
@ -24,11 +24,13 @@ from tensorflow.python.framework import constant_op
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class EnumerateDatasetTest(test_base.DatasetTestBase):
|
class EnumerateDatasetTest(test_base.DatasetTestBase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testEnumerateDataset(self):
|
def testEnumerateDataset(self):
|
||||||
components = (["a", "b"], [1, 2], [37.0, 38])
|
components = (["a", "b"], [1, 2], [37.0, 38])
|
||||||
start = constant_op.constant(20, dtype=dtypes.int64)
|
start = constant_op.constant(20, dtype=dtypes.int64)
|
||||||
|
@ -107,11 +107,13 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
|||||||
self.assertEqual(elem, [5.0])
|
self.assertEqual(elem, [5.0])
|
||||||
self.evaluate(destroy_op)
|
self.evaluate(destroy_op)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSameDeviceCPU(self):
|
def testSameDeviceCPU(self):
|
||||||
self._prefetch_fn_helper_one_shot("same_device_cpu",
|
self._prefetch_fn_helper_one_shot("same_device_cpu",
|
||||||
"/job:localhost/replica:0/task:0/cpu:0",
|
"/job:localhost/replica:0/task:0/cpu:0",
|
||||||
"/job:localhost/replica:0/task:0/cpu:0")
|
"/job:localhost/replica:0/task:0/cpu:0")
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testDifferentDeviceCPU(self):
|
def testDifferentDeviceCPU(self):
|
||||||
self._prefetch_fn_helper_one_shot("diff_device_cpu",
|
self._prefetch_fn_helper_one_shot("diff_device_cpu",
|
||||||
"/job:localhost/replica:0/task:0/cpu:0",
|
"/job:localhost/replica:0/task:0/cpu:0",
|
||||||
@ -125,6 +127,7 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
|||||||
"/job:localhost/replica:0/task:0/cpu:0",
|
"/job:localhost/replica:0/task:0/cpu:0",
|
||||||
"/job:localhost/replica:0/task:0/gpu:0")
|
"/job:localhost/replica:0/task:0/gpu:0")
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testReinitialization(self):
|
def testReinitialization(self):
|
||||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||||
|
|
||||||
@ -165,6 +168,7 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
|||||||
self.assertEqual(elem, [5.0])
|
self.assertEqual(elem, [5.0])
|
||||||
self.evaluate(destroy_op)
|
self.evaluate(destroy_op)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testReinitializationOutOfRange(self):
|
def testReinitializationOutOfRange(self):
|
||||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ from tensorflow.python.data.ops import dataset_ops
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
@ -39,6 +40,7 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
("MoreThanOne", 0, 2, errors.InvalidArgumentError,
|
("MoreThanOne", 0, 2, errors.InvalidArgumentError,
|
||||||
"Dataset had more than one element."),
|
"Dataset had more than one element."),
|
||||||
)
|
)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testGetSingleElement(self, skip, take, error=None, error_msg=None):
|
def testGetSingleElement(self, skip, take, error=None, error_msg=None):
|
||||||
skip_t = array_ops.placeholder(dtypes.int64, shape=[])
|
skip_t = array_ops.placeholder(dtypes.int64, shape=[])
|
||||||
take_t = array_ops.placeholder(dtypes.int64, shape=[])
|
take_t = array_ops.placeholder(dtypes.int64, shape=[])
|
||||||
|
@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -44,6 +45,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(get_next)
|
self.evaluate(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSum(self):
|
def testSum(self):
|
||||||
reducer = grouping.Reducer(
|
reducer = grouping.Reducer(
|
||||||
init_func=lambda _: np.int64(0),
|
init_func=lambda _: np.int64(0),
|
||||||
@ -55,6 +57,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
|||||||
self.checkResults(
|
self.checkResults(
|
||||||
dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
|
dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testAverage(self):
|
def testAverage(self):
|
||||||
|
|
||||||
def reduce_fn(x, y):
|
def reduce_fn(x, y):
|
||||||
@ -72,6 +75,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
|||||||
self.checkResults(
|
self.checkResults(
|
||||||
dataset, shapes=tensor_shape.scalar(), values=[i - 1, i])
|
dataset, shapes=tensor_shape.scalar(), values=[i - 1, i])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testConcat(self):
|
def testConcat(self):
|
||||||
components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
|
components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
|
||||||
reducer = grouping.Reducer(
|
reducer = grouping.Reducer(
|
||||||
@ -88,6 +92,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
|||||||
shapes=tensor_shape.scalar(),
|
shapes=tensor_shape.scalar(),
|
||||||
values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]])
|
values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSparseSum(self):
|
def testSparseSum(self):
|
||||||
def _sparse(i):
|
def _sparse(i):
|
||||||
return sparse_tensor.SparseTensorValue(
|
return sparse_tensor.SparseTensorValue(
|
||||||
@ -105,6 +110,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
|||||||
self.checkResults(
|
self.checkResults(
|
||||||
dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
|
dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testChangingStateShape(self):
|
def testChangingStateShape(self):
|
||||||
|
|
||||||
def reduce_fn(x, _):
|
def reduce_fn(x, _):
|
||||||
|
@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import string_ops
|
from tensorflow.python.ops import string_ops
|
||||||
@ -49,6 +50,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape(
|
32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape(
|
||||||
[None]), tensor_shape.TensorShape([3])))))
|
[None]), tensor_shape.TensorShape([3])))))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSingleBucket(self):
|
def testSingleBucket(self):
|
||||||
|
|
||||||
def _map_fn(v):
|
def _map_fn(v):
|
||||||
@ -84,6 +86,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
self.assertAllEqual(expected_unk_int64, bucketed_values[1])
|
self.assertAllEqual(expected_unk_int64, bucketed_values[1])
|
||||||
self.assertAllEqual(expected_vec3_str, bucketed_values[2])
|
self.assertAllEqual(expected_vec3_str, bucketed_values[2])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testEvenOddBuckets(self):
|
def testEvenOddBuckets(self):
|
||||||
|
|
||||||
def _map_fn(v):
|
def _map_fn(v):
|
||||||
@ -141,6 +144,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
|
self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
|
||||||
self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
|
self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testEvenOddBucketsFilterOutAllOdd(self):
|
def testEvenOddBucketsFilterOutAllOdd(self):
|
||||||
|
|
||||||
def _map_fn(v):
|
def _map_fn(v):
|
||||||
@ -188,6 +192,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
|
np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testDynamicWindowSize(self):
|
def testDynamicWindowSize(self):
|
||||||
components = np.arange(100).astype(np.int64)
|
components = np.arange(100).astype(np.int64)
|
||||||
|
|
||||||
@ -221,6 +226,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
self.assertEqual(batches, 15)
|
self.assertEqual(batches, 15)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSimple(self):
|
def testSimple(self):
|
||||||
components = np.random.randint(100, size=(200,)).astype(np.int64)
|
components = np.random.randint(100, size=(200,)).astype(np.int64)
|
||||||
iterator = (
|
iterator = (
|
||||||
@ -248,6 +254,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
self.assertGreaterEqual(num_full_batches, 24)
|
self.assertGreaterEqual(num_full_batches, 24)
|
||||||
self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))
|
self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testImmediateOutput(self):
|
def testImmediateOutput(self):
|
||||||
components = np.array(
|
components = np.array(
|
||||||
[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
|
[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
|
||||||
@ -270,6 +277,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
self.assertAllEqual([2, 2, 2, 2], self.evaluate(get_next))
|
self.assertAllEqual([2, 2, 2, 2], self.evaluate(get_next))
|
||||||
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
|
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSmallGroups(self):
|
def testSmallGroups(self):
|
||||||
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
|
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
|
||||||
iterator = (
|
iterator = (
|
||||||
@ -288,6 +296,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
self.assertAllEqual([0, 0, 0], self.evaluate(get_next))
|
self.assertAllEqual([0, 0, 0], self.evaluate(get_next))
|
||||||
self.assertAllEqual([1], self.evaluate(get_next))
|
self.assertAllEqual([1], self.evaluate(get_next))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testEmpty(self):
|
def testEmpty(self):
|
||||||
iterator = (
|
iterator = (
|
||||||
dataset_ops.Dataset.range(4).apply(
|
dataset_ops.Dataset.range(4).apply(
|
||||||
@ -303,6 +312,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
"Window size must be greater than zero, but got 0."):
|
"Window size must be greater than zero, but got 0."):
|
||||||
print(self.evaluate(get_next))
|
print(self.evaluate(get_next))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testReduceFuncError(self):
|
def testReduceFuncError(self):
|
||||||
components = np.random.randint(100, size=(200,)).astype(np.int64)
|
components = np.random.randint(100, size=(200,)).astype(np.int64)
|
||||||
|
|
||||||
@ -327,6 +337,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
self.evaluate(get_next)
|
self.evaluate(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testConsumeWindowDatasetMoreThanOnce(self):
|
def testConsumeWindowDatasetMoreThanOnce(self):
|
||||||
components = np.random.randint(50, size=(200,)).astype(np.int64)
|
components = np.random.randint(50, size=(200,)).astype(np.int64)
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ from tensorflow.python.data.experimental.ops import error_ops
|
|||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import io_ops
|
from tensorflow.python.ops import io_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -35,6 +36,7 @@ _NUMPY_RANDOM_SEED = 42
|
|||||||
|
|
||||||
class IgnoreErrorsTest(test_base.DatasetTestBase):
|
class IgnoreErrorsTest(test_base.DatasetTestBase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMapIgnoreError(self):
|
def testMapIgnoreError(self):
|
||||||
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
|
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
|
||||||
|
|
||||||
@ -53,6 +55,7 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(get_next)
|
self.evaluate(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testParallelMapIgnoreError(self):
|
def testParallelMapIgnoreError(self):
|
||||||
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
|
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
|
||||||
|
|
||||||
@ -71,6 +74,7 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(get_next)
|
self.evaluate(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testReadFileIgnoreError(self):
|
def testReadFileIgnoreError(self):
|
||||||
|
|
||||||
def write_string_to_file(value, filename):
|
def write_string_to_file(value, filename):
|
||||||
|
@ -24,6 +24,7 @@ from tensorflow.python.data.kernel_tests import test_base
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
|
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -31,6 +32,7 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
class IndexedDatasetOpsTest(test_base.DatasetTestBase):
|
class IndexedDatasetOpsTest(test_base.DatasetTestBase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testLowLevelIndexedDatasetOps(self):
|
def testLowLevelIndexedDatasetOps(self):
|
||||||
identity = ged_ops.experimental_identity_indexed_dataset(
|
identity = ged_ops.experimental_identity_indexed_dataset(
|
||||||
ops.convert_to_tensor(16, dtype=dtypes.uint64))
|
ops.convert_to_tensor(16, dtype=dtypes.uint64))
|
||||||
@ -49,6 +51,7 @@ class IndexedDatasetOpsTest(test_base.DatasetTestBase):
|
|||||||
self.evaluate(materialize)
|
self.evaluate(materialize)
|
||||||
self.assertEqual([3], sess.run(get_op, feed_dict={index: 3}))
|
self.assertEqual([3], sess.run(get_op, feed_dict={index: 3}))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testIdentityIndexedDataset(self):
|
def testIdentityIndexedDataset(self):
|
||||||
ds = indexed_dataset_ops.IdentityIndexedDataset(16)
|
ds = indexed_dataset_ops.IdentityIndexedDataset(16)
|
||||||
materialized = ds.materialize()
|
materialized = ds.materialize()
|
||||||
|
@ -26,6 +26,7 @@ from tensorflow.python.data.util import nest
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import io_ops
|
from tensorflow.python.ops import io_ops
|
||||||
from tensorflow.python.ops import parsing_ops
|
from tensorflow.python.ops import parsing_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -98,6 +99,7 @@ class MakeBatchedFeaturesDatasetTest(
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self._next_actual_batch(sess)
|
self._next_actual_batch(sess)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testReadWithEquivalentDataset(self):
|
def testReadWithEquivalentDataset(self):
|
||||||
features = {
|
features = {
|
||||||
"file": parsing_ops.FixedLenFeature([], dtypes.int64),
|
"file": parsing_ops.FixedLenFeature([], dtypes.int64),
|
||||||
|
@ -30,6 +30,7 @@ from tensorflow.python.framework import constant_op
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -127,6 +128,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
|||||||
self._verify_output(sess, dataset, batch_size, num_epochs, label_name,
|
self._verify_output(sess, dataset, batch_size, num_epochs, label_name,
|
||||||
expected_output, expected_keys)
|
expected_output, expected_keys)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMakeCSVDataset(self):
|
def testMakeCSVDataset(self):
|
||||||
"""Tests making a CSV dataset with keys and defaults provided."""
|
"""Tests making a CSV dataset with keys and defaults provided."""
|
||||||
record_defaults = [
|
record_defaults = [
|
||||||
@ -158,6 +160,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
|||||||
column_defaults=record_defaults,
|
column_defaults=record_defaults,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMakeCSVDataset_withBatchSizeAndEpochs(self):
|
def testMakeCSVDataset_withBatchSizeAndEpochs(self):
|
||||||
"""Tests making a CSV dataset with keys and defaults provided."""
|
"""Tests making a CSV dataset with keys and defaults provided."""
|
||||||
record_defaults = [
|
record_defaults = [
|
||||||
@ -189,6 +192,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
|||||||
column_defaults=record_defaults,
|
column_defaults=record_defaults,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMakeCSVDataset_withCompressionType(self):
|
def testMakeCSVDataset_withCompressionType(self):
|
||||||
"""Tests `compression_type` argument."""
|
"""Tests `compression_type` argument."""
|
||||||
record_defaults = [
|
record_defaults = [
|
||||||
@ -257,6 +261,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
|||||||
label_name="not_a_real_label",
|
label_name="not_a_real_label",
|
||||||
column_names=column_names)
|
column_names=column_names)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMakeCSVDataset_withNoLabel(self):
|
def testMakeCSVDataset_withNoLabel(self):
|
||||||
"""Tests making a CSV dataset with no label provided."""
|
"""Tests making a CSV dataset with no label provided."""
|
||||||
record_defaults = [
|
record_defaults = [
|
||||||
@ -286,6 +291,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
|||||||
column_defaults=record_defaults,
|
column_defaults=record_defaults,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMakeCSVDataset_withNoHeader(self):
|
def testMakeCSVDataset_withNoHeader(self):
|
||||||
"""Tests that datasets can be created from CSV files with no header line.
|
"""Tests that datasets can be created from CSV files with no header line.
|
||||||
"""
|
"""
|
||||||
@ -347,6 +353,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
|||||||
column_defaults=record_defaults,
|
column_defaults=record_defaults,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMakeCSVDataset_withNoColNames(self):
|
def testMakeCSVDataset_withNoColNames(self):
|
||||||
"""Tests that datasets can be created when column names are not specified.
|
"""Tests that datasets can be created when column names are not specified.
|
||||||
|
|
||||||
@ -451,6 +458,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
|||||||
header=True,
|
header=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMakeCSVDataset_withSelectCols(self):
|
def testMakeCSVDataset_withSelectCols(self):
|
||||||
record_defaults = [
|
record_defaults = [
|
||||||
constant_op.constant([], dtypes.int32),
|
constant_op.constant([], dtypes.int32),
|
||||||
@ -557,6 +565,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
|||||||
label_name=None,
|
label_name=None,
|
||||||
select_columns=["invalid_col_name"])
|
select_columns=["invalid_col_name"])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMakeCSVDataset_withShuffle(self):
|
def testMakeCSVDataset_withShuffle(self):
|
||||||
record_defaults = [
|
record_defaults = [
|
||||||
constant_op.constant([], dtypes.int32),
|
constant_op.constant([], dtypes.int32),
|
||||||
|
@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
@ -48,6 +49,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
("ParallelCallsNUMA", 2, None, True),
|
("ParallelCallsNUMA", 2, None, True),
|
||||||
("ParallelBatchesNUMA", None, 10, True),
|
("ParallelBatchesNUMA", None, 10, True),
|
||||||
)
|
)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMapAndBatch(self, num_parallel_calls, num_parallel_batches,
|
def testMapAndBatch(self, num_parallel_calls, num_parallel_batches,
|
||||||
numa_aware):
|
numa_aware):
|
||||||
"""Test a dataset that maps a TF function across its input elements."""
|
"""Test a dataset that maps a TF function across its input elements."""
|
||||||
@ -132,6 +134,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
("EvenNUMA", False, True),
|
("EvenNUMA", False, True),
|
||||||
("UnevenNUMA", True, True),
|
("UnevenNUMA", True, True),
|
||||||
)
|
)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMapAndBatchPartialBatch(self, drop_remainder, numa_aware):
|
def testMapAndBatchPartialBatch(self, drop_remainder, numa_aware):
|
||||||
dataset = (
|
dataset = (
|
||||||
dataset_ops.Dataset.range(10).apply(
|
dataset_ops.Dataset.range(10).apply(
|
||||||
@ -163,6 +166,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
("Normal", False),
|
("Normal", False),
|
||||||
("NUMA", True),
|
("NUMA", True),
|
||||||
)
|
)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMapAndBatchYieldsPartialBatch(self, numa_aware):
|
def testMapAndBatchYieldsPartialBatch(self, numa_aware):
|
||||||
dataset = (
|
dataset = (
|
||||||
dataset_ops.Dataset.range(10).apply(
|
dataset_ops.Dataset.range(10).apply(
|
||||||
@ -187,6 +191,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
("Normal", False),
|
("Normal", False),
|
||||||
("NUMA", True),
|
("NUMA", True),
|
||||||
)
|
)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMapAndBatchParallelGetNext(self, numa_aware):
|
def testMapAndBatchParallelGetNext(self, numa_aware):
|
||||||
dataset = dataset_ops.Dataset.range(50000).apply(
|
dataset = dataset_ops.Dataset.range(50000).apply(
|
||||||
batching.map_and_batch(lambda x: x, batch_size=100))
|
batching.map_and_batch(lambda x: x, batch_size=100))
|
||||||
@ -214,6 +219,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
("Normal", False),
|
("Normal", False),
|
||||||
("NUMA", True),
|
("NUMA", True),
|
||||||
)
|
)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMapAndBatchParallelGetNextDropRemainder(self, numa_aware):
|
def testMapAndBatchParallelGetNextDropRemainder(self, numa_aware):
|
||||||
dataset = dataset_ops.Dataset.range(49999).apply(
|
dataset = dataset_ops.Dataset.range(49999).apply(
|
||||||
batching.map_and_batch(
|
batching.map_and_batch(
|
||||||
@ -243,6 +249,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
("Normal", False),
|
("Normal", False),
|
||||||
("NUMA", True),
|
("NUMA", True),
|
||||||
)
|
)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMapAndBatchSparse(self, numa_aware):
|
def testMapAndBatchSparse(self, numa_aware):
|
||||||
|
|
||||||
def _sparse(i):
|
def _sparse(i):
|
||||||
@ -277,6 +284,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
("Normal", False),
|
("Normal", False),
|
||||||
("NUMA", True),
|
("NUMA", True),
|
||||||
)
|
)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMapAndBatchFails(self, numa_aware):
|
def testMapAndBatchFails(self, numa_aware):
|
||||||
"""Test a dataset that maps a TF function across its input elements."""
|
"""Test a dataset that maps a TF function across its input elements."""
|
||||||
dataset = dataset_ops.Dataset.from_tensors(
|
dataset = dataset_ops.Dataset.from_tensors(
|
||||||
@ -299,6 +307,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
("Normal", False),
|
("Normal", False),
|
||||||
("NUMA", True),
|
("NUMA", True),
|
||||||
)
|
)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMapAndBatchShapeMismatch(self, numa_aware):
|
def testMapAndBatchShapeMismatch(self, numa_aware):
|
||||||
"""Test a dataset that maps a TF function across its input elements."""
|
"""Test a dataset that maps a TF function across its input elements."""
|
||||||
|
|
||||||
@ -370,6 +379,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
("5NUMA", 95, True),
|
("5NUMA", 95, True),
|
||||||
("6NUMA", 99, True),
|
("6NUMA", 99, True),
|
||||||
)
|
)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMapAndBatchOutOfRangeError(self, threshold, numa_aware):
|
def testMapAndBatchOutOfRangeError(self, threshold, numa_aware):
|
||||||
|
|
||||||
def raising_py_fn(i):
|
def raising_py_fn(i):
|
||||||
@ -452,6 +462,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
("Swap", (None, None), lambda x, y: (y, x), None),
|
("Swap", (None, None), lambda x, y: (y, x), None),
|
||||||
("Project", (None, None), lambda x, y: x, None),
|
("Project", (None, None), lambda x, y: x, None),
|
||||||
)
|
)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testShortCircuit(self, structure, map_fn, num_parallel_calls):
|
def testShortCircuit(self, structure, map_fn, num_parallel_calls):
|
||||||
dataset = self.structuredDataset(structure).repeat().apply(
|
dataset = self.structuredDataset(structure).repeat().apply(
|
||||||
batching.map_and_batch(map_fn, batch_size=10))
|
batching.map_and_batch(map_fn, batch_size=10))
|
||||||
@ -466,6 +477,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
sess.run(self.structuredElement(structure, shape=[10])))
|
sess.run(self.structuredElement(structure, shape=[10])))
|
||||||
self.assertAllEqual(expected, self.evaluate(get_next))
|
self.assertAllEqual(expected, self.evaluate(get_next))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testShortCircuitCapturedInput(self):
|
def testShortCircuitCapturedInput(self):
|
||||||
captured_t = array_ops.placeholder(dtypes.int64, shape=[])
|
captured_t = array_ops.placeholder(dtypes.int64, shape=[])
|
||||||
dataset = self.structuredDataset(None).repeat().apply(
|
dataset = self.structuredDataset(None).repeat().apply(
|
||||||
@ -481,6 +493,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
("Normal", False),
|
("Normal", False),
|
||||||
("NUMA", True),
|
("NUMA", True),
|
||||||
)
|
)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMapAndBatchControlFlow(self, numa_aware):
|
def testMapAndBatchControlFlow(self, numa_aware):
|
||||||
|
|
||||||
def map_fn(x):
|
def map_fn(x):
|
||||||
|
@ -24,6 +24,7 @@ import tempfile
|
|||||||
from tensorflow.python.data.experimental.ops import matching_files
|
from tensorflow.python.data.experimental.ops import matching_files
|
||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
|
|
||||||
@ -40,6 +41,7 @@ class MatchingFilesTest(test_base.DatasetTestBase):
|
|||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
open(os.path.join(self.tmp_dir, filename), 'a').close()
|
open(os.path.join(self.tmp_dir, filename), 'a').close()
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNonExistingDirectory(self):
|
def testNonExistingDirectory(self):
|
||||||
"""Test the MatchingFiles dataset with a non-existing directory."""
|
"""Test the MatchingFiles dataset with a non-existing directory."""
|
||||||
|
|
||||||
@ -51,6 +53,7 @@ class MatchingFilesTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.NotFoundError):
|
with self.assertRaises(errors.NotFoundError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testEmptyDirectory(self):
|
def testEmptyDirectory(self):
|
||||||
"""Test the MatchingFiles dataset with an empty directory."""
|
"""Test the MatchingFiles dataset with an empty directory."""
|
||||||
|
|
||||||
@ -61,6 +64,7 @@ class MatchingFilesTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.NotFoundError):
|
with self.assertRaises(errors.NotFoundError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSimpleDirectory(self):
|
def testSimpleDirectory(self):
|
||||||
"""Test the MatchingFiles dataset with a simple directory."""
|
"""Test the MatchingFiles dataset with a simple directory."""
|
||||||
|
|
||||||
@ -83,6 +87,7 @@ class MatchingFilesTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFileSuffixes(self):
|
def testFileSuffixes(self):
|
||||||
"""Test the MatchingFiles dataset using the suffixes of filename."""
|
"""Test the MatchingFiles dataset using the suffixes of filename."""
|
||||||
|
|
||||||
@ -104,6 +109,7 @@ class MatchingFilesTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFileMiddles(self):
|
def testFileMiddles(self):
|
||||||
"""Test the MatchingFiles dataset using the middles of filename."""
|
"""Test the MatchingFiles dataset using the middles of filename."""
|
||||||
|
|
||||||
@ -125,6 +131,7 @@ class MatchingFilesTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNestedDirectories(self):
|
def testNestedDirectories(self):
|
||||||
"""Test the MatchingFiles dataset with nested directories."""
|
"""Test the MatchingFiles dataset with nested directories."""
|
||||||
|
|
||||||
|
@ -29,6 +29,7 @@ from tensorflow.python.data.kernel_tests import test_base
|
|||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import script_ops
|
from tensorflow.python.ops import script_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
@ -81,6 +82,7 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
|
|||||||
("8", 4, 1),
|
("8", 4, 1),
|
||||||
("9", 4, 4),
|
("9", 4, 4),
|
||||||
)
|
)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNumThreadsDeprecated(self, num_threads, max_intra_op_parallelism):
|
def testNumThreadsDeprecated(self, num_threads, max_intra_op_parallelism):
|
||||||
|
|
||||||
def override_threadpool_fn(dataset):
|
def override_threadpool_fn(dataset):
|
||||||
@ -107,6 +109,7 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
|
|||||||
("11", 4, 4),
|
("11", 4, 4),
|
||||||
("12", None, None),
|
("12", None, None),
|
||||||
)
|
)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNumThreads(self, num_threads, max_intra_op_parallelism):
|
def testNumThreads(self, num_threads, max_intra_op_parallelism):
|
||||||
|
|
||||||
def override_threadpool_fn(dataset):
|
def override_threadpool_fn(dataset):
|
||||||
|
@ -144,6 +144,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
|||||||
expected_values=expected_output,
|
expected_values=expected_output,
|
||||||
create_iterator_twice=True)
|
create_iterator_twice=True)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testEmptySerializedWithoutDefaultsShouldFail(self):
|
def testEmptySerializedWithoutDefaultsShouldFail(self):
|
||||||
input_features = {
|
input_features = {
|
||||||
"st_a":
|
"st_a":
|
||||||
@ -177,6 +178,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
|||||||
expected_err=(errors_impl.InvalidArgumentError,
|
expected_err=(errors_impl.InvalidArgumentError,
|
||||||
"Feature: c \\(data type: float\\) is required"))
|
"Feature: c \\(data type: float\\) is required"))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testDenseNotMatchingShapeShouldFail(self):
|
def testDenseNotMatchingShapeShouldFail(self):
|
||||||
original = [
|
original = [
|
||||||
example(features=features({
|
example(features=features({
|
||||||
@ -669,6 +671,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
|||||||
for batch_size in (1, 10, 20, 100, 256):
|
for batch_size in (1, 10, 20, 100, 256):
|
||||||
self._testSerializedContainingVarLenDenseLargerBatch(batch_size)
|
self._testSerializedContainingVarLenDenseLargerBatch(batch_size)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSkipEagerSerializedShapeMismatch(self):
|
def testSkipEagerSerializedShapeMismatch(self):
|
||||||
aname = "a"
|
aname = "a"
|
||||||
bname = "b"
|
bname = "b"
|
||||||
@ -706,6 +709,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
|||||||
expected_err=(ValueError,
|
expected_err=(ValueError,
|
||||||
"Cannot reshape a tensor with 0 elements to shape"))
|
"Cannot reshape a tensor with 0 elements to shape"))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSerializedContainingVarLenDense(self):
|
def testSerializedContainingVarLenDense(self):
|
||||||
aname = "a"
|
aname = "a"
|
||||||
bname = "b"
|
bname = "b"
|
||||||
|
@ -31,6 +31,7 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrefetchToDevice(self):
|
def testPrefetchToDevice(self):
|
||||||
host_dataset = dataset_ops.Dataset.range(10)
|
host_dataset = dataset_ops.Dataset.range(10)
|
||||||
device_dataset = host_dataset.apply(
|
device_dataset = host_dataset.apply(
|
||||||
@ -61,6 +62,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrefetchToSameDevice(self):
|
def testPrefetchToSameDevice(self):
|
||||||
host_dataset = dataset_ops.Dataset.range(10)
|
host_dataset = dataset_ops.Dataset.range(10)
|
||||||
device_dataset = host_dataset.apply(
|
device_dataset = host_dataset.apply(
|
||||||
@ -91,6 +93,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrefetchDictToDevice(self):
|
def testPrefetchDictToDevice(self):
|
||||||
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
|
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
|
||||||
device_dataset = host_dataset.apply(
|
device_dataset = host_dataset.apply(
|
||||||
@ -121,6 +124,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrefetchSparseTensorsToDevice(self):
|
def testPrefetchSparseTensorsToDevice(self):
|
||||||
def make_tensor(i):
|
def make_tensor(i):
|
||||||
return sparse_tensor.SparseTensorValue(
|
return sparse_tensor.SparseTensorValue(
|
||||||
@ -174,6 +178,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrefetchToDeviceWithReInit(self):
|
def testPrefetchToDeviceWithReInit(self):
|
||||||
host_dataset = dataset_ops.Dataset.range(10)
|
host_dataset = dataset_ops.Dataset.range(10)
|
||||||
device_dataset = host_dataset.apply(
|
device_dataset = host_dataset.apply(
|
||||||
|
@ -28,6 +28,7 @@ from tensorflow.python.data.kernel_tests import test_base
|
|||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops import string_ops
|
from tensorflow.python.ops import string_ops
|
||||||
@ -63,6 +64,7 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
("InitialDistributionKnown", True),
|
("InitialDistributionKnown", True),
|
||||||
("InitialDistributionUnknown", False))
|
("InitialDistributionUnknown", False))
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testDistribution(self, initial_known):
|
def testDistribution(self, initial_known):
|
||||||
classes = np.random.randint(5, size=(20000,)) # Uniformly sampled
|
classes = np.random.randint(5, size=(20000,)) # Uniformly sampled
|
||||||
target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
|
target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
|
||||||
@ -97,6 +99,7 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
("OnlyInitial", True),
|
("OnlyInitial", True),
|
||||||
("NotInitial", False))
|
("NotInitial", False))
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist):
|
def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist):
|
||||||
init_dist = [0.5, 0.5]
|
init_dist = [0.5, 0.5]
|
||||||
target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0]
|
target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0]
|
||||||
@ -122,6 +125,7 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
while True:
|
while True:
|
||||||
returned.append(sess.run(get_next))
|
returned.append(sess.run(get_next))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testRandomClasses(self):
|
def testRandomClasses(self):
|
||||||
init_dist = [0.25, 0.25, 0.25, 0.25]
|
init_dist = [0.25, 0.25, 0.25, 0.25]
|
||||||
target_dist = [0.0, 0.0, 0.0, 1.0]
|
target_dist = [0.0, 0.0, 0.0, 1.0]
|
||||||
|
@ -22,12 +22,14 @@ from tensorflow.python.data.kernel_tests import test_base
|
|||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.data.util import nest
|
from tensorflow.python.data.util import nest
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class RestructuredDatasetTest(test_base.DatasetTestBase):
|
class RestructuredDatasetTest(test_base.DatasetTestBase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testRestructureDataset(self):
|
def testRestructureDataset(self):
|
||||||
components = (array_ops.placeholder(dtypes.int32),
|
components = (array_ops.placeholder(dtypes.int32),
|
||||||
(array_ops.placeholder(dtypes.int32, shape=[None]),
|
(array_ops.placeholder(dtypes.int32, shape=[None]),
|
||||||
|
@ -40,6 +40,7 @@ class ScanTest(test_base.DatasetTestBase):
|
|||||||
return dataset_ops.Dataset.from_tensors(0).repeat().apply(
|
return dataset_ops.Dataset.from_tensors(0).repeat().apply(
|
||||||
scan_ops.scan(start, scan_fn))
|
scan_ops.scan(start, scan_fn))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCount(self):
|
def testCount(self):
|
||||||
def make_scan_fn(step):
|
def make_scan_fn(step):
|
||||||
return lambda state, _: (state + step, state)
|
return lambda state, _: (state + step, state)
|
||||||
@ -83,6 +84,7 @@ class ScanTest(test_base.DatasetTestBase):
|
|||||||
self.assertEqual(5, self.evaluate(next_element()))
|
self.assertEqual(5, self.evaluate(next_element()))
|
||||||
self.assertEqual(8, self.evaluate(next_element()))
|
self.assertEqual(8, self.evaluate(next_element()))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSparseCount(self):
|
def testSparseCount(self):
|
||||||
def _sparse(i):
|
def _sparse(i):
|
||||||
return sparse_tensor.SparseTensorValue(
|
return sparse_tensor.SparseTensorValue(
|
||||||
@ -114,6 +116,7 @@ class ScanTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testChangingStateShape(self):
|
def testChangingStateShape(self):
|
||||||
# Test the fixed-point shape invariant calculations: start with
|
# Test the fixed-point shape invariant calculations: start with
|
||||||
# initial values with known shapes, and use a scan function that
|
# initial values with known shapes, and use a scan function that
|
||||||
|
@ -26,6 +26,7 @@ from tensorflow.python.estimator import model_fn
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -68,6 +69,7 @@ class CheckpointInputPipelineHookTest(test_base.DatasetTestBase):
|
|||||||
def _build_iterator_saver_hook(self, est):
|
def _build_iterator_saver_hook(self, est):
|
||||||
return iterator_ops.CheckpointInputPipelineHook(est)
|
return iterator_ops.CheckpointInputPipelineHook(est)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testReturnDatasetFromInputFn(self):
|
def testReturnDatasetFromInputFn(self):
|
||||||
|
|
||||||
def _input_fn():
|
def _input_fn():
|
||||||
@ -80,6 +82,7 @@ class CheckpointInputPipelineHookTest(test_base.DatasetTestBase):
|
|||||||
est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
|
est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
|
||||||
self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
|
self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBuildIteratorInInputFn(self):
|
def testBuildIteratorInInputFn(self):
|
||||||
|
|
||||||
def _input_fn():
|
def _input_fn():
|
||||||
@ -94,6 +97,7 @@ class CheckpointInputPipelineHookTest(test_base.DatasetTestBase):
|
|||||||
est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
|
est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
|
||||||
self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
|
self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testDoNotRestore(self):
|
def testDoNotRestore(self):
|
||||||
|
|
||||||
def _input_fn():
|
def _input_fn():
|
||||||
|
@ -24,6 +24,7 @@ from tensorflow.python.data.kernel_tests import test_base
|
|||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -44,6 +45,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
|
|||||||
self.evaluate(get_next)
|
self.evaluate(get_next)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCorrectOutput(self):
|
def testCorrectOutput(self):
|
||||||
output = self._gen_outputs(lambda: self._build_ds(10), 100)
|
output = self._gen_outputs(lambda: self._build_ds(10), 100)
|
||||||
self.assertSequenceEqual(
|
self.assertSequenceEqual(
|
||||||
@ -52,6 +54,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
|
|||||||
for i in range(5):
|
for i in range(5):
|
||||||
self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20))
|
self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testReshuffling(self):
|
def testReshuffling(self):
|
||||||
# Check that the output orders of different epochs are indeed different.
|
# Check that the output orders of different epochs are indeed different.
|
||||||
output = self._gen_outputs(lambda: self._build_ds(10), 100)
|
output = self._gen_outputs(lambda: self._build_ds(10), 100)
|
||||||
@ -60,17 +63,20 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
|
|||||||
epoch2 = output[(i + 1) * 20:(i + 2) * 20]
|
epoch2 = output[(i + 1) * 20:(i + 2) * 20]
|
||||||
self.assertNotEqual(epoch1, epoch2)
|
self.assertNotEqual(epoch1, epoch2)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSameOrderForSameSeeds(self):
|
def testSameOrderForSameSeeds(self):
|
||||||
output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
|
output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
|
||||||
output2 = self._gen_outputs(lambda: self._build_ds(10), 100)
|
output2 = self._gen_outputs(lambda: self._build_ds(10), 100)
|
||||||
self.assertEqual(output1, output2)
|
self.assertEqual(output1, output2)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testDifferentOrderForDifferentSeeds(self):
|
def testDifferentOrderForDifferentSeeds(self):
|
||||||
output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
|
output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
|
||||||
output2 = self._gen_outputs(lambda: self._build_ds(20), 100)
|
output2 = self._gen_outputs(lambda: self._build_ds(20), 100)
|
||||||
self.assertNotEqual(output1, output2)
|
self.assertNotEqual(output1, output2)
|
||||||
self.assertEqual(sorted(output1), sorted(output2))
|
self.assertEqual(sorted(output1), sorted(output2))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCountNone(self):
|
def testCountNone(self):
|
||||||
output1 = self._gen_outputs(
|
output1 = self._gen_outputs(
|
||||||
lambda: self._build_ds(10, count=None), 100, verify_exhausted=False)
|
lambda: self._build_ds(10, count=None), 100, verify_exhausted=False)
|
||||||
@ -79,6 +85,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
|
|||||||
self.assertNotEqual(output1, output2)
|
self.assertNotEqual(output1, output2)
|
||||||
self.assertEqual(sorted(output1), sorted(output2))
|
self.assertEqual(sorted(output1), sorted(output2))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCountMinusOne(self):
|
def testCountMinusOne(self):
|
||||||
output1 = self._gen_outputs(
|
output1 = self._gen_outputs(
|
||||||
lambda: self._build_ds(10, count=-1), 100, verify_exhausted=False)
|
lambda: self._build_ds(10, count=-1), 100, verify_exhausted=False)
|
||||||
|
@ -23,6 +23,7 @@ from tensorflow.python.data.experimental.ops import sleep
|
|||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
_NUMPY_RANDOM_SEED = 42
|
_NUMPY_RANDOM_SEED = 42
|
||||||
@ -30,6 +31,7 @@ _NUMPY_RANDOM_SEED = 42
|
|||||||
|
|
||||||
class SleepTest(test_base.DatasetTestBase):
|
class SleepTest(test_base.DatasetTestBase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSleep(self):
|
def testSleep(self):
|
||||||
sleep_microseconds = 100
|
sleep_microseconds = 100
|
||||||
dataset = dataset_ops.Dataset.range(10).apply(
|
dataset = dataset_ops.Dataset.range(10).apply(
|
||||||
|
@ -30,6 +30,7 @@ from tensorflow.python.data.experimental.ops import stats_options
|
|||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -59,6 +60,7 @@ def function_apply_options(dataset, aggregator, prefix="", counter_prefix=""):
|
|||||||
)
|
)
|
||||||
class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBytesProduced(self, dataset_transformation):
|
def testBytesProduced(self, dataset_transformation):
|
||||||
aggregator = stats_aggregator.StatsAggregator()
|
aggregator = stats_aggregator.StatsAggregator()
|
||||||
dataset = dataset_ops.Dataset.range(100).map(
|
dataset = dataset_ops.Dataset.range(100).map(
|
||||||
@ -85,6 +87,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
|
self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
|
||||||
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
|
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testLatencyStats(self, dataset_transformation):
|
def testLatencyStats(self, dataset_transformation):
|
||||||
aggregator = stats_aggregator.StatsAggregator()
|
aggregator = stats_aggregator.StatsAggregator()
|
||||||
dataset = dataset_ops.Dataset.range(100).apply(
|
dataset = dataset_ops.Dataset.range(100).apply(
|
||||||
@ -105,6 +108,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
self._assertSummaryHasCount(
|
self._assertSummaryHasCount(
|
||||||
self.evaluate(summary_t), "record_latency", 100.0)
|
self.evaluate(summary_t), "record_latency", 100.0)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrefetchBufferUtilization(self, dataset_transformation):
|
def testPrefetchBufferUtilization(self, dataset_transformation):
|
||||||
aggregator = stats_aggregator.StatsAggregator()
|
aggregator = stats_aggregator.StatsAggregator()
|
||||||
dataset = dataset_ops.Dataset.range(100).map(
|
dataset = dataset_ops.Dataset.range(100).map(
|
||||||
@ -132,6 +136,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
|
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
|
||||||
100)
|
100)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrefetchBufferScalars(self, dataset_transformation):
|
def testPrefetchBufferScalars(self, dataset_transformation):
|
||||||
aggregator = stats_aggregator.StatsAggregator()
|
aggregator = stats_aggregator.StatsAggregator()
|
||||||
dataset = dataset_ops.Dataset.range(10).map(
|
dataset = dataset_ops.Dataset.range(10).map(
|
||||||
@ -154,6 +159,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFilteredElementsStats(self, dataset_transformation):
|
def testFilteredElementsStats(self, dataset_transformation):
|
||||||
aggregator = stats_aggregator.StatsAggregator()
|
aggregator = stats_aggregator.StatsAggregator()
|
||||||
dataset = dataset_ops.Dataset.range(101).filter(
|
dataset = dataset_ops.Dataset.range(101).filter(
|
||||||
@ -180,6 +186,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
self._assertSummaryHasScalarValue(
|
self._assertSummaryHasScalarValue(
|
||||||
self.evaluate(summary_t), "Filter::filtered_elements", 34.0)
|
self.evaluate(summary_t), "Filter::filtered_elements", 34.0)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMapBufferUtilization(self, dataset_transformation):
|
def testMapBufferUtilization(self, dataset_transformation):
|
||||||
|
|
||||||
def dataset_fn():
|
def dataset_fn():
|
||||||
@ -194,6 +201,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
dataset_transformation,
|
dataset_transformation,
|
||||||
function_processing_time=True)
|
function_processing_time=True)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMapAutoTuneBufferUtilization(self, dataset_transformation):
|
def testMapAutoTuneBufferUtilization(self, dataset_transformation):
|
||||||
|
|
||||||
def dataset_fn():
|
def dataset_fn():
|
||||||
@ -211,6 +219,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
dataset_transformation,
|
dataset_transformation,
|
||||||
function_processing_time=True)
|
function_processing_time=True)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testInterleaveAutoTuneBufferUtilization(self, dataset_transformation):
|
def testInterleaveAutoTuneBufferUtilization(self, dataset_transformation):
|
||||||
|
|
||||||
def dataset_fn():
|
def dataset_fn():
|
||||||
@ -227,6 +236,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
self._testParallelCallsStats(dataset_fn, "ParallelInterleaveV2", 10,
|
self._testParallelCallsStats(dataset_fn, "ParallelInterleaveV2", 10,
|
||||||
dataset_transformation)
|
dataset_transformation)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMapAndBatchAutoTuneBufferUtilization(self, dataset_transformation):
|
def testMapAndBatchAutoTuneBufferUtilization(self, dataset_transformation):
|
||||||
|
|
||||||
def dataset_fn():
|
def dataset_fn():
|
||||||
@ -248,6 +258,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
check_elements=False,
|
check_elements=False,
|
||||||
function_processing_time=True)
|
function_processing_time=True)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testReinitialize(self, dataset_transformation):
|
def testReinitialize(self, dataset_transformation):
|
||||||
aggregator = stats_aggregator.StatsAggregator()
|
aggregator = stats_aggregator.StatsAggregator()
|
||||||
dataset = dataset_ops.Dataset.range(100).apply(
|
dataset = dataset_ops.Dataset.range(100).apply(
|
||||||
@ -270,6 +281,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
self._assertSummaryHasCount(
|
self._assertSummaryHasCount(
|
||||||
self.evaluate(summary_t), "record_latency", (j + 1) * 100.0)
|
self.evaluate(summary_t), "record_latency", (j + 1) * 100.0)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNoAggregatorRegistered(self, dataset_transformation):
|
def testNoAggregatorRegistered(self, dataset_transformation):
|
||||||
dataset = dataset_ops.Dataset.range(100).apply(
|
dataset = dataset_ops.Dataset.range(100).apply(
|
||||||
stats_ops.latency_stats("record_latency"))
|
stats_ops.latency_stats("record_latency"))
|
||||||
@ -283,6 +295,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMultipleTags(self, dataset_transformation):
|
def testMultipleTags(self, dataset_transformation):
|
||||||
aggregator = stats_aggregator.StatsAggregator()
|
aggregator = stats_aggregator.StatsAggregator()
|
||||||
dataset = dataset_ops.Dataset.range(100).apply(
|
dataset = dataset_ops.Dataset.range(100).apply(
|
||||||
@ -308,6 +321,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
self._assertSummaryHasCount(
|
self._assertSummaryHasCount(
|
||||||
self.evaluate(summary_t), "record_latency_2", 100.0)
|
self.evaluate(summary_t), "record_latency_2", 100.0)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testRepeatedTags(self, dataset_transformation):
|
def testRepeatedTags(self, dataset_transformation):
|
||||||
aggregator = stats_aggregator.StatsAggregator()
|
aggregator = stats_aggregator.StatsAggregator()
|
||||||
dataset = dataset_ops.Dataset.range(100).apply(
|
dataset = dataset_ops.Dataset.range(100).apply(
|
||||||
@ -329,6 +343,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
self._assertSummaryHasCount(
|
self._assertSummaryHasCount(
|
||||||
self.evaluate(summary_t), "record_latency", 200.0)
|
self.evaluate(summary_t), "record_latency", 200.0)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMultipleIteratorsSameAggregator(self, dataset_transformation):
|
def testMultipleIteratorsSameAggregator(self, dataset_transformation):
|
||||||
aggregator = stats_aggregator.StatsAggregator()
|
aggregator = stats_aggregator.StatsAggregator()
|
||||||
dataset = dataset_ops.Dataset.range(100).apply(
|
dataset = dataset_ops.Dataset.range(100).apply(
|
||||||
@ -350,6 +365,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
self._assertSummaryHasCount(
|
self._assertSummaryHasCount(
|
||||||
self.evaluate(summary_t), "record_latency", 200.0)
|
self.evaluate(summary_t), "record_latency", 200.0)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMultipleDatasetWithPrefixes(self, dataset_transformation):
|
def testMultipleDatasetWithPrefixes(self, dataset_transformation):
|
||||||
aggregator = stats_aggregator.StatsAggregator()
|
aggregator = stats_aggregator.StatsAggregator()
|
||||||
dataset = dataset_ops.Dataset.range(100).apply(
|
dataset = dataset_ops.Dataset.range(100).apply(
|
||||||
@ -390,6 +406,7 @@ class FeatureStatsDatasetTest(
|
|||||||
stats_dataset_test_base.StatsDatasetTestBase,
|
stats_dataset_test_base.StatsDatasetTestBase,
|
||||||
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
|
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFeaturesStats(self, dataset_transformation):
|
def testFeaturesStats(self, dataset_transformation):
|
||||||
num_epochs = 5
|
num_epochs = 5
|
||||||
total_records = num_epochs * self._num_records
|
total_records = num_epochs * self._num_records
|
||||||
|
@ -28,6 +28,7 @@ from tensorflow.python.framework import constant_op
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import string_ops
|
from tensorflow.python.ops import string_ops
|
||||||
@ -37,6 +38,7 @@ from tensorflow.python.util import compat
|
|||||||
|
|
||||||
class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testUnbatchWithUnknownRankInput(self):
|
def testUnbatchWithUnknownRankInput(self):
|
||||||
placeholder = array_ops.placeholder(dtypes.int32)
|
placeholder = array_ops.placeholder(dtypes.int32)
|
||||||
dataset = dataset_ops.Dataset.from_tensors(placeholder).apply(
|
dataset = dataset_ops.Dataset.from_tensors(placeholder).apply(
|
||||||
@ -51,6 +53,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_elem)
|
self.evaluate(next_elem)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testUnbatchScalarDataset(self):
|
def testUnbatchScalarDataset(self):
|
||||||
data = tuple([math_ops.range(10) for _ in range(3)])
|
data = tuple([math_ops.range(10) for _ in range(3)])
|
||||||
data = dataset_ops.Dataset.from_tensor_slices(data)
|
data = dataset_ops.Dataset.from_tensor_slices(data)
|
||||||
@ -70,6 +73,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(op)
|
self.evaluate(op)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testUnbatchDatasetWithStrings(self):
|
def testUnbatchDatasetWithStrings(self):
|
||||||
data = tuple([math_ops.range(10) for _ in range(3)])
|
data = tuple([math_ops.range(10) for _ in range(3)])
|
||||||
data = dataset_ops.Dataset.from_tensor_slices(data)
|
data = dataset_ops.Dataset.from_tensor_slices(data)
|
||||||
@ -90,6 +94,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(op)
|
self.evaluate(op)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testUnbatchDatasetWithSparseTensor(self):
|
def testUnbatchDatasetWithSparseTensor(self):
|
||||||
st = sparse_tensor.SparseTensorValue(
|
st = sparse_tensor.SparseTensorValue(
|
||||||
indices=[[i, i] for i in range(10)],
|
indices=[[i, i] for i in range(10)],
|
||||||
@ -111,6 +116,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testUnbatchDatasetWithDenseAndSparseTensor(self):
|
def testUnbatchDatasetWithDenseAndSparseTensor(self):
|
||||||
st = sparse_tensor.SparseTensorValue(
|
st = sparse_tensor.SparseTensorValue(
|
||||||
indices=[[i, i] for i in range(10)],
|
indices=[[i, i] for i in range(10)],
|
||||||
@ -133,6 +139,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testUnbatchSingleElementTupleDataset(self):
|
def testUnbatchSingleElementTupleDataset(self):
|
||||||
data = tuple([(math_ops.range(10),) for _ in range(3)])
|
data = tuple([(math_ops.range(10),) for _ in range(3)])
|
||||||
data = dataset_ops.Dataset.from_tensor_slices(data)
|
data = dataset_ops.Dataset.from_tensor_slices(data)
|
||||||
@ -152,6 +159,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(op)
|
self.evaluate(op)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testUnbatchMultiElementTupleDataset(self):
|
def testUnbatchMultiElementTupleDataset(self):
|
||||||
data = tuple([(math_ops.range(10 * i, 10 * i + 10),
|
data = tuple([(math_ops.range(10 * i, 10 * i + 10),
|
||||||
array_ops.fill([10], "hi")) for i in range(3)])
|
array_ops.fill([10], "hi")) for i in range(3)])
|
||||||
@ -173,6 +181,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(op)
|
self.evaluate(op)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testUnbatchEmpty(self):
|
def testUnbatchEmpty(self):
|
||||||
data = dataset_ops.Dataset.from_tensors(
|
data = dataset_ops.Dataset.from_tensors(
|
||||||
(constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
|
(constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
|
||||||
@ -191,6 +200,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
data.apply(batching.unbatch())
|
data.apply(batching.unbatch())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testUnbatchDynamicShapeMismatch(self):
|
def testUnbatchDynamicShapeMismatch(self):
|
||||||
ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
|
ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
|
||||||
ph2 = array_ops.placeholder(dtypes.int32, shape=None)
|
ph2 = array_ops.placeholder(dtypes.int32, shape=None)
|
||||||
|
@ -22,6 +22,7 @@ from tensorflow.python.data.kernel_tests import test_base
|
|||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
|
|
||||||
@ -57,6 +58,7 @@ class UniqueTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSimpleInt(self):
|
def testSimpleInt(self):
|
||||||
for dtype in [dtypes.int32, dtypes.int64]:
|
for dtype in [dtypes.int32, dtypes.int64]:
|
||||||
self._testSimpleHelper(dtype, [
|
self._testSimpleHelper(dtype, [
|
||||||
@ -69,6 +71,7 @@ class UniqueTest(test_base.DatasetTestBase):
|
|||||||
([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]),
|
([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSimpleString(self):
|
def testSimpleString(self):
|
||||||
self._testSimpleHelper(dtypes.string, [
|
self._testSimpleHelper(dtypes.string, [
|
||||||
([], []),
|
([], []),
|
||||||
|
@ -60,6 +60,7 @@ class FlatMapTest(test_base.DatasetTestBase):
|
|||||||
self.assertDatasetProduces(dataset, expected_output=expected_output)
|
self.assertDatasetProduces(dataset, expected_output=expected_output)
|
||||||
|
|
||||||
# Note: no eager mode coverage, session specific test.
|
# Note: no eager mode coverage, session specific test.
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSkipEagerSharedResourceNestedFlatMapDataset(self):
|
def testSkipEagerSharedResourceNestedFlatMapDataset(self):
|
||||||
repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
|
repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
|
||||||
components = np.array(repeats, dtype=np.int64)
|
components = np.array(repeats, dtype=np.int64)
|
||||||
|
@ -27,6 +27,7 @@ from tensorflow.python.data.ops import dataset_ops
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import script_ops
|
from tensorflow.python.ops import script_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
@ -69,6 +70,7 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFromGeneratorUsingFunction(self):
|
def testFromGeneratorUsingFunction(self):
|
||||||
def generator():
|
def generator():
|
||||||
for i in range(1, 100):
|
for i in range(1, 100):
|
||||||
@ -79,18 +81,21 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
|||||||
self._testFromGeneratorOneShot(generator, elem_sequence, 1)
|
self._testFromGeneratorOneShot(generator, elem_sequence, 1)
|
||||||
self._testFromGeneratorOneShot(generator, elem_sequence, 5)
|
self._testFromGeneratorOneShot(generator, elem_sequence, 5)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFromGeneratorUsingList(self):
|
def testFromGeneratorUsingList(self):
|
||||||
generator = lambda: [[i] * i for i in range(1, 100)]
|
generator = lambda: [[i] * i for i in range(1, 100)]
|
||||||
elem_sequence = list(generator())
|
elem_sequence = list(generator())
|
||||||
self._testFromGenerator(generator, elem_sequence, 1)
|
self._testFromGenerator(generator, elem_sequence, 1)
|
||||||
self._testFromGenerator(generator, elem_sequence, 5)
|
self._testFromGenerator(generator, elem_sequence, 5)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFromGeneratorUsingNdarray(self):
|
def testFromGeneratorUsingNdarray(self):
|
||||||
generator = lambda: np.arange(100, dtype=np.int64)
|
generator = lambda: np.arange(100, dtype=np.int64)
|
||||||
elem_sequence = list(generator())
|
elem_sequence = list(generator())
|
||||||
self._testFromGenerator(generator, elem_sequence, 1, output_types=np.int64)
|
self._testFromGenerator(generator, elem_sequence, 1, output_types=np.int64)
|
||||||
self._testFromGenerator(generator, elem_sequence, 5, output_types=np.int64)
|
self._testFromGenerator(generator, elem_sequence, 5, output_types=np.int64)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFromGeneratorUsingGeneratorExpression(self):
|
def testFromGeneratorUsingGeneratorExpression(self):
|
||||||
# NOTE(mrry): Generator *expressions* are not repeatable (or in
|
# NOTE(mrry): Generator *expressions* are not repeatable (or in
|
||||||
# general reusable), because they eagerly evaluate the `for`
|
# general reusable), because they eagerly evaluate the `for`
|
||||||
@ -102,6 +107,7 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
|||||||
self._testFromGenerator(generator, elem_sequence, 1)
|
self._testFromGenerator(generator, elem_sequence, 1)
|
||||||
self._testFromGenerator(generator, elem_sequence, 5)
|
self._testFromGenerator(generator, elem_sequence, 5)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFromMultipleConcurrentGenerators(self):
|
def testFromMultipleConcurrentGenerators(self):
|
||||||
num_inner_repeats = 5
|
num_inner_repeats = 5
|
||||||
num_outer_repeats = 100
|
num_outer_repeats = 100
|
||||||
@ -199,6 +205,7 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFromGeneratorImplicitConversion(self):
|
def testFromGeneratorImplicitConversion(self):
|
||||||
def generator():
|
def generator():
|
||||||
yield [1]
|
yield [1]
|
||||||
@ -223,6 +230,7 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFromGeneratorString(self):
|
def testFromGeneratorString(self):
|
||||||
def generator():
|
def generator():
|
||||||
yield "foo"
|
yield "foo"
|
||||||
@ -243,6 +251,7 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFromGeneratorTypeError(self):
|
def testFromGeneratorTypeError(self):
|
||||||
def generator():
|
def generator():
|
||||||
yield np.array([1, 2, 3], dtype=np.int64)
|
yield np.array([1, 2, 3], dtype=np.int64)
|
||||||
@ -266,6 +275,7 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFromGeneratorShapeError(self):
|
def testFromGeneratorShapeError(self):
|
||||||
def generator():
|
def generator():
|
||||||
yield np.array([1, 2, 3], dtype=np.int64)
|
yield np.array([1, 2, 3], dtype=np.int64)
|
||||||
@ -289,6 +299,7 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFromGeneratorStructureError(self):
|
def testFromGeneratorStructureError(self):
|
||||||
def generator():
|
def generator():
|
||||||
yield 1, 2
|
yield 1, 2
|
||||||
@ -317,6 +328,7 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFromGeneratorHeterogeneous(self):
|
def testFromGeneratorHeterogeneous(self):
|
||||||
def generator():
|
def generator():
|
||||||
yield 1
|
yield 1
|
||||||
@ -335,6 +347,7 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFromGeneratorStopShort(self):
|
def testFromGeneratorStopShort(self):
|
||||||
|
|
||||||
def generator():
|
def generator():
|
||||||
@ -353,6 +366,7 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
|||||||
self.assertAllEqual(0, sess.run(get_next))
|
self.assertAllEqual(0, sess.run(get_next))
|
||||||
self.assertAllEqual(1, sess.run(get_next))
|
self.assertAllEqual(1, sess.run(get_next))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFromGeneratorDestructorCalled(self):
|
def testFromGeneratorDestructorCalled(self):
|
||||||
# Use an `Event` to signal that the generator has been deleted.
|
# Use an `Event` to signal that the generator has been deleted.
|
||||||
event = threading.Event()
|
event = threading.Event()
|
||||||
@ -387,6 +401,7 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
|||||||
# iterator terminates (and the generator iterator is deleted).
|
# iterator terminates (and the generator iterator is deleted).
|
||||||
self.assertTrue(event.is_set())
|
self.assertTrue(event.is_set())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFromGeneratorWithArgs(self):
|
def testFromGeneratorWithArgs(self):
|
||||||
|
|
||||||
def flat_map_fn(elem):
|
def flat_map_fn(elem):
|
||||||
@ -414,6 +429,7 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFromGeneratorWithTwoArgs(self):
|
def testFromGeneratorWithTwoArgs(self):
|
||||||
|
|
||||||
def flat_map_fn(elem, message):
|
def flat_map_fn(elem, message):
|
||||||
@ -446,6 +462,7 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testGeneratorDatasetFinalizeFunctionCalled(self):
|
def testGeneratorDatasetFinalizeFunctionCalled(self):
|
||||||
# NOTE(mrry): This test tests the internal `_GeneratorDataset`,
|
# NOTE(mrry): This test tests the internal `_GeneratorDataset`,
|
||||||
# which affords more control over what the finalize function can do than
|
# which affords more control over what the finalize function can do than
|
||||||
|
@ -32,6 +32,7 @@ from tensorflow.python.platform import test
|
|||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class FromSparseTensorSlicesTest(test_base.DatasetTestBase):
|
class FromSparseTensorSlicesTest(test_base.DatasetTestBase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSkipEagerFromSparseTensorSlices(self):
|
def testSkipEagerFromSparseTensorSlices(self):
|
||||||
"""Test a dataset based on slices of a `tf.SparseTensor`."""
|
"""Test a dataset based on slices of a `tf.SparseTensor`."""
|
||||||
st = array_ops.sparse_placeholder(dtypes.float64)
|
st = array_ops.sparse_placeholder(dtypes.float64)
|
||||||
|
@ -154,6 +154,7 @@ class FromTensorsTest(test_base.DatasetTestBase):
|
|||||||
self.assertEquals(([], ([], []), []), dataset.output_shapes)
|
self.assertEquals(([], ([], []), []), dataset.output_shapes)
|
||||||
|
|
||||||
# TODO(b/117581999): more specific shapes in eager mode.
|
# TODO(b/117581999): more specific shapes in eager mode.
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSkipEagerNestedStructure(self):
|
def testSkipEagerNestedStructure(self):
|
||||||
components = (np.array([1, 2, 3], dtype=np.int64), (np.array([4., 5.]),
|
components = (np.array([1, 2, 3], dtype=np.int64), (np.array([4., 5.]),
|
||||||
np.array([6., 7.])),
|
np.array([6., 7.])),
|
||||||
|
@ -55,6 +55,7 @@ from tensorflow.python.util import compat
|
|||||||
|
|
||||||
class IteratorTest(test.TestCase, parameterized.TestCase):
|
class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNoGradients(self):
|
def testNoGradients(self):
|
||||||
component = constant_op.constant([1.])
|
component = constant_op.constant([1.])
|
||||||
side = constant_op.constant(0.)
|
side = constant_op.constant(0.)
|
||||||
@ -65,6 +66,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertIsNone(gradients_impl.gradients(value, side)[0])
|
self.assertIsNone(gradients_impl.gradients(value, side)[0])
|
||||||
self.assertIsNone(gradients_impl.gradients(value, [component, side])[0])
|
self.assertIsNone(gradients_impl.gradients(value, [component, side])[0])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCapturingStateInOneShotRaisesException(self):
|
def testCapturingStateInOneShotRaisesException(self):
|
||||||
var = variables.Variable(37.0, name="myvar")
|
var = variables.Variable(37.0, name="myvar")
|
||||||
dataset = (
|
dataset = (
|
||||||
@ -75,6 +77,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
"datasets that capture stateful objects.+myvar"):
|
"datasets that capture stateful objects.+myvar"):
|
||||||
dataset.make_one_shot_iterator()
|
dataset.make_one_shot_iterator()
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testOneShotIterator(self):
|
def testOneShotIterator(self):
|
||||||
components = (np.arange(7),
|
components = (np.arange(7),
|
||||||
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
||||||
@ -100,6 +103,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testOneShotIteratorCaptureByValue(self):
|
def testOneShotIteratorCaptureByValue(self):
|
||||||
components = (np.arange(7),
|
components = (np.arange(7),
|
||||||
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
||||||
@ -162,6 +166,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testOneShotIteratorNonBlocking(self):
|
def testOneShotIteratorNonBlocking(self):
|
||||||
dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x)
|
dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x)
|
||||||
iterator = dataset.make_one_shot_iterator()
|
iterator = dataset.make_one_shot_iterator()
|
||||||
@ -200,6 +205,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
len([None for r in results if r is None]))
|
len([None for r in results if r is None]))
|
||||||
self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None])
|
self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testOneShotIteratorInitializerFails(self):
|
def testOneShotIteratorInitializerFails(self):
|
||||||
# Define a dataset whose initialization will always fail.
|
# Define a dataset whose initialization will always fail.
|
||||||
dataset = dataset_ops.Dataset.from_tensors(
|
dataset = dataset_ops.Dataset.from_tensors(
|
||||||
@ -280,6 +286,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNotInitializedError(self):
|
def testNotInitializedError(self):
|
||||||
components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
|
components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
|
||||||
iterator = (
|
iterator = (
|
||||||
@ -292,6 +299,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
"iterator has not been initialized"):
|
"iterator has not been initialized"):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testReinitializableIterator(self):
|
def testReinitializableIterator(self):
|
||||||
dataset_3 = dataset_ops.Dataset.from_tensors(
|
dataset_3 = dataset_ops.Dataset.from_tensors(
|
||||||
constant_op.constant([1, 2, 3]))
|
constant_op.constant([1, 2, 3]))
|
||||||
@ -331,6 +339,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testReinitializableIteratorWithFunctions(self):
|
def testReinitializableIteratorWithFunctions(self):
|
||||||
|
|
||||||
def g():
|
def g():
|
||||||
@ -390,6 +399,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
(constant_op.constant([1, 2, 3], dtype=dtypes.int64),
|
(constant_op.constant([1, 2, 3], dtype=dtypes.int64),
|
||||||
constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64))))
|
constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64))))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testIteratorStringHandle(self):
|
def testIteratorStringHandle(self):
|
||||||
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
||||||
dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
|
dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
|
||||||
@ -445,6 +455,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
sess.run(
|
sess.run(
|
||||||
next_element, feed_dict={handle_placeholder: iterator_4_handle})
|
next_element, feed_dict={handle_placeholder: iterator_4_handle})
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testIteratorStringHandleFuture(self):
|
def testIteratorStringHandleFuture(self):
|
||||||
with forward_compat.forward_compatibility_horizon(2018, 8, 4):
|
with forward_compat.forward_compatibility_horizon(2018, 8, 4):
|
||||||
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
||||||
@ -508,6 +519,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
sess.run(
|
sess.run(
|
||||||
next_element, feed_dict={handle_placeholder: iterator_4_handle})
|
next_element, feed_dict={handle_placeholder: iterator_4_handle})
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testIteratorStringHandleReuseTensorObject(self):
|
def testIteratorStringHandleReuseTensorObject(self):
|
||||||
dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
||||||
one_shot_iterator = dataset.make_one_shot_iterator()
|
one_shot_iterator = dataset.make_one_shot_iterator()
|
||||||
@ -536,6 +548,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertEqual("foo_1", handle_with_same_name.op.name)
|
self.assertEqual("foo_1", handle_with_same_name.op.name)
|
||||||
self.assertIsNot(handle_with_name, handle_with_same_name)
|
self.assertIsNot(handle_with_name, handle_with_same_name)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testIteratorStringHandleError(self):
|
def testIteratorStringHandleError(self):
|
||||||
dataset_int_scalar = (
|
dataset_int_scalar = (
|
||||||
dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).repeat())
|
dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).repeat())
|
||||||
@ -576,6 +589,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
feedable_int_vector.get_next(),
|
feedable_int_vector.get_next(),
|
||||||
feed_dict={handle_placeholder: handle_float_vector}))
|
feed_dict={handle_placeholder: handle_float_vector}))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
|
def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
|
||||||
worker_config = config_pb2.ConfigProto()
|
worker_config = config_pb2.ConfigProto()
|
||||||
worker_config.device_count["CPU"] = 3
|
worker_config.device_count["CPU"] = 3
|
||||||
@ -632,6 +646,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
|
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self):
|
def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self):
|
||||||
s1 = server_lib.Server.create_local_server()
|
s1 = server_lib.Server.create_local_server()
|
||||||
s2 = server_lib.Server.create_local_server()
|
s2 = server_lib.Server.create_local_server()
|
||||||
@ -739,6 +754,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
|
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testIncorrectIteratorRestore(self):
|
def testIncorrectIteratorRestore(self):
|
||||||
|
|
||||||
def _path():
|
def _path():
|
||||||
@ -797,6 +813,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
sess.run(restore_op)
|
sess.run(restore_op)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testRepeatedGetNextWarning(self):
|
def testRepeatedGetNextWarning(self):
|
||||||
iterator = dataset_ops.Dataset.range(10).make_one_shot_iterator()
|
iterator = dataset_ops.Dataset.range(10).make_one_shot_iterator()
|
||||||
warnings.simplefilter("always")
|
warnings.simplefilter("always")
|
||||||
|
@ -44,6 +44,7 @@ class ListFilesTest(test_base.DatasetTestBase):
|
|||||||
open(path.join(self.tmp_dir, filename), 'a').close()
|
open(path.join(self.tmp_dir, filename), 'a').close()
|
||||||
|
|
||||||
# Note: eager mode fails in assertion error same as initializer in graph mode.
|
# Note: eager mode fails in assertion error same as initializer in graph mode.
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSkipEagerEmptyDirectory(self):
|
def testSkipEagerEmptyDirectory(self):
|
||||||
dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
|
dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
|
||||||
self.assertDatasetProduces(dataset, expected_output=[])
|
self.assertDatasetProduces(dataset, expected_output=[])
|
||||||
|
@ -74,6 +74,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.assertAllEqual(expected.dense_shape,
|
self.assertAllEqual(expected.dense_shape,
|
||||||
self.evaluate(actual.dense_shape))
|
self.evaluate(actual.dense_shape))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFromNone(self):
|
def testFromNone(self):
|
||||||
value_structure = structure.TensorStructure(dtypes.float32, [])
|
value_structure = structure.TensorStructure(dtypes.float32, [])
|
||||||
opt = optional_ops.Optional.none_from_structure(value_structure)
|
opt = optional_ops.Optional.none_from_structure(value_structure)
|
||||||
@ -267,6 +268,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
optional_ops.OptionalStructure(
|
optional_ops.OptionalStructure(
|
||||||
structure.TensorStructure(dtypes.float32, []))),
|
structure.TensorStructure(dtypes.float32, []))),
|
||||||
)
|
)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSkipEagerOptionalStructure(self, tf_value_fn,
|
def testSkipEagerOptionalStructure(self, tf_value_fn,
|
||||||
expected_value_structure):
|
expected_value_structure):
|
||||||
tf_value = tf_value_fn()
|
tf_value = tf_value_fn()
|
||||||
@ -322,6 +324,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
indices=[[0, 1], [1, 0]], values=[37.0, 42.0],
|
indices=[[0, 1], [1, 0]], values=[37.0, 42.0],
|
||||||
dense_shape=[2, 2])}, False),
|
dense_shape=[2, 2])}, False),
|
||||||
)
|
)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSkipEagerIteratorGetNextAsOptional(self, np_value, tf_value_fn,
|
def testSkipEagerIteratorGetNextAsOptional(self, np_value, tf_value_fn,
|
||||||
works_on_gpu):
|
works_on_gpu):
|
||||||
if not works_on_gpu and test.is_gpu_available():
|
if not works_on_gpu and test.is_gpu_available():
|
||||||
|
@ -93,6 +93,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(get_next())
|
self.evaluate(get_next())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPaddedBatchShortPadding(self):
|
def testPaddedBatchShortPadding(self):
|
||||||
dataset = (
|
dataset = (
|
||||||
dataset_ops.Dataset.from_tensor_slices(
|
dataset_ops.Dataset.from_tensor_slices(
|
||||||
@ -155,6 +156,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
next_element = self.getNext(padded_dataset)
|
next_element = self.getNext(padded_dataset)
|
||||||
self.evaluate(next_element())
|
self.evaluate(next_element())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSkipEagerPaddedBatchDatasetShapeSpecifications(self):
|
def testSkipEagerPaddedBatchDatasetShapeSpecifications(self):
|
||||||
int_placeholder = array_ops.placeholder(dtypes.int32)
|
int_placeholder = array_ops.placeholder(dtypes.int32)
|
||||||
float_placeholder = array_ops.placeholder(dtypes.float32)
|
float_placeholder = array_ops.placeholder(dtypes.float32)
|
||||||
@ -226,6 +228,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
_ = dataset_ops.Dataset.range(10).padded_batch(
|
_ = dataset_ops.Dataset.range(10).padded_batch(
|
||||||
5, padded_shapes=shape_as_tensor)
|
5, padded_shapes=shape_as_tensor)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSkipEagerPaddedBatchShapeError(self):
|
def testSkipEagerPaddedBatchShapeError(self):
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError,
|
ValueError,
|
||||||
|
@ -68,6 +68,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.assertEqual(((i + 1) * i) // 2, s)
|
self.assertEqual(((i + 1) * i) // 2, s)
|
||||||
self.assertEqual(i, c)
|
self.assertEqual(i, c)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSkipEagerSquareUsingPlaceholder(self):
|
def testSkipEagerSquareUsingPlaceholder(self):
|
||||||
delta = array_ops.placeholder(dtype=dtypes.int64)
|
delta = array_ops.placeholder(dtype=dtypes.int64)
|
||||||
|
|
||||||
|
@ -115,6 +115,7 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(get_next())
|
self.evaluate(get_next())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSkipEagerSeedZero(self):
|
def testSkipEagerSeedZero(self):
|
||||||
"""Test for same behavior when the seed is a Python or Tensor zero."""
|
"""Test for same behavior when the seed is a Python or Tensor zero."""
|
||||||
iterator = (
|
iterator = (
|
||||||
|
@ -22,6 +22,7 @@ from tensorflow.python.data.util import convert
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
|
|
||||||
@ -60,6 +61,7 @@ class ConvertTest(test.TestCase):
|
|||||||
convert.partial_shape_to_tensor(
|
convert.partial_shape_to_tensor(
|
||||||
constant_op.constant([1], dtype=dtypes.int64))))
|
constant_op.constant([1], dtype=dtypes.int64))))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartialShapeToTensorUnknownDimension(self):
|
def testPartialShapeToTensorUnknownDimension(self):
|
||||||
self.assertAllEqual([-1],
|
self.assertAllEqual([-1],
|
||||||
self.evaluate(
|
self.evaluate(
|
||||||
|
@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -296,6 +297,7 @@ class SparseTest(test.TestCase):
|
|||||||
self.assertAllEqual(a.eval().values, self.evaluate(b).values)
|
self.assertAllEqual(a.eval().values, self.evaluate(b).values)
|
||||||
self.assertAllEqual(a.eval().dense_shape, self.evaluate(b).dense_shape)
|
self.assertAllEqual(a.eval().dense_shape, self.evaluate(b).dense_shape)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSerializeDeserialize(self):
|
def testSerializeDeserialize(self):
|
||||||
test_cases = (
|
test_cases = (
|
||||||
(),
|
(),
|
||||||
@ -325,6 +327,7 @@ class SparseTest(test.TestCase):
|
|||||||
for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
|
for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
|
||||||
self.assertSparseValuesEqual(a, e)
|
self.assertSparseValuesEqual(a, e)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSerializeManyDeserialize(self):
|
def testSerializeManyDeserialize(self):
|
||||||
test_cases = (
|
test_cases = (
|
||||||
(),
|
(),
|
||||||
|
@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -115,6 +116,7 @@ class StructureTest(test.TestCase, parameterized.TestCase):
|
|||||||
indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
|
indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
|
||||||
}, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
|
}, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
|
||||||
)
|
)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testIsCompatibleWithStructure(
|
def testIsCompatibleWithStructure(
|
||||||
self, original_value_fn, compatible_values_fn, incompatible_values_fn):
|
self, original_value_fn, compatible_values_fn, incompatible_values_fn):
|
||||||
original_value = original_value_fn()
|
original_value = original_value_fn()
|
||||||
|
@ -645,6 +645,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(len("Size (B)") + 1, dump_size_col_width)
|
self.assertEqual(len("Size (B)") + 1, dump_size_col_width)
|
||||||
self.assertEqual(len("Op type") + 1, op_type_col_width)
|
self.assertEqual(len("Op type") + 1, op_type_col_width)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMeasureTensorListColumnWidthsGivesRightAnswerForData(self):
|
def testMeasureTensorListColumnWidthsGivesRightAnswerForData(self):
|
||||||
dump = self._debug_dump.dumped_tensor_data[0]
|
dump = self._debug_dump.dumped_tensor_data[0]
|
||||||
self.assertLess(dump.dump_size_bytes, 1000)
|
self.assertLess(dump.dump_size_bytes, 1000)
|
||||||
@ -660,6 +661,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
# column should be determined by the length of "VariableV2".
|
# column should be determined by the length of "VariableV2".
|
||||||
self.assertEqual(len("VariableV2") + 1, op_type_col_width)
|
self.assertEqual(len("VariableV2") + 1, op_type_col_width)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListTensors(self):
|
def testListTensors(self):
|
||||||
# Use shorthand alias for the command prefix.
|
# Use shorthand alias for the command prefix.
|
||||||
out = self._registry.dispatch_command("lt", [])
|
out = self._registry.dispatch_command("lt", [])
|
||||||
@ -673,6 +675,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
# Check the main menu.
|
# Check the main menu.
|
||||||
check_main_menu(self, out, list_tensors_enabled=False)
|
check_main_menu(self, out, list_tensors_enabled=False)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListTensorsInReverseTimeOrderWorks(self):
|
def testListTensorsInReverseTimeOrderWorks(self):
|
||||||
# Use shorthand alias for the command prefix.
|
# Use shorthand alias for the command prefix.
|
||||||
out = self._registry.dispatch_command("lt", ["-s", "timestamp", "-r"])
|
out = self._registry.dispatch_command("lt", ["-s", "timestamp", "-r"])
|
||||||
@ -688,6 +691,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
reverse=True)
|
reverse=True)
|
||||||
check_main_menu(self, out, list_tensors_enabled=False)
|
check_main_menu(self, out, list_tensors_enabled=False)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListTensorsInDumpSizeOrderWorks(self):
|
def testListTensorsInDumpSizeOrderWorks(self):
|
||||||
out = self._registry.dispatch_command("lt", ["-s", "dump_size"])
|
out = self._registry.dispatch_command("lt", ["-s", "dump_size"])
|
||||||
assert_listed_tensors(
|
assert_listed_tensors(
|
||||||
@ -701,6 +705,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
sort_by="dump_size")
|
sort_by="dump_size")
|
||||||
check_main_menu(self, out, list_tensors_enabled=False)
|
check_main_menu(self, out, list_tensors_enabled=False)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListTensorsInReverseDumpSizeOrderWorks(self):
|
def testListTensorsInReverseDumpSizeOrderWorks(self):
|
||||||
out = self._registry.dispatch_command("lt", ["-s", "dump_size", "-r"])
|
out = self._registry.dispatch_command("lt", ["-s", "dump_size", "-r"])
|
||||||
assert_listed_tensors(
|
assert_listed_tensors(
|
||||||
@ -720,6 +725,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertIn("ValueError: Unsupported key to sort tensors by: foobar",
|
self.assertIn("ValueError: Unsupported key to sort tensors by: foobar",
|
||||||
out.lines)
|
out.lines)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListTensorsInOpTypeOrderWorks(self):
|
def testListTensorsInOpTypeOrderWorks(self):
|
||||||
# Use shorthand alias for the command prefix.
|
# Use shorthand alias for the command prefix.
|
||||||
out = self._registry.dispatch_command("lt", ["-s", "op_type"])
|
out = self._registry.dispatch_command("lt", ["-s", "op_type"])
|
||||||
@ -735,6 +741,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
reverse=False)
|
reverse=False)
|
||||||
check_main_menu(self, out, list_tensors_enabled=False)
|
check_main_menu(self, out, list_tensors_enabled=False)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListTensorsInReverseOpTypeOrderWorks(self):
|
def testListTensorsInReverseOpTypeOrderWorks(self):
|
||||||
# Use shorthand alias for the command prefix.
|
# Use shorthand alias for the command prefix.
|
||||||
out = self._registry.dispatch_command("lt", ["-s", "op_type", "-r"])
|
out = self._registry.dispatch_command("lt", ["-s", "op_type", "-r"])
|
||||||
@ -750,6 +757,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
reverse=True)
|
reverse=True)
|
||||||
check_main_menu(self, out, list_tensors_enabled=False)
|
check_main_menu(self, out, list_tensors_enabled=False)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListTensorsInTensorNameOrderWorks(self):
|
def testListTensorsInTensorNameOrderWorks(self):
|
||||||
# Use shorthand alias for the command prefix.
|
# Use shorthand alias for the command prefix.
|
||||||
out = self._registry.dispatch_command("lt", ["-s", "tensor_name"])
|
out = self._registry.dispatch_command("lt", ["-s", "tensor_name"])
|
||||||
@ -765,6 +773,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
reverse=False)
|
reverse=False)
|
||||||
check_main_menu(self, out, list_tensors_enabled=False)
|
check_main_menu(self, out, list_tensors_enabled=False)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListTensorsInReverseTensorNameOrderWorks(self):
|
def testListTensorsInReverseTensorNameOrderWorks(self):
|
||||||
# Use shorthand alias for the command prefix.
|
# Use shorthand alias for the command prefix.
|
||||||
out = self._registry.dispatch_command("lt", ["-s", "tensor_name", "-r"])
|
out = self._registry.dispatch_command("lt", ["-s", "tensor_name", "-r"])
|
||||||
@ -780,6 +789,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
reverse=True)
|
reverse=True)
|
||||||
check_main_menu(self, out, list_tensors_enabled=False)
|
check_main_menu(self, out, list_tensors_enabled=False)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListTensorsFilterByNodeNameRegex(self):
|
def testListTensorsFilterByNodeNameRegex(self):
|
||||||
out = self._registry.dispatch_command("list_tensors",
|
out = self._registry.dispatch_command("list_tensors",
|
||||||
["--node_name_filter", ".*read.*"])
|
["--node_name_filter", ".*read.*"])
|
||||||
@ -793,6 +803,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
assert_listed_tensors(self, out, [], [], node_name_regex="^read")
|
assert_listed_tensors(self, out, [], [], node_name_regex="^read")
|
||||||
check_main_menu(self, out, list_tensors_enabled=False)
|
check_main_menu(self, out, list_tensors_enabled=False)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListTensorFilterByOpTypeRegex(self):
|
def testListTensorFilterByOpTypeRegex(self):
|
||||||
out = self._registry.dispatch_command("list_tensors",
|
out = self._registry.dispatch_command("list_tensors",
|
||||||
["--op_type_filter", "Identity"])
|
["--op_type_filter", "Identity"])
|
||||||
@ -821,6 +832,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
op_type_regex="(Add|MatMul)")
|
op_type_regex="(Add|MatMul)")
|
||||||
check_main_menu(self, out, list_tensors_enabled=False)
|
check_main_menu(self, out, list_tensors_enabled=False)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListTensorWithFilterAndNodeNameExclusionWorks(self):
|
def testListTensorWithFilterAndNodeNameExclusionWorks(self):
|
||||||
# First, create and register the filter.
|
# First, create and register the filter.
|
||||||
def is_2x1_vector(datum, tensor):
|
def is_2x1_vector(datum, tensor):
|
||||||
@ -877,6 +889,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
out = self._registry.dispatch_command("list_tensors", ["--bar"])
|
out = self._registry.dispatch_command("list_tensors", ["--bar"])
|
||||||
check_syntax_error_output(self, out, "list_tensors")
|
check_syntax_error_output(self, out, "list_tensors")
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNodeInfoByNodeName(self):
|
def testNodeInfoByNodeName(self):
|
||||||
node_name = "simple_mul_add/matmul"
|
node_name = "simple_mul_add/matmul"
|
||||||
out = self._registry.dispatch_command("node_info", [node_name])
|
out = self._registry.dispatch_command("node_info", [node_name])
|
||||||
@ -901,6 +914,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
[(len(out.lines[0]) - len(node_name), len(out.lines[0]), "bold")],
|
[(len(out.lines[0]) - len(node_name), len(out.lines[0]), "bold")],
|
||||||
out.font_attr_segs[0])
|
out.font_attr_segs[0])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNodeInfoShowAttributes(self):
|
def testNodeInfoShowAttributes(self):
|
||||||
node_name = "simple_mul_add/matmul"
|
node_name = "simple_mul_add/matmul"
|
||||||
out = self._registry.dispatch_command("node_info", ["-a", node_name])
|
out = self._registry.dispatch_command("node_info", ["-a", node_name])
|
||||||
@ -924,6 +938,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
print_tensor_node_name=node_name,
|
print_tensor_node_name=node_name,
|
||||||
list_outputs_node_name=node_name)
|
list_outputs_node_name=node_name)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNodeInfoShowDumps(self):
|
def testNodeInfoShowDumps(self):
|
||||||
node_name = "simple_mul_add/matmul"
|
node_name = "simple_mul_add/matmul"
|
||||||
out = self._registry.dispatch_command("node_info", ["-d", node_name])
|
out = self._registry.dispatch_command("node_info", ["-d", node_name])
|
||||||
@ -948,6 +963,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
len(out.lines[16]) - len(out.lines[16].strip()),
|
len(out.lines[16]) - len(out.lines[16].strip()),
|
||||||
len(out.lines[16]), "pt %s:0 -n 0" % node_name)
|
len(out.lines[16]), "pt %s:0 -n 0" % node_name)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNodeInfoShowStackTraceUnavailableIsIndicated(self):
|
def testNodeInfoShowStackTraceUnavailableIsIndicated(self):
|
||||||
self._debug_dump.set_python_graph(None)
|
self._debug_dump.set_python_graph(None)
|
||||||
|
|
||||||
@ -971,6 +987,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
print_tensor_node_name=node_name,
|
print_tensor_node_name=node_name,
|
||||||
list_outputs_node_name=node_name)
|
list_outputs_node_name=node_name)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNodeInfoShowStackTraceAvailableWorks(self):
|
def testNodeInfoShowStackTraceAvailableWorks(self):
|
||||||
self._debug_dump.set_python_graph(self._sess.graph)
|
self._debug_dump.set_python_graph(self._sess.graph)
|
||||||
|
|
||||||
@ -994,6 +1011,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
print_tensor_node_name=node_name,
|
print_tensor_node_name=node_name,
|
||||||
list_outputs_node_name=node_name)
|
list_outputs_node_name=node_name)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNodeInfoByTensorName(self):
|
def testNodeInfoByTensorName(self):
|
||||||
node_name = "simple_mul_add/u/read"
|
node_name = "simple_mul_add/u/read"
|
||||||
tensor_name = node_name + ":0"
|
tensor_name = node_name + ":0"
|
||||||
@ -1363,6 +1381,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
break
|
break
|
||||||
return index
|
return index
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrintSourceForOpNamesWholeFileWorks(self):
|
def testPrintSourceForOpNamesWholeFileWorks(self):
|
||||||
self._debug_dump.set_python_graph(self._sess.graph)
|
self._debug_dump.set_python_graph(self._sess.graph)
|
||||||
out = self._registry.dispatch_command(
|
out = self._registry.dispatch_command(
|
||||||
@ -1415,6 +1434,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual("pt simple_mul_add/add",
|
self.assertEqual("pt simple_mul_add/add",
|
||||||
out.font_attr_segs[index + 1][0][2].content)
|
out.font_attr_segs[index + 1][0][2].content)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrintSourceForTensorNamesWholeFileWorks(self):
|
def testPrintSourceForTensorNamesWholeFileWorks(self):
|
||||||
self._debug_dump.set_python_graph(self._sess.graph)
|
self._debug_dump.set_python_graph(self._sess.graph)
|
||||||
out = self._registry.dispatch_command(
|
out = self._registry.dispatch_command(
|
||||||
@ -1435,6 +1455,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual("pt simple_mul_add/u:0",
|
self.assertEqual("pt simple_mul_add/u:0",
|
||||||
out.font_attr_segs[index + 2][0][2].content)
|
out.font_attr_segs[index + 2][0][2].content)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrintSourceForOpNamesStartingAtSpecifiedLineWorks(self):
|
def testPrintSourceForOpNamesStartingAtSpecifiedLineWorks(self):
|
||||||
self._debug_dump.set_python_graph(self._sess.graph)
|
self._debug_dump.set_python_graph(self._sess.graph)
|
||||||
out = self._registry.dispatch_command(
|
out = self._registry.dispatch_command(
|
||||||
@ -1461,6 +1482,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual("pt simple_mul_add/u/read",
|
self.assertEqual("pt simple_mul_add/u/read",
|
||||||
out.font_attr_segs[index + 3][0][2].content)
|
out.font_attr_segs[index + 3][0][2].content)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrintSourceForOpNameSettingMaximumElementCountWorks(self):
|
def testPrintSourceForOpNameSettingMaximumElementCountWorks(self):
|
||||||
self._debug_dump.set_python_graph(self._sess.graph)
|
self._debug_dump.set_python_graph(self._sess.graph)
|
||||||
out = self._registry.dispatch_command(
|
out = self._registry.dispatch_command(
|
||||||
@ -1505,6 +1527,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertTrue(cli_shared.COLOR_GRAY in attr_seg[2] or
|
self.assertTrue(cli_shared.COLOR_GRAY in attr_seg[2] or
|
||||||
attr_seg[2] == cli_shared.COLOR_GRAY)
|
attr_seg[2] == cli_shared.COLOR_GRAY)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListSourceWithNodeNameFilterWithMatchesWorks(self):
|
def testListSourceWithNodeNameFilterWithMatchesWorks(self):
|
||||||
self._debug_dump.set_python_graph(self._sess.graph)
|
self._debug_dump.set_python_graph(self._sess.graph)
|
||||||
out = self._registry.dispatch_command("list_source", ["-n", ".*/read"])
|
out = self._registry.dispatch_command("list_source", ["-n", ".*/read"])
|
||||||
@ -1719,6 +1742,7 @@ class AnalyzerCLIControlDepTest(test_util.TensorFlowTestCase):
|
|||||||
# Tear down temporary dump directory.
|
# Tear down temporary dump directory.
|
||||||
shutil.rmtree(cls._dump_root)
|
shutil.rmtree(cls._dump_root)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNodeInfoWithControlDependencies(self):
|
def testNodeInfoWithControlDependencies(self):
|
||||||
# Call node_info on a node with control inputs.
|
# Call node_info on a node with control inputs.
|
||||||
out = self._registry.dispatch_command("node_info",
|
out = self._registry.dispatch_command("node_info",
|
||||||
@ -1759,6 +1783,7 @@ class AnalyzerCLIControlDepTest(test_util.TensorFlowTestCase):
|
|||||||
len(out.lines[z_line]),
|
len(out.lines[z_line]),
|
||||||
"ni -a -d -t control_deps/ctrl_dep_z")
|
"ni -a -d -t control_deps/ctrl_dep_z")
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListInputsNonRecursiveNoControl(self):
|
def testListInputsNonRecursiveNoControl(self):
|
||||||
"""List inputs non-recursively, without any control inputs."""
|
"""List inputs non-recursively, without any control inputs."""
|
||||||
|
|
||||||
@ -1801,6 +1826,7 @@ class AnalyzerCLIControlDepTest(test_util.TensorFlowTestCase):
|
|||||||
len(out.lines[3]) - len("control_deps/ctrl_dep_y"),
|
len(out.lines[3]) - len("control_deps/ctrl_dep_y"),
|
||||||
len(out.lines[3]), "li -c -r control_deps/ctrl_dep_y")
|
len(out.lines[3]), "li -c -r control_deps/ctrl_dep_y")
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListInputsNonRecursiveNoControlUsingTensorName(self):
|
def testListInputsNonRecursiveNoControlUsingTensorName(self):
|
||||||
"""List inputs using the name of an output tensor of the node."""
|
"""List inputs using the name of an output tensor of the node."""
|
||||||
|
|
||||||
@ -1829,6 +1855,7 @@ class AnalyzerCLIControlDepTest(test_util.TensorFlowTestCase):
|
|||||||
len(out.lines[3]) - len("control_deps/ctrl_dep_y"),
|
len(out.lines[3]) - len("control_deps/ctrl_dep_y"),
|
||||||
len(out.lines[3]), "li -c -r control_deps/ctrl_dep_y")
|
len(out.lines[3]), "li -c -r control_deps/ctrl_dep_y")
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListInputsNonRecursiveWithControls(self):
|
def testListInputsNonRecursiveWithControls(self):
|
||||||
"""List inputs non-recursively, with control inputs."""
|
"""List inputs non-recursively, with control inputs."""
|
||||||
node_name = "control_deps/ctrl_dep_z"
|
node_name = "control_deps/ctrl_dep_z"
|
||||||
@ -1859,6 +1886,7 @@ class AnalyzerCLIControlDepTest(test_util.TensorFlowTestCase):
|
|||||||
len(out.lines[5]) - len("control_deps/x"),
|
len(out.lines[5]) - len("control_deps/x"),
|
||||||
len(out.lines[5]), "li -c -r control_deps/x")
|
len(out.lines[5]), "li -c -r control_deps/x")
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListInputsRecursiveWithControls(self):
|
def testListInputsRecursiveWithControls(self):
|
||||||
"""List inputs recursively, with control inputs."""
|
"""List inputs recursively, with control inputs."""
|
||||||
node_name = "control_deps/ctrl_dep_z"
|
node_name = "control_deps/ctrl_dep_z"
|
||||||
@ -1904,6 +1932,7 @@ class AnalyzerCLIControlDepTest(test_util.TensorFlowTestCase):
|
|||||||
len(out.lines[18]) - len("control_deps/x"),
|
len(out.lines[18]) - len("control_deps/x"),
|
||||||
len(out.lines[18]), "li -c -r control_deps/x")
|
len(out.lines[18]), "li -c -r control_deps/x")
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListInputsRecursiveWithControlsWithDepthLimit(self):
|
def testListInputsRecursiveWithControlsWithDepthLimit(self):
|
||||||
"""List inputs recursively, with control inputs and a depth limit."""
|
"""List inputs recursively, with control inputs and a depth limit."""
|
||||||
node_name = "control_deps/ctrl_dep_z"
|
node_name = "control_deps/ctrl_dep_z"
|
||||||
@ -1963,6 +1992,7 @@ class AnalyzerCLIControlDepTest(test_util.TensorFlowTestCase):
|
|||||||
"ERROR: There is no node named \"control_deps/z/foo\" in the "
|
"ERROR: There is no node named \"control_deps/z/foo\" in the "
|
||||||
"partition graphs"], out.lines)
|
"partition graphs"], out.lines)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListRecipientsRecursiveWithControlsWithDepthLimit(self):
|
def testListRecipientsRecursiveWithControlsWithDepthLimit(self):
|
||||||
"""List recipients recursively, with control inputs and a depth limit."""
|
"""List recipients recursively, with control inputs and a depth limit."""
|
||||||
|
|
||||||
@ -2034,6 +2064,7 @@ class AnalyzerCLIWhileLoopTest(test_util.TensorFlowTestCase):
|
|||||||
# Tear down temporary dump directory.
|
# Tear down temporary dump directory.
|
||||||
shutil.rmtree(cls._dump_root)
|
shutil.rmtree(cls._dump_root)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMultipleDumpsPrintTensorNoNumber(self):
|
def testMultipleDumpsPrintTensorNoNumber(self):
|
||||||
output = self._registry.dispatch_command("pt", ["while/Identity:0"])
|
output = self._registry.dispatch_command("pt", ["while/Identity:0"])
|
||||||
|
|
||||||
@ -2051,6 +2082,7 @@ class AnalyzerCLIWhileLoopTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual("For example:", output.lines[-2])
|
self.assertEqual("For example:", output.lines[-2])
|
||||||
self.assertEqual(" print_tensor while/Identity:0 -n 0", output.lines[-1])
|
self.assertEqual(" print_tensor while/Identity:0 -n 0", output.lines[-1])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMultipleDumpsPrintTensorWithNumber(self):
|
def testMultipleDumpsPrintTensorWithNumber(self):
|
||||||
for i in xrange(5):
|
for i in xrange(5):
|
||||||
output = self._registry.dispatch_command(
|
output = self._registry.dispatch_command(
|
||||||
@ -2064,6 +2096,7 @@ class AnalyzerCLIWhileLoopTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertTrue(output.lines[4].startswith("array(%d" % i))
|
self.assertTrue(output.lines[4].startswith("array(%d" % i))
|
||||||
self.assertTrue(output.lines[4].endswith(")"))
|
self.assertTrue(output.lines[4].endswith(")"))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMultipleDumpsPrintTensorInvalidNumber(self):
|
def testMultipleDumpsPrintTensorInvalidNumber(self):
|
||||||
output = self._registry.dispatch_command("pt",
|
output = self._registry.dispatch_command("pt",
|
||||||
["while/Identity:0", "-n", "10"])
|
["while/Identity:0", "-n", "10"])
|
||||||
|
@ -118,6 +118,7 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
|||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
ops.reset_default_graph()
|
ops.reset_default_graph()
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSingleFetchNoFeeds(self):
|
def testSingleFetchNoFeeds(self):
|
||||||
run_start_intro = cli_shared.get_run_start_intro(12, self.const_a, None, {})
|
run_start_intro = cli_shared.get_run_start_intro(12, self.const_a, None, {})
|
||||||
|
|
||||||
@ -181,6 +182,7 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
|||||||
run_start_intro = cli_shared.get_run_start_intro(1, self.sparse_d, None, {})
|
run_start_intro = cli_shared.get_run_start_intro(1, self.sparse_d, None, {})
|
||||||
self.assertEqual(str(self.sparse_d), run_start_intro.lines[4].strip())
|
self.assertEqual(str(self.sparse_d), run_start_intro.lines[4].strip())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testTwoFetchesListNoFeeds(self):
|
def testTwoFetchesListNoFeeds(self):
|
||||||
fetches = [self.const_a, self.const_b]
|
fetches = [self.const_a, self.const_b]
|
||||||
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
||||||
@ -197,6 +199,7 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
|||||||
description = cli_shared.get_run_short_description(1, fetches, None)
|
description = cli_shared.get_run_short_description(1, fetches, None)
|
||||||
self.assertEqual("run #1: 2 fetches; 0 feeds", description)
|
self.assertEqual("run #1: 2 fetches; 0 feeds", description)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNestedListAsFetches(self):
|
def testNestedListAsFetches(self):
|
||||||
fetches = [self.const_c, [self.const_a, self.const_b]]
|
fetches = [self.const_c, [self.const_a, self.const_b]]
|
||||||
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
||||||
@ -210,6 +213,7 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
|||||||
description = cli_shared.get_run_short_description(1, fetches, None)
|
description = cli_shared.get_run_short_description(1, fetches, None)
|
||||||
self.assertEqual("run #1: 3 fetches; 0 feeds", description)
|
self.assertEqual("run #1: 3 fetches; 0 feeds", description)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNestedDictAsFetches(self):
|
def testNestedDictAsFetches(self):
|
||||||
fetches = {"c": self.const_c, "ab": {"a": self.const_a, "b": self.const_b}}
|
fetches = {"c": self.const_c, "ab": {"a": self.const_a, "b": self.const_b}}
|
||||||
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
||||||
@ -227,6 +231,7 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
|||||||
description = cli_shared.get_run_short_description(1, fetches, None)
|
description = cli_shared.get_run_short_description(1, fetches, None)
|
||||||
self.assertEqual("run #1: 3 fetches; 0 feeds", description)
|
self.assertEqual("run #1: 3 fetches; 0 feeds", description)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testTwoFetchesAsTupleNoFeeds(self):
|
def testTwoFetchesAsTupleNoFeeds(self):
|
||||||
fetches = (self.const_a, self.const_b)
|
fetches = (self.const_a, self.const_b)
|
||||||
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
||||||
@ -243,6 +248,7 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
|||||||
description = cli_shared.get_run_short_description(1, fetches, None)
|
description = cli_shared.get_run_short_description(1, fetches, None)
|
||||||
self.assertEqual("run #1: 2 fetches; 0 feeds", description)
|
self.assertEqual("run #1: 2 fetches; 0 feeds", description)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testTwoFetchesAsNamedTupleNoFeeds(self):
|
def testTwoFetchesAsNamedTupleNoFeeds(self):
|
||||||
fetches_namedtuple = namedtuple("fetches", "x y")
|
fetches_namedtuple = namedtuple("fetches", "x y")
|
||||||
fetches = fetches_namedtuple(self.const_b, self.const_c)
|
fetches = fetches_namedtuple(self.const_b, self.const_c)
|
||||||
@ -260,6 +266,7 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
|||||||
description = cli_shared.get_run_short_description(1, fetches, None)
|
description = cli_shared.get_run_short_description(1, fetches, None)
|
||||||
self.assertEqual("run #1: 2 fetches; 0 feeds", description)
|
self.assertEqual("run #1: 2 fetches; 0 feeds", description)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWithFeedDict(self):
|
def testWithFeedDict(self):
|
||||||
feed_dict = {
|
feed_dict = {
|
||||||
self.const_a: 10.0,
|
self.const_a: 10.0,
|
||||||
@ -283,6 +290,7 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
|||||||
feed_dict)
|
feed_dict)
|
||||||
self.assertEqual("run #1: 1 fetch (c:0); 2 feeds", description)
|
self.assertEqual("run #1: 1 fetch (c:0); 2 feeds", description)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testTensorFilters(self):
|
def testTensorFilters(self):
|
||||||
feed_dict = {self.const_a: 10.0}
|
feed_dict = {self.const_a: 10.0}
|
||||||
tensor_filters = {
|
tensor_filters = {
|
||||||
@ -313,11 +321,13 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
|||||||
command_set.add(annot[2].content)
|
command_set.add(annot[2].content)
|
||||||
self.assertEqual({"run -f filter_a", "run -f filter_b"}, command_set)
|
self.assertEqual({"run -f filter_a", "run -f filter_b"}, command_set)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testGetRunShortDescriptionWorksForTensorFeedKey(self):
|
def testGetRunShortDescriptionWorksForTensorFeedKey(self):
|
||||||
short_description = cli_shared.get_run_short_description(
|
short_description = cli_shared.get_run_short_description(
|
||||||
1, self.const_a, {self.const_a: 42.0})
|
1, self.const_a, {self.const_a: 42.0})
|
||||||
self.assertEqual("run #1: 1 fetch (a:0); 1 feed (a:0)", short_description)
|
self.assertEqual("run #1: 1 fetch (a:0); 1 feed (a:0)", short_description)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testGetRunShortDescriptionWorksForUnicodeFeedKey(self):
|
def testGetRunShortDescriptionWorksForUnicodeFeedKey(self):
|
||||||
short_description = cli_shared.get_run_short_description(
|
short_description = cli_shared.get_run_short_description(
|
||||||
1, self.const_a, {u"foo": 42.0})
|
1, self.const_a, {u"foo": 42.0})
|
||||||
@ -332,6 +342,7 @@ class GetErrorIntroTest(test_util.TensorFlowTestCase):
|
|||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
ops.reset_default_graph()
|
ops.reset_default_graph()
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testShapeError(self):
|
def testShapeError(self):
|
||||||
tf_error = errors.OpError(None, self.var_a.initializer, "foo description",
|
tf_error = errors.OpError(None, self.var_a.initializer, "foo description",
|
||||||
None)
|
None)
|
||||||
|
@ -348,6 +348,7 @@ class ProfileAnalyzerPrintSourceTest(test_util.TensorFlowTestCase):
|
|||||||
ops.reset_default_graph()
|
ops.reset_default_graph()
|
||||||
super(ProfileAnalyzerPrintSourceTest, self).tearDown()
|
super(ProfileAnalyzerPrintSourceTest, self).tearDown()
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrintSourceForWhileLoop(self):
|
def testPrintSourceForWhileLoop(self):
|
||||||
prof_output = self.prof_analyzer.print_source([__file__])
|
prof_output = self.prof_analyzer.print_source([__file__])
|
||||||
|
|
||||||
@ -361,6 +362,7 @@ class ProfileAnalyzerPrintSourceTest(test_util.TensorFlowTestCase):
|
|||||||
r"\[(\|)+(\s)*\] .*us .*7\(55\) .*L%d.*(\S)+" % self.loop_lineno,
|
r"\[(\|)+(\s)*\] .*us .*7\(55\) .*L%d.*(\S)+" % self.loop_lineno,
|
||||||
prof_output.lines)
|
prof_output.lines)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrintSourceOutputContainsClickableLinks(self):
|
def testPrintSourceOutputContainsClickableLinks(self):
|
||||||
prof_output = self.prof_analyzer.print_source([__file__])
|
prof_output = self.prof_analyzer.print_source([__file__])
|
||||||
any_match, line_index = _at_least_one_line_matches(
|
any_match, line_index = _at_least_one_line_matches(
|
||||||
@ -377,6 +379,7 @@ class ProfileAnalyzerPrintSourceTest(test_util.TensorFlowTestCase):
|
|||||||
break
|
break
|
||||||
self.assertTrue(any_menu_item_match)
|
self.assertTrue(any_menu_item_match)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrintSourceWithNonDefaultTimeUnit(self):
|
def testPrintSourceWithNonDefaultTimeUnit(self):
|
||||||
prof_output = self.prof_analyzer.print_source([
|
prof_output = self.prof_analyzer.print_source([
|
||||||
__file__, "--time_unit", "ms"])
|
__file__, "--time_unit", "ms"])
|
||||||
@ -391,6 +394,7 @@ class ProfileAnalyzerPrintSourceTest(test_util.TensorFlowTestCase):
|
|||||||
r"\[(\|)+(\s)*\] .*ms .*7\(55\) .*L%d.*(\S)+" % self.loop_lineno,
|
r"\[(\|)+(\s)*\] .*ms .*7\(55\) .*L%d.*(\S)+" % self.loop_lineno,
|
||||||
prof_output.lines)
|
prof_output.lines)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrintSourceWithNodeNameFilter(self):
|
def testPrintSourceWithNodeNameFilter(self):
|
||||||
prof_output = self.prof_analyzer.print_source([
|
prof_output = self.prof_analyzer.print_source([
|
||||||
__file__, "--node_name_filter", "x$"])
|
__file__, "--node_name_filter", "x$"])
|
||||||
@ -423,6 +427,7 @@ class ProfileAnalyzerPrintSourceTest(test_util.TensorFlowTestCase):
|
|||||||
break
|
break
|
||||||
self.assertTrue(any_menu_item_match)
|
self.assertTrue(any_menu_item_match)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPrintSourceWithOpTypeFilter(self):
|
def testPrintSourceWithOpTypeFilter(self):
|
||||||
prof_output = self.prof_analyzer.print_source([
|
prof_output = self.prof_analyzer.print_source([
|
||||||
__file__, "--op_type_filter", "Less"])
|
__file__, "--op_type_filter", "Less"])
|
||||||
|
@ -27,6 +27,7 @@ from tensorflow.python.platform import googletest
|
|||||||
|
|
||||||
class CommonTest(test_util.TensorFlowTestCase):
|
class CommonTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testOnFeedOneFetch(self):
|
def testOnFeedOneFetch(self):
|
||||||
a = constant_op.constant(10.0, name="a")
|
a = constant_op.constant(10.0, name="a")
|
||||||
b = constant_op.constant(20.0, name="b")
|
b = constant_op.constant(20.0, name="b")
|
||||||
@ -35,6 +36,7 @@ class CommonTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertItemsEqual(["a:0"], loaded[0])
|
self.assertItemsEqual(["a:0"], loaded[0])
|
||||||
self.assertItemsEqual(["b:0"], loaded[1])
|
self.assertItemsEqual(["b:0"], loaded[1])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testGetRunKeyFlat(self):
|
def testGetRunKeyFlat(self):
|
||||||
a = constant_op.constant(10.0, name="a")
|
a = constant_op.constant(10.0, name="a")
|
||||||
b = constant_op.constant(20.0, name="b")
|
b = constant_op.constant(20.0, name="b")
|
||||||
@ -43,6 +45,7 @@ class CommonTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertItemsEqual(["a:0"], loaded[0])
|
self.assertItemsEqual(["a:0"], loaded[0])
|
||||||
self.assertItemsEqual(["a:0", "b:0"], loaded[1])
|
self.assertItemsEqual(["a:0", "b:0"], loaded[1])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testGetRunKeyNestedFetches(self):
|
def testGetRunKeyNestedFetches(self):
|
||||||
a = constant_op.constant(10.0, name="a")
|
a = constant_op.constant(10.0, name="a")
|
||||||
b = constant_op.constant(20.0, name="b")
|
b = constant_op.constant(20.0, name="b")
|
||||||
|
@ -54,6 +54,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
|||||||
ops.reset_default_graph()
|
ops.reset_default_graph()
|
||||||
debug_gradients.clear_gradient_debuggers()
|
debug_gradients.clear_gradient_debuggers()
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testIdentifyGradientGivesCorrectTensorObjectWithoutContextManager(self):
|
def testIdentifyGradientGivesCorrectTensorObjectWithoutContextManager(self):
|
||||||
grad_debugger = debug_gradients.GradientsDebugger()
|
grad_debugger = debug_gradients.GradientsDebugger()
|
||||||
id_grad_w = grad_debugger.identify_gradient(self.w)
|
id_grad_w = grad_debugger.identify_gradient(self.w)
|
||||||
@ -84,6 +85,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertIsInstance(w_grad, ops.Tensor)
|
self.assertIsInstance(w_grad, ops.Tensor)
|
||||||
self.assertAllClose(1.0, self.sess.run(w_grad))
|
self.assertAllClose(1.0, self.sess.run(w_grad))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testIdentifyGradientGivesCorrectTensorObjectWithTfGradients(self):
|
def testIdentifyGradientGivesCorrectTensorObjectWithTfGradients(self):
|
||||||
grad_debugger = debug_gradients.GradientsDebugger()
|
grad_debugger = debug_gradients.GradientsDebugger()
|
||||||
id_grad_w = grad_debugger.identify_gradient(self.w)
|
id_grad_w = grad_debugger.identify_gradient(self.w)
|
||||||
@ -115,6 +117,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertIsInstance(w_grad, ops.Tensor)
|
self.assertIsInstance(w_grad, ops.Tensor)
|
||||||
self.assertAllClose(1.0, self.sess.run(w_grad))
|
self.assertAllClose(1.0, self.sess.run(w_grad))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCallingIdentifyGradientTwiceWithTheSameGradientsDebuggerErrors(self):
|
def testCallingIdentifyGradientTwiceWithTheSameGradientsDebuggerErrors(self):
|
||||||
grad_debugger = debug_gradients.GradientsDebugger()
|
grad_debugger = debug_gradients.GradientsDebugger()
|
||||||
grad_debugger.identify_gradient(self.w)
|
grad_debugger.identify_gradient(self.w)
|
||||||
@ -122,6 +125,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
|||||||
"The graph already contains an op named .*"):
|
"The graph already contains an op named .*"):
|
||||||
grad_debugger.identify_gradient(self.w)
|
grad_debugger.identify_gradient(self.w)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testIdentifyGradientWorksOnMultipleLosses(self):
|
def testIdentifyGradientWorksOnMultipleLosses(self):
|
||||||
grad_debugger_1 = debug_gradients.GradientsDebugger()
|
grad_debugger_1 = debug_gradients.GradientsDebugger()
|
||||||
grad_debugger_2 = debug_gradients.GradientsDebugger()
|
grad_debugger_2 = debug_gradients.GradientsDebugger()
|
||||||
@ -150,6 +154,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllClose(2.0 * 5.0, self.sess.run(dz1_dy))
|
self.assertAllClose(2.0 * 5.0, self.sess.run(dz1_dy))
|
||||||
self.assertAllClose(0.5 * (5.0**-0.5), self.sess.run(dz2_dy))
|
self.assertAllClose(0.5 * (5.0**-0.5), self.sess.run(dz2_dy))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testIdentifyGradientRaisesLookupErrorForUnknownXTensor(self):
|
def testIdentifyGradientRaisesLookupErrorForUnknownXTensor(self):
|
||||||
grad_debugger_1 = debug_gradients.GradientsDebugger()
|
grad_debugger_1 = debug_gradients.GradientsDebugger()
|
||||||
grad_debugger_2 = debug_gradients.GradientsDebugger()
|
grad_debugger_2 = debug_gradients.GradientsDebugger()
|
||||||
@ -170,6 +175,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
|||||||
r"This GradientsDebugger has not received any gradient tensor for "):
|
r"This GradientsDebugger has not received any gradient tensor for "):
|
||||||
grad_debugger_2.gradient_tensor(self.w)
|
grad_debugger_2.gradient_tensor(self.w)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testIdentifyGradientRaisesTypeErrorForNonTensorOrTensorNameInput(self):
|
def testIdentifyGradientRaisesTypeErrorForNonTensorOrTensorNameInput(self):
|
||||||
grad_debugger = debug_gradients.GradientsDebugger()
|
grad_debugger = debug_gradients.GradientsDebugger()
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
@ -178,6 +184,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
|||||||
r"has type .*Operation.*"):
|
r"has type .*Operation.*"):
|
||||||
grad_debugger.gradient_tensor(variables.global_variables_initializer())
|
grad_debugger.gradient_tensor(variables.global_variables_initializer())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testIdentifyGradientTensorWorksWithGradientDescentOptimizer(self):
|
def testIdentifyGradientTensorWorksWithGradientDescentOptimizer(self):
|
||||||
grad_debugger = debug_gradients.GradientsDebugger()
|
grad_debugger = debug_gradients.GradientsDebugger()
|
||||||
id_grad_w = grad_debugger.identify_gradient(self.w)
|
id_grad_w = grad_debugger.identify_gradient(self.w)
|
||||||
@ -193,6 +200,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertIsInstance(w_grad, ops.Tensor)
|
self.assertIsInstance(w_grad, ops.Tensor)
|
||||||
self.assertAllClose(1.0, self.sess.run(w_grad))
|
self.assertAllClose(1.0, self.sess.run(w_grad))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWatchGradientsByXTensorNamesWorks(self):
|
def testWatchGradientsByXTensorNamesWorks(self):
|
||||||
y = math_ops.add(self.w, -1.0, name="y")
|
y = math_ops.add(self.w, -1.0, name="y")
|
||||||
|
|
||||||
@ -219,6 +227,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertIsInstance(w_grad, ops.Tensor)
|
self.assertIsInstance(w_grad, ops.Tensor)
|
||||||
self.assertAllClose(1.0, self.sess.run(w_grad))
|
self.assertAllClose(1.0, self.sess.run(w_grad))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWatchGradientsByXTensorNamesWorksWithoutContextManager(self):
|
def testWatchGradientsByXTensorNamesWorksWithoutContextManager(self):
|
||||||
y = math_ops.add(self.w, -1.0, name="y")
|
y = math_ops.add(self.w, -1.0, name="y")
|
||||||
|
|
||||||
@ -245,6 +254,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertIsInstance(w_grad, ops.Tensor)
|
self.assertIsInstance(w_grad, ops.Tensor)
|
||||||
self.assertAllClose(1.0, self.sess.run(w_grad))
|
self.assertAllClose(1.0, self.sess.run(w_grad))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWatchGradientsWorksOnRefTensor(self):
|
def testWatchGradientsWorksOnRefTensor(self):
|
||||||
y = math_ops.add(self.w, -1.0, name="y")
|
y = math_ops.add(self.w, -1.0, name="y")
|
||||||
|
|
||||||
@ -263,6 +273,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllClose(3.0, self.sess.run(
|
self.assertAllClose(3.0, self.sess.run(
|
||||||
grad_debugger.gradient_tensor("u:0")))
|
grad_debugger.gradient_tensor("u:0")))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWatchGradientsWorksOnMultipleTensors(self):
|
def testWatchGradientsWorksOnMultipleTensors(self):
|
||||||
y = math_ops.add(self.w, -1.0, name="y")
|
y = math_ops.add(self.w, -1.0, name="y")
|
||||||
|
|
||||||
@ -283,6 +294,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllClose(3.0, self.sess.run(
|
self.assertAllClose(3.0, self.sess.run(
|
||||||
grad_debugger.gradient_tensor("u:0")))
|
grad_debugger.gradient_tensor("u:0")))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWatchGradientsByXTensorsWorks(self):
|
def testWatchGradientsByXTensorsWorks(self):
|
||||||
y = math_ops.add(self.w, -1.0, name="foo/y")
|
y = math_ops.add(self.w, -1.0, name="foo/y")
|
||||||
z = math_ops.square(y, name="foo/z")
|
z = math_ops.square(y, name="foo/z")
|
||||||
@ -305,6 +317,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllClose(10.0, self.sess.run(w_grad))
|
self.assertAllClose(10.0, self.sess.run(w_grad))
|
||||||
self.assertAllClose(30.0, self.sess.run(u_grad))
|
self.assertAllClose(30.0, self.sess.run(u_grad))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWatchGradientsByTensorCanWorkOnMultipleLosses(self):
|
def testWatchGradientsByTensorCanWorkOnMultipleLosses(self):
|
||||||
y = math_ops.add(self.w, -1.0, name="y")
|
y = math_ops.add(self.w, -1.0, name="y")
|
||||||
z1 = math_ops.square(y, name="z1")
|
z1 = math_ops.square(y, name="z1")
|
||||||
@ -330,6 +343,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllClose(2.0 * 5.0, self.sess.run(dz1_dy))
|
self.assertAllClose(2.0 * 5.0, self.sess.run(dz1_dy))
|
||||||
self.assertAllClose(0.5 * (5.0**-0.5), self.sess.run(dz2_dy))
|
self.assertAllClose(0.5 * (5.0**-0.5), self.sess.run(dz2_dy))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testGradientsValuesFromDumpWorks(self):
|
def testGradientsValuesFromDumpWorks(self):
|
||||||
y = math_ops.add(self.w, -1.0, name="y")
|
y = math_ops.add(self.w, -1.0, name="y")
|
||||||
z = math_ops.square(y, name="z")
|
z = math_ops.square(y, name="z")
|
||||||
|
@ -185,6 +185,7 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(["file:///tmp/tfdbg_1", "file:///tmp/tfdbg_2"],
|
self.assertEqual(["file:///tmp/tfdbg_1", "file:///tmp/tfdbg_2"],
|
||||||
watch_0.debug_urls)
|
watch_0.debug_urls)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWatchGraph_allNodes(self):
|
def testWatchGraph_allNodes(self):
|
||||||
debug_utils.watch_graph(
|
debug_utils.watch_graph(
|
||||||
self._run_options,
|
self._run_options,
|
||||||
@ -216,6 +217,7 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertTrue("p1" in node_names)
|
self.assertTrue("p1" in node_names)
|
||||||
self.assertTrue("s" in node_names)
|
self.assertTrue("s" in node_names)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWatchGraph_nodeNameWhitelist(self):
|
def testWatchGraph_nodeNameWhitelist(self):
|
||||||
debug_utils.watch_graph(
|
debug_utils.watch_graph(
|
||||||
self._run_options,
|
self._run_options,
|
||||||
@ -230,6 +232,7 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
|||||||
sorted(["a1_init", "a1", "a1/Assign", "a1/read", "p1"]),
|
sorted(["a1_init", "a1", "a1/Assign", "a1/read", "p1"]),
|
||||||
sorted(node_names))
|
sorted(node_names))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWatchGraph_opTypeWhitelist(self):
|
def testWatchGraph_opTypeWhitelist(self):
|
||||||
debug_utils.watch_graph(
|
debug_utils.watch_graph(
|
||||||
self._run_options,
|
self._run_options,
|
||||||
@ -255,6 +258,7 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
|||||||
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
||||||
self.assertEqual(["p1"], node_names)
|
self.assertEqual(["p1"], node_names)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWatchGraph_tensorDTypeWhitelist(self):
|
def testWatchGraph_tensorDTypeWhitelist(self):
|
||||||
debug_utils.watch_graph(
|
debug_utils.watch_graph(
|
||||||
self._run_options,
|
self._run_options,
|
||||||
@ -267,6 +271,7 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
|||||||
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
||||||
self.assertItemsEqual(["a1", "a1/Assign", "b", "b/Assign"], node_names)
|
self.assertItemsEqual(["a1", "a1/Assign", "b", "b/Assign"], node_names)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWatchGraph_nodeNameAndTensorDTypeWhitelists(self):
|
def testWatchGraph_nodeNameAndTensorDTypeWhitelists(self):
|
||||||
debug_utils.watch_graph(
|
debug_utils.watch_graph(
|
||||||
self._run_options,
|
self._run_options,
|
||||||
@ -280,6 +285,7 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
|||||||
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
||||||
self.assertItemsEqual(["a1", "a1/Assign"], node_names)
|
self.assertItemsEqual(["a1", "a1/Assign"], node_names)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWatchGraph_nodeNameBlacklist(self):
|
def testWatchGraph_nodeNameBlacklist(self):
|
||||||
debug_utils.watch_graph_with_blacklists(
|
debug_utils.watch_graph_with_blacklists(
|
||||||
self._run_options,
|
self._run_options,
|
||||||
@ -294,6 +300,7 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
|||||||
sorted(["b_init", "b", "b/Assign", "b/read", "c", "s"]),
|
sorted(["b_init", "b", "b/Assign", "b/read", "c", "s"]),
|
||||||
sorted(node_names))
|
sorted(node_names))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWatchGraph_opTypeBlacklist(self):
|
def testWatchGraph_opTypeBlacklist(self):
|
||||||
debug_utils.watch_graph_with_blacklists(
|
debug_utils.watch_graph_with_blacklists(
|
||||||
self._run_options,
|
self._run_options,
|
||||||
@ -306,6 +313,7 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
|||||||
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
||||||
self.assertEqual(sorted(["p1", "s"]), sorted(node_names))
|
self.assertEqual(sorted(["p1", "s"]), sorted(node_names))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWatchGraph_nodeNameAndOpTypeBlacklists(self):
|
def testWatchGraph_nodeNameAndOpTypeBlacklists(self):
|
||||||
debug_utils.watch_graph_with_blacklists(
|
debug_utils.watch_graph_with_blacklists(
|
||||||
self._run_options,
|
self._run_options,
|
||||||
@ -319,6 +327,7 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
|||||||
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
||||||
self.assertEqual(["s"], node_names)
|
self.assertEqual(["s"], node_names)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWatchGraph_tensorDTypeBlacklists(self):
|
def testWatchGraph_tensorDTypeBlacklists(self):
|
||||||
debug_utils.watch_graph_with_blacklists(
|
debug_utils.watch_graph_with_blacklists(
|
||||||
self._run_options,
|
self._run_options,
|
||||||
@ -335,6 +344,7 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertNotIn("b/Assign", node_names)
|
self.assertNotIn("b/Assign", node_names)
|
||||||
self.assertIn("s", node_names)
|
self.assertIn("s", node_names)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWatchGraph_nodeNameAndTensorDTypeBlacklists(self):
|
def testWatchGraph_nodeNameAndTensorDTypeBlacklists(self):
|
||||||
debug_utils.watch_graph_with_blacklists(
|
debug_utils.watch_graph_with_blacklists(
|
||||||
self._run_options,
|
self._run_options,
|
||||||
|
@ -28,6 +28,7 @@ from tensorflow.python.debug.lib import debug_utils
|
|||||||
from tensorflow.python.debug.lib import session_debug_testlib
|
from tensorflow.python.debug.lib import session_debug_testlib
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import googletest
|
from tensorflow.python.platform import googletest
|
||||||
@ -44,6 +45,7 @@ class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase):
|
|||||||
else:
|
else:
|
||||||
return os.path.join(self._dump_root, "run_%d" % run_number)
|
return os.path.join(self._dump_root, "run_%d" % run_number)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testAllowsDifferentWatchesOnDifferentRuns(self):
|
def testAllowsDifferentWatchesOnDifferentRuns(self):
|
||||||
"""Test watching different tensors on different runs of the same graph."""
|
"""Test watching different tensors on different runs of the same graph."""
|
||||||
|
|
||||||
|
@ -65,6 +65,7 @@ class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
source_utils.guess_is_tensorflow_py_library(source_utils.__file__))
|
source_utils.guess_is_tensorflow_py_library(source_utils.__file__))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFileInPythonKernelsPathReturnsTrue(self):
|
def testFileInPythonKernelsPathReturnsTrue(self):
|
||||||
x = constant_op.constant(42.0, name="x")
|
x = constant_op.constant(42.0, name="x")
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
|
@ -37,6 +37,7 @@ from tensorflow.python.platform import tf_logging
|
|||||||
|
|
||||||
class AllReduceTest(test_util.TensorFlowTestCase):
|
class AllReduceTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFlattenTensorsShapesDefined(self):
|
def testFlattenTensorsShapesDefined(self):
|
||||||
x = array_ops.placeholder(types_pb2.DT_FLOAT, [None])
|
x = array_ops.placeholder(types_pb2.DT_FLOAT, [None])
|
||||||
with self.assertRaisesRegexp(ValueError,
|
with self.assertRaisesRegexp(ValueError,
|
||||||
@ -100,6 +101,7 @@ class AllReduceTest(test_util.TensorFlowTestCase):
|
|||||||
input_tensors.append(array_ops.identity(t8))
|
input_tensors.append(array_ops.identity(t8))
|
||||||
return input_tensors, device_names
|
return input_tensors, device_names
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBuildRingGatherPassStructure(self):
|
def testBuildRingGatherPassStructure(self):
|
||||||
# 1 worker, 1 device
|
# 1 worker, 1 device
|
||||||
input_tensors, device_names = self._buildInput(1, 1)
|
input_tensors, device_names = self._buildInput(1, 1)
|
||||||
@ -170,6 +172,7 @@ class AllReduceTest(test_util.TensorFlowTestCase):
|
|||||||
"subdiv=%d elapsed=%f" %
|
"subdiv=%d elapsed=%f" %
|
||||||
(num_workers, num_gpus, shape, subdiv, elapsed))
|
(num_workers, num_gpus, shape, subdiv, elapsed))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testRingAllReduce(self):
|
def testRingAllReduce(self):
|
||||||
self._testRingAllReduce(1, 2, [], 1)
|
self._testRingAllReduce(1, 2, [], 1)
|
||||||
self._testRingAllReduce(1, 2, [8], 1)
|
self._testRingAllReduce(1, 2, [8], 1)
|
||||||
@ -199,6 +202,7 @@ class AllReduceTest(test_util.TensorFlowTestCase):
|
|||||||
tf_logging.info("ShuffleAllReduce num_workers=%d num_gpus=%d shape=%s "
|
tf_logging.info("ShuffleAllReduce num_workers=%d num_gpus=%d shape=%s "
|
||||||
"elapsed=%f" % (num_workers, num_gpus, shape, elapsed))
|
"elapsed=%f" % (num_workers, num_gpus, shape, elapsed))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testShuffleAllReduce(self):
|
def testShuffleAllReduce(self):
|
||||||
self._testShuffleAllReduce(1, 2, [], 1)
|
self._testShuffleAllReduce(1, 2, [], 1)
|
||||||
self._testShuffleAllReduce(1, 2, [8], 1)
|
self._testShuffleAllReduce(1, 2, [8], 1)
|
||||||
@ -225,6 +229,7 @@ class AllReduceTest(test_util.TensorFlowTestCase):
|
|||||||
"shape=%s elapsed=%f" %
|
"shape=%s elapsed=%f" %
|
||||||
(num_workers, num_gpus, shape, elapsed))
|
(num_workers, num_gpus, shape, elapsed))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testRecursiveHDAllReduce(self):
|
def testRecursiveHDAllReduce(self):
|
||||||
self._testRecursiveHDAllReduce(1, 2, [8])
|
self._testRecursiveHDAllReduce(1, 2, [8])
|
||||||
self._testRecursiveHDAllReduce(1, 2, [4, 4])
|
self._testRecursiveHDAllReduce(1, 2, [4, 4])
|
||||||
|
@ -21,11 +21,13 @@ from __future__ import print_function
|
|||||||
from tensorflow.python.distribute import device_util
|
from tensorflow.python.distribute import device_util
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class DeviceUtilTest(test.TestCase):
|
class DeviceUtilTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCurrentDeviceWithGlobalGraph(self):
|
def testCurrentDeviceWithGlobalGraph(self):
|
||||||
with ops.device("/cpu:0"):
|
with ops.device("/cpu:0"):
|
||||||
self.assertEqual(device_util.current(), "/device:CPU:0")
|
self.assertEqual(device_util.current(), "/device:CPU:0")
|
||||||
@ -49,6 +51,7 @@ class DeviceUtilTest(test.TestCase):
|
|||||||
self.assertEqual(device_util.current(),
|
self.assertEqual(device_util.current(),
|
||||||
"/job:localhost/replica:0/task:0/device:CPU:0")
|
"/job:localhost/replica:0/task:0/device:CPU:0")
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCanonicalizeWithoutDefaultDevice(self):
|
def testCanonicalizeWithoutDefaultDevice(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
device_util.canonicalize("/cpu:0"),
|
device_util.canonicalize("/cpu:0"),
|
||||||
|
@ -24,6 +24,7 @@ from tensorflow.python.data.ops import dataset_ops
|
|||||||
from tensorflow.python.data.ops import readers
|
from tensorflow.python.data.ops import readers
|
||||||
from tensorflow.python.distribute import input_ops
|
from tensorflow.python.distribute import input_ops
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.lib.io import python_io
|
from tensorflow.python.lib.io import python_io
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
@ -96,6 +97,7 @@ class AutoShardDatasetTest(test.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testTFRecordDataset(self):
|
def testTFRecordDataset(self):
|
||||||
dataset = readers.TFRecordDataset(self._createTFRecordFiles())
|
dataset = readers.TFRecordDataset(self._createTFRecordFiles())
|
||||||
dataset = input_ops.auto_shard_dataset(
|
dataset = input_ops.auto_shard_dataset(
|
||||||
@ -103,6 +105,7 @@ class AutoShardDatasetTest(test.TestCase):
|
|||||||
|
|
||||||
self._verifySimpleShardingOutput(dataset, self._record)
|
self._verifySimpleShardingOutput(dataset, self._record)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFlatMap(self):
|
def testFlatMap(self):
|
||||||
dataset = dataset_ops.Dataset.from_tensor_slices(
|
dataset = dataset_ops.Dataset.from_tensor_slices(
|
||||||
self._createTFRecordFiles())
|
self._createTFRecordFiles())
|
||||||
@ -112,6 +115,7 @@ class AutoShardDatasetTest(test.TestCase):
|
|||||||
|
|
||||||
self._verifySimpleShardingOutput(dataset, self._record)
|
self._verifySimpleShardingOutput(dataset, self._record)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testInterleave(self):
|
def testInterleave(self):
|
||||||
dataset = dataset_ops.Dataset.from_tensor_slices(
|
dataset = dataset_ops.Dataset.from_tensor_slices(
|
||||||
self._createTFRecordFiles())
|
self._createTFRecordFiles())
|
||||||
@ -124,6 +128,7 @@ class AutoShardDatasetTest(test.TestCase):
|
|||||||
# contain records in order of files.
|
# contain records in order of files.
|
||||||
self._verifySimpleShardingOutput(dataset, self._record)
|
self._verifySimpleShardingOutput(dataset, self._record)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testListfiles(self):
|
def testListfiles(self):
|
||||||
filenames = self._createTFRecordFiles()
|
filenames = self._createTFRecordFiles()
|
||||||
file_pattern = filenames[0].rsplit(os.sep, 1)[0] + "/tf_record.*.txt"
|
file_pattern = filenames[0].rsplit(os.sep, 1)[0] + "/tf_record.*.txt"
|
||||||
@ -144,6 +149,7 @@ class AutoShardDatasetTest(test.TestCase):
|
|||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
self.assertAllEqual(expected, actual)
|
self.assertAllEqual(expected, actual)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testComplexPipeline(self):
|
def testComplexPipeline(self):
|
||||||
# Setup a complex input pipeline.
|
# Setup a complex input pipeline.
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
@ -183,6 +189,7 @@ class AutoShardDatasetTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual(sorted(expected), sorted(actual))
|
self.assertAllEqual(sorted(expected), sorted(actual))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testZip(self):
|
def testZip(self):
|
||||||
dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
|
dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
|
||||||
dataset2 = readers.TextLineDataset(self._createTextFiles())
|
dataset2 = readers.TextLineDataset(self._createTextFiles())
|
||||||
@ -193,6 +200,7 @@ class AutoShardDatasetTest(test.TestCase):
|
|||||||
record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f))
|
record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f))
|
||||||
self._verifySimpleShardingOutput(dataset, record_fn)
|
self._verifySimpleShardingOutput(dataset, record_fn)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testConcat(self):
|
def testConcat(self):
|
||||||
dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
|
dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
|
||||||
dataset2 = readers.TextLineDataset(self._createTextFiles())
|
dataset2 = readers.TextLineDataset(self._createTextFiles())
|
||||||
@ -213,6 +221,7 @@ class AutoShardDatasetTest(test.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(next_element)
|
self.evaluate(next_element)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testTextLineReader(self):
|
def testTextLineReader(self):
|
||||||
dataset = readers.TextLineDataset(self._createTextFiles())
|
dataset = readers.TextLineDataset(self._createTextFiles())
|
||||||
dataset = input_ops.auto_shard_dataset(
|
dataset = input_ops.auto_shard_dataset(
|
||||||
@ -220,6 +229,7 @@ class AutoShardDatasetTest(test.TestCase):
|
|||||||
|
|
||||||
self._verifySimpleShardingOutput(dataset, self._text_line)
|
self._verifySimpleShardingOutput(dataset, self._text_line)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testTextLineReaderWithFlatMap(self):
|
def testTextLineReaderWithFlatMap(self):
|
||||||
dataset = dataset_ops.Dataset.from_tensor_slices(self._createTextFiles())
|
dataset = dataset_ops.Dataset.from_tensor_slices(self._createTextFiles())
|
||||||
dataset = dataset.flat_map(readers.TextLineDataset)
|
dataset = dataset.flat_map(readers.TextLineDataset)
|
||||||
@ -228,6 +238,7 @@ class AutoShardDatasetTest(test.TestCase):
|
|||||||
|
|
||||||
self._verifySimpleShardingOutput(dataset, self._text_line)
|
self._verifySimpleShardingOutput(dataset, self._text_line)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFixedLengthReader(self):
|
def testFixedLengthReader(self):
|
||||||
dataset = readers.FixedLengthRecordDataset(
|
dataset = readers.FixedLengthRecordDataset(
|
||||||
self._createFixedLengthRecordFiles(), self._record_bytes)
|
self._createFixedLengthRecordFiles(), self._record_bytes)
|
||||||
@ -236,6 +247,7 @@ class AutoShardDatasetTest(test.TestCase):
|
|||||||
|
|
||||||
self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
|
self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFixedLengthReaderWithFlatMap(self):
|
def testFixedLengthReaderWithFlatMap(self):
|
||||||
dataset = dataset_ops.Dataset.from_tensor_slices(
|
dataset = dataset_ops.Dataset.from_tensor_slices(
|
||||||
self._createFixedLengthRecordFiles())
|
self._createFixedLengthRecordFiles())
|
||||||
|
@ -187,6 +187,7 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual(2, g(constant_op.constant(2.)))
|
self.assertAllEqual(2, g(constant_op.constant(2.)))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testGraphModeEagerGradError(self):
|
def testGraphModeEagerGradError(self):
|
||||||
with context.graph_mode():
|
with context.graph_mode():
|
||||||
def f():
|
def f():
|
||||||
|
@ -1283,6 +1283,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
defined.get_concrete_function(
|
defined.get_concrete_function(
|
||||||
tensor_spec.TensorSpec(shape=(3,), dtype=dtypes.float32))
|
tensor_spec.TensorSpec(shape=(3,), dtype=dtypes.float32))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testInputSignatureForFunctionWithNonTensorInputsNotAllowed(self):
|
def testInputSignatureForFunctionWithNonTensorInputsNotAllowed(self):
|
||||||
|
|
||||||
def foo(a, training=True):
|
def foo(a, training=True):
|
||||||
|
@ -29,12 +29,14 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
class GraphOnlyOpsTest(test_util.TensorFlowTestCase):
|
class GraphOnlyOpsTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testGraphZerosLike(self):
|
def testGraphZerosLike(self):
|
||||||
x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
|
x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
|
||||||
z_tf = graph_only_ops.graph_zeros_like(x)
|
z_tf = graph_only_ops.graph_zeros_like(x)
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAllClose(np.zeros((2, 3)), self.evaluate(z_tf))
|
self.assertAllClose(np.zeros((2, 3)), self.evaluate(z_tf))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testGraphPlaceholder(self):
|
def testGraphPlaceholder(self):
|
||||||
x_tf = graph_only_ops.graph_placeholder(dtypes.int32, shape=(1,))
|
x_tf = graph_only_ops.graph_placeholder(dtypes.int32, shape=(1,))
|
||||||
y_tf = math_ops.square(x_tf)
|
y_tf = math_ops.square(x_tf)
|
||||||
|
@ -170,6 +170,7 @@ class LazyColumnTest(test.TestCase):
|
|||||||
TypeError, '"key" must be either a "str" or "_FeatureColumn".'):
|
TypeError, '"key" must be either a "str" or "_FeatureColumn".'):
|
||||||
builder.get(NotAFeatureColumn())
|
builder.get(NotAFeatureColumn())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_expand_dim_rank_1_sparse_tensor_empty_batch(self):
|
def test_expand_dim_rank_1_sparse_tensor_empty_batch(self):
|
||||||
# empty 1-D sparse tensor:
|
# empty 1-D sparse tensor:
|
||||||
builder = _LazyBuilder(features={'a': sparse_tensor.SparseTensor(
|
builder = _LazyBuilder(features={'a': sparse_tensor.SparseTensor(
|
||||||
@ -185,6 +186,7 @@ class LazyColumnTest(test.TestCase):
|
|||||||
|
|
||||||
class NumericColumnTest(test.TestCase):
|
class NumericColumnTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_defaults(self):
|
def test_defaults(self):
|
||||||
a = fc._numeric_column('aaa')
|
a = fc._numeric_column('aaa')
|
||||||
self.assertEqual('aaa', a.key)
|
self.assertEqual('aaa', a.key)
|
||||||
@ -263,6 +265,7 @@ class NumericColumnTest(test.TestCase):
|
|||||||
'aaa': parsing_ops.FixedLenFeature((2, 3), dtype=dtypes.int32)
|
'aaa': parsing_ops.FixedLenFeature((2, 3), dtype=dtypes.int32)
|
||||||
}, a._parse_example_spec)
|
}, a._parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example_no_default_value(self):
|
def test_parse_example_no_default_value(self):
|
||||||
price = fc._numeric_column('price', shape=[2])
|
price = fc._numeric_column('price', shape=[2])
|
||||||
data = example_pb2.Example(features=feature_pb2.Features(
|
data = example_pb2.Example(features=feature_pb2.Features(
|
||||||
@ -278,6 +281,7 @@ class NumericColumnTest(test.TestCase):
|
|||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAllEqual([[20., 110.]], features['price'].eval())
|
self.assertAllEqual([[20., 110.]], features['price'].eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example_with_default_value(self):
|
def test_parse_example_with_default_value(self):
|
||||||
price = fc._numeric_column('price', shape=[2], default_value=11.)
|
price = fc._numeric_column('price', shape=[2], default_value=11.)
|
||||||
data = example_pb2.Example(features=feature_pb2.Features(
|
data = example_pb2.Example(features=feature_pb2.Features(
|
||||||
@ -304,6 +308,7 @@ class NumericColumnTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(TypeError, 'must be a callable'):
|
with self.assertRaisesRegexp(TypeError, 'must be a callable'):
|
||||||
fc._numeric_column('price', normalizer_fn='NotACallable')
|
fc._numeric_column('price', normalizer_fn='NotACallable')
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_normalizer_fn_transform_feature(self):
|
def test_normalizer_fn_transform_feature(self):
|
||||||
|
|
||||||
def _increment_two(input_tensor):
|
def _increment_two(input_tensor):
|
||||||
@ -314,6 +319,7 @@ class NumericColumnTest(test.TestCase):
|
|||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAllEqual([[3., 4.], [7., 8.]], output[price].eval())
|
self.assertAllEqual([[3., 4.], [7., 8.]], output[price].eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor(self):
|
def test_get_dense_tensor(self):
|
||||||
|
|
||||||
def _increment_two(input_tensor):
|
def _increment_two(input_tensor):
|
||||||
@ -333,6 +339,7 @@ class NumericColumnTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'must be a Tensor'):
|
with self.assertRaisesRegexp(ValueError, 'must be a Tensor'):
|
||||||
price._transform_feature(builder)
|
price._transform_feature(builder)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
a = fc._numeric_column('aaa', shape=[1, 2], default_value=[[3., 2.]])
|
a = fc._numeric_column('aaa', shape=[1, 2], default_value=[[3., 2.]])
|
||||||
a_copy = copy.deepcopy(a)
|
a_copy = copy.deepcopy(a)
|
||||||
@ -345,6 +352,7 @@ class NumericColumnTest(test.TestCase):
|
|||||||
'aaa', shape=[1, 2], default_value=np.array([[3., 2.]]))
|
'aaa', shape=[1, 2], default_value=np.array([[3., 2.]]))
|
||||||
self.assertEqual(a.default_value, ((3., 2.),))
|
self.assertEqual(a.default_value, ((3., 2.),))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
price = fc._numeric_column('price')
|
price = fc._numeric_column('price')
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
@ -359,6 +367,7 @@ class NumericColumnTest(test.TestCase):
|
|||||||
sess.run(price_var.assign([[10.]]))
|
sess.run(price_var.assign([[10.]]))
|
||||||
self.assertAllClose([[10.], [50.]], self.evaluate(predictions))
|
self.assertAllClose([[10.], [50.]], self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_keras_linear_model(self):
|
def test_keras_linear_model(self):
|
||||||
price = fc._numeric_column('price')
|
price = fc._numeric_column('price')
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
@ -433,6 +442,7 @@ class BucketizedColumnTest(test.TestCase):
|
|||||||
# Column 'aaa` has shape [2] times three buckets -> num_buckets=6.
|
# Column 'aaa` has shape [2] times three buckets -> num_buckets=6.
|
||||||
self.assertEqual(6, b._num_buckets)
|
self.assertEqual(6, b._num_buckets)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example(self):
|
def test_parse_example(self):
|
||||||
price = fc._numeric_column('price', shape=[2])
|
price = fc._numeric_column('price', shape=[2])
|
||||||
bucketized_price = fc._bucketized_column(price, boundaries=[0, 50])
|
bucketized_price = fc._bucketized_column(price, boundaries=[0, 50])
|
||||||
@ -449,6 +459,7 @@ class BucketizedColumnTest(test.TestCase):
|
|||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAllEqual([[20., 110.]], features['price'].eval())
|
self.assertAllEqual([[20., 110.]], features['price'].eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_feature(self):
|
def test_transform_feature(self):
|
||||||
price = fc._numeric_column('price', shape=[2])
|
price = fc._numeric_column('price', shape=[2])
|
||||||
bucketized_price = fc._bucketized_column(price, boundaries=[0, 2, 4, 6])
|
bucketized_price = fc._bucketized_column(price, boundaries=[0, 2, 4, 6])
|
||||||
@ -531,6 +542,7 @@ class BucketizedColumnTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'must be a Tensor'):
|
with self.assertRaisesRegexp(ValueError, 'must be a Tensor'):
|
||||||
bucketized_price._transform_feature(builder)
|
bucketized_price._transform_feature(builder)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
a = fc._numeric_column('aaa', shape=[2])
|
a = fc._numeric_column('aaa', shape=[2])
|
||||||
a_bucketized = fc._bucketized_column(a, boundaries=[0, 1])
|
a_bucketized = fc._bucketized_column(a, boundaries=[0, 1])
|
||||||
@ -658,6 +670,7 @@ class BucketizedColumnTest(test.TestCase):
|
|||||||
|
|
||||||
class HashedCategoricalColumnTest(test.TestCase):
|
class HashedCategoricalColumnTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_defaults(self):
|
def test_defaults(self):
|
||||||
a = fc._categorical_column_with_hash_bucket('aaa', 10)
|
a = fc._categorical_column_with_hash_bucket('aaa', 10)
|
||||||
self.assertEqual('aaa', a.name)
|
self.assertEqual('aaa', a.name)
|
||||||
@ -685,6 +698,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
|
with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
|
||||||
fc._categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.float32)
|
fc._categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.float32)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
original = fc._categorical_column_with_hash_bucket('aaa', 10)
|
original = fc._categorical_column_with_hash_bucket('aaa', 10)
|
||||||
for column in (original, copy.deepcopy(original)):
|
for column in (original, copy.deepcopy(original)):
|
||||||
@ -705,6 +719,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
|
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
|
||||||
}, a._parse_example_spec)
|
}, a._parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example(self):
|
def test_parse_example(self):
|
||||||
a = fc._categorical_column_with_hash_bucket('aaa', 10)
|
a = fc._categorical_column_with_hash_bucket('aaa', 10)
|
||||||
data = example_pb2.Example(features=feature_pb2.Features(
|
data = example_pb2.Example(features=feature_pb2.Features(
|
||||||
@ -726,6 +741,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=[1, 2]),
|
dense_shape=[1, 2]),
|
||||||
features['aaa'].eval())
|
features['aaa'].eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_strings_should_be_hashed(self):
|
def test_strings_should_be_hashed(self):
|
||||||
hashed_sparse = fc._categorical_column_with_hash_bucket('wire', 10)
|
hashed_sparse = fc._categorical_column_with_hash_bucket('wire', 10)
|
||||||
wire_tensor = sparse_tensor.SparseTensor(
|
wire_tensor = sparse_tensor.SparseTensor(
|
||||||
@ -781,6 +797,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
|
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
|
||||||
builder.get(hashed_sparse)
|
builder.get(hashed_sparse)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_ints_should_be_hashed(self):
|
def test_ints_should_be_hashed(self):
|
||||||
hashed_sparse = fc._categorical_column_with_hash_bucket(
|
hashed_sparse = fc._categorical_column_with_hash_bucket(
|
||||||
'wire', 10, dtype=dtypes.int64)
|
'wire', 10, dtype=dtypes.int64)
|
||||||
@ -795,6 +812,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAllEqual(expected_values, output.values.eval())
|
self.assertAllEqual(expected_values, output.values.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_int32_64_is_compatible(self):
|
def test_int32_64_is_compatible(self):
|
||||||
hashed_sparse = fc._categorical_column_with_hash_bucket(
|
hashed_sparse = fc._categorical_column_with_hash_bucket(
|
||||||
'wire', 10, dtype=dtypes.int64)
|
'wire', 10, dtype=dtypes.int64)
|
||||||
@ -809,6 +827,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAllEqual(expected_values, output.values.eval())
|
self.assertAllEqual(expected_values, output.values.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors(self):
|
def test_get_sparse_tensors(self):
|
||||||
hashed_sparse = fc._categorical_column_with_hash_bucket('wire', 10)
|
hashed_sparse = fc._categorical_column_with_hash_bucket('wire', 10)
|
||||||
builder = _LazyBuilder({
|
builder = _LazyBuilder({
|
||||||
@ -837,6 +856,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
[], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
[], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
||||||
self.assertItemsEqual([], ops.get_collection('my_weights'))
|
self.assertItemsEqual([], ops.get_collection('my_weights'))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_dense_input(self):
|
def test_get_sparse_tensors_dense_input(self):
|
||||||
hashed_sparse = fc._categorical_column_with_hash_bucket('wire', 10)
|
hashed_sparse = fc._categorical_column_with_hash_bucket('wire', 10)
|
||||||
builder = _LazyBuilder({'wire': (('omar', ''), ('stringer', 'marlo'))})
|
builder = _LazyBuilder({'wire': (('omar', ''), ('stringer', 'marlo'))})
|
||||||
@ -844,6 +864,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
self.assertIsNone(id_weight_pair.weight_tensor)
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
self.assertEqual(builder.get(hashed_sparse), id_weight_pair.id_tensor)
|
self.assertEqual(builder.get(hashed_sparse), id_weight_pair.id_tensor)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
wire_column = fc._categorical_column_with_hash_bucket('wire', 4)
|
wire_column = fc._categorical_column_with_hash_bucket('wire', 4)
|
||||||
self.assertEqual(4, wire_column._num_buckets)
|
self.assertEqual(4, wire_column._num_buckets)
|
||||||
@ -866,6 +887,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
# 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
|
# 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
|
||||||
self.assertAllClose(((4.,), (6.,)), self.evaluate(predictions))
|
self.assertAllClose(((4.,), (6.,)), self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_keras_linear_model(self):
|
def test_keras_linear_model(self):
|
||||||
wire_column = fc._categorical_column_with_hash_bucket('wire', 4)
|
wire_column = fc._categorical_column_with_hash_bucket('wire', 4)
|
||||||
self.assertEqual(4, wire_column._num_buckets)
|
self.assertEqual(4, wire_column._num_buckets)
|
||||||
@ -975,6 +997,7 @@ class CrossedColumnTest(test.TestCase):
|
|||||||
crossed = fc._crossed_column([b, 'c'], 15)
|
crossed = fc._crossed_column([b, 'c'], 15)
|
||||||
self.assertEqual(15, crossed._num_buckets)
|
self.assertEqual(15, crossed._num_buckets)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
a = fc._numeric_column('a', dtype=dtypes.int32)
|
a = fc._numeric_column('a', dtype=dtypes.int32)
|
||||||
b = fc._bucketized_column(a, boundaries=[0, 1])
|
b = fc._bucketized_column(a, boundaries=[0, 1])
|
||||||
@ -985,6 +1008,7 @@ class CrossedColumnTest(test.TestCase):
|
|||||||
self.assertEqual(15, crossed2_copy.hash_bucket_size)
|
self.assertEqual(15, crossed2_copy.hash_bucket_size)
|
||||||
self.assertEqual(5, crossed2_copy.hash_key)
|
self.assertEqual(5, crossed2_copy.hash_key)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example(self):
|
def test_parse_example(self):
|
||||||
price = fc._numeric_column('price', shape=[2])
|
price = fc._numeric_column('price', shape=[2])
|
||||||
bucketized_price = fc._bucketized_column(price, boundaries=[0, 50])
|
bucketized_price = fc._bucketized_column(price, boundaries=[0, 50])
|
||||||
@ -1011,6 +1035,7 @@ class CrossedColumnTest(test.TestCase):
|
|||||||
self.assertAllEqual([b'omar', b'stringer'], wire_sparse.values.eval())
|
self.assertAllEqual([b'omar', b'stringer'], wire_sparse.values.eval())
|
||||||
self.assertAllEqual([1, 2], wire_sparse.dense_shape.eval())
|
self.assertAllEqual([1, 2], wire_sparse.dense_shape.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_feature(self):
|
def test_transform_feature(self):
|
||||||
price = fc._numeric_column('price', shape=[2])
|
price = fc._numeric_column('price', shape=[2])
|
||||||
bucketized_price = fc._bucketized_column(price, boundaries=[0, 50])
|
bucketized_price = fc._bucketized_column(price, boundaries=[0, 50])
|
||||||
@ -1034,6 +1059,7 @@ class CrossedColumnTest(test.TestCase):
|
|||||||
self.assertIn(val, list(range(hash_bucket_size)))
|
self.assertIn(val, list(range(hash_bucket_size)))
|
||||||
self.assertAllEqual([2, 4], output_val.dense_shape)
|
self.assertAllEqual([2, 4], output_val.dense_shape)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors(self):
|
def test_get_sparse_tensors(self):
|
||||||
a = fc._numeric_column('a', dtype=dtypes.int32, shape=(2,))
|
a = fc._numeric_column('a', dtype=dtypes.int32, shape=(2,))
|
||||||
b = fc._bucketized_column(a, boundaries=(0, 1))
|
b = fc._bucketized_column(a, boundaries=(0, 1))
|
||||||
@ -1101,6 +1127,7 @@ class CrossedColumnTest(test.TestCase):
|
|||||||
self.assertAllEqual(expected_values, id_tensor_eval.values)
|
self.assertAllEqual(expected_values, id_tensor_eval.values)
|
||||||
self.assertAllEqual((2, 4), id_tensor_eval.dense_shape)
|
self.assertAllEqual((2, 4), id_tensor_eval.dense_shape)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
"""Tests linear_model.
|
"""Tests linear_model.
|
||||||
|
|
||||||
@ -1182,6 +1209,7 @@ class CrossedColumnTest(test.TestCase):
|
|||||||
dense_shape=(2, 2)),
|
dense_shape=(2, 2)),
|
||||||
}, (crossed,))
|
}, (crossed,))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_keras_linear_model(self):
|
def test_keras_linear_model(self):
|
||||||
"""Tests _LinearModel.
|
"""Tests _LinearModel.
|
||||||
|
|
||||||
@ -1854,6 +1882,7 @@ class LinearModelTest(test.TestCase):
|
|||||||
features['price2']: [[1.], [5.]],
|
features['price2']: [[1.], [5.]],
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_1d_sparse_tensor(self):
|
def test_with_1d_sparse_tensor(self):
|
||||||
price = fc._numeric_column('price')
|
price = fc._numeric_column('price')
|
||||||
price_buckets = fc._bucketized_column(
|
price_buckets = fc._bucketized_column(
|
||||||
@ -1889,6 +1918,7 @@ class LinearModelTest(test.TestCase):
|
|||||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||||
self.evaluate(net))
|
self.evaluate(net))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||||
price = fc._numeric_column('price')
|
price = fc._numeric_column('price')
|
||||||
price_buckets = fc._bucketized_column(
|
price_buckets = fc._bucketized_column(
|
||||||
@ -1936,6 +1966,7 @@ class LinearModelTest(test.TestCase):
|
|||||||
features['country']: country_data
|
features['country']: country_data
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_rank_0_feature(self):
|
def test_with_rank_0_feature(self):
|
||||||
price = fc._numeric_column('price')
|
price = fc._numeric_column('price')
|
||||||
features = {
|
features = {
|
||||||
@ -2488,6 +2519,7 @@ class _LinearModelTest(test.TestCase):
|
|||||||
features['price2']: [[1.], [5.]],
|
features['price2']: [[1.], [5.]],
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_1d_sparse_tensor(self):
|
def test_with_1d_sparse_tensor(self):
|
||||||
price = fc._numeric_column('price')
|
price = fc._numeric_column('price')
|
||||||
price_buckets = fc._bucketized_column(
|
price_buckets = fc._bucketized_column(
|
||||||
@ -2529,6 +2561,7 @@ class _LinearModelTest(test.TestCase):
|
|||||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||||
self.evaluate(net))
|
self.evaluate(net))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||||
price = fc._numeric_column('price')
|
price = fc._numeric_column('price')
|
||||||
price_buckets = fc._bucketized_column(
|
price_buckets = fc._bucketized_column(
|
||||||
@ -2575,6 +2608,7 @@ class _LinearModelTest(test.TestCase):
|
|||||||
features['country']: country_data
|
features['country']: country_data
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_rank_0_feature(self):
|
def test_with_rank_0_feature(self):
|
||||||
price = fc._numeric_column('price')
|
price = fc._numeric_column('price')
|
||||||
features = {
|
features = {
|
||||||
@ -2815,6 +2849,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
|||||||
variables_lib.Variable)
|
variables_lib.Variable)
|
||||||
self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10])
|
self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_fills_cols_to_vars_shared_embedding(self):
|
def test_fills_cols_to_vars_shared_embedding(self):
|
||||||
# Provide 5 DenseColumn's to input_layer: a NumericColumn, a
|
# Provide 5 DenseColumn's to input_layer: a NumericColumn, a
|
||||||
# BucketizedColumn, an EmbeddingColumn, two SharedEmbeddingColumns. The
|
# BucketizedColumn, an EmbeddingColumn, two SharedEmbeddingColumns. The
|
||||||
@ -3012,6 +3047,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
|||||||
expected_var_names,
|
expected_var_names,
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_multiple_layers_with_same_shared_embedding_column(self):
|
def test_multiple_layers_with_same_shared_embedding_column(self):
|
||||||
categorical_column_a = fc._categorical_column_with_identity(
|
categorical_column_a = fc._categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -3045,6 +3081,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
|||||||
['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
|
['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(self):
|
def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(self):
|
||||||
categorical_column_a = fc._categorical_column_with_identity(
|
categorical_column_a = fc._categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -3096,6 +3133,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
|||||||
['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
|
['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_1d_sparse_tensor(self):
|
def test_with_1d_sparse_tensor(self):
|
||||||
embedding_values = (
|
embedding_values = (
|
||||||
(1., 2., 3., 4., 5.), # id 0
|
(1., 2., 3., 4., 5.), # id 0
|
||||||
@ -3146,6 +3184,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
|||||||
[1., 0., 0., 1., 2., 3., 4., 5., 12.]],
|
[1., 0., 0., 1., 2., 3., 4., 5., 12.]],
|
||||||
sess.run(net))
|
sess.run(net))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||||
embedding_values = (
|
embedding_values = (
|
||||||
(1., 2.), # id 0
|
(1., 2.), # id 0
|
||||||
@ -3205,6 +3244,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
|||||||
features['country']: country_data
|
features['country']: country_data
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_rank_0_feature(self):
|
def test_with_rank_0_feature(self):
|
||||||
# price has 1 dimension in input_layer
|
# price has 1 dimension in input_layer
|
||||||
price = fc._numeric_column('price')
|
price = fc._numeric_column('price')
|
||||||
@ -3335,6 +3375,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
'python/feature_column/testdata/wire_vocabulary.txt')
|
'python/feature_column/testdata/wire_vocabulary.txt')
|
||||||
self._wire_vocabulary_size = 3
|
self._wire_vocabulary_size = 3
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_defaults(self):
|
def test_defaults(self):
|
||||||
column = fc._categorical_column_with_vocabulary_file(
|
column = fc._categorical_column_with_vocabulary_file(
|
||||||
key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
|
key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
|
||||||
@ -3351,6 +3392,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
fc._categorical_column_with_vocabulary_file(
|
fc._categorical_column_with_vocabulary_file(
|
||||||
key=('aaa',), vocabulary_file='path_to_file', vocabulary_size=3)
|
key=('aaa',), vocabulary_file='path_to_file', vocabulary_size=3)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_all_constructor_args(self):
|
def test_all_constructor_args(self):
|
||||||
column = fc._categorical_column_with_vocabulary_file(
|
column = fc._categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -3363,6 +3405,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
|
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
|
||||||
}, column._parse_example_spec)
|
}, column._parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
original = fc._categorical_column_with_vocabulary_file(
|
original = fc._categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -3387,6 +3430,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
fc._categorical_column_with_vocabulary_file(
|
fc._categorical_column_with_vocabulary_file(
|
||||||
key='aaa', vocabulary_file='', vocabulary_size=3)
|
key='aaa', vocabulary_file='', vocabulary_size=3)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_invalid_vocabulary_file(self):
|
def test_invalid_vocabulary_file(self):
|
||||||
column = fc._categorical_column_with_vocabulary_file(
|
column = fc._categorical_column_with_vocabulary_file(
|
||||||
key='aaa', vocabulary_file='file_does_not_exist', vocabulary_size=10)
|
key='aaa', vocabulary_file='file_does_not_exist', vocabulary_size=10)
|
||||||
@ -3411,6 +3455,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
vocabulary_file=self._wire_vocabulary_file_name,
|
vocabulary_file=self._wire_vocabulary_file_name,
|
||||||
vocabulary_size=0)
|
vocabulary_size=0)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_too_large_vocabulary_size(self):
|
def test_too_large_vocabulary_size(self):
|
||||||
column = fc._categorical_column_with_vocabulary_file(
|
column = fc._categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -3477,6 +3522,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
|
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
|
||||||
column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
|
column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example(self):
|
def test_parse_example(self):
|
||||||
a = fc._categorical_column_with_vocabulary_file(
|
a = fc._categorical_column_with_vocabulary_file(
|
||||||
key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
|
key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
|
||||||
@ -3499,6 +3545,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=[1, 2]),
|
dense_shape=[1, 2]),
|
||||||
features['aaa'].eval())
|
features['aaa'].eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors(self):
|
def test_get_sparse_tensors(self):
|
||||||
column = fc._categorical_column_with_vocabulary_file(
|
column = fc._categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -3519,6 +3566,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_none_vocabulary_size(self):
|
def test_get_sparse_tensors_none_vocabulary_size(self):
|
||||||
column = fc._categorical_column_with_vocabulary_file(
|
column = fc._categorical_column_with_vocabulary_file(
|
||||||
key='aaa', vocabulary_file=self._wire_vocabulary_file_name)
|
key='aaa', vocabulary_file=self._wire_vocabulary_file_name)
|
||||||
@ -3537,6 +3585,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_feature(self):
|
def test_transform_feature(self):
|
||||||
column = fc._categorical_column_with_vocabulary_file(
|
column = fc._categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -3573,6 +3622,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
[], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
[], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
||||||
self.assertItemsEqual([], ops.get_collection('my_weights'))
|
self.assertItemsEqual([], ops.get_collection('my_weights'))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_dense_input(self):
|
def test_get_sparse_tensors_dense_input(self):
|
||||||
column = fc._categorical_column_with_vocabulary_file(
|
column = fc._categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -3592,6 +3642,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=(2, 2)),
|
dense_shape=(2, 2)),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_default_value_in_vocabulary(self):
|
def test_get_sparse_tensors_default_value_in_vocabulary(self):
|
||||||
column = fc._categorical_column_with_vocabulary_file(
|
column = fc._categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -3613,6 +3664,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_with_oov_buckets(self):
|
def test_get_sparse_tensors_with_oov_buckets(self):
|
||||||
column = fc._categorical_column_with_vocabulary_file(
|
column = fc._categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -3634,6 +3686,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_small_vocabulary_size(self):
|
def test_get_sparse_tensors_small_vocabulary_size(self):
|
||||||
# 'marlo' is the last entry in our vocabulary file, so be setting
|
# 'marlo' is the last entry in our vocabulary file, so be setting
|
||||||
# `vocabulary_size` to 1 less than number of entries in file, we take
|
# `vocabulary_size` to 1 less than number of entries in file, we take
|
||||||
@ -3657,6 +3710,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_int32(self):
|
def test_get_sparse_tensors_int32(self):
|
||||||
column = fc._categorical_column_with_vocabulary_file(
|
column = fc._categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -3678,6 +3732,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_int32_dense_input(self):
|
def test_get_sparse_tensors_int32_dense_input(self):
|
||||||
default_value = -100
|
default_value = -100
|
||||||
column = fc._categorical_column_with_vocabulary_file(
|
column = fc._categorical_column_with_vocabulary_file(
|
||||||
@ -3700,6 +3755,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=(3, 3)),
|
dense_shape=(3, 3)),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_int32_with_oov_buckets(self):
|
def test_get_sparse_tensors_int32_with_oov_buckets(self):
|
||||||
column = fc._categorical_column_with_vocabulary_file(
|
column = fc._categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -3722,6 +3778,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
wire_column = fc._categorical_column_with_vocabulary_file(
|
wire_column = fc._categorical_column_with_vocabulary_file(
|
||||||
key='wire',
|
key='wire',
|
||||||
@ -3748,6 +3805,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
# 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
|
# 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
|
||||||
self.assertAllClose(((3.,), (5.,)), self.evaluate(predictions))
|
self.assertAllClose(((3.,), (5.,)), self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_keras_linear_model(self):
|
def test_keras_linear_model(self):
|
||||||
wire_column = fc._categorical_column_with_vocabulary_file(
|
wire_column = fc._categorical_column_with_vocabulary_file(
|
||||||
key='wire',
|
key='wire',
|
||||||
@ -3805,6 +3863,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
||||||
}, column._parse_example_spec)
|
}, column._parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_all_constructor_args(self):
|
def test_all_constructor_args(self):
|
||||||
column = fc._categorical_column_with_vocabulary_list(
|
column = fc._categorical_column_with_vocabulary_list(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -3816,6 +3875,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
|
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
|
||||||
}, column._parse_example_spec)
|
}, column._parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
original = fc._categorical_column_with_vocabulary_list(
|
original = fc._categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32)
|
key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32)
|
||||||
@ -3904,6 +3964,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
|
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
|
||||||
column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
|
column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example_string(self):
|
def test_parse_example_string(self):
|
||||||
a = fc._categorical_column_with_vocabulary_list(
|
a = fc._categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
@ -3926,6 +3987,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=[1, 2]),
|
dense_shape=[1, 2]),
|
||||||
features['aaa'].eval())
|
features['aaa'].eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example_int(self):
|
def test_parse_example_int(self):
|
||||||
a = fc._categorical_column_with_vocabulary_list(
|
a = fc._categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=(11, 21, 31))
|
key='aaa', vocabulary_list=(11, 21, 31))
|
||||||
@ -3948,6 +4010,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=[1, 2]),
|
dense_shape=[1, 2]),
|
||||||
features['aaa'].eval())
|
features['aaa'].eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors(self):
|
def test_get_sparse_tensors(self):
|
||||||
column = fc._categorical_column_with_vocabulary_list(
|
column = fc._categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
@ -3966,6 +4029,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_feature(self):
|
def test_transform_feature(self):
|
||||||
column = fc._categorical_column_with_vocabulary_list(
|
column = fc._categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
@ -3998,6 +4062,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
[], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
[], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
||||||
self.assertItemsEqual([], ops.get_collection('my_weights'))
|
self.assertItemsEqual([], ops.get_collection('my_weights'))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_dense_input(self):
|
def test_get_sparse_tensors_dense_input(self):
|
||||||
column = fc._categorical_column_with_vocabulary_list(
|
column = fc._categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
@ -4015,6 +4080,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=(2, 2)),
|
dense_shape=(2, 2)),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_default_value_in_vocabulary(self):
|
def test_get_sparse_tensors_default_value_in_vocabulary(self):
|
||||||
column = fc._categorical_column_with_vocabulary_list(
|
column = fc._categorical_column_with_vocabulary_list(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -4035,6 +4101,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_with_oov_buckets(self):
|
def test_get_sparse_tensors_with_oov_buckets(self):
|
||||||
column = fc._categorical_column_with_vocabulary_list(
|
column = fc._categorical_column_with_vocabulary_list(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -4055,6 +4122,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_int32(self):
|
def test_get_sparse_tensors_int32(self):
|
||||||
column = fc._categorical_column_with_vocabulary_list(
|
column = fc._categorical_column_with_vocabulary_list(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -4075,6 +4143,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_int32_dense_input(self):
|
def test_get_sparse_tensors_int32_dense_input(self):
|
||||||
default_value = -100
|
default_value = -100
|
||||||
column = fc._categorical_column_with_vocabulary_list(
|
column = fc._categorical_column_with_vocabulary_list(
|
||||||
@ -4098,6 +4167,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=(3, 3)),
|
dense_shape=(3, 3)),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_int32_with_oov_buckets(self):
|
def test_get_sparse_tensors_int32_with_oov_buckets(self):
|
||||||
column = fc._categorical_column_with_vocabulary_list(
|
column = fc._categorical_column_with_vocabulary_list(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -4119,6 +4189,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
wire_column = fc._categorical_column_with_vocabulary_list(
|
wire_column = fc._categorical_column_with_vocabulary_list(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -4144,6 +4215,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
# 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
|
# 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
|
||||||
self.assertAllClose(((3.,), (5.,)), self.evaluate(predictions))
|
self.assertAllClose(((3.,), (5.,)), self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_keras_linear_model(self):
|
def test_keras_linear_model(self):
|
||||||
wire_column = fc._categorical_column_with_vocabulary_list(
|
wire_column = fc._categorical_column_with_vocabulary_list(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -4187,6 +4259,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
|
with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
|
||||||
fc._categorical_column_with_identity(key=('aaa',), num_buckets=3)
|
fc._categorical_column_with_identity(key=('aaa',), num_buckets=3)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
original = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
original = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
for column in (original, copy.deepcopy(original)):
|
for column in (original, copy.deepcopy(original)):
|
||||||
@ -4223,6 +4296,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'Invalid input, not integer'):
|
with self.assertRaisesRegexp(ValueError, 'Invalid input, not integer'):
|
||||||
column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
|
column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example(self):
|
def test_parse_example(self):
|
||||||
a = fc._categorical_column_with_identity(key='aaa', num_buckets=30)
|
a = fc._categorical_column_with_identity(key='aaa', num_buckets=30)
|
||||||
data = example_pb2.Example(features=feature_pb2.Features(
|
data = example_pb2.Example(features=feature_pb2.Features(
|
||||||
@ -4244,6 +4318,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=[1, 2]),
|
dense_shape=[1, 2]),
|
||||||
features['aaa'].eval())
|
features['aaa'].eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors(self):
|
def test_get_sparse_tensors(self):
|
||||||
column = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
column = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
inputs = sparse_tensor.SparseTensorValue(
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
@ -4261,6 +4336,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_feature(self):
|
def test_transform_feature(self):
|
||||||
column = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
column = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
inputs = sparse_tensor.SparseTensorValue(
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
@ -4291,6 +4367,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
[], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
[], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
||||||
self.assertItemsEqual([], ops.get_collection('my_weights'))
|
self.assertItemsEqual([], ops.get_collection('my_weights'))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_dense_input(self):
|
def test_get_sparse_tensors_dense_input(self):
|
||||||
column = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
column = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
id_weight_pair = column._get_sparse_tensors(
|
id_weight_pair = column._get_sparse_tensors(
|
||||||
@ -4307,6 +4384,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=(2, 2)),
|
dense_shape=(2, 2)),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_with_inputs_too_small(self):
|
def test_get_sparse_tensors_with_inputs_too_small(self):
|
||||||
column = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
column = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
inputs = sparse_tensor.SparseTensorValue(
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
@ -4320,6 +4398,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
errors.OpError, 'assert_greater_or_equal_0'):
|
errors.OpError, 'assert_greater_or_equal_0'):
|
||||||
id_weight_pair.id_tensor.eval()
|
id_weight_pair.id_tensor.eval()
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_with_inputs_too_big(self):
|
def test_get_sparse_tensors_with_inputs_too_big(self):
|
||||||
column = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
column = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
inputs = sparse_tensor.SparseTensorValue(
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
@ -4333,6 +4412,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
errors.OpError, 'assert_less_than_num_buckets'):
|
errors.OpError, 'assert_less_than_num_buckets'):
|
||||||
id_weight_pair.id_tensor.eval()
|
id_weight_pair.id_tensor.eval()
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_with_default_value(self):
|
def test_get_sparse_tensors_with_default_value(self):
|
||||||
column = fc._categorical_column_with_identity(
|
column = fc._categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=4, default_value=3)
|
key='aaa', num_buckets=4, default_value=3)
|
||||||
@ -4351,6 +4431,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
id_weight_pair.id_tensor.eval())
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_with_default_value_and_placeholder_inputs(self):
|
def test_get_sparse_tensors_with_default_value_and_placeholder_inputs(self):
|
||||||
column = fc._categorical_column_with_identity(
|
column = fc._categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=4, default_value=3)
|
key='aaa', num_buckets=4, default_value=3)
|
||||||
@ -4376,6 +4457,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
input_shape: (2, 2),
|
input_shape: (2, 2),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
column = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
column = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
self.assertEqual(3, column._num_buckets)
|
self.assertEqual(3, column._num_buckets)
|
||||||
@ -4397,6 +4479,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
# weight_var[2] + weight_var[1] = 3+2 = 5
|
# weight_var[2] + weight_var[1] = 3+2 = 5
|
||||||
self.assertAllClose(((1.,), (5.,)), self.evaluate(predictions))
|
self.assertAllClose(((1.,), (5.,)), self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_keras_linear_model(self):
|
def test_keras_linear_model(self):
|
||||||
column = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
column = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
self.assertEqual(3, column._num_buckets)
|
self.assertEqual(3, column._num_buckets)
|
||||||
@ -4548,6 +4631,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAllEqual([[0., 1., 1., 0.]], self.evaluate(output))
|
self.assertAllEqual([[0., 1., 1., 0.]], self.evaluate(output))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
a = fc._categorical_column_with_hash_bucket('a', 4)
|
a = fc._categorical_column_with_hash_bucket('a', 4)
|
||||||
column = fc._indicator_column(a)
|
column = fc._indicator_column(a)
|
||||||
@ -4556,6 +4640,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
self.assertEqual(column.name, 'a_indicator')
|
self.assertEqual(column.name, 'a_indicator')
|
||||||
self.assertEqual(column._variable_shape, [1, 4])
|
self.assertEqual(column._variable_shape, [1, 4])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example(self):
|
def test_parse_example(self):
|
||||||
a = fc._categorical_column_with_vocabulary_list(
|
a = fc._categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
@ -4579,6 +4664,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
dense_shape=[1, 2]),
|
dense_shape=[1, 2]),
|
||||||
features['aaa'].eval())
|
features['aaa'].eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform(self):
|
def test_transform(self):
|
||||||
a = fc._categorical_column_with_vocabulary_list(
|
a = fc._categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
@ -4594,6 +4680,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
self.assertAllEqual([[0, 0, 1], [1, 0, 0]],
|
self.assertAllEqual([[0, 0, 1], [1, 0, 0]],
|
||||||
self.evaluate(indicator_tensor))
|
self.evaluate(indicator_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_with_weighted_column(self):
|
def test_transform_with_weighted_column(self):
|
||||||
# Github issue 12557
|
# Github issue 12557
|
||||||
ids = fc._categorical_column_with_vocabulary_list(
|
ids = fc._categorical_column_with_vocabulary_list(
|
||||||
@ -4608,6 +4695,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
self.assertAllEqual([[6., 4., 3.]], self.evaluate(indicator_tensor))
|
self.assertAllEqual([[6., 4., 3.]], self.evaluate(indicator_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_with_missing_value_in_weighted_column(self):
|
def test_transform_with_missing_value_in_weighted_column(self):
|
||||||
# Github issue 12583
|
# Github issue 12583
|
||||||
ids = fc._categorical_column_with_vocabulary_list(
|
ids = fc._categorical_column_with_vocabulary_list(
|
||||||
@ -4622,6 +4710,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
self.assertAllEqual([[0., 4., 2.]], self.evaluate(indicator_tensor))
|
self.assertAllEqual([[0., 4., 2.]], self.evaluate(indicator_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_with_missing_value_in_categorical_column(self):
|
def test_transform_with_missing_value_in_categorical_column(self):
|
||||||
# Github issue 12583
|
# Github issue 12583
|
||||||
ids = fc._categorical_column_with_vocabulary_list(
|
ids = fc._categorical_column_with_vocabulary_list(
|
||||||
@ -4634,6 +4723,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
self.assertAllEqual([[0., 1., 1.]], self.evaluate(indicator_tensor))
|
self.assertAllEqual([[0., 1., 1.]], self.evaluate(indicator_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
animal = fc._indicator_column(
|
animal = fc._indicator_column(
|
||||||
fc._categorical_column_with_identity('animal', num_buckets=4))
|
fc._categorical_column_with_identity('animal', num_buckets=4))
|
||||||
@ -4653,6 +4743,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
|
weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
|
||||||
self.assertAllClose([[2. + 3.]], self.evaluate(predictions))
|
self.assertAllClose([[2. + 3.]], self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_keras_linear_model(self):
|
def test_keras_linear_model(self):
|
||||||
animal = fc._indicator_column(
|
animal = fc._indicator_column(
|
||||||
fc._categorical_column_with_identity('animal', num_buckets=4))
|
fc._categorical_column_with_identity('animal', num_buckets=4))
|
||||||
@ -4672,6 +4763,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
|
weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
|
||||||
self.assertAllClose([[2. + 3.]], self.evaluate(predictions))
|
self.assertAllClose([[2. + 3.]], self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_input_layer(self):
|
def test_input_layer(self):
|
||||||
animal = fc._indicator_column(
|
animal = fc._indicator_column(
|
||||||
fc._categorical_column_with_identity('animal', num_buckets=4))
|
fc._categorical_column_with_identity('animal', num_buckets=4))
|
||||||
@ -4688,6 +4780,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
|
|
||||||
class EmbeddingColumnTest(test.TestCase):
|
class EmbeddingColumnTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_defaults(self):
|
def test_defaults(self):
|
||||||
categorical_column = fc._categorical_column_with_identity(
|
categorical_column = fc._categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -4709,6 +4802,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
||||||
}, embedding_column._parse_example_spec)
|
}, embedding_column._parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_all_constructor_args(self):
|
def test_all_constructor_args(self):
|
||||||
categorical_column = fc._categorical_column_with_identity(
|
categorical_column = fc._categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -4737,6 +4831,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
||||||
}, embedding_column._parse_example_spec)
|
}, embedding_column._parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
categorical_column = fc._categorical_column_with_identity(
|
categorical_column = fc._categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -4770,6 +4865,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
||||||
}, embedding_column._parse_example_spec)
|
}, embedding_column._parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_invalid_initializer(self):
|
def test_invalid_initializer(self):
|
||||||
categorical_column = fc._categorical_column_with_identity(
|
categorical_column = fc._categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -4777,6 +4873,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
fc._embedding_column(
|
fc._embedding_column(
|
||||||
categorical_column, dimension=2, initializer='not_fn')
|
categorical_column, dimension=2, initializer='not_fn')
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example(self):
|
def test_parse_example(self):
|
||||||
a = fc._categorical_column_with_vocabulary_list(
|
a = fc._categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
@ -4800,6 +4897,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
dense_shape=[1, 2]),
|
dense_shape=[1, 2]),
|
||||||
features['aaa'].eval())
|
features['aaa'].eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_feature(self):
|
def test_transform_feature(self):
|
||||||
a = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
a = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
a_embedded = fc._embedding_column(a, dimension=2)
|
a_embedded = fc._embedding_column(a, dimension=2)
|
||||||
@ -4816,6 +4914,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
_assert_sparse_tensor_value(self, self.evaluate(output_a),
|
_assert_sparse_tensor_value(self, self.evaluate(output_a),
|
||||||
self.evaluate(output_embedded))
|
self.evaluate(output_embedded))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor(self):
|
def test_get_dense_tensor(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 3
|
||||||
@ -4875,6 +4974,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertAllEqual(embedding_values, global_vars[0].eval())
|
self.assertAllEqual(embedding_values, global_vars[0].eval())
|
||||||
self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup))
|
self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor_3d(self):
|
def test_get_dense_tensor_3d(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 4
|
vocabulary_size = 4
|
||||||
@ -4936,6 +5036,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertAllEqual(embedding_values, global_vars[0].eval())
|
self.assertAllEqual(embedding_values, global_vars[0].eval())
|
||||||
self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup))
|
self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor_weight_collections(self):
|
def test_get_dense_tensor_weight_collections(self):
|
||||||
sparse_input = sparse_tensor.SparseTensorValue(
|
sparse_input = sparse_tensor.SparseTensorValue(
|
||||||
# example 0, ids [2]
|
# example 0, ids [2]
|
||||||
@ -4965,6 +5066,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertItemsEqual(
|
self.assertItemsEqual(
|
||||||
('embedding_weights:0',), tuple([v.name for v in my_vars]))
|
('embedding_weights:0',), tuple([v.name for v in my_vars]))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor_placeholder_inputs(self):
|
def test_get_dense_tensor_placeholder_inputs(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 3
|
||||||
@ -5036,6 +5138,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
input_shape: sparse_input.dense_shape,
|
input_shape: sparse_input.dense_shape,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor_restore_from_ckpt(self):
|
def test_get_dense_tensor_restore_from_ckpt(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 3
|
||||||
@ -5094,6 +5197,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertAllEqual(embedding_values, global_vars[0].eval())
|
self.assertAllEqual(embedding_values, global_vars[0].eval())
|
||||||
self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup))
|
self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
@ -5173,6 +5277,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertAllClose(((94.,), (29.,), (0.,), (42.,)),
|
self.assertAllClose(((94.,), (29.,), (0.,), (42.,)),
|
||||||
self.evaluate(predictions))
|
self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_keras_linear_model(self):
|
def test_keras_linear_model(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
@ -5252,6 +5357,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertAllClose(((94.,), (29.,), (0.,), (42.,)),
|
self.assertAllClose(((94.,), (29.,), (0.,), (42.,)),
|
||||||
self.evaluate(predictions))
|
self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_input_layer(self):
|
def test_input_layer(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 3
|
||||||
@ -5313,6 +5419,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertAllEqual(embedding_values, trainable_vars[0].eval())
|
self.assertAllEqual(embedding_values, trainable_vars[0].eval())
|
||||||
self.assertAllEqual(expected_lookups, self.evaluate(input_layer))
|
self.assertAllEqual(expected_lookups, self.evaluate(input_layer))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_input_layer_not_trainable(self):
|
def test_input_layer_not_trainable(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 3
|
||||||
@ -5376,6 +5483,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
|
|
||||||
class SharedEmbeddingColumnTest(test.TestCase):
|
class SharedEmbeddingColumnTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_defaults(self):
|
def test_defaults(self):
|
||||||
categorical_column_a = fc._categorical_column_with_identity(
|
categorical_column_a = fc._categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -5420,6 +5528,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
'bbb': parsing_ops.VarLenFeature(dtypes.int64)
|
'bbb': parsing_ops.VarLenFeature(dtypes.int64)
|
||||||
}, embedding_column_b._parse_example_spec)
|
}, embedding_column_b._parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_all_constructor_args(self):
|
def test_all_constructor_args(self):
|
||||||
categorical_column_a = fc._categorical_column_with_identity(
|
categorical_column_a = fc._categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -5471,6 +5580,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
'bbb': parsing_ops.VarLenFeature(dtypes.int64)
|
'bbb': parsing_ops.VarLenFeature(dtypes.int64)
|
||||||
}, embedding_column_b._parse_example_spec)
|
}, embedding_column_b._parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
categorical_column_a = fc._categorical_column_with_identity(
|
categorical_column_a = fc._categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -5509,6 +5619,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
||||||
}, embedding_column_a._parse_example_spec)
|
}, embedding_column_a._parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_invalid_initializer(self):
|
def test_invalid_initializer(self):
|
||||||
categorical_column_a = fc._categorical_column_with_identity(
|
categorical_column_a = fc._categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -5520,6 +5631,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
dimension=2,
|
dimension=2,
|
||||||
initializer='not_fn')
|
initializer='not_fn')
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_incompatible_column_type(self):
|
def test_incompatible_column_type(self):
|
||||||
categorical_column_a = fc._categorical_column_with_identity(
|
categorical_column_a = fc._categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -5535,6 +5647,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
[categorical_column_a, categorical_column_b, categorical_column_c],
|
[categorical_column_a, categorical_column_b, categorical_column_c],
|
||||||
dimension=2)
|
dimension=2)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_weighted_categorical_column_ok(self):
|
def test_weighted_categorical_column_ok(self):
|
||||||
categorical_column_a = fc._categorical_column_with_identity(
|
categorical_column_a = fc._categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -5552,6 +5665,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
[weighted_categorical_column_a, weighted_categorical_column_b],
|
[weighted_categorical_column_a, weighted_categorical_column_b],
|
||||||
dimension=2)
|
dimension=2)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example(self):
|
def test_parse_example(self):
|
||||||
a = fc._categorical_column_with_vocabulary_list(
|
a = fc._categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
@ -5589,6 +5703,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
dense_shape=[1, 2]),
|
dense_shape=[1, 2]),
|
||||||
features['bbb'].eval())
|
features['bbb'].eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_feature(self):
|
def test_transform_feature(self):
|
||||||
a = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
a = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
b = fc._categorical_column_with_identity(key='bbb', num_buckets=3)
|
b = fc._categorical_column_with_identity(key='bbb', num_buckets=3)
|
||||||
@ -5615,6 +5730,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
_assert_sparse_tensor_value(self, self.evaluate(output_b),
|
_assert_sparse_tensor_value(self, self.evaluate(output_b),
|
||||||
self.evaluate(output_b_embedded))
|
self.evaluate(output_b_embedded))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor(self):
|
def test_get_dense_tensor(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 3
|
||||||
@ -5683,6 +5799,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertAllEqual(expected_lookups_a, self.evaluate(embedding_lookup_a))
|
self.assertAllEqual(expected_lookups_a, self.evaluate(embedding_lookup_a))
|
||||||
self.assertAllEqual(expected_lookups_b, self.evaluate(embedding_lookup_b))
|
self.assertAllEqual(expected_lookups_b, self.evaluate(embedding_lookup_b))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor_weight_collections(self):
|
def test_get_dense_tensor_weight_collections(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 3
|
||||||
@ -5735,6 +5852,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
|
('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
|
||||||
tuple(v.name for v in my_vars))
|
tuple(v.name for v in my_vars))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor_placeholder_inputs(self):
|
def test_get_dense_tensor_placeholder_inputs(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 3
|
||||||
@ -5791,6 +5909,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
with _initialized_session() as sess:
|
with _initialized_session() as sess:
|
||||||
sess.run([embedding_lookup_a, embedding_lookup_b], feed_dict=feed_dict)
|
sess.run([embedding_lookup_a, embedding_lookup_b], feed_dict=feed_dict)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
@ -5881,6 +6000,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
# = [3*1 + 5*2, 3*0 +5*0] = [13, 0]
|
# = [3*1 + 5*2, 3*0 +5*0] = [13, 0]
|
||||||
self.assertAllClose([[94. + 13.], [29.]], self.evaluate(predictions))
|
self.assertAllClose([[94. + 13.], [29.]], self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_keras_linear_model(self):
|
def test_keras_linear_model(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
@ -6048,15 +6168,18 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertAllEqual(embedding_values, shared_embedding_vars[0].eval())
|
self.assertAllEqual(embedding_values, shared_embedding_vars[0].eval())
|
||||||
self.assertAllEqual(expected_lookups, self.evaluate(input_layer))
|
self.assertAllEqual(expected_lookups, self.evaluate(input_layer))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_input_layer(self):
|
def test_input_layer(self):
|
||||||
self._test_input_layer()
|
self._test_input_layer()
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_input_layer_no_trainable(self):
|
def test_input_layer_no_trainable(self):
|
||||||
self._test_input_layer(trainable=False)
|
self._test_input_layer(trainable=False)
|
||||||
|
|
||||||
|
|
||||||
class WeightedCategoricalColumnTest(test.TestCase):
|
class WeightedCategoricalColumnTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_defaults(self):
|
def test_defaults(self):
|
||||||
column = fc._weighted_categorical_column(
|
column = fc._weighted_categorical_column(
|
||||||
categorical_column=fc._categorical_column_with_identity(
|
categorical_column=fc._categorical_column_with_identity(
|
||||||
@ -6070,6 +6193,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
|
|||||||
'values': parsing_ops.VarLenFeature(dtypes.float32)
|
'values': parsing_ops.VarLenFeature(dtypes.float32)
|
||||||
}, column._parse_example_spec)
|
}, column._parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
"""Tests deepcopy of categorical_column_with_hash_bucket."""
|
"""Tests deepcopy of categorical_column_with_hash_bucket."""
|
||||||
original = fc._weighted_categorical_column(
|
original = fc._weighted_categorical_column(
|
||||||
@ -6132,6 +6256,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
|
|||||||
ValueError, 'values is not in features dictionary'):
|
ValueError, 'values is not in features dictionary'):
|
||||||
_transform_features({'ids': inputs}, (column,))
|
_transform_features({'ids': inputs}, (column,))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example(self):
|
def test_parse_example(self):
|
||||||
a = fc._categorical_column_with_vocabulary_list(
|
a = fc._categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
@ -6167,6 +6292,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=[1, 2]),
|
dense_shape=[1, 2]),
|
||||||
features['weights'].eval())
|
features['weights'].eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_features(self):
|
def test_transform_features(self):
|
||||||
column = fc._weighted_categorical_column(
|
column = fc._weighted_categorical_column(
|
||||||
categorical_column=fc._categorical_column_with_identity(
|
categorical_column=fc._categorical_column_with_identity(
|
||||||
@ -6198,6 +6324,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array(weights.values, dtype=np.float32),
|
values=np.array(weights.values, dtype=np.float32),
|
||||||
dense_shape=weights.dense_shape), self.evaluate(weight_tensor))
|
dense_shape=weights.dense_shape), self.evaluate(weight_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_features_dense_input(self):
|
def test_transform_features_dense_input(self):
|
||||||
column = fc._weighted_categorical_column(
|
column = fc._weighted_categorical_column(
|
||||||
categorical_column=fc._categorical_column_with_identity(
|
categorical_column=fc._categorical_column_with_identity(
|
||||||
@ -6225,6 +6352,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array(weights.values, dtype=np.float32),
|
values=np.array(weights.values, dtype=np.float32),
|
||||||
dense_shape=weights.dense_shape), self.evaluate(weight_tensor))
|
dense_shape=weights.dense_shape), self.evaluate(weight_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_features_dense_weights(self):
|
def test_transform_features_dense_weights(self):
|
||||||
column = fc._weighted_categorical_column(
|
column = fc._weighted_categorical_column(
|
||||||
categorical_column=fc._categorical_column_with_identity(
|
categorical_column=fc._categorical_column_with_identity(
|
||||||
@ -6252,6 +6380,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array((.5, 1., .1), dtype=np.float32),
|
values=np.array((.5, 1., .1), dtype=np.float32),
|
||||||
dense_shape=(2, 2)), self.evaluate(weight_tensor))
|
dense_shape=(2, 2)), self.evaluate(weight_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_keras_linear_model(self):
|
def test_keras_linear_model(self):
|
||||||
column = fc._weighted_categorical_column(
|
column = fc._weighted_categorical_column(
|
||||||
categorical_column=fc._categorical_column_with_identity(
|
categorical_column=fc._categorical_column_with_identity(
|
||||||
@ -6354,6 +6483,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
|
|||||||
# = 3*1 + 2*.1 = 3+.2 = 3.2
|
# = 3*1 + 2*.1 = 3+.2 = 3.2
|
||||||
self.assertAllClose(((.5,), (3.2,)), self.evaluate(predictions))
|
self.assertAllClose(((.5,), (3.2,)), self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
column = fc._weighted_categorical_column(
|
column = fc._weighted_categorical_column(
|
||||||
categorical_column=fc._categorical_column_with_identity(
|
categorical_column=fc._categorical_column_with_identity(
|
||||||
|
@ -218,6 +218,7 @@ class LazyColumnTest(test.TestCase):
|
|||||||
TypeError, '"key" must be either a "str" or "FeatureColumn".'):
|
TypeError, '"key" must be either a "str" or "FeatureColumn".'):
|
||||||
transformation_cache.get(NotAFeatureColumn(), None)
|
transformation_cache.get(NotAFeatureColumn(), None)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_expand_dim_rank_1_sparse_tensor_empty_batch(self):
|
def test_expand_dim_rank_1_sparse_tensor_empty_batch(self):
|
||||||
# empty 1-D sparse tensor:
|
# empty 1-D sparse tensor:
|
||||||
transformation_cache = fc.FeatureTransformationCache(
|
transformation_cache = fc.FeatureTransformationCache(
|
||||||
@ -237,6 +238,7 @@ class LazyColumnTest(test.TestCase):
|
|||||||
|
|
||||||
class NumericColumnTest(test.TestCase):
|
class NumericColumnTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_defaults(self):
|
def test_defaults(self):
|
||||||
a = fc.numeric_column('aaa')
|
a = fc.numeric_column('aaa')
|
||||||
self.assertEqual('aaa', a.key)
|
self.assertEqual('aaa', a.key)
|
||||||
@ -315,6 +317,7 @@ class NumericColumnTest(test.TestCase):
|
|||||||
'aaa': parsing_ops.FixedLenFeature((2, 3), dtype=dtypes.int32)
|
'aaa': parsing_ops.FixedLenFeature((2, 3), dtype=dtypes.int32)
|
||||||
}, a.parse_example_spec)
|
}, a.parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example_no_default_value(self):
|
def test_parse_example_no_default_value(self):
|
||||||
price = fc.numeric_column('price', shape=[2])
|
price = fc.numeric_column('price', shape=[2])
|
||||||
data = example_pb2.Example(
|
data = example_pb2.Example(
|
||||||
@ -331,6 +334,7 @@ class NumericColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual([[20., 110.]], self.evaluate(features['price']))
|
self.assertAllEqual([[20., 110.]], self.evaluate(features['price']))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example_with_default_value(self):
|
def test_parse_example_with_default_value(self):
|
||||||
price = fc.numeric_column('price', shape=[2], default_value=11.)
|
price = fc.numeric_column('price', shape=[2], default_value=11.)
|
||||||
data = example_pb2.Example(
|
data = example_pb2.Example(
|
||||||
@ -360,6 +364,7 @@ class NumericColumnTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(TypeError, 'must be a callable'):
|
with self.assertRaisesRegexp(TypeError, 'must be a callable'):
|
||||||
fc.numeric_column('price', normalizer_fn='NotACallable')
|
fc.numeric_column('price', normalizer_fn='NotACallable')
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_normalizer_fn_transform_feature(self):
|
def test_normalizer_fn_transform_feature(self):
|
||||||
|
|
||||||
def _increment_two(input_tensor):
|
def _increment_two(input_tensor):
|
||||||
@ -372,6 +377,7 @@ class NumericColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual([[3., 4.], [7., 8.]], self.evaluate(output[price]))
|
self.assertAllEqual([[3., 4.], [7., 8.]], self.evaluate(output[price]))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor(self):
|
def test_get_dense_tensor(self):
|
||||||
|
|
||||||
def _increment_two(input_tensor):
|
def _increment_two(input_tensor):
|
||||||
@ -395,6 +401,7 @@ class NumericColumnTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'must be a Tensor'):
|
with self.assertRaisesRegexp(ValueError, 'must be a Tensor'):
|
||||||
price.transform_feature(transformation_cache, None)
|
price.transform_feature(transformation_cache, None)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3., 2.]])
|
a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3., 2.]])
|
||||||
a_copy = copy.deepcopy(a)
|
a_copy = copy.deepcopy(a)
|
||||||
@ -407,6 +414,7 @@ class NumericColumnTest(test.TestCase):
|
|||||||
'aaa', shape=[1, 2], default_value=np.array([[3., 2.]]))
|
'aaa', shape=[1, 2], default_value=np.array([[3., 2.]]))
|
||||||
self.assertEqual(a.default_value, ((3., 2.),))
|
self.assertEqual(a.default_value, ((3., 2.),))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
price = fc.numeric_column('price')
|
price = fc.numeric_column('price')
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
@ -435,6 +443,7 @@ class NumericColumnTest(test.TestCase):
|
|||||||
sess.run(price_var.assign([[10.]]))
|
sess.run(price_var.assign([[10.]]))
|
||||||
self.assertAllClose([[10.], [50.]], self.evaluate(predictions))
|
self.assertAllClose([[10.], [50.]], self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_serialization(self):
|
def test_serialization(self):
|
||||||
|
|
||||||
def _increment_two(input_tensor):
|
def _increment_two(input_tensor):
|
||||||
@ -519,6 +528,7 @@ class BucketizedColumnTest(test.TestCase):
|
|||||||
# Column 'aaa` has shape [2] times three buckets -> num_buckets=6.
|
# Column 'aaa` has shape [2] times three buckets -> num_buckets=6.
|
||||||
self.assertEqual(6, b.num_buckets)
|
self.assertEqual(6, b.num_buckets)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example(self):
|
def test_parse_example(self):
|
||||||
price = fc.numeric_column('price', shape=[2])
|
price = fc.numeric_column('price', shape=[2])
|
||||||
bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
|
bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
|
||||||
@ -536,6 +546,7 @@ class BucketizedColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual([[20., 110.]], self.evaluate(features['price']))
|
self.assertAllEqual([[20., 110.]], self.evaluate(features['price']))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_feature(self):
|
def test_transform_feature(self):
|
||||||
price = fc.numeric_column('price', shape=[2])
|
price = fc.numeric_column('price', shape=[2])
|
||||||
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
|
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
|
||||||
@ -639,6 +650,7 @@ class BucketizedColumnTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'must be a Tensor'):
|
with self.assertRaisesRegexp(ValueError, 'must be a Tensor'):
|
||||||
bucketized_price.transform_feature(transformation_cache, None)
|
bucketized_price.transform_feature(transformation_cache, None)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
a = fc.numeric_column('aaa', shape=[2])
|
a = fc.numeric_column('aaa', shape=[2])
|
||||||
a_bucketized = fc.bucketized_column(a, boundaries=[0, 1])
|
a_bucketized = fc.bucketized_column(a, boundaries=[0, 1])
|
||||||
@ -789,6 +801,7 @@ class BucketizedColumnTest(test.TestCase):
|
|||||||
self.assertAllClose([[11.], [21.], [41.], [51.]],
|
self.assertAllClose([[11.], [21.], [41.], [51.]],
|
||||||
self.evaluate(predictions))
|
self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_serialization(self):
|
def test_serialization(self):
|
||||||
price = fc.numeric_column('price', shape=[2])
|
price = fc.numeric_column('price', shape=[2])
|
||||||
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
|
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
|
||||||
@ -821,6 +834,7 @@ class BucketizedColumnTest(test.TestCase):
|
|||||||
|
|
||||||
class HashedCategoricalColumnTest(test.TestCase):
|
class HashedCategoricalColumnTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_defaults(self):
|
def test_defaults(self):
|
||||||
a = fc.categorical_column_with_hash_bucket('aaa', 10)
|
a = fc.categorical_column_with_hash_bucket('aaa', 10)
|
||||||
self.assertEqual('aaa', a.name)
|
self.assertEqual('aaa', a.name)
|
||||||
@ -848,6 +862,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
|
with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
|
||||||
fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.float32)
|
fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.float32)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
original = fc.categorical_column_with_hash_bucket('aaa', 10)
|
original = fc.categorical_column_with_hash_bucket('aaa', 10)
|
||||||
for column in (original, copy.deepcopy(original)):
|
for column in (original, copy.deepcopy(original)):
|
||||||
@ -868,6 +883,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
|
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
|
||||||
}, a.parse_example_spec)
|
}, a.parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example(self):
|
def test_parse_example(self):
|
||||||
a = fc.categorical_column_with_hash_bucket('aaa', 10)
|
a = fc.categorical_column_with_hash_bucket('aaa', 10)
|
||||||
data = example_pb2.Example(
|
data = example_pb2.Example(
|
||||||
@ -890,6 +906,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array([b'omar', b'stringer'], dtype=np.object_),
|
values=np.array([b'omar', b'stringer'], dtype=np.object_),
|
||||||
dense_shape=[1, 2]), self.evaluate(features['aaa']))
|
dense_shape=[1, 2]), self.evaluate(features['aaa']))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_strings_should_be_hashed(self):
|
def test_strings_should_be_hashed(self):
|
||||||
hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
|
hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
|
||||||
wire_tensor = sparse_tensor.SparseTensor(
|
wire_tensor = sparse_tensor.SparseTensor(
|
||||||
@ -943,6 +960,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
|
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
|
||||||
transformation_cache.get(hashed_sparse, None)
|
transformation_cache.get(hashed_sparse, None)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_ints_should_be_hashed(self):
|
def test_ints_should_be_hashed(self):
|
||||||
hashed_sparse = fc.categorical_column_with_hash_bucket(
|
hashed_sparse = fc.categorical_column_with_hash_bucket(
|
||||||
'wire', 10, dtype=dtypes.int64)
|
'wire', 10, dtype=dtypes.int64)
|
||||||
@ -957,6 +975,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual(expected_values, self.evaluate(output.values))
|
self.assertAllEqual(expected_values, self.evaluate(output.values))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_int32_64_is_compatible(self):
|
def test_int32_64_is_compatible(self):
|
||||||
hashed_sparse = fc.categorical_column_with_hash_bucket(
|
hashed_sparse = fc.categorical_column_with_hash_bucket(
|
||||||
'wire', 10, dtype=dtypes.int64)
|
'wire', 10, dtype=dtypes.int64)
|
||||||
@ -971,6 +990,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual(expected_values, self.evaluate(output.values))
|
self.assertAllEqual(expected_values, self.evaluate(output.values))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors(self):
|
def test_get_sparse_tensors(self):
|
||||||
hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
|
hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
|
||||||
transformation_cache = fc.FeatureTransformationCache({
|
transformation_cache = fc.FeatureTransformationCache({
|
||||||
@ -986,6 +1006,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor)
|
transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_dense_input(self):
|
def test_get_sparse_tensors_dense_input(self):
|
||||||
hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
|
hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
|
||||||
transformation_cache = fc.FeatureTransformationCache({
|
transformation_cache = fc.FeatureTransformationCache({
|
||||||
@ -997,6 +1018,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor)
|
transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
|
wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
|
||||||
self.assertEqual(4, wire_column.num_buckets)
|
self.assertEqual(4, wire_column.num_buckets)
|
||||||
@ -1047,6 +1069,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
# 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
|
# 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
|
||||||
self.assertAllClose(((4.,), (6.,)), self.evaluate(predictions))
|
self.assertAllClose(((4.,), (6.,)), self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_serialization(self):
|
def test_serialization(self):
|
||||||
wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
|
wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
|
||||||
self.assertEqual(['wire'], wire_column.parents)
|
self.assertEqual(['wire'], wire_column.parents)
|
||||||
@ -1148,6 +1171,7 @@ class CrossedColumnTest(test.TestCase):
|
|||||||
crossed = fc.crossed_column([b, 'c'], 15)
|
crossed = fc.crossed_column([b, 'c'], 15)
|
||||||
self.assertEqual(15, crossed.num_buckets)
|
self.assertEqual(15, crossed.num_buckets)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
a = fc.numeric_column('a', dtype=dtypes.int32)
|
a = fc.numeric_column('a', dtype=dtypes.int32)
|
||||||
b = fc.bucketized_column(a, boundaries=[0, 1])
|
b = fc.bucketized_column(a, boundaries=[0, 1])
|
||||||
@ -1161,6 +1185,7 @@ class CrossedColumnTest(test.TestCase):
|
|||||||
self.assertEqual(15, crossed2_copy.hash_bucket_size)
|
self.assertEqual(15, crossed2_copy.hash_bucket_size)
|
||||||
self.assertEqual(5, crossed2_copy.hash_key)
|
self.assertEqual(5, crossed2_copy.hash_key)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example(self):
|
def test_parse_example(self):
|
||||||
price = fc.numeric_column('price', shape=[2])
|
price = fc.numeric_column('price', shape=[2])
|
||||||
bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
|
bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
|
||||||
@ -1190,6 +1215,7 @@ class CrossedColumnTest(test.TestCase):
|
|||||||
self.evaluate(wire_sparse.values))
|
self.evaluate(wire_sparse.values))
|
||||||
self.assertAllEqual([1, 2], self.evaluate(wire_sparse.dense_shape))
|
self.assertAllEqual([1, 2], self.evaluate(wire_sparse.dense_shape))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_feature(self):
|
def test_transform_feature(self):
|
||||||
price = fc.numeric_column('price', shape=[2])
|
price = fc.numeric_column('price', shape=[2])
|
||||||
bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
|
bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
|
||||||
@ -1214,6 +1240,7 @@ class CrossedColumnTest(test.TestCase):
|
|||||||
self.assertIn(val, list(range(hash_bucket_size)))
|
self.assertIn(val, list(range(hash_bucket_size)))
|
||||||
self.assertAllEqual([2, 4], output_val.dense_shape)
|
self.assertAllEqual([2, 4], output_val.dense_shape)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors(self):
|
def test_get_sparse_tensors(self):
|
||||||
a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
|
a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
|
||||||
b = fc.bucketized_column(a, boundaries=(0, 1))
|
b = fc.bucketized_column(a, boundaries=(0, 1))
|
||||||
@ -1285,6 +1312,7 @@ class CrossedColumnTest(test.TestCase):
|
|||||||
self.assertAllEqual(expected_values, id_tensor_eval.values)
|
self.assertAllEqual(expected_values, id_tensor_eval.values)
|
||||||
self.assertAllEqual((2, 4), id_tensor_eval.dense_shape)
|
self.assertAllEqual((2, 4), id_tensor_eval.dense_shape)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
"""Tests linear_model.
|
"""Tests linear_model.
|
||||||
|
|
||||||
@ -1520,6 +1548,7 @@ class CrossedColumnTest(test.TestCase):
|
|||||||
sess.run(bias.assign((.1,)))
|
sess.run(bias.assign((.1,)))
|
||||||
self.assertAllClose(((3.1,), (14.1,)), self.evaluate(predictions))
|
self.assertAllClose(((3.1,), (14.1,)), self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_serialization(self):
|
def test_serialization(self):
|
||||||
a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
|
a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
|
||||||
b = fc.bucketized_column(a, boundaries=(0, 1))
|
b = fc.bucketized_column(a, boundaries=(0, 1))
|
||||||
@ -2077,6 +2106,7 @@ class LinearModelTest(test.TestCase):
|
|||||||
features['price2']: [[1.], [5.]],
|
features['price2']: [[1.], [5.]],
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_numpy_input_fn(self):
|
def test_with_numpy_input_fn(self):
|
||||||
price = fc.numeric_column('price')
|
price = fc.numeric_column('price')
|
||||||
price_buckets = fc.bucketized_column(
|
price_buckets = fc.bucketized_column(
|
||||||
@ -2115,6 +2145,7 @@ class LinearModelTest(test.TestCase):
|
|||||||
coord.request_stop()
|
coord.request_stop()
|
||||||
coord.join(threads)
|
coord.join(threads)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_1d_sparse_tensor(self):
|
def test_with_1d_sparse_tensor(self):
|
||||||
price = fc.numeric_column('price')
|
price = fc.numeric_column('price')
|
||||||
price_buckets = fc.bucketized_column(
|
price_buckets = fc.bucketized_column(
|
||||||
@ -2154,6 +2185,7 @@ class LinearModelTest(test.TestCase):
|
|||||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||||
self.evaluate(net))
|
self.evaluate(net))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||||
price = fc.numeric_column('price')
|
price = fc.numeric_column('price')
|
||||||
price_buckets = fc.bucketized_column(
|
price_buckets = fc.bucketized_column(
|
||||||
@ -2198,6 +2230,7 @@ class LinearModelTest(test.TestCase):
|
|||||||
features['country']: country_data
|
features['country']: country_data
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_rank_0_feature(self):
|
def test_with_rank_0_feature(self):
|
||||||
price = fc.numeric_column('price')
|
price = fc.numeric_column('price')
|
||||||
features = {
|
features = {
|
||||||
@ -2835,6 +2868,7 @@ class OldLinearModelTest(test.TestCase):
|
|||||||
features['price2']: [[1.], [5.]],
|
features['price2']: [[1.], [5.]],
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_1d_sparse_tensor(self):
|
def test_with_1d_sparse_tensor(self):
|
||||||
price = fc.numeric_column('price')
|
price = fc.numeric_column('price')
|
||||||
price_buckets = fc.bucketized_column(
|
price_buckets = fc.bucketized_column(
|
||||||
@ -2875,6 +2909,7 @@ class OldLinearModelTest(test.TestCase):
|
|||||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||||
self.evaluate(net))
|
self.evaluate(net))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||||
price = fc.numeric_column('price')
|
price = fc.numeric_column('price')
|
||||||
price_buckets = fc.bucketized_column(
|
price_buckets = fc.bucketized_column(
|
||||||
@ -2920,6 +2955,7 @@ class OldLinearModelTest(test.TestCase):
|
|||||||
features['country']: country_data
|
features['country']: country_data
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_rank_0_feature(self):
|
def test_with_rank_0_feature(self):
|
||||||
price = fc.numeric_column('price')
|
price = fc.numeric_column('price')
|
||||||
features = {
|
features = {
|
||||||
@ -2962,6 +2998,7 @@ class OldLinearModelTest(test.TestCase):
|
|||||||
sess.run(bias2.assign([5.]))
|
sess.run(bias2.assign([5.]))
|
||||||
self.assertAllClose([[25.], [105.]], self.evaluate(predictions2))
|
self.assertAllClose([[25.], [105.]], self.evaluate(predictions2))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model_v1_shared_embedding_all_other_v2(self):
|
def test_linear_model_v1_shared_embedding_all_other_v2(self):
|
||||||
price = fc.numeric_column('price') # v2
|
price = fc.numeric_column('price') # v2
|
||||||
some_sparse_column = fc.categorical_column_with_hash_bucket(
|
some_sparse_column = fc.categorical_column_with_hash_bucket(
|
||||||
@ -3001,6 +3038,7 @@ class OldLinearModelTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllClose([0.], self.evaluate(bias))
|
self.assertAllClose([0.], self.evaluate(bias))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model_v1_shared_embedding_with_v2_cat_all_other_v2(self):
|
def test_linear_model_v1_shared_embedding_with_v2_cat_all_other_v2(self):
|
||||||
price = fc.numeric_column('price') # v2
|
price = fc.numeric_column('price') # v2
|
||||||
some_sparse_column = fc.categorical_column_with_hash_bucket(
|
some_sparse_column = fc.categorical_column_with_hash_bucket(
|
||||||
@ -3040,6 +3078,7 @@ class OldLinearModelTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllClose([0.], self.evaluate(bias))
|
self.assertAllClose([0.], self.evaluate(bias))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model_v1_v2_mix(self):
|
def test_linear_model_v1_v2_mix(self):
|
||||||
price = fc.numeric_column('price') # v2
|
price = fc.numeric_column('price') # v2
|
||||||
some_sparse_column = fc.categorical_column_with_hash_bucket(
|
some_sparse_column = fc.categorical_column_with_hash_bucket(
|
||||||
@ -3079,6 +3118,7 @@ class OldLinearModelTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllClose([0.], self.evaluate(bias))
|
self.assertAllClose([0.], self.evaluate(bias))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model_v2_shared_embedding_all_other_v1(self):
|
def test_linear_model_v2_shared_embedding_all_other_v1(self):
|
||||||
price = fc.numeric_column('price') # v1
|
price = fc.numeric_column('price') # v1
|
||||||
some_sparse_column = fc.categorical_column_with_hash_bucket(
|
some_sparse_column = fc.categorical_column_with_hash_bucket(
|
||||||
@ -3468,6 +3508,7 @@ class DenseFeaturesTest(test.TestCase):
|
|||||||
expected_var_names,
|
expected_var_names,
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_multiple_layers_with_same_shared_embedding_column(self):
|
def test_multiple_layers_with_same_shared_embedding_column(self):
|
||||||
categorical_column_a = fc.categorical_column_with_identity(
|
categorical_column_a = fc.categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -3501,6 +3542,7 @@ class DenseFeaturesTest(test.TestCase):
|
|||||||
['aaa_bbb_shared_embedding:0'],
|
['aaa_bbb_shared_embedding:0'],
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(self):
|
def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(self):
|
||||||
categorical_column_a = fc.categorical_column_with_identity(
|
categorical_column_a = fc.categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -3552,6 +3594,7 @@ class DenseFeaturesTest(test.TestCase):
|
|||||||
['aaa_bbb_shared_embedding:0'],
|
['aaa_bbb_shared_embedding:0'],
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_numpy_input_fn(self):
|
def test_with_numpy_input_fn(self):
|
||||||
embedding_values = (
|
embedding_values = (
|
||||||
(1., 2., 3., 4., 5.), # id 0
|
(1., 2., 3., 4., 5.), # id 0
|
||||||
@ -3596,6 +3639,7 @@ class DenseFeaturesTest(test.TestCase):
|
|||||||
coord.request_stop()
|
coord.request_stop()
|
||||||
coord.join(threads)
|
coord.join(threads)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_1d_sparse_tensor(self):
|
def test_with_1d_sparse_tensor(self):
|
||||||
embedding_values = (
|
embedding_values = (
|
||||||
(1., 2., 3., 4., 5.), # id 0
|
(1., 2., 3., 4., 5.), # id 0
|
||||||
@ -3652,6 +3696,7 @@ class DenseFeaturesTest(test.TestCase):
|
|||||||
[1., 0., 0., 1., 2., 3., 4., 5., 12.]],
|
[1., 0., 0., 1., 2., 3., 4., 5., 12.]],
|
||||||
sess.run(net))
|
sess.run(net))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||||
embedding_values = (
|
embedding_values = (
|
||||||
(1., 2.), # id 0
|
(1., 2.), # id 0
|
||||||
@ -3710,6 +3755,7 @@ class DenseFeaturesTest(test.TestCase):
|
|||||||
features['country']: country_data
|
features['country']: country_data
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_rank_0_feature(self):
|
def test_with_rank_0_feature(self):
|
||||||
# price has 1 dimension in dense_features
|
# price has 1 dimension in dense_features
|
||||||
price = fc.numeric_column('price')
|
price = fc.numeric_column('price')
|
||||||
@ -3967,6 +4013,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
|||||||
variables_lib.Variable)
|
variables_lib.Variable)
|
||||||
self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10])
|
self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_fills_cols_to_vars_shared_embedding(self):
|
def test_fills_cols_to_vars_shared_embedding(self):
|
||||||
# Provide 5 DenseColumn's to input_layer: a NumericColumn, a
|
# Provide 5 DenseColumn's to input_layer: a NumericColumn, a
|
||||||
# BucketizedColumn, an EmbeddingColumn, two SharedEmbeddingColumns. The
|
# BucketizedColumn, an EmbeddingColumn, two SharedEmbeddingColumns. The
|
||||||
@ -4167,6 +4214,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
|||||||
expected_var_names,
|
expected_var_names,
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_1d_sparse_tensor(self):
|
def test_with_1d_sparse_tensor(self):
|
||||||
embedding_values = (
|
embedding_values = (
|
||||||
(1., 2., 3., 4., 5.), # id 0
|
(1., 2., 3., 4., 5.), # id 0
|
||||||
@ -4223,6 +4271,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
|||||||
[1., 0., 0., 1., 2., 3., 4., 5., 12.]],
|
[1., 0., 0., 1., 2., 3., 4., 5., 12.]],
|
||||||
sess.run(net))
|
sess.run(net))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||||
embedding_values = (
|
embedding_values = (
|
||||||
(1., 2.), # id 0
|
(1., 2.), # id 0
|
||||||
@ -4281,6 +4330,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
|||||||
features['country']: country_data
|
features['country']: country_data
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_with_rank_0_feature(self):
|
def test_with_rank_0_feature(self):
|
||||||
# price has 1 dimension in input_layer
|
# price has 1 dimension in input_layer
|
||||||
price = fc.numeric_column('price')
|
price = fc.numeric_column('price')
|
||||||
@ -4444,6 +4494,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
'python/feature_column/testdata/wire_vocabulary.txt')
|
'python/feature_column/testdata/wire_vocabulary.txt')
|
||||||
self._wire_vocabulary_size = 3
|
self._wire_vocabulary_size = 3
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_defaults(self):
|
def test_defaults(self):
|
||||||
column = fc.categorical_column_with_vocabulary_file(
|
column = fc.categorical_column_with_vocabulary_file(
|
||||||
key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
|
key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
|
||||||
@ -4460,6 +4511,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
fc.categorical_column_with_vocabulary_file(
|
fc.categorical_column_with_vocabulary_file(
|
||||||
key=('aaa',), vocabulary_file='path_to_file', vocabulary_size=3)
|
key=('aaa',), vocabulary_file='path_to_file', vocabulary_size=3)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_all_constructor_args(self):
|
def test_all_constructor_args(self):
|
||||||
column = fc.categorical_column_with_vocabulary_file(
|
column = fc.categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -4472,6 +4524,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
|
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
|
||||||
}, column.parse_example_spec)
|
}, column.parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
original = fc.categorical_column_with_vocabulary_file(
|
original = fc.categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -4496,6 +4549,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
fc.categorical_column_with_vocabulary_file(
|
fc.categorical_column_with_vocabulary_file(
|
||||||
key='aaa', vocabulary_file='', vocabulary_size=3)
|
key='aaa', vocabulary_file='', vocabulary_size=3)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_invalid_vocabulary_file(self):
|
def test_invalid_vocabulary_file(self):
|
||||||
column = fc.categorical_column_with_vocabulary_file(
|
column = fc.categorical_column_with_vocabulary_file(
|
||||||
key='aaa', vocabulary_file='file_does_not_exist', vocabulary_size=10)
|
key='aaa', vocabulary_file='file_does_not_exist', vocabulary_size=10)
|
||||||
@ -4522,6 +4576,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
vocabulary_file=self._wire_vocabulary_file_name,
|
vocabulary_file=self._wire_vocabulary_file_name,
|
||||||
vocabulary_size=0)
|
vocabulary_size=0)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_too_large_vocabulary_size(self):
|
def test_too_large_vocabulary_size(self):
|
||||||
column = fc.categorical_column_with_vocabulary_file(
|
column = fc.categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -4596,6 +4651,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
'aaa': inputs
|
'aaa': inputs
|
||||||
}), None)
|
}), None)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example(self):
|
def test_parse_example(self):
|
||||||
a = fc.categorical_column_with_vocabulary_file(
|
a = fc.categorical_column_with_vocabulary_file(
|
||||||
key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
|
key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
|
||||||
@ -4619,6 +4675,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array([b'omar', b'stringer'], dtype=np.object_),
|
values=np.array([b'omar', b'stringer'], dtype=np.object_),
|
||||||
dense_shape=[1, 2]), self.evaluate(features['aaa']))
|
dense_shape=[1, 2]), self.evaluate(features['aaa']))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors(self):
|
def test_get_sparse_tensors(self):
|
||||||
column = fc.categorical_column_with_vocabulary_file(
|
column = fc.categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -4645,6 +4702,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
self.evaluate(id_weight_pair.id_tensor))
|
self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_none_vocabulary_size(self):
|
def test_get_sparse_tensors_none_vocabulary_size(self):
|
||||||
column = fc.categorical_column_with_vocabulary_file(
|
column = fc.categorical_column_with_vocabulary_file(
|
||||||
key='aaa', vocabulary_file=self._wire_vocabulary_file_name)
|
key='aaa', vocabulary_file=self._wire_vocabulary_file_name)
|
||||||
@ -4669,6 +4727,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
self.evaluate(id_weight_pair.id_tensor))
|
self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_feature(self):
|
def test_transform_feature(self):
|
||||||
column = fc.categorical_column_with_vocabulary_file(
|
column = fc.categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -4692,6 +4751,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array((2, -1, 0), dtype=np.int64),
|
values=np.array((2, -1, 0), dtype=np.int64),
|
||||||
dense_shape=inputs.dense_shape), self.evaluate(id_tensor))
|
dense_shape=inputs.dense_shape), self.evaluate(id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_dense_input(self):
|
def test_get_sparse_tensors_dense_input(self):
|
||||||
column = fc.categorical_column_with_vocabulary_file(
|
column = fc.categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -4713,6 +4773,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array((2, -1, 0), dtype=np.int64),
|
values=np.array((2, -1, 0), dtype=np.int64),
|
||||||
dense_shape=(2, 2)), self.evaluate(id_weight_pair.id_tensor))
|
dense_shape=(2, 2)), self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_default_value_in_vocabulary(self):
|
def test_get_sparse_tensors_default_value_in_vocabulary(self):
|
||||||
column = fc.categorical_column_with_vocabulary_file(
|
column = fc.categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -4740,6 +4801,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
self.evaluate(id_weight_pair.id_tensor))
|
self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_with_oov_buckets(self):
|
def test_get_sparse_tensors_with_oov_buckets(self):
|
||||||
column = fc.categorical_column_with_vocabulary_file(
|
column = fc.categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -4767,6 +4829,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
self.evaluate(id_weight_pair.id_tensor))
|
self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_small_vocabulary_size(self):
|
def test_get_sparse_tensors_small_vocabulary_size(self):
|
||||||
# 'marlo' is the last entry in our vocabulary file, so be setting
|
# 'marlo' is the last entry in our vocabulary file, so be setting
|
||||||
# `vocabulary_size` to 1 less than number of entries in file, we take
|
# `vocabulary_size` to 1 less than number of entries in file, we take
|
||||||
@ -4796,6 +4859,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
self.evaluate(id_weight_pair.id_tensor))
|
self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_int32(self):
|
def test_get_sparse_tensors_int32(self):
|
||||||
column = fc.categorical_column_with_vocabulary_file(
|
column = fc.categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -4823,6 +4887,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
self.evaluate(id_weight_pair.id_tensor))
|
self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_int32_dense_input(self):
|
def test_get_sparse_tensors_int32_dense_input(self):
|
||||||
default_value = -100
|
default_value = -100
|
||||||
column = fc.categorical_column_with_vocabulary_file(
|
column = fc.categorical_column_with_vocabulary_file(
|
||||||
@ -4847,6 +4912,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array((2, default_value, 0, 4), dtype=np.int64),
|
values=np.array((2, default_value, 0, 4), dtype=np.int64),
|
||||||
dense_shape=(3, 3)), self.evaluate(id_weight_pair.id_tensor))
|
dense_shape=(3, 3)), self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_int32_with_oov_buckets(self):
|
def test_get_sparse_tensors_int32_with_oov_buckets(self):
|
||||||
column = fc.categorical_column_with_vocabulary_file(
|
column = fc.categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -4875,6 +4941,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
self.evaluate(id_weight_pair.id_tensor))
|
self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
wire_column = fc.categorical_column_with_vocabulary_file(
|
wire_column = fc.categorical_column_with_vocabulary_file(
|
||||||
key='wire',
|
key='wire',
|
||||||
@ -4933,6 +5000,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
# 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
|
# 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
|
||||||
self.assertAllClose(((3.,), (5.,)), self.evaluate(predictions))
|
self.assertAllClose(((3.,), (5.,)), self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_serialization(self):
|
def test_serialization(self):
|
||||||
wire_column = fc.categorical_column_with_vocabulary_file(
|
wire_column = fc.categorical_column_with_vocabulary_file(
|
||||||
key='wire',
|
key='wire',
|
||||||
@ -4984,6 +5052,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
||||||
}, column.parse_example_spec)
|
}, column.parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_all_constructor_args(self):
|
def test_all_constructor_args(self):
|
||||||
column = fc.categorical_column_with_vocabulary_list(
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -4995,6 +5064,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
|
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
|
||||||
}, column.parse_example_spec)
|
}, column.parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
original = fc.categorical_column_with_vocabulary_list(
|
original = fc.categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32)
|
key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32)
|
||||||
@ -5089,6 +5159,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
'aaa': inputs
|
'aaa': inputs
|
||||||
}), None)
|
}), None)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example_string(self):
|
def test_parse_example_string(self):
|
||||||
a = fc.categorical_column_with_vocabulary_list(
|
a = fc.categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
@ -5112,6 +5183,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array([b'omar', b'stringer'], dtype=np.object_),
|
values=np.array([b'omar', b'stringer'], dtype=np.object_),
|
||||||
dense_shape=[1, 2]), self.evaluate(features['aaa']))
|
dense_shape=[1, 2]), self.evaluate(features['aaa']))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example_int(self):
|
def test_parse_example_int(self):
|
||||||
a = fc.categorical_column_with_vocabulary_list(
|
a = fc.categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=(11, 21, 31))
|
key='aaa', vocabulary_list=(11, 21, 31))
|
||||||
@ -5133,6 +5205,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
indices=[[0, 0], [0, 1]], values=[11, 21], dense_shape=[1, 2]),
|
indices=[[0, 0], [0, 1]], values=[11, 21], dense_shape=[1, 2]),
|
||||||
self.evaluate(features['aaa']))
|
self.evaluate(features['aaa']))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors(self):
|
def test_get_sparse_tensors(self):
|
||||||
column = fc.categorical_column_with_vocabulary_list(
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
@ -5157,6 +5230,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
self.evaluate(id_weight_pair.id_tensor))
|
self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_feature(self):
|
def test_transform_feature(self):
|
||||||
column = fc.categorical_column_with_vocabulary_list(
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
@ -5178,6 +5252,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array((2, -1, 0), dtype=np.int64),
|
values=np.array((2, -1, 0), dtype=np.int64),
|
||||||
dense_shape=inputs.dense_shape), self.evaluate(id_tensor))
|
dense_shape=inputs.dense_shape), self.evaluate(id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_dense_input(self):
|
def test_get_sparse_tensors_dense_input(self):
|
||||||
column = fc.categorical_column_with_vocabulary_list(
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
@ -5197,6 +5272,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array((2, -1, 0), dtype=np.int64),
|
values=np.array((2, -1, 0), dtype=np.int64),
|
||||||
dense_shape=(2, 2)), self.evaluate(id_weight_pair.id_tensor))
|
dense_shape=(2, 2)), self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_default_value_in_vocabulary(self):
|
def test_get_sparse_tensors_default_value_in_vocabulary(self):
|
||||||
column = fc.categorical_column_with_vocabulary_list(
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -5223,6 +5299,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
self.evaluate(id_weight_pair.id_tensor))
|
self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_with_oov_buckets(self):
|
def test_get_sparse_tensors_with_oov_buckets(self):
|
||||||
column = fc.categorical_column_with_vocabulary_list(
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -5249,6 +5326,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
self.evaluate(id_weight_pair.id_tensor))
|
self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_int32(self):
|
def test_get_sparse_tensors_int32(self):
|
||||||
column = fc.categorical_column_with_vocabulary_list(
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -5275,6 +5353,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
self.evaluate(id_weight_pair.id_tensor))
|
self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_int32_dense_input(self):
|
def test_get_sparse_tensors_int32_dense_input(self):
|
||||||
default_value = -100
|
default_value = -100
|
||||||
column = fc.categorical_column_with_vocabulary_list(
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
@ -5300,6 +5379,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array((2, default_value, 0, 4), dtype=np.int64),
|
values=np.array((2, default_value, 0, 4), dtype=np.int64),
|
||||||
dense_shape=(3, 3)), self.evaluate(id_weight_pair.id_tensor))
|
dense_shape=(3, 3)), self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_int32_with_oov_buckets(self):
|
def test_get_sparse_tensors_int32_with_oov_buckets(self):
|
||||||
column = fc.categorical_column_with_vocabulary_list(
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -5327,6 +5407,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
self.evaluate(id_weight_pair.id_tensor))
|
self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
wire_column = fc.categorical_column_with_vocabulary_list(
|
wire_column = fc.categorical_column_with_vocabulary_list(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -5383,6 +5464,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
# 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
|
# 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
|
||||||
self.assertAllClose(((3.,), (5.,)), self.evaluate(predictions))
|
self.assertAllClose(((3.,), (5.,)), self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_serialization(self):
|
def test_serialization(self):
|
||||||
wire_column = fc.categorical_column_with_vocabulary_list(
|
wire_column = fc.categorical_column_with_vocabulary_list(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -5420,6 +5502,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
|
with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
|
||||||
fc.categorical_column_with_identity(key=('aaa',), num_buckets=3)
|
fc.categorical_column_with_identity(key=('aaa',), num_buckets=3)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
original = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
original = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
for column in (original, copy.deepcopy(original)):
|
for column in (original, copy.deepcopy(original)):
|
||||||
@ -5459,6 +5542,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
'aaa': inputs
|
'aaa': inputs
|
||||||
}), None)
|
}), None)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example(self):
|
def test_parse_example(self):
|
||||||
a = fc.categorical_column_with_identity(key='aaa', num_buckets=30)
|
a = fc.categorical_column_with_identity(key='aaa', num_buckets=30)
|
||||||
data = example_pb2.Example(
|
data = example_pb2.Example(
|
||||||
@ -5480,6 +5564,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array([11, 21], dtype=np.int64),
|
values=np.array([11, 21], dtype=np.int64),
|
||||||
dense_shape=[1, 2]), self.evaluate(features['aaa']))
|
dense_shape=[1, 2]), self.evaluate(features['aaa']))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors(self):
|
def test_get_sparse_tensors(self):
|
||||||
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
inputs = sparse_tensor.SparseTensorValue(
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
@ -5501,6 +5586,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
self.evaluate(id_weight_pair.id_tensor))
|
self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_feature(self):
|
def test_transform_feature(self):
|
||||||
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
inputs = sparse_tensor.SparseTensorValue(
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
@ -5519,6 +5605,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array((0, 1, 0), dtype=np.int64),
|
values=np.array((0, 1, 0), dtype=np.int64),
|
||||||
dense_shape=inputs.dense_shape), self.evaluate(id_tensor))
|
dense_shape=inputs.dense_shape), self.evaluate(id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_dense_input(self):
|
def test_get_sparse_tensors_dense_input(self):
|
||||||
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
id_weight_pair = column.get_sparse_tensors(
|
id_weight_pair = column.get_sparse_tensors(
|
||||||
@ -5537,6 +5624,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array((0, 1, 0), dtype=np.int64),
|
values=np.array((0, 1, 0), dtype=np.int64),
|
||||||
dense_shape=(2, 2)), self.evaluate(id_weight_pair.id_tensor))
|
dense_shape=(2, 2)), self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_with_inputs_too_small(self):
|
def test_get_sparse_tensors_with_inputs_too_small(self):
|
||||||
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
inputs = sparse_tensor.SparseTensorValue(
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
@ -5553,6 +5641,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(errors.OpError, 'assert_greater_or_equal_0'):
|
with self.assertRaisesRegexp(errors.OpError, 'assert_greater_or_equal_0'):
|
||||||
self.evaluate(id_weight_pair.id_tensor)
|
self.evaluate(id_weight_pair.id_tensor)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_with_inputs_too_big(self):
|
def test_get_sparse_tensors_with_inputs_too_big(self):
|
||||||
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
inputs = sparse_tensor.SparseTensorValue(
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
@ -5570,6 +5659,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
'assert_less_than_num_buckets'):
|
'assert_less_than_num_buckets'):
|
||||||
self.evaluate(id_weight_pair.id_tensor)
|
self.evaluate(id_weight_pair.id_tensor)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_with_default_value(self):
|
def test_get_sparse_tensors_with_default_value(self):
|
||||||
column = fc.categorical_column_with_identity(
|
column = fc.categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=4, default_value=3)
|
key='aaa', num_buckets=4, default_value=3)
|
||||||
@ -5594,6 +5684,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
dense_shape=inputs.dense_shape),
|
dense_shape=inputs.dense_shape),
|
||||||
self.evaluate(id_weight_pair.id_tensor))
|
self.evaluate(id_weight_pair.id_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_with_default_value_and_placeholder_inputs(self):
|
def test_get_sparse_tensors_with_default_value_and_placeholder_inputs(self):
|
||||||
column = fc.categorical_column_with_identity(
|
column = fc.categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=4, default_value=3)
|
key='aaa', num_buckets=4, default_value=3)
|
||||||
@ -5624,6 +5715,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
input_shape: (2, 2),
|
input_shape: (2, 2),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
self.assertEqual(3, column.num_buckets)
|
self.assertEqual(3, column.num_buckets)
|
||||||
@ -5674,6 +5766,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
# weight_var[2] + weight_var[1] = 3+2 = 5
|
# weight_var[2] + weight_var[1] = 3+2 = 5
|
||||||
self.assertAllClose(((1.,), (5.,)), self.evaluate(predictions))
|
self.assertAllClose(((1.,), (5.,)), self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_serialization(self):
|
def test_serialization(self):
|
||||||
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
|
|
||||||
@ -5827,6 +5920,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual([[0., 1., 1., 0.]], self.evaluate(output))
|
self.assertAllEqual([[0., 1., 1., 0.]], self.evaluate(output))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
a = fc.categorical_column_with_hash_bucket('a', 4)
|
a = fc.categorical_column_with_hash_bucket('a', 4)
|
||||||
column = fc.indicator_column(a)
|
column = fc.indicator_column(a)
|
||||||
@ -5835,6 +5929,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
self.assertEqual(column.name, 'a_indicator')
|
self.assertEqual(column.name, 'a_indicator')
|
||||||
self.assertEqual(column.variable_shape, [1, 4])
|
self.assertEqual(column.variable_shape, [1, 4])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example(self):
|
def test_parse_example(self):
|
||||||
a = fc.categorical_column_with_vocabulary_list(
|
a = fc.categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
@ -5859,6 +5954,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
values=np.array([b'omar', b'stringer'], dtype=np.object_),
|
values=np.array([b'omar', b'stringer'], dtype=np.object_),
|
||||||
dense_shape=[1, 2]), self.evaluate(features['aaa']))
|
dense_shape=[1, 2]), self.evaluate(features['aaa']))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform(self):
|
def test_transform(self):
|
||||||
a = fc.categorical_column_with_vocabulary_list(
|
a = fc.categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
@ -5878,6 +5974,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual([[0, 0, 1], [1, 0, 0]], self.evaluate(indicator_tensor))
|
self.assertAllEqual([[0, 0, 1], [1, 0, 0]], self.evaluate(indicator_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_with_weighted_column(self):
|
def test_transform_with_weighted_column(self):
|
||||||
# Github issue 12557
|
# Github issue 12557
|
||||||
ids = fc.categorical_column_with_vocabulary_list(
|
ids = fc.categorical_column_with_vocabulary_list(
|
||||||
@ -5896,6 +5993,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual([[6., 4., 3.]], self.evaluate(indicator_tensor))
|
self.assertAllEqual([[6., 4., 3.]], self.evaluate(indicator_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_with_missing_value_in_weighted_column(self):
|
def test_transform_with_missing_value_in_weighted_column(self):
|
||||||
# Github issue 12583
|
# Github issue 12583
|
||||||
ids = fc.categorical_column_with_vocabulary_list(
|
ids = fc.categorical_column_with_vocabulary_list(
|
||||||
@ -5914,6 +6012,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual([[0., 4., 2.]], self.evaluate(indicator_tensor))
|
self.assertAllEqual([[0., 4., 2.]], self.evaluate(indicator_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_with_missing_value_in_categorical_column(self):
|
def test_transform_with_missing_value_in_categorical_column(self):
|
||||||
# Github issue 12583
|
# Github issue 12583
|
||||||
ids = fc.categorical_column_with_vocabulary_list(
|
ids = fc.categorical_column_with_vocabulary_list(
|
||||||
@ -5930,6 +6029,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual([[0., 1., 1.]], self.evaluate(indicator_tensor))
|
self.assertAllEqual([[0., 1., 1.]], self.evaluate(indicator_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
animal = fc.indicator_column(
|
animal = fc.indicator_column(
|
||||||
fc.categorical_column_with_identity('animal', num_buckets=4))
|
fc.categorical_column_with_identity('animal', num_buckets=4))
|
||||||
@ -5997,6 +6097,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
self.evaluate(weight_var.assign([[1.], [2.], [3.], [4.]]))
|
self.evaluate(weight_var.assign([[1.], [2.], [3.], [4.]]))
|
||||||
self.assertAllClose([[2. + 3.]], self.evaluate(predictions))
|
self.assertAllClose([[2. + 3.]], self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_dense_features(self):
|
def test_dense_features(self):
|
||||||
animal = fc.indicator_column(
|
animal = fc.indicator_column(
|
||||||
fc.categorical_column_with_identity('animal', num_buckets=4))
|
fc.categorical_column_with_identity('animal', num_buckets=4))
|
||||||
@ -6013,6 +6114,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllClose([[0., 1., 1., 0.]], self.evaluate(net))
|
self.assertAllClose([[0., 1., 1., 0.]], self.evaluate(net))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_input_layer(self):
|
def test_input_layer(self):
|
||||||
animal = fc.indicator_column(
|
animal = fc.indicator_column(
|
||||||
fc.categorical_column_with_identity('animal', num_buckets=4))
|
fc.categorical_column_with_identity('animal', num_buckets=4))
|
||||||
@ -6045,6 +6147,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllClose([[0., 1., 1., 0.]], self.evaluate(net))
|
self.assertAllClose([[0., 1., 1., 0.]], self.evaluate(net))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_serialization(self):
|
def test_serialization(self):
|
||||||
parent = fc.categorical_column_with_identity('animal', num_buckets=4)
|
parent = fc.categorical_column_with_identity('animal', num_buckets=4)
|
||||||
animal = fc.indicator_column(parent)
|
animal = fc.indicator_column(parent)
|
||||||
@ -6114,6 +6217,7 @@ class _TestStateManager(fc.StateManager):
|
|||||||
|
|
||||||
class EmbeddingColumnTest(test.TestCase):
|
class EmbeddingColumnTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_defaults(self):
|
def test_defaults(self):
|
||||||
categorical_column = fc.categorical_column_with_identity(
|
categorical_column = fc.categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -6142,6 +6246,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
categorical_column, dimension=embedding_dimension)
|
categorical_column, dimension=embedding_dimension)
|
||||||
self.assertFalse(embedding_column._is_v2_column)
|
self.assertFalse(embedding_column._is_v2_column)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_all_constructor_args(self):
|
def test_all_constructor_args(self):
|
||||||
categorical_column = fc.categorical_column_with_identity(
|
categorical_column = fc.categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -6168,6 +6273,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
||||||
}, embedding_column.parse_example_spec)
|
}, embedding_column.parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
categorical_column = fc.categorical_column_with_identity(
|
categorical_column = fc.categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -6200,12 +6306,14 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
||||||
}, embedding_column.parse_example_spec)
|
}, embedding_column.parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_invalid_initializer(self):
|
def test_invalid_initializer(self):
|
||||||
categorical_column = fc.categorical_column_with_identity(
|
categorical_column = fc.categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
with self.assertRaisesRegexp(ValueError, 'initializer must be callable'):
|
with self.assertRaisesRegexp(ValueError, 'initializer must be callable'):
|
||||||
fc.embedding_column(categorical_column, dimension=2, initializer='not_fn')
|
fc.embedding_column(categorical_column, dimension=2, initializer='not_fn')
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example(self):
|
def test_parse_example(self):
|
||||||
a = fc.categorical_column_with_vocabulary_list(
|
a = fc.categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
@ -6230,6 +6338,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
values=np.array([b'omar', b'stringer'], dtype=np.object_),
|
values=np.array([b'omar', b'stringer'], dtype=np.object_),
|
||||||
dense_shape=[1, 2]), self.evaluate(features['aaa']))
|
dense_shape=[1, 2]), self.evaluate(features['aaa']))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_feature(self):
|
def test_transform_feature(self):
|
||||||
a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
a_embedded = fc.embedding_column(a, dimension=2)
|
a_embedded = fc.embedding_column(a, dimension=2)
|
||||||
@ -6250,6 +6359,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
_assert_sparse_tensor_value(self, self.evaluate(output_a),
|
_assert_sparse_tensor_value(self, self.evaluate(output_a),
|
||||||
self.evaluate(output_embedded))
|
self.evaluate(output_embedded))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor(self):
|
def test_get_dense_tensor(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 3
|
||||||
@ -6315,6 +6425,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertAllEqual(embedding_values, self.evaluate(global_vars[0]))
|
self.assertAllEqual(embedding_values, self.evaluate(global_vars[0]))
|
||||||
self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup))
|
self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor_old_categorical(self):
|
def test_get_dense_tensor_old_categorical(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 3
|
||||||
@ -6378,6 +6489,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertAllEqual(embedding_values, self.evaluate(global_vars[0]))
|
self.assertAllEqual(embedding_values, self.evaluate(global_vars[0]))
|
||||||
self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup))
|
self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor_3d(self):
|
def test_get_dense_tensor_3d(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 4
|
vocabulary_size = 4
|
||||||
@ -6445,6 +6557,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertAllEqual(embedding_values, self.evaluate(global_vars[0]))
|
self.assertAllEqual(embedding_values, self.evaluate(global_vars[0]))
|
||||||
self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup))
|
self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor_placeholder_inputs(self):
|
def test_get_dense_tensor_placeholder_inputs(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 3
|
||||||
@ -6524,6 +6637,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
input_shape: sparse_input.dense_shape,
|
input_shape: sparse_input.dense_shape,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor_restore_from_ckpt(self):
|
def test_get_dense_tensor_restore_from_ckpt(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 3
|
||||||
@ -6587,6 +6701,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertAllEqual(embedding_values, self.evaluate(global_vars[0]))
|
self.assertAllEqual(embedding_values, self.evaluate(global_vars[0]))
|
||||||
self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup))
|
self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
@ -6668,6 +6783,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertAllClose(((94.,), (29.,), (0.,), (42.,)),
|
self.assertAllClose(((94.,), (29.,), (0.,), (42.,)),
|
||||||
self.evaluate(predictions))
|
self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_dense_features(self):
|
def test_dense_features(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 3
|
||||||
@ -6734,6 +6850,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertAllEqual(embedding_values, self.evaluate(trainable_vars[0]))
|
self.assertAllEqual(embedding_values, self.evaluate(trainable_vars[0]))
|
||||||
self.assertAllEqual(expected_lookups, self.evaluate(dense_features))
|
self.assertAllEqual(expected_lookups, self.evaluate(dense_features))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_dense_features_not_trainable(self):
|
def test_dense_features_not_trainable(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 3
|
||||||
@ -6799,6 +6916,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertAllEqual(embedding_values, self.evaluate(global_vars[0]))
|
self.assertAllEqual(embedding_values, self.evaluate(global_vars[0]))
|
||||||
self.assertAllEqual(expected_lookups, self.evaluate(dense_features))
|
self.assertAllEqual(expected_lookups, self.evaluate(dense_features))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_input_layer(self):
|
def test_input_layer(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 3
|
||||||
@ -7028,6 +7146,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertAllClose(((94.,), (29.,), (0.,), (42.,)),
|
self.assertAllClose(((94.,), (29.,), (0.,), (42.,)),
|
||||||
self.evaluate(predictions))
|
self.evaluate(predictions))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_serialization(self):
|
def test_serialization(self):
|
||||||
|
|
||||||
def _initializer(shape, dtype, partition_info):
|
def _initializer(shape, dtype, partition_info):
|
||||||
@ -7081,6 +7200,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
|
|
||||||
class SharedEmbeddingColumnTest(test.TestCase):
|
class SharedEmbeddingColumnTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_defaults(self):
|
def test_defaults(self):
|
||||||
categorical_column_a = fc.categorical_column_with_identity(
|
categorical_column_a = fc.categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -7105,6 +7225,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
'bbb': parsing_ops.VarLenFeature(dtypes.int64)
|
'bbb': parsing_ops.VarLenFeature(dtypes.int64)
|
||||||
}, embedding_column_b.parse_example_spec)
|
}, embedding_column_b.parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_all_constructor_args(self):
|
def test_all_constructor_args(self):
|
||||||
categorical_column_a = fc.categorical_column_with_identity(
|
categorical_column_a = fc.categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -7136,6 +7257,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
'bbb': parsing_ops.VarLenFeature(dtypes.int64)
|
'bbb': parsing_ops.VarLenFeature(dtypes.int64)
|
||||||
}, embedding_column_b.parse_example_spec)
|
}, embedding_column_b.parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
categorical_column_a = fc.categorical_column_with_identity(
|
categorical_column_a = fc.categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -7167,6 +7289,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
||||||
}, embedding_column_a.parse_example_spec)
|
}, embedding_column_a.parse_example_spec)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_invalid_initializer(self):
|
def test_invalid_initializer(self):
|
||||||
categorical_column_a = fc.categorical_column_with_identity(
|
categorical_column_a = fc.categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -7178,6 +7301,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
dimension=2,
|
dimension=2,
|
||||||
initializer='not_fn')
|
initializer='not_fn')
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_incompatible_column_type(self):
|
def test_incompatible_column_type(self):
|
||||||
categorical_column_a = fc.categorical_column_with_identity(
|
categorical_column_a = fc.categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -7192,6 +7316,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
[categorical_column_a, categorical_column_b, categorical_column_c],
|
[categorical_column_a, categorical_column_b, categorical_column_c],
|
||||||
dimension=2)
|
dimension=2)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_weighted_categorical_column_ok(self):
|
def test_weighted_categorical_column_ok(self):
|
||||||
categorical_column_a = fc.categorical_column_with_identity(
|
categorical_column_a = fc.categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=3)
|
key='aaa', num_buckets=3)
|
||||||
@ -7209,6 +7334,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
[weighted_categorical_column_a, weighted_categorical_column_b],
|
[weighted_categorical_column_a, weighted_categorical_column_b],
|
||||||
dimension=2)
|
dimension=2)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example(self):
|
def test_parse_example(self):
|
||||||
a = fc.categorical_column_with_vocabulary_list(
|
a = fc.categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
@ -7246,6 +7372,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
values=np.array([b'stringer', b'marlo'], dtype=np.object_),
|
values=np.array([b'stringer', b'marlo'], dtype=np.object_),
|
||||||
dense_shape=[1, 2]), self.evaluate(features['bbb']))
|
dense_shape=[1, 2]), self.evaluate(features['bbb']))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_feature(self):
|
def test_transform_feature(self):
|
||||||
a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
b = fc.categorical_column_with_identity(key='bbb', num_buckets=3)
|
b = fc.categorical_column_with_identity(key='bbb', num_buckets=3)
|
||||||
@ -7277,6 +7404,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
_assert_sparse_tensor_value(self, self.evaluate(output_b),
|
_assert_sparse_tensor_value(self, self.evaluate(output_b),
|
||||||
self.evaluate(output_b_embedded))
|
self.evaluate(output_b_embedded))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor(self):
|
def test_get_dense_tensor(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 3
|
||||||
@ -7348,6 +7476,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertAllEqual(expected_lookups_a, self.evaluate(embedding_lookup_a))
|
self.assertAllEqual(expected_lookups_a, self.evaluate(embedding_lookup_a))
|
||||||
self.assertAllEqual(expected_lookups_b, self.evaluate(embedding_lookup_b))
|
self.assertAllEqual(expected_lookups_b, self.evaluate(embedding_lookup_b))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor_placeholder_inputs(self):
|
def test_get_dense_tensor_placeholder_inputs(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 3
|
||||||
@ -7407,6 +7536,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
with _initialized_session() as sess:
|
with _initialized_session() as sess:
|
||||||
sess.run([embedding_lookup_a, embedding_lookup_b], feed_dict=feed_dict)
|
sess.run([embedding_lookup_a, embedding_lookup_b], feed_dict=feed_dict)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
@ -7619,12 +7749,15 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
self.evaluate(shared_embedding_vars[0]))
|
self.evaluate(shared_embedding_vars[0]))
|
||||||
self.assertAllEqual(expected_lookups, self.evaluate(dense_features))
|
self.assertAllEqual(expected_lookups, self.evaluate(dense_features))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_dense_features(self):
|
def test_dense_features(self):
|
||||||
self._test_dense_features()
|
self._test_dense_features()
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_dense_features_no_trainable(self):
|
def test_dense_features_no_trainable(self):
|
||||||
self._test_dense_features(trainable=False)
|
self._test_dense_features(trainable=False)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_serialization(self):
|
def test_serialization(self):
|
||||||
|
|
||||||
def _initializer(shape, dtype, partition_info):
|
def _initializer(shape, dtype, partition_info):
|
||||||
@ -7647,6 +7780,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
|
|
||||||
class WeightedCategoricalColumnTest(test.TestCase):
|
class WeightedCategoricalColumnTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_defaults(self):
|
def test_defaults(self):
|
||||||
column = fc.weighted_categorical_column(
|
column = fc.weighted_categorical_column(
|
||||||
categorical_column=fc.categorical_column_with_identity(
|
categorical_column=fc.categorical_column_with_identity(
|
||||||
@ -7667,6 +7801,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
|
|||||||
weight_feature_key='values')
|
weight_feature_key='values')
|
||||||
self.assertFalse(column._is_v2_column)
|
self.assertFalse(column._is_v2_column)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_deep_copy(self):
|
def test_deep_copy(self):
|
||||||
"""Tests deepcopy of categorical_column_with_hash_bucket."""
|
"""Tests deepcopy of categorical_column_with_hash_bucket."""
|
||||||
original = fc.weighted_categorical_column(
|
original = fc.weighted_categorical_column(
|
||||||
@ -7732,6 +7867,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
|
|||||||
'values is not in features dictionary'):
|
'values is not in features dictionary'):
|
||||||
fc._transform_features_v2({'ids': inputs}, (column,), None)
|
fc._transform_features_v2({'ids': inputs}, (column,), None)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_parse_example(self):
|
def test_parse_example(self):
|
||||||
a = fc.categorical_column_with_vocabulary_list(
|
a = fc.categorical_column_with_vocabulary_list(
|
||||||
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
@ -7766,6 +7902,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array([1., 10.], dtype=np.float32),
|
values=np.array([1., 10.], dtype=np.float32),
|
||||||
dense_shape=[1, 2]), self.evaluate(features['weights']))
|
dense_shape=[1, 2]), self.evaluate(features['weights']))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_features(self):
|
def test_transform_features(self):
|
||||||
column = fc.weighted_categorical_column(
|
column = fc.weighted_categorical_column(
|
||||||
categorical_column=fc.categorical_column_with_identity(
|
categorical_column=fc.categorical_column_with_identity(
|
||||||
@ -7798,6 +7935,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array(weights.values, dtype=np.float32),
|
values=np.array(weights.values, dtype=np.float32),
|
||||||
dense_shape=weights.dense_shape), self.evaluate(weight_tensor))
|
dense_shape=weights.dense_shape), self.evaluate(weight_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_features_dense_input(self):
|
def test_transform_features_dense_input(self):
|
||||||
column = fc.weighted_categorical_column(
|
column = fc.weighted_categorical_column(
|
||||||
categorical_column=fc.categorical_column_with_identity(
|
categorical_column=fc.categorical_column_with_identity(
|
||||||
@ -7828,6 +7966,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array(weights.values, dtype=np.float32),
|
values=np.array(weights.values, dtype=np.float32),
|
||||||
dense_shape=weights.dense_shape), self.evaluate(weight_tensor))
|
dense_shape=weights.dense_shape), self.evaluate(weight_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_transform_features_dense_weights(self):
|
def test_transform_features_dense_weights(self):
|
||||||
column = fc.weighted_categorical_column(
|
column = fc.weighted_categorical_column(
|
||||||
categorical_column=fc.categorical_column_with_identity(
|
categorical_column=fc.categorical_column_with_identity(
|
||||||
@ -7856,6 +7995,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
|
|||||||
values=np.array((.5, 1., .1), dtype=np.float32),
|
values=np.array((.5, 1., .1), dtype=np.float32),
|
||||||
dense_shape=(2, 2)), self.evaluate(weight_tensor))
|
dense_shape=(2, 2)), self.evaluate(weight_tensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_model(self):
|
def test_linear_model(self):
|
||||||
column = fc.weighted_categorical_column(
|
column = fc.weighted_categorical_column(
|
||||||
categorical_column=fc.categorical_column_with_identity(
|
categorical_column=fc.categorical_column_with_identity(
|
||||||
@ -8106,6 +8246,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
|
|||||||
|
|
||||||
# TODO(ptucker): Add test with embedding of weighted categorical.
|
# TODO(ptucker): Add test with embedding of weighted categorical.
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_serialization(self):
|
def test_serialization(self):
|
||||||
categorical_column = fc.categorical_column_with_identity(
|
categorical_column = fc.categorical_column_with_identity(
|
||||||
key='ids', num_buckets=3)
|
key='ids', num_buckets=3)
|
||||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.framework import load_library
|
from tensorflow.python.framework import load_library
|
||||||
from tensorflow.python.ops import data_flow_ops
|
from tensorflow.python.ops import data_flow_ops
|
||||||
from tensorflow.python.ops import io_ops
|
from tensorflow.python.ops import io_ops
|
||||||
@ -36,6 +37,7 @@ class FileSystemTest(test.TestCase):
|
|||||||
"test_file_system.so")
|
"test_file_system.so")
|
||||||
load_library.load_file_system_library(file_system_library)
|
load_library.load_file_system_library(file_system_library)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
reader = io_ops.WholeFileReader("test_reader")
|
reader = io_ops.WholeFileReader("test_reader")
|
||||||
|
@ -25,6 +25,7 @@ from tensorflow.python.framework import function_def_to_graph
|
|||||||
from tensorflow.python.framework import graph_to_function_def
|
from tensorflow.python.framework import graph_to_function_def
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.framework import test_ops
|
from tensorflow.python.framework import test_ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
@ -52,6 +53,7 @@ class FunctionDefToGraphTest(test.TestCase):
|
|||||||
fdef.signature.name = "_whats_in_a_name"
|
fdef.signature.name = "_whats_in_a_name"
|
||||||
return fdef
|
return fdef
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testInputsAndOutputs(self):
|
def testInputsAndOutputs(self):
|
||||||
fdef = self._build_function_def()
|
fdef = self._build_function_def()
|
||||||
g = function_def_to_graph.function_def_to_graph(fdef)
|
g = function_def_to_graph.function_def_to_graph(fdef)
|
||||||
@ -186,6 +188,7 @@ class FunctionDefToGraphDefTest(test.TestCase):
|
|||||||
self.assertEqual(g.node[0].attr["shape"].shape.unknown_rank, False)
|
self.assertEqual(g.node[0].attr["shape"].shape.unknown_rank, False)
|
||||||
self.assertFalse("shape" in g.node[2].attr)
|
self.assertFalse("shape" in g.node[2].attr)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFunctionCallsFromFunction(self):
|
def testFunctionCallsFromFunction(self):
|
||||||
x = constant_op.constant(5.0)
|
x = constant_op.constant(5.0)
|
||||||
y = constant_op.constant(10.0)
|
y = constant_op.constant(10.0)
|
||||||
|
@ -35,6 +35,7 @@ from tensorflow.python.framework import function
|
|||||||
from tensorflow.python.framework import graph_to_function_def
|
from tensorflow.python.framework import graph_to_function_def
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.framework.errors import InvalidArgumentError
|
from tensorflow.python.framework.errors import InvalidArgumentError
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
@ -104,6 +105,7 @@ class FunctionTest(test.TestCase):
|
|||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
self.assertAllEqual([18.0], self.evaluate(call))
|
self.assertAllEqual([18.0], self.evaluate(call))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testIdentityImplicitDeref(self):
|
def testIdentityImplicitDeref(self):
|
||||||
|
|
||||||
@function.Defun(dtypes.float32, func_name="MyIdentity")
|
@function.Defun(dtypes.float32, func_name="MyIdentity")
|
||||||
@ -322,6 +324,7 @@ class FunctionTest(test.TestCase):
|
|||||||
self.assertEqual(x.get_shape(), dx.get_shape())
|
self.assertEqual(x.get_shape(), dx.get_shape())
|
||||||
self.assertEqual(y.get_shape(), dy.get_shape())
|
self.assertEqual(y.get_shape(), dy.get_shape())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSymGradAttr(self):
|
def testSymGradAttr(self):
|
||||||
|
|
||||||
@function.Defun(noinline=True)
|
@function.Defun(noinline=True)
|
||||||
@ -438,6 +441,7 @@ class FunctionTest(test.TestCase):
|
|||||||
"assertion failed.*-3"):
|
"assertion failed.*-3"):
|
||||||
self.assertAllEqual(Foo(constant_op.constant(-3.0)).eval(), 6.0)
|
self.assertAllEqual(Foo(constant_op.constant(-3.0)).eval(), 6.0)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testAssertWrapper(self):
|
def testAssertWrapper(self):
|
||||||
|
|
||||||
@function.Defun(dtypes.float32)
|
@function.Defun(dtypes.float32)
|
||||||
@ -452,6 +456,7 @@ class FunctionTest(test.TestCase):
|
|||||||
"assertion"):
|
"assertion"):
|
||||||
_ = MyFn(100.0).eval()
|
_ = MyFn(100.0).eval()
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWhileLoopCallsFunc(self):
|
def testWhileLoopCallsFunc(self):
|
||||||
with self.session(use_gpu=True) as sess:
|
with self.session(use_gpu=True) as sess:
|
||||||
|
|
||||||
@ -471,6 +476,7 @@ class FunctionTest(test.TestCase):
|
|||||||
ans = self.evaluate(loop)
|
ans = self.evaluate(loop)
|
||||||
self.assertAllClose(ans, 131072.)
|
self.assertAllClose(ans, 131072.)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testControlFlowStrictness(self):
|
def testControlFlowStrictness(self):
|
||||||
"""Inlined functions must not execute in a untaken control flow branch."""
|
"""Inlined functions must not execute in a untaken control flow branch."""
|
||||||
|
|
||||||
@ -517,6 +523,7 @@ class FunctionTest(test.TestCase):
|
|||||||
"assertion"):
|
"assertion"):
|
||||||
sess.run(loop, {pred: True, x: 3})
|
sess.run(loop, {pred: True, x: 3})
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testVar(self):
|
def testVar(self):
|
||||||
|
|
||||||
@function.Defun(dtypes.float32)
|
@function.Defun(dtypes.float32)
|
||||||
@ -532,6 +539,7 @@ class FunctionTest(test.TestCase):
|
|||||||
variables.global_variables_initializer().run()
|
variables.global_variables_initializer().run()
|
||||||
self.assertAllEqual(z.eval(), 101.)
|
self.assertAllEqual(z.eval(), 101.)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testResourceVarAsImplicitInput(self):
|
def testResourceVarAsImplicitInput(self):
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with g.as_default(), ops.device("cpu:0"):
|
with g.as_default(), ops.device("cpu:0"):
|
||||||
@ -707,6 +715,7 @@ class FunctionTest(test.TestCase):
|
|||||||
gdef = g.as_graph_def()
|
gdef = g.as_graph_def()
|
||||||
self.assertEqual(0, len(gdef.library.function))
|
self.assertEqual(0, len(gdef.library.function))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testReduction(self):
|
def testReduction(self):
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
|
|
||||||
@ -735,6 +744,7 @@ class FunctionTest(test.TestCase):
|
|||||||
self.assertAllClose(vals[0], vals[1])
|
self.assertAllClose(vals[0], vals[1])
|
||||||
self.assertAllClose(vals[2], vals[3])
|
self.assertAllClose(vals[2], vals[3])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCapture(self):
|
def testCapture(self):
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with g.as_default():
|
with g.as_default():
|
||||||
@ -781,6 +791,7 @@ class FunctionTest(test.TestCase):
|
|||||||
# NOTE: We still do not support capturing control deps.
|
# NOTE: We still do not support capturing control deps.
|
||||||
_ = Foo(x)
|
_ = Foo(x)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCaptureInWhileLoop(self):
|
def testCaptureInWhileLoop(self):
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with g.as_default():
|
with g.as_default():
|
||||||
@ -796,6 +807,7 @@ class FunctionTest(test.TestCase):
|
|||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
self.assertEqual(self.evaluate(y), 10)
|
self.assertEqual(self.evaluate(y), 10)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCaptureInCond(self):
|
def testCaptureInCond(self):
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with g.as_default():
|
with g.as_default():
|
||||||
@ -825,6 +837,7 @@ class FunctionTest(test.TestCase):
|
|||||||
self.assertEqual("Foo_aCYSbwBkR5A",
|
self.assertEqual("Foo_aCYSbwBkR5A",
|
||||||
Foo.instantiate([dtypes.float32] * 3).name)
|
Foo.instantiate([dtypes.float32] * 3).name)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSignatureHash(self):
|
def testSignatureHash(self):
|
||||||
# Foo.Inner and Bar.Inner have identical function body but have
|
# Foo.Inner and Bar.Inner have identical function body but have
|
||||||
# different signatures. They should be treated as two different functions.
|
# different signatures. They should be treated as two different functions.
|
||||||
@ -877,6 +890,7 @@ class FunctionTest(test.TestCase):
|
|||||||
y = Bar(array_ops.zeros([1, 2, 3]))
|
y = Bar(array_ops.zeros([1, 2, 3]))
|
||||||
self.assertAllEqual(y.get_shape().as_list(), [1, 1, 2, 3])
|
self.assertAllEqual(y.get_shape().as_list(), [1, 1, 2, 3])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testVariableReuse(self):
|
def testVariableReuse(self):
|
||||||
|
|
||||||
def LinearWithReuse(input_tensor, reuse=None):
|
def LinearWithReuse(input_tensor, reuse=None):
|
||||||
@ -905,6 +919,7 @@ class FunctionTest(test.TestCase):
|
|||||||
output_op, feed_dict={input_op: np.random.rand(32, 100)})
|
output_op, feed_dict={input_op: np.random.rand(32, 100)})
|
||||||
self.assertEqual(output_val.shape, (32, 100))
|
self.assertEqual(output_val.shape, (32, 100))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFunctionCallInDifferentVariableScopes(self):
|
def testFunctionCallInDifferentVariableScopes(self):
|
||||||
|
|
||||||
@function.Defun(dtypes.float32)
|
@function.Defun(dtypes.float32)
|
||||||
@ -968,6 +983,7 @@ class FunctionTest(test.TestCase):
|
|||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
np.array([1.0, 0.0]).astype(np.float32), sess.run(dinp, {inp: x}))
|
np.array([1.0, 0.0]).astype(np.float32), sess.run(dinp, {inp: x}))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFunctionMarkedStateful(self):
|
def testFunctionMarkedStateful(self):
|
||||||
|
|
||||||
@function.Defun(dtypes.int32, dtypes.float32)
|
@function.Defun(dtypes.int32, dtypes.float32)
|
||||||
@ -995,6 +1011,7 @@ class FunctionTest(test.TestCase):
|
|||||||
self.assertEqual(100, self.evaluate(result_2))
|
self.assertEqual(100, self.evaluate(result_2))
|
||||||
self.assertEqual((4.0, 100), sess.run((result_1, result_2)))
|
self.assertEqual((4.0, 100), sess.run((result_1, result_2)))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testStatefulFunction(self):
|
def testStatefulFunction(self):
|
||||||
|
|
||||||
@function.Defun()
|
@function.Defun()
|
||||||
@ -1037,6 +1054,7 @@ class FunctionTest(test.TestCase):
|
|||||||
self.assertFalse(all(val3 == val1))
|
self.assertFalse(all(val3 == val1))
|
||||||
self.assertFalse(all(val4 == val2))
|
self.assertFalse(all(val4 == val2))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSameFunctionOnTwoDevices(self):
|
def testSameFunctionOnTwoDevices(self):
|
||||||
|
|
||||||
@function.Defun(dtypes.float32)
|
@function.Defun(dtypes.float32)
|
||||||
@ -1056,6 +1074,7 @@ class FunctionTest(test.TestCase):
|
|||||||
self.assertEqual(44.0, self.evaluate(f_1))
|
self.assertEqual(44.0, self.evaluate(f_1))
|
||||||
self.assertEqual((42.0, 44.0), sess.run((f_0, f_1)))
|
self.assertEqual((42.0, 44.0), sess.run((f_0, f_1)))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testGuaranteedConstsAreCaptured(self):
|
def testGuaranteedConstsAreCaptured(self):
|
||||||
var = variables.Variable(1.0)
|
var = variables.Variable(1.0)
|
||||||
const = array_ops.guarantee_const(var)
|
const = array_ops.guarantee_const(var)
|
||||||
@ -1079,6 +1098,7 @@ class FunctionTest(test.TestCase):
|
|||||||
self.evaluate(var.initializer)
|
self.evaluate(var.initializer)
|
||||||
_ = sess.run(CapturesGuaranteedConst(), {also_not_const: 1.0})
|
_ = sess.run(CapturesGuaranteedConst(), {also_not_const: 1.0})
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSameFunctionDifferentGrads(self):
|
def testSameFunctionDifferentGrads(self):
|
||||||
|
|
||||||
def PartOne(x):
|
def PartOne(x):
|
||||||
@ -1150,6 +1170,7 @@ class FunctionsFromProtos(test.TestCase):
|
|||||||
self.assertEqual(func.declared_input_types, new_func.declared_input_types)
|
self.assertEqual(func.declared_input_types, new_func.declared_input_types)
|
||||||
self.assertEqual(func.captured_inputs, new_func.captured_inputs)
|
self.assertEqual(func.captured_inputs, new_func.captured_inputs)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
|
|
||||||
@function.Defun(dtypes.float32, dtypes.float32)
|
@function.Defun(dtypes.float32, dtypes.float32)
|
||||||
@ -1359,6 +1380,7 @@ class FunctionsFromProtos(test.TestCase):
|
|||||||
|
|
||||||
class FunctionOverloadTest(test.TestCase):
|
class FunctionOverloadTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
|
|
||||||
@function.Defun()
|
@function.Defun()
|
||||||
@ -1411,6 +1433,7 @@ class FunctionOverloadTest(test.TestCase):
|
|||||||
|
|
||||||
class FunctionCaptureByValueTest(test.TestCase):
|
class FunctionCaptureByValueTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCaptureByValue(self):
|
def testCaptureByValue(self):
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with g.as_default():
|
with g.as_default():
|
||||||
@ -1634,6 +1657,7 @@ class FunctionInlineControlTest(test.TestCase):
|
|||||||
|
|
||||||
class ModuleFunctionTest(test.TestCase):
|
class ModuleFunctionTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
|
|
||||||
@function.Defun(*[dtypes.float32] * 3)
|
@function.Defun(*[dtypes.float32] * 3)
|
||||||
@ -1717,10 +1741,12 @@ class VariableHoistingTest(test.TestCase):
|
|||||||
self.assertAllEqual(db.shape, (64,))
|
self.assertAllEqual(db.shape, (64,))
|
||||||
self.assertAllClose(np.sum(db), 0.509, rtol=1e-2)
|
self.assertAllClose(np.sum(db), 0.509, rtol=1e-2)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
self._testSimpleModel(True)
|
self._testSimpleModel(True)
|
||||||
self._testSimpleModel(False)
|
self._testSimpleModel(False)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBasicResource(self):
|
def testBasicResource(self):
|
||||||
self._testSimpleModel(True, use_resource=True)
|
self._testSimpleModel(True, use_resource=True)
|
||||||
self._testSimpleModel(False, use_resource=True)
|
self._testSimpleModel(False, use_resource=True)
|
||||||
|
@ -29,6 +29,7 @@ from tensorflow.python.framework import graph_util
|
|||||||
from tensorflow.python.framework import importer
|
from tensorflow.python.framework import importer
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import gen_state_ops
|
from tensorflow.python.ops import gen_state_ops
|
||||||
from tensorflow.python.ops import math_ops # pylint: disable=unused-import
|
from tensorflow.python.ops import math_ops # pylint: disable=unused-import
|
||||||
from tensorflow.python.ops import math_ops as math_ops_lib
|
from tensorflow.python.ops import math_ops as math_ops_lib
|
||||||
@ -102,6 +103,7 @@ class DeviceFunctionsTest(test.TestCase):
|
|||||||
self.assertDeviceEqual(var_5.device, "/device:GPU:0")
|
self.assertDeviceEqual(var_5.device, "/device:GPU:0")
|
||||||
self.assertDeviceEqual(var_6.device, "/device:CPU:0")
|
self.assertDeviceEqual(var_6.device, "/device:CPU:0")
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNestedDeviceFunctions(self):
|
def testNestedDeviceFunctions(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
var_0 = variables.VariableV1(0)
|
var_0 = variables.VariableV1(0)
|
||||||
|
@ -63,6 +63,7 @@ def _TestDir(test_name):
|
|||||||
|
|
||||||
class SimpleMetaGraphTest(test.TestCase):
|
class SimpleMetaGraphTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNoVariables(self):
|
def testNoVariables(self):
|
||||||
test_dir = _TestDir("no_variables")
|
test_dir = _TestDir("no_variables")
|
||||||
filename = os.path.join(test_dir, "metafile")
|
filename = os.path.join(test_dir, "metafile")
|
||||||
@ -116,6 +117,7 @@ class SimpleMetaGraphTest(test.TestCase):
|
|||||||
{new_input_tensor: input_feed_value})
|
{new_input_tensor: input_feed_value})
|
||||||
self.assertEqual(new_output_value, output_value)
|
self.assertEqual(new_output_value, output_value)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testStrippedOpListNestedFunctions(self):
|
def testStrippedOpListNestedFunctions(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
# Square two levels deep
|
# Square two levels deep
|
||||||
@ -158,6 +160,7 @@ class SimpleMetaGraphTest(test.TestCase):
|
|||||||
op_list = meta_graph.stripped_op_list_for_graph(graph)
|
op_list = meta_graph.stripped_op_list_for_graph(graph)
|
||||||
self.assertEqual(["Const"], [op.name for op in op_list.op])
|
self.assertEqual(["Const"], [op.name for op in op_list.op])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testDefaultAttrStripping(self):
|
def testDefaultAttrStripping(self):
|
||||||
"""Verifies that default attributes are stripped from a graph def."""
|
"""Verifies that default attributes are stripped from a graph def."""
|
||||||
|
|
||||||
@ -210,6 +213,7 @@ class SimpleMetaGraphTest(test.TestCase):
|
|||||||
self.assertEqual(node_def.attr["Tout"].type, dtypes.complex128)
|
self.assertEqual(node_def.attr["Tout"].type, dtypes.complex128)
|
||||||
self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
|
self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testDefaultAttrStrippingNestedFunctions(self):
|
def testDefaultAttrStrippingNestedFunctions(self):
|
||||||
"""Verifies that default attributes are stripped from function node defs."""
|
"""Verifies that default attributes are stripped from function node defs."""
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
@ -261,6 +265,7 @@ class SimpleMetaGraphTest(test.TestCase):
|
|||||||
self.assertEqual(node_def.attr["attr_1"].i, 1)
|
self.assertEqual(node_def.attr["attr_1"].i, 1)
|
||||||
self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
|
self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testVariableObjectsAreSharedAmongCollections(self):
|
def testVariableObjectsAreSharedAmongCollections(self):
|
||||||
with ops.Graph().as_default() as graph1:
|
with ops.Graph().as_default() as graph1:
|
||||||
v = variables.Variable(3.0)
|
v = variables.Variable(3.0)
|
||||||
@ -454,6 +459,7 @@ class ScopedMetaGraphTest(test.TestCase):
|
|||||||
|
|
||||||
# Verifies that we can export the subgraph under each layer and import
|
# Verifies that we can export the subgraph under each layer and import
|
||||||
# them into new layers in a new graph.
|
# them into new layers in a new graph.
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testScopedExportAndImport(self):
|
def testScopedExportAndImport(self):
|
||||||
test_dir = _TestDir("scoped_export_import")
|
test_dir = _TestDir("scoped_export_import")
|
||||||
filenames = [
|
filenames = [
|
||||||
@ -522,6 +528,7 @@ class ScopedMetaGraphTest(test.TestCase):
|
|||||||
actual_grad_value = self.evaluate(grad)
|
actual_grad_value = self.evaluate(grad)
|
||||||
self.assertEqual(expected_grad_value, actual_grad_value)
|
self.assertEqual(expected_grad_value, actual_grad_value)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testImportWhileLoopInWhileLoop(self):
|
def testImportWhileLoopInWhileLoop(self):
|
||||||
# Create a simple while loop.
|
# Create a simple while loop.
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
@ -547,6 +554,7 @@ class ScopedMetaGraphTest(test.TestCase):
|
|||||||
self.evaluate(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
self.evaluate(x)
|
self.evaluate(x)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testScopedImportUnderNameScope(self):
|
def testScopedImportUnderNameScope(self):
|
||||||
graph = ops.Graph()
|
graph = ops.Graph()
|
||||||
with graph.as_default():
|
with graph.as_default():
|
||||||
@ -562,6 +570,7 @@ class ScopedMetaGraphTest(test.TestCase):
|
|||||||
self.assertEqual(list(imported_variables.values())[0].name,
|
self.assertEqual(list(imported_variables.values())[0].name,
|
||||||
"foo/bar/myvar:0")
|
"foo/bar/myvar:0")
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testScopedImportUnderNameScopeNoVarScope(self):
|
def testScopedImportUnderNameScopeNoVarScope(self):
|
||||||
graph = ops.Graph()
|
graph = ops.Graph()
|
||||||
with graph.as_default():
|
with graph.as_default():
|
||||||
@ -590,6 +599,7 @@ class ScopedMetaGraphTest(test.TestCase):
|
|||||||
self.assertEqual(list(imported_variables.values())[0].name,
|
self.assertEqual(list(imported_variables.values())[0].name,
|
||||||
"s" + suffix + "/v:0")
|
"s" + suffix + "/v:0")
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testScopedImportWithSelectedCollections(self):
|
def testScopedImportWithSelectedCollections(self):
|
||||||
meta_graph_filename = os.path.join(
|
meta_graph_filename = os.path.join(
|
||||||
_TestDir("selected_collections_import"), "meta_graph.pb")
|
_TestDir("selected_collections_import"), "meta_graph.pb")
|
||||||
@ -687,6 +697,7 @@ class ScopedMetaGraphTest(test.TestCase):
|
|||||||
|
|
||||||
# Verifies that we can export the subgraph containing a FIFOQueue under
|
# Verifies that we can export the subgraph containing a FIFOQueue under
|
||||||
# "queue1" and import it into "new_queue1" in a new graph.
|
# "queue1" and import it into "new_queue1" in a new graph.
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testScopedWithQueue(self):
|
def testScopedWithQueue(self):
|
||||||
test_dir = _TestDir("scoped_with_queue")
|
test_dir = _TestDir("scoped_with_queue")
|
||||||
orig_meta_graph = self._testScopedExportWithQueue(test_dir,
|
orig_meta_graph = self._testScopedExportWithQueue(test_dir,
|
||||||
@ -749,12 +760,15 @@ class ScopedMetaGraphTest(test.TestCase):
|
|||||||
for n, e in zip(nodes, expected):
|
for n, e in zip(nodes, expected):
|
||||||
self.assertEqual([e], graph2.get_operation_by_name(n).get_attr("_class"))
|
self.assertEqual([e], graph2.get_operation_by_name(n).get_attr("_class"))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testExportNestedNames(self):
|
def testExportNestedNames(self):
|
||||||
self.doTestExportNestedNames(use_resource=False)
|
self.doTestExportNestedNames(use_resource=False)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testExportNestedNamesResource(self):
|
def testExportNestedNamesResource(self):
|
||||||
self.doTestExportNestedNames(use_resource=True)
|
self.doTestExportNestedNames(use_resource=True)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPotentialCycle(self):
|
def testPotentialCycle(self):
|
||||||
graph1 = ops.Graph()
|
graph1 = ops.Graph()
|
||||||
with graph1.as_default():
|
with graph1.as_default():
|
||||||
@ -783,6 +797,7 @@ class ScopedMetaGraphTest(test.TestCase):
|
|||||||
4.0, shape=[2, 2])
|
4.0, shape=[2, 2])
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testClearDevices(self):
|
def testClearDevices(self):
|
||||||
graph1 = ops.Graph()
|
graph1 = ops.Graph()
|
||||||
with graph1.as_default():
|
with graph1.as_default():
|
||||||
@ -842,6 +857,7 @@ class ScopedMetaGraphTest(test.TestCase):
|
|||||||
|
|
||||||
class MetaGraphWithVariableScopeTest(test.TestCase):
|
class MetaGraphWithVariableScopeTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMetricsCollection(self):
|
def testMetricsCollection(self):
|
||||||
|
|
||||||
def _enqueue_vector(sess, queue, values, shape=None):
|
def _enqueue_vector(sess, queue, values, shape=None):
|
||||||
@ -899,6 +915,7 @@ class MetaGraphWithVariableScopeTest(test.TestCase):
|
|||||||
|
|
||||||
class ExportImportAcrossScopesTest(test.TestCase):
|
class ExportImportAcrossScopesTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPartionedVariables(self):
|
def testPartionedVariables(self):
|
||||||
|
|
||||||
def make_graph_with_partitioned_variables(use_resource):
|
def make_graph_with_partitioned_variables(use_resource):
|
||||||
|
@ -57,11 +57,13 @@ ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn)
|
|||||||
|
|
||||||
class ResourceTest(test_util.TensorFlowTestCase):
|
class ResourceTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBuildGraph(self):
|
def testBuildGraph(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
pt = test_ops.stub_resource_handle_op(container="a", shared_name="b")
|
pt = test_ops.stub_resource_handle_op(container="a", shared_name="b")
|
||||||
test_ops.resource_create_op(pt).run()
|
test_ops.resource_create_op(pt).run()
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testInitialize(self):
|
def testInitialize(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
handle = test_ops.stub_resource_handle_op(container="a", shared_name="b")
|
handle = test_ops.stub_resource_handle_op(container="a", shared_name="b")
|
||||||
@ -106,6 +108,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
|||||||
c = a + b
|
c = a + b
|
||||||
self.assertEqual([2, 3], c.shape)
|
self.assertEqual([2, 3], c.shape)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testUnknownDim(self):
|
def testUnknownDim(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
|
a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
|
||||||
@ -113,6 +116,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
|||||||
c = a + b
|
c = a + b
|
||||||
self.assertEqual([2, None, 3], c.shape.as_list())
|
self.assertEqual([2, None, 3], c.shape.as_list())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testUnknownShape(self):
|
def testUnknownShape(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
|
a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
|
||||||
@ -120,6 +124,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
|||||||
c = a + b
|
c = a + b
|
||||||
self.assertEqual(tensor_shape.unknown_shape(), c.shape)
|
self.assertEqual(tensor_shape.unknown_shape(), c.shape)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testScalarShape(self):
|
def testScalarShape(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
|
a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
|
||||||
@ -127,6 +132,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
|||||||
c = a + b
|
c = a + b
|
||||||
self.assertEqual(tensor_shape.scalar(), c.shape)
|
self.assertEqual(tensor_shape.scalar(), c.shape)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testShapeFunctionError(self):
|
def testShapeFunctionError(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
a = array_ops.ones([1, 2, 3])
|
a = array_ops.ones([1, 2, 3])
|
||||||
@ -140,6 +146,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
class IndexedSlicesTest(test_util.TensorFlowTestCase):
|
class IndexedSlicesTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testToTensor(self):
|
def testToTensor(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
|
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
|
||||||
@ -149,6 +156,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase):
|
|||||||
tensor = ops.convert_to_tensor(x, name="tensor")
|
tensor = ops.convert_to_tensor(x, name="tensor")
|
||||||
self.assertAllEqual(tensor.eval(), [[2, 3], [0, 0], [5, 7]])
|
self.assertAllEqual(tensor.eval(), [[2, 3], [0, 0], [5, 7]])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNegation(self):
|
def testNegation(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
|
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
|
||||||
@ -157,6 +165,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllEqual(x.values.eval(), [[-2, -3], [-5, -7]])
|
self.assertAllEqual(x.values.eval(), [[-2, -3], [-5, -7]])
|
||||||
self.assertAllEqual(x.indices.eval(), [0, 2])
|
self.assertAllEqual(x.indices.eval(), [0, 2])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testScalarMul(self):
|
def testScalarMul(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
|
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
|
||||||
@ -190,6 +199,7 @@ def _apply_op(g, *args, **kwargs):
|
|||||||
|
|
||||||
class OperationTest(test_util.TensorFlowTestCase):
|
class OperationTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNoInputs(self):
|
def testNoInputs(self):
|
||||||
op = test_ops.float_output_string_output(name="myop").a.op
|
op = test_ops.float_output_string_output(name="myop").a.op
|
||||||
self.assertEqual(2, len(op.values()))
|
self.assertEqual(2, len(op.values()))
|
||||||
@ -212,6 +222,7 @@ class OperationTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'",
|
self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'",
|
||||||
op.node_def)
|
op.node_def)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNoOutputs(self):
|
def testNoOutputs(self):
|
||||||
op1 = test_ops.float_output(name="myop1").op
|
op1 = test_ops.float_output(name="myop1").op
|
||||||
float_t, = op1.values()
|
float_t, = op1.values()
|
||||||
@ -227,6 +238,7 @@ class OperationTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'",
|
self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'",
|
||||||
op2.node_def)
|
op2.node_def)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testInputsAndOutputs(self):
|
def testInputsAndOutputs(self):
|
||||||
op1 = test_ops.float_output(name="myop1").op
|
op1 = test_ops.float_output(name="myop1").op
|
||||||
self.assertEqual(1, len(op1.values()))
|
self.assertEqual(1, len(op1.values()))
|
||||||
@ -308,6 +320,7 @@ class OperationTest(test_util.TensorFlowTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
ops.Operation(ops._NodeDef("op", "invalid:0"), g)
|
ops.Operation(ops._NodeDef("op", "invalid:0"), g)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNoShapeFunction(self):
|
def testNoShapeFunction(self):
|
||||||
op = test_ops.a()
|
op = test_ops.a()
|
||||||
self.assertEqual(tensor_shape.unknown_shape(), op.get_shape())
|
self.assertEqual(tensor_shape.unknown_shape(), op.get_shape())
|
||||||
@ -333,6 +346,7 @@ class OperationTest(test_util.TensorFlowTestCase):
|
|||||||
converted = ops.convert_to_tensor(1)
|
converted = ops.convert_to_tensor(1)
|
||||||
self.assertTrue(isinstance(converted, ops.EagerTensor))
|
self.assertTrue(isinstance(converted, ops.EagerTensor))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testConvertToTensorNestedTuple(self):
|
def testConvertToTensorNestedTuple(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
values = ((2,), (3,), (5,), (7,))
|
values = ((2,), (3,), (5,), (7,))
|
||||||
@ -384,6 +398,7 @@ class OperationTest(test_util.TensorFlowTestCase):
|
|||||||
values = [1.23]
|
values = [1.23]
|
||||||
_ = ops.convert_to_tensor(values, dtype=dtypes.int64)
|
_ = ops.convert_to_tensor(values, dtype=dtypes.int64)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNoConvert(self):
|
def testNoConvert(self):
|
||||||
# Operation cannot be converted to Tensor.
|
# Operation cannot be converted to Tensor.
|
||||||
op = control_flow_ops.no_op()
|
op = control_flow_ops.no_op()
|
||||||
@ -401,6 +416,7 @@ class OperationTest(test_util.TensorFlowTestCase):
|
|||||||
ops._NodeDef("None", "op1"), ops.Graph(), [], [dtypes.float32])
|
ops._NodeDef("None", "op1"), ops.Graph(), [], [dtypes.float32])
|
||||||
self.assertEqual("<tf.Operation 'op1' type=None>", repr(op))
|
self.assertEqual("<tf.Operation 'op1' type=None>", repr(op))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testGetAttr(self):
|
def testGetAttr(self):
|
||||||
op = test_ops.default_attrs()
|
op = test_ops.default_attrs()
|
||||||
self.assertEqual(op.get_attr("string_val"), b"abc")
|
self.assertEqual(op.get_attr("string_val"), b"abc")
|
||||||
@ -446,6 +462,7 @@ class OperationTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
# TODO(b/65162920): remove this test when users who are directly mutating the
|
# TODO(b/65162920): remove this test when users who are directly mutating the
|
||||||
# node_def have been updated to proper usage.
|
# node_def have been updated to proper usage.
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSetAttr(self):
|
def testSetAttr(self):
|
||||||
op = test_ops.int_attr().op
|
op = test_ops.int_attr().op
|
||||||
op._set_attr("foo", attr_value_pb2.AttrValue(i=2))
|
op._set_attr("foo", attr_value_pb2.AttrValue(i=2))
|
||||||
@ -466,6 +483,7 @@ class OperationTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(z.control_inputs, [x, y])
|
self.assertEqual(z.control_inputs, [x, y])
|
||||||
self.assertEqual(x._control_outputs, [z])
|
self.assertEqual(x._control_outputs, [z])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testRemoveAllControlInputs(self):
|
def testRemoveAllControlInputs(self):
|
||||||
a = constant_op.constant(1)
|
a = constant_op.constant(1)
|
||||||
with ops.control_dependencies([a]):
|
with ops.control_dependencies([a]):
|
||||||
@ -490,6 +508,7 @@ class OperationTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(f.op.control_inputs, [])
|
self.assertEqual(f.op.control_inputs, [])
|
||||||
self.assertEqual(list(f.op.inputs), [d, e])
|
self.assertEqual(list(f.op.inputs), [d, e])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testControlInputCycle(self):
|
def testControlInputCycle(self):
|
||||||
graph = ops.Graph()
|
graph = ops.Graph()
|
||||||
with graph.as_default():
|
with graph.as_default():
|
||||||
@ -582,6 +601,7 @@ class OperationTest(test_util.TensorFlowTestCase):
|
|||||||
):
|
):
|
||||||
x.op._update_input(1, x) # pylint: disable=protected-access
|
x.op._update_input(1, x) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testOpDef(self):
|
def testOpDef(self):
|
||||||
x = constant_op.constant(0)
|
x = constant_op.constant(0)
|
||||||
y = constant_op.constant(1)
|
y = constant_op.constant(1)
|
||||||
@ -681,6 +701,7 @@ class CreateOpTest(test_util.TensorFlowTestCase):
|
|||||||
# the low-level behavior.
|
# the low-level behavior.
|
||||||
class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):
|
class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with g.as_default():
|
with g.as_default():
|
||||||
@ -731,6 +752,7 @@ class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(op3.name, "myop_2")
|
self.assertEqual(op3.name, "myop_2")
|
||||||
self.assertEqual(op4.name, "myop_1_1")
|
self.assertEqual(op4.name, "myop_1_1")
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCond(self):
|
def testCond(self):
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with g.as_default():
|
with g.as_default():
|
||||||
@ -760,6 +782,7 @@ class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):
|
|||||||
"cond/cond_text")
|
"cond/cond_text")
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWhileLoop(self):
|
def testWhileLoop(self):
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with g.as_default():
|
with g.as_default():
|
||||||
@ -789,6 +812,7 @@ class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):
|
|||||||
"myloop/while_context")
|
"myloop/while_context")
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWhileLoopWithInternalControlDep(self):
|
def testWhileLoopWithInternalControlDep(self):
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with g.as_default():
|
with g.as_default():
|
||||||
@ -812,6 +836,7 @@ class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):
|
|||||||
# Internal control dep is preserved
|
# Internal control dep is preserved
|
||||||
self.assertEqual(op.control_inputs, [c])
|
self.assertEqual(op.control_inputs, [c])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testWhileLoopWithExternalControlDep(self):
|
def testWhileLoopWithExternalControlDep(self):
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with g.as_default():
|
with g.as_default():
|
||||||
@ -945,6 +970,7 @@ class NameStackTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual("bar_2", g.unique_name("bar", mark_as_used=False))
|
self.assertEqual("bar_2", g.unique_name("bar", mark_as_used=False))
|
||||||
self.assertEqual("bar_2", g.unique_name("bar"))
|
self.assertEqual("bar_2", g.unique_name("bar"))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNameAndVariableScope(self):
|
def testNameAndVariableScope(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
with sess.graph.name_scope("l0"):
|
with sess.graph.name_scope("l0"):
|
||||||
@ -1671,6 +1697,7 @@ def _CopyOverrideGrad(op, x_grad): # pylint: disable=invalid-name
|
|||||||
|
|
||||||
class RegistrationTest(test_util.TensorFlowTestCase):
|
class RegistrationTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testRegisterGradients(self):
|
def testRegisterGradients(self):
|
||||||
x = test_ops.float_output()
|
x = test_ops.float_output()
|
||||||
y = test_ops.copy_op(x)
|
y = test_ops.copy_op(x)
|
||||||
@ -1710,6 +1737,7 @@ class ComparisonTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
class ControlDependenciesTest(test_util.TensorFlowTestCase):
|
class ControlDependenciesTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with g.as_default():
|
with g.as_default():
|
||||||
@ -1953,6 +1981,7 @@ class OpScopeTest(test_util.TensorFlowTestCase):
|
|||||||
with ops.name_scope(None, "default2") as scope2:
|
with ops.name_scope(None, "default2") as scope2:
|
||||||
self.assertEqual(scope2, "default/default2/")
|
self.assertEqual(scope2, "default/default2/")
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNoScopeName(self):
|
def testNoScopeName(self):
|
||||||
g0 = ops.Graph()
|
g0 = ops.Graph()
|
||||||
values = [
|
values = [
|
||||||
@ -1966,6 +1995,7 @@ class OpScopeTest(test_util.TensorFlowTestCase):
|
|||||||
with ops.name_scope(None, None, values):
|
with ops.name_scope(None, None, values):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testEmptyScopeName(self):
|
def testEmptyScopeName(self):
|
||||||
g0 = ops.Graph()
|
g0 = ops.Graph()
|
||||||
a = g0.create_op("A", [], [dtypes.float32])
|
a = g0.create_op("A", [], [dtypes.float32])
|
||||||
@ -1977,6 +2007,7 @@ class OpScopeTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual("", scope)
|
self.assertEqual("", scope)
|
||||||
self.assertEqual(g0, ops.get_default_graph())
|
self.assertEqual(g0, ops.get_default_graph())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testDefaultScopeName(self):
|
def testDefaultScopeName(self):
|
||||||
g0 = ops.Graph()
|
g0 = ops.Graph()
|
||||||
a = g0.create_op("A", [], [dtypes.float32])
|
a = g0.create_op("A", [], [dtypes.float32])
|
||||||
@ -2001,12 +2032,14 @@ class OpScopeTest(test_util.TensorFlowTestCase):
|
|||||||
with ops.name_scope(scope_name, values=graph_elements + [a]):
|
with ops.name_scope(scope_name, values=graph_elements + [a]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testTensor(self):
|
def testTensor(self):
|
||||||
g0 = ops.Graph()
|
g0 = ops.Graph()
|
||||||
a = g0.create_op("A", [], [dtypes.float32])
|
a = g0.create_op("A", [], [dtypes.float32])
|
||||||
b = g0.create_op("B", [], [dtypes.float32])
|
b = g0.create_op("B", [], [dtypes.float32])
|
||||||
self._testGraphElements([a, b])
|
self._testGraphElements([a, b])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSparseTensor(self):
|
def testSparseTensor(self):
|
||||||
g0 = ops.Graph()
|
g0 = ops.Graph()
|
||||||
a = g0.create_op("A", [], [dtypes.float32])
|
a = g0.create_op("A", [], [dtypes.float32])
|
||||||
@ -2017,6 +2050,7 @@ class OpScopeTest(test_util.TensorFlowTestCase):
|
|||||||
_apply_op(g0, "Int64Output", [], [dtypes.int64]))
|
_apply_op(g0, "Int64Output", [], [dtypes.int64]))
|
||||||
self._testGraphElements([a, sparse, b])
|
self._testGraphElements([a, sparse, b])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testVariable(self):
|
def testVariable(self):
|
||||||
g0 = ops.Graph()
|
g0 = ops.Graph()
|
||||||
with g0.as_default():
|
with g0.as_default():
|
||||||
@ -2221,6 +2255,7 @@ class InitScopeTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(4, int(compiled_outer(inner=compiled_inner)))
|
self.assertEqual(4, int(compiled_outer(inner=compiled_inner)))
|
||||||
self.assertEqual(7, int(compiled_outer(inner=compiled_inner)))
|
self.assertEqual(7, int(compiled_outer(inner=compiled_inner)))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFallsBackToGlobalGraphWhenAllGraphsAreBuildingFunctions(self):
|
def testFallsBackToGlobalGraphWhenAllGraphsAreBuildingFunctions(self):
|
||||||
with context.graph_mode():
|
with context.graph_mode():
|
||||||
ops.reset_default_graph()
|
ops.reset_default_graph()
|
||||||
@ -2357,6 +2392,7 @@ class GraphTest(test_util.TensorFlowTestCase):
|
|||||||
g.prevent_feeding(a)
|
g.prevent_feeding(a)
|
||||||
self.assertFalse(g.is_feedable(a))
|
self.assertFalse(g.is_feedable(a))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPreventFetching(self):
|
def testPreventFetching(self):
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
a = constant_op.constant(2.0)
|
a = constant_op.constant(2.0)
|
||||||
@ -2440,10 +2476,12 @@ class AttrScopeTest(test_util.TensorFlowTestCase):
|
|||||||
b = None
|
b = None
|
||||||
return (a, b)
|
return (a, b)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNoLabel(self):
|
def testNoLabel(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAllEqual((None, None), self._get_test_attrs())
|
self.assertAllEqual((None, None), self._get_test_attrs())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testLabelMap(self):
|
def testLabelMap(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
a1 = self._get_test_attrs()
|
a1 = self._get_test_attrs()
|
||||||
@ -2478,11 +2516,13 @@ ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape)
|
|||||||
|
|
||||||
class KernelLabelTest(test_util.TensorFlowTestCase):
|
class KernelLabelTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNoLabel(self):
|
def testNoLabel(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAllEqual(b"My label is: default",
|
self.assertAllEqual(b"My label is: default",
|
||||||
test_ops.kernel_label().eval())
|
test_ops.kernel_label().eval())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testLabelMap(self):
|
def testLabelMap(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
default_1 = test_ops.kernel_label()
|
default_1 = test_ops.kernel_label()
|
||||||
@ -2599,6 +2639,7 @@ class StatisticsTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
class DeviceStackTest(test_util.TensorFlowTestCase):
|
class DeviceStackTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBasicDeviceAssignmentMetadata(self):
|
def testBasicDeviceAssignmentMetadata(self):
|
||||||
|
|
||||||
def device_func(unused_op):
|
def device_func(unused_op):
|
||||||
@ -2630,6 +2671,7 @@ class DeviceStackTest(test_util.TensorFlowTestCase):
|
|||||||
expected_regex = r"device_func<.*ops_test.py, [0-9]+"
|
expected_regex = r"device_func<.*ops_test.py, [0-9]+"
|
||||||
self.assertRegexpMatches(func_description, expected_regex)
|
self.assertRegexpMatches(func_description, expected_regex)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self):
|
def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self):
|
||||||
|
|
||||||
with ops.device("/cpu"):
|
with ops.device("/cpu"):
|
||||||
@ -2649,6 +2691,7 @@ class DeviceStackTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
class ColocationGroupTest(test_util.TensorFlowTestCase):
|
class ColocationGroupTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
a = constant_op.constant([2.0], name="a")
|
a = constant_op.constant([2.0], name="a")
|
||||||
with ops.colocate_with(a.op):
|
with ops.colocate_with(a.op):
|
||||||
@ -2659,6 +2702,7 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
c.op.get_attr("_class")
|
c.op.get_attr("_class")
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBasicColocationMetadata(self):
|
def testBasicColocationMetadata(self):
|
||||||
const_two = constant_op.constant([2.0], name="two")
|
const_two = constant_op.constant([2.0], name="two")
|
||||||
with ops.colocate_with(const_two.op):
|
with ops.colocate_with(const_two.op):
|
||||||
@ -2671,6 +2715,7 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
|
|||||||
# colocation statement.
|
# colocation statement.
|
||||||
self.assertEqual("ops_test.py", os.path.basename(metadata.filename))
|
self.assertEqual("ops_test.py", os.path.basename(metadata.filename))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testColocationDeviceInteraction(self):
|
def testColocationDeviceInteraction(self):
|
||||||
with ops.device("/cpu:0"):
|
with ops.device("/cpu:0"):
|
||||||
with ops.device("/device:GPU:0"):
|
with ops.device("/device:GPU:0"):
|
||||||
@ -2683,6 +2728,7 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual([b"loc:@a"], b.op.colocation_groups())
|
self.assertEqual([b"loc:@a"], b.op.colocation_groups())
|
||||||
self.assertEqual(a.op.device, b.op.device)
|
self.assertEqual(a.op.device, b.op.device)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testColocationCanonicalization(self):
|
def testColocationCanonicalization(self):
|
||||||
with ops.device("/device:GPU:0"):
|
with ops.device("/device:GPU:0"):
|
||||||
_ = constant_op.constant(2.0)
|
_ = constant_op.constant(2.0)
|
||||||
@ -2698,6 +2744,7 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
|
|||||||
# inherits B's device name, after canonicalizing the names.
|
# inherits B's device name, after canonicalizing the names.
|
||||||
self.assertEqual(b.op.device, c.op.device)
|
self.assertEqual(b.op.device, c.op.device)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testLocationOverrides(self):
|
def testLocationOverrides(self):
|
||||||
with ops.device("/cpu:0"):
|
with ops.device("/cpu:0"):
|
||||||
with ops.device("/device:GPU:0"):
|
with ops.device("/device:GPU:0"):
|
||||||
@ -2719,6 +2766,7 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual("/device:GPU:0", c.op.device)
|
self.assertEqual("/device:GPU:0", c.op.device)
|
||||||
self.assertEqual("/device:CPU:0", d.op.device)
|
self.assertEqual("/device:CPU:0", d.op.device)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNestedColocateWith(self):
|
def testNestedColocateWith(self):
|
||||||
a = constant_op.constant([2.0], name="a")
|
a = constant_op.constant([2.0], name="a")
|
||||||
with ops.colocate_with(a.op):
|
with ops.colocate_with(a.op):
|
||||||
@ -2728,6 +2776,7 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual([b"loc:@a"], b.op.colocation_groups())
|
self.assertEqual([b"loc:@a"], b.op.colocation_groups())
|
||||||
self.assertEqual([b"loc:@a"], c.op.colocation_groups())
|
self.assertEqual([b"loc:@a"], c.op.colocation_groups())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMultiColocationGroups(self):
|
def testMultiColocationGroups(self):
|
||||||
a = constant_op.constant([2.0], name="a")
|
a = constant_op.constant([2.0], name="a")
|
||||||
b = constant_op.constant(3.0, name="b")
|
b = constant_op.constant(3.0, name="b")
|
||||||
@ -2736,6 +2785,7 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
|
|||||||
c = constant_op.constant(4.0)
|
c = constant_op.constant(4.0)
|
||||||
self.assertEqual(set([b"loc:@a", b"loc:@b"]), set(c.op.colocation_groups()))
|
self.assertEqual(set([b"loc:@a", b"loc:@b"]), set(c.op.colocation_groups()))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testColocationIgnoreStack(self):
|
def testColocationIgnoreStack(self):
|
||||||
a = constant_op.constant([2.0], name="a")
|
a = constant_op.constant([2.0], name="a")
|
||||||
b = constant_op.constant(3.0, name="b")
|
b = constant_op.constant(3.0, name="b")
|
||||||
@ -2744,6 +2794,7 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
|
|||||||
c = constant_op.constant(4.0)
|
c = constant_op.constant(4.0)
|
||||||
self.assertEqual(set([b"loc:@b"]), set(c.op.colocation_groups()))
|
self.assertEqual(set([b"loc:@b"]), set(c.op.colocation_groups()))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testColocateWithReset(self):
|
def testColocateWithReset(self):
|
||||||
a = constant_op.constant([2.0], name="a")
|
a = constant_op.constant([2.0], name="a")
|
||||||
with ops.colocate_with(a.op):
|
with ops.colocate_with(a.op):
|
||||||
@ -2753,6 +2804,7 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual([b"loc:@a"], b.op.colocation_groups())
|
self.assertEqual([b"loc:@a"], b.op.colocation_groups())
|
||||||
self.assertEqual([b"loc:@c"], c.op.colocation_groups())
|
self.assertEqual([b"loc:@c"], c.op.colocation_groups())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testColocateWithInitialNoneThenNested(self):
|
def testColocateWithInitialNoneThenNested(self):
|
||||||
a = constant_op.constant([2.0], name="a")
|
a = constant_op.constant([2.0], name="a")
|
||||||
with ops.colocate_with(a.op):
|
with ops.colocate_with(a.op):
|
||||||
@ -2763,12 +2815,14 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual([b"loc:@b"], b.op.colocation_groups())
|
self.assertEqual([b"loc:@b"], b.op.colocation_groups())
|
||||||
self.assertEqual([b"loc:@b"], c.op.colocation_groups())
|
self.assertEqual([b"loc:@b"], c.op.colocation_groups())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testColocateVariables(self):
|
def testColocateVariables(self):
|
||||||
a = variables.Variable([2.0], name="a")
|
a = variables.Variable([2.0], name="a")
|
||||||
with ops.colocate_with(a.op):
|
with ops.colocate_with(a.op):
|
||||||
b = variables.Variable([3.0], name="b")
|
b = variables.Variable([3.0], name="b")
|
||||||
self.assertEqual([b"loc:@a"], b.op.colocation_groups())
|
self.assertEqual([b"loc:@a"], b.op.colocation_groups())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testInconsistentDeviceWithinColocate(self):
|
def testInconsistentDeviceWithinColocate(self):
|
||||||
with ops.device("/device:GPU:0"):
|
with ops.device("/device:GPU:0"):
|
||||||
a = constant_op.constant([2.0], name="a")
|
a = constant_op.constant([2.0], name="a")
|
||||||
@ -2782,6 +2836,7 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
self.assertEqual("/device:CPU:0", b.device)
|
self.assertEqual("/device:CPU:0", b.device)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMakeColocationConflictMessage(self):
|
def testMakeColocationConflictMessage(self):
|
||||||
"""Test that provides an example of a complicated error message."""
|
"""Test that provides an example of a complicated error message."""
|
||||||
# We could test the message with any ops, but this test will be more
|
# We could test the message with any ops, but this test will be more
|
||||||
@ -2926,6 +2981,7 @@ class NameScopeTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
class TracebackTest(test_util.TensorFlowTestCase):
|
class TracebackTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testTracebackWithStartLines(self):
|
def testTracebackWithStartLines(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
a = constant_op.constant(2.0)
|
a = constant_op.constant(2.0)
|
||||||
@ -2947,6 +3003,7 @@ class TracebackTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
class EnableEagerExecutionTest(test_util.TensorFlowTestCase):
|
class EnableEagerExecutionTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBadArgumentsToEnableEagerExecution(self):
|
def testBadArgumentsToEnableEagerExecution(self):
|
||||||
with self.assertRaisesRegexp(TypeError, "config must be a tf.ConfigProto"):
|
with self.assertRaisesRegexp(TypeError, "config must be a tf.ConfigProto"):
|
||||||
ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT)
|
ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT)
|
||||||
|
@ -35,6 +35,7 @@ def raise_exception():
|
|||||||
|
|
||||||
class SmartCondTest(test_util.TensorFlowTestCase):
|
class SmartCondTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testTrue(self):
|
def testTrue(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
with session.Session():
|
with session.Session():
|
||||||
@ -44,6 +45,7 @@ class SmartCondTest(test_util.TensorFlowTestCase):
|
|||||||
lambda: math_ops.multiply(y, 5))
|
lambda: math_ops.multiply(y, 5))
|
||||||
self.assertEqual(z.eval(), 32)
|
self.assertEqual(z.eval(), 32)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFalse(self):
|
def testFalse(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
with session.Session():
|
with session.Session():
|
||||||
@ -99,6 +101,7 @@ class SmartCondTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
class SmartCaseTest(test_util.TensorFlowTestCase):
|
class SmartCaseTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testTrue(self):
|
def testTrue(self):
|
||||||
x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
|
x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
|
||||||
conditions = [(True, lambda: constant_op.constant(1)),
|
conditions = [(True, lambda: constant_op.constant(1)),
|
||||||
@ -112,6 +115,7 @@ class SmartCaseTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(self.evaluate(y), 1)
|
self.assertEqual(self.evaluate(y), 1)
|
||||||
self.assertEqual(self.evaluate(z), 1)
|
self.assertEqual(self.evaluate(z), 1)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFalse(self):
|
def testFalse(self):
|
||||||
conditions = [(False, raise_exception)]
|
conditions = [(False, raise_exception)]
|
||||||
y = smart_cond.smart_case(conditions,
|
y = smart_cond.smart_case(conditions,
|
||||||
@ -124,6 +128,7 @@ class SmartCaseTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(self.evaluate(y), 1)
|
self.assertEqual(self.evaluate(y), 1)
|
||||||
self.assertEqual(self.evaluate(z), 1)
|
self.assertEqual(self.evaluate(z), 1)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMix(self):
|
def testMix(self):
|
||||||
x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
|
x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
|
||||||
y = constant_op.constant(10)
|
y = constant_op.constant(10)
|
||||||
|
@ -65,6 +65,7 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
|
|||||||
sparse_tensor.is_sparse(
|
sparse_tensor.is_sparse(
|
||||||
sparse_tensor.SparseTensorValue([[0]], [0], [1])))
|
sparse_tensor.SparseTensorValue([[0]], [0], [1])))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testConsumers(self):
|
def testConsumers(self):
|
||||||
sp = sparse_tensor.SparseTensor([[0, 0], [1, 2]], [1.0, 3.0], [3, 4])
|
sp = sparse_tensor.SparseTensor([[0, 0], [1, 2]], [1.0, 3.0], [3, 4])
|
||||||
w = ops.convert_to_tensor(np.ones([4, 1], np.float32))
|
w = ops.convert_to_tensor(np.ones([4, 1], np.float32))
|
||||||
@ -87,6 +88,7 @@ class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):
|
|||||||
value)
|
value)
|
||||||
self.assertAllEqual(value, self.evaluate(from_value))
|
self.assertAllEqual(value, self.evaluate(from_value))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_convert_sparse(self):
|
def test_convert_sparse(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
indices = [[0, 1], [1, 0]]
|
indices = [[0, 1], [1, 0]]
|
||||||
|
@ -43,6 +43,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
all(subscribe._is_subscribed_identity(x) for x in container))
|
all(subscribe._is_subscribed_identity(x) for x in container))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSideEffect(self):
|
def testSideEffect(self):
|
||||||
a = constant_op.constant(1)
|
a = constant_op.constant(1)
|
||||||
b = constant_op.constant(1)
|
b = constant_op.constant(1)
|
||||||
@ -75,6 +76,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(d_out, [42])
|
self.assertEqual(d_out, [42])
|
||||||
self.assertEqual(shared, [2, 2, 2])
|
self.assertEqual(shared, [2, 2, 2])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSupportedTypes(self):
|
def testSupportedTypes(self):
|
||||||
"""Confirm that supported types are correctly detected and handled."""
|
"""Confirm that supported types are correctly detected and handled."""
|
||||||
|
|
||||||
@ -120,6 +122,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
|
|||||||
subscribe.subscribe(c.name,
|
subscribe.subscribe(c.name,
|
||||||
lambda t: script_ops.py_func(sub, [t], [t.dtype]))
|
lambda t: script_ops.py_func(sub, [t], [t.dtype]))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCaching(self):
|
def testCaching(self):
|
||||||
"""Confirm caching of control output is recalculated between calls."""
|
"""Confirm caching of control output is recalculated between calls."""
|
||||||
a = constant_op.constant(1)
|
a = constant_op.constant(1)
|
||||||
@ -152,6 +155,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(d_out, [11])
|
self.assertEqual(d_out, [11])
|
||||||
self.assertEqual(shared, {2: 1, 1: 1})
|
self.assertEqual(shared, {2: 1, 1: 1})
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testIsSubscribedIdentity(self):
|
def testIsSubscribedIdentity(self):
|
||||||
"""Confirm subscribed identity ops are correctly detected."""
|
"""Confirm subscribed identity ops are correctly detected."""
|
||||||
a = constant_op.constant(1)
|
a = constant_op.constant(1)
|
||||||
@ -165,6 +169,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertFalse(subscribe._is_subscribed_identity(idop))
|
self.assertFalse(subscribe._is_subscribed_identity(idop))
|
||||||
self.assertTrue(subscribe._is_subscribed_identity(c_sub))
|
self.assertTrue(subscribe._is_subscribed_identity(c_sub))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSubscribeExtend(self):
|
def testSubscribeExtend(self):
|
||||||
"""Confirm side effect are correctly added for different input types."""
|
"""Confirm side effect are correctly added for different input types."""
|
||||||
a = constant_op.constant(1)
|
a = constant_op.constant(1)
|
||||||
@ -210,6 +215,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertIn('graph2', shared)
|
self.assertIn('graph2', shared)
|
||||||
self.assertIn('graph3', shared)
|
self.assertIn('graph3', shared)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSubscribeVariable(self):
|
def testSubscribeVariable(self):
|
||||||
"""Confirm that variables can be subscribed."""
|
"""Confirm that variables can be subscribed."""
|
||||||
v1 = variables.VariableV1(0.0)
|
v1 = variables.VariableV1(0.0)
|
||||||
@ -248,6 +254,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
|
|||||||
# Make sure the values read from the variable match the expected ones.
|
# Make sure the values read from the variable match the expected ones.
|
||||||
self.assertEqual([0.0, 3.0], shared)
|
self.assertEqual([0.0, 3.0], shared)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testResourceType(self):
|
def testResourceType(self):
|
||||||
"""Confirm that subscribe correctly handles tensors with 'resource' type."""
|
"""Confirm that subscribe correctly handles tensors with 'resource' type."""
|
||||||
tensor_array = tensor_array_ops.TensorArray(
|
tensor_array = tensor_array_ops.TensorArray(
|
||||||
@ -276,6 +283,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
|
|||||||
self.evaluate([reader])
|
self.evaluate([reader])
|
||||||
self.assertEqual(0, len(shared))
|
self.assertEqual(0, len(shared))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMultipleOutputs(self):
|
def testMultipleOutputs(self):
|
||||||
"""Handle subscriptions to multiple outputs from the same op."""
|
"""Handle subscriptions to multiple outputs from the same op."""
|
||||||
sparse_tensor_1 = sparse_tensor.SparseTensor(
|
sparse_tensor_1 = sparse_tensor.SparseTensor(
|
||||||
@ -309,6 +317,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
|
|||||||
# All three ops have been processed.
|
# All three ops have been processed.
|
||||||
self.assertEqual(3, len(shared))
|
self.assertEqual(3, len(shared))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_subscribe_tensors_on_different_devices(self):
|
def test_subscribe_tensors_on_different_devices(self):
|
||||||
"""Side effect ops are added with the same device of the subscribed op."""
|
"""Side effect ops are added with the same device of the subscribed op."""
|
||||||
c1 = constant_op.constant(10)
|
c1 = constant_op.constant(10)
|
||||||
@ -335,6 +344,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(add.device, add_sub.device)
|
self.assertEqual(add.device, add_sub.device)
|
||||||
self.assertEqual(mul.device, mul_sub.device)
|
self.assertEqual(mul.device, mul_sub.device)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_subscribe_tensors_within_control_flow_context(self):
|
def test_subscribe_tensors_within_control_flow_context(self):
|
||||||
"""Side effect ops are added with the same control flow context."""
|
"""Side effect ops are added with the same control flow context."""
|
||||||
c1 = constant_op.constant(10)
|
c1 = constant_op.constant(10)
|
||||||
|
@ -45,6 +45,7 @@ class TensorSpecTest(test_util.TensorFlowTestCase):
|
|||||||
desc = tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)
|
desc = tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)
|
||||||
self.assertEqual(desc.shape, tensor_shape.TensorShape(None))
|
self.assertEqual(desc.shape, tensor_shape.TensorShape(None))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testShapeCompatibility(self):
|
def testShapeCompatibility(self):
|
||||||
unknown = array_ops.placeholder(dtypes.int64)
|
unknown = array_ops.placeholder(dtypes.int64)
|
||||||
partial = array_ops.placeholder(dtypes.int64, shape=[None, 1])
|
partial = array_ops.placeholder(dtypes.int64, shape=[None, 1])
|
||||||
@ -75,6 +76,7 @@ class TensorSpecTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertFalse(desc_rank3.is_compatible_with(full))
|
self.assertFalse(desc_rank3.is_compatible_with(full))
|
||||||
self.assertTrue(desc_rank3.is_compatible_with(rank3))
|
self.assertTrue(desc_rank3.is_compatible_with(rank3))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testTypeCompatibility(self):
|
def testTypeCompatibility(self):
|
||||||
floats = array_ops.placeholder(dtypes.float32, shape=[10, 10])
|
floats = array_ops.placeholder(dtypes.float32, shape=[10, 10])
|
||||||
ints = array_ops.placeholder(dtypes.int32, shape=[10, 10])
|
ints = array_ops.placeholder(dtypes.int32, shape=[10, 10])
|
||||||
@ -106,6 +108,7 @@ class TensorSpecTest(test_util.TensorFlowTestCase):
|
|||||||
spec_2 = tensor_spec.TensorSpec.from_spec(spec_1)
|
spec_2 = tensor_spec.TensorSpec.from_spec(spec_1)
|
||||||
self.assertEqual(spec_1, spec_2)
|
self.assertEqual(spec_1, spec_2)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFromTensor(self):
|
def testFromTensor(self):
|
||||||
zero = constant_op.constant(0)
|
zero = constant_op.constant(0)
|
||||||
spec = tensor_spec.TensorSpec.from_tensor(zero)
|
spec = tensor_spec.TensorSpec.from_tensor(zero)
|
||||||
@ -113,6 +116,7 @@ class TensorSpecTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(spec.shape, [])
|
self.assertEqual(spec.shape, [])
|
||||||
self.assertEqual(spec.name, "Const")
|
self.assertEqual(spec.name, "Const")
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFromPlaceholder(self):
|
def testFromPlaceholder(self):
|
||||||
unknown = array_ops.placeholder(dtypes.int64, name="unknown")
|
unknown = array_ops.placeholder(dtypes.int64, name="unknown")
|
||||||
partial = array_ops.placeholder(dtypes.float32,
|
partial = array_ops.placeholder(dtypes.float32,
|
||||||
|
@ -758,6 +758,7 @@ class TensorUtilTest(test.TestCase):
|
|||||||
self.assertFalse(tensor_util.ShapeEquals(t, [1, 4]))
|
self.assertFalse(tensor_util.ShapeEquals(t, [1, 4]))
|
||||||
self.assertFalse(tensor_util.ShapeEquals(t, [4]))
|
self.assertFalse(tensor_util.ShapeEquals(t, [4]))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testMockArray(self):
|
def testMockArray(self):
|
||||||
|
|
||||||
class MockArray(object):
|
class MockArray(object):
|
||||||
@ -787,6 +788,7 @@ class ConstantValueTest(test.TestCase):
|
|||||||
tf_val = constant_op.constant(np_val)
|
tf_val = constant_op.constant(np_val)
|
||||||
self.assertAllClose(np_val, tensor_util.constant_value(tf_val))
|
self.assertAllClose(np_val, tensor_util.constant_value(tf_val))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testUnknown(self):
|
def testUnknown(self):
|
||||||
tf_val = gen_state_ops.variable(
|
tf_val = gen_state_ops.variable(
|
||||||
shape=[3, 4, 7],
|
shape=[3, 4, 7],
|
||||||
@ -815,12 +817,14 @@ class ConstantValueTest(test.TestCase):
|
|||||||
c_val = tensor_util.constant_value(tf_val)
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
self.assertEqual(6, c_val)
|
self.assertEqual(6, c_val)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSizeOfScalar(self):
|
def testSizeOfScalar(self):
|
||||||
tf_val = array_ops.size(constant_op.constant(0.0))
|
tf_val = array_ops.size(constant_op.constant(0.0))
|
||||||
c_val = tensor_util.constant_value(tf_val)
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
self.assertEqual(1, c_val)
|
self.assertEqual(1, c_val)
|
||||||
self.assertEqual(np.ndarray, type(c_val))
|
self.assertEqual(np.ndarray, type(c_val))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testRank(self):
|
def testRank(self):
|
||||||
tf_val = array_ops.rank(constant_op.constant(0.0, shape=[1, 2, 3]))
|
tf_val = array_ops.rank(constant_op.constant(0.0, shape=[1, 2, 3]))
|
||||||
c_val = tensor_util.constant_value(tf_val)
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
@ -852,6 +856,7 @@ class ConstantValueTest(test.TestCase):
|
|||||||
c_val = tensor_util.constant_value(tf_val)
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
self.assertAllClose(np_val.astype(np.float64), c_val)
|
self.assertAllClose(np_val.astype(np.float64), c_val)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testConcat(self):
|
def testConcat(self):
|
||||||
np_val = np.random.rand(3, 4, 7).astype(np.float32)
|
np_val = np.random.rand(3, 4, 7).astype(np.float32)
|
||||||
tf_val = array_ops.concat(
|
tf_val = array_ops.concat(
|
||||||
@ -871,6 +876,7 @@ class ConstantValueTest(test.TestCase):
|
|||||||
c_val = tensor_util.constant_value(tf_val)
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
self.assertIs(None, c_val)
|
self.assertIs(None, c_val)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPack_Axis0(self):
|
def testPack_Axis0(self):
|
||||||
inputs = [np.random.rand(4, 7) for _ in range(3)]
|
inputs = [np.random.rand(4, 7) for _ in range(3)]
|
||||||
np_val = np.array(inputs)
|
np_val = np.array(inputs)
|
||||||
@ -883,6 +889,7 @@ class ConstantValueTest(test.TestCase):
|
|||||||
c_val = tensor_util.constant_value(tf_val)
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
self.assertIs(None, c_val)
|
self.assertIs(None, c_val)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPack_Axis1(self):
|
def testPack_Axis1(self):
|
||||||
inputs = [np.random.rand(4, 7) for _ in range(3)]
|
inputs = [np.random.rand(4, 7) for _ in range(3)]
|
||||||
tf_val = array_ops.stack(inputs, axis=1)
|
tf_val = array_ops.stack(inputs, axis=1)
|
||||||
@ -894,6 +901,7 @@ class ConstantValueTest(test.TestCase):
|
|||||||
c_val = tensor_util.constant_value(tf_val)
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
self.assertIs(None, c_val)
|
self.assertIs(None, c_val)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPack_Partial_Axis0(self):
|
def testPack_Partial_Axis0(self):
|
||||||
input_ = np.random.rand(4, 7)
|
input_ = np.random.rand(4, 7)
|
||||||
tf_val = array_ops.stack([input_, array_ops.placeholder(dtypes.float32)])
|
tf_val = array_ops.stack([input_, array_ops.placeholder(dtypes.float32)])
|
||||||
@ -901,6 +909,7 @@ class ConstantValueTest(test.TestCase):
|
|||||||
self.assertAllClose(input_, c_val[0])
|
self.assertAllClose(input_, c_val[0])
|
||||||
self.assertIsNone(c_val[1])
|
self.assertIsNone(c_val[1])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPack_Partial_Axis1(self):
|
def testPack_Partial_Axis1(self):
|
||||||
input_ = np.random.rand(4, 7)
|
input_ = np.random.rand(4, 7)
|
||||||
tf_val = array_ops.stack([input_, array_ops.placeholder(dtypes.float32)],
|
tf_val = array_ops.stack([input_, array_ops.placeholder(dtypes.float32)],
|
||||||
@ -966,12 +975,14 @@ class ConstantValueAsShapeTest(test.TestCase):
|
|||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
self.assertEqual([None, 1, None], c_val.as_list())
|
self.assertEqual([None, 1, None], c_val.as_list())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testPack(self):
|
def testPack(self):
|
||||||
tf_val = array_ops.stack(
|
tf_val = array_ops.stack(
|
||||||
[constant_op.constant(16), 37, array_ops.placeholder(dtypes.int32)])
|
[constant_op.constant(16), 37, array_ops.placeholder(dtypes.int32)])
|
||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
self.assertEqual([16, 37, None], c_val.as_list())
|
self.assertEqual([16, 37, None], c_val.as_list())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testConcat(self):
|
def testConcat(self):
|
||||||
tf_val = array_ops.concat(
|
tf_val = array_ops.concat(
|
||||||
[[16, 37], array_ops.placeholder(
|
[[16, 37], array_ops.placeholder(
|
||||||
@ -985,6 +996,7 @@ class ConstantValueAsShapeTest(test.TestCase):
|
|||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
self.assertEqual([16, 37, None, 48], c_val.as_list())
|
self.assertEqual([16, 37, None, 48], c_val.as_list())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSlice(self):
|
def testSlice(self):
|
||||||
tf_val = array_ops.placeholder(dtypes.int32, shape=(4,))[0:2]
|
tf_val = array_ops.placeholder(dtypes.int32, shape=(4,))[0:2]
|
||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
|
@ -49,6 +49,7 @@ from tensorflow.python.platform import googletest
|
|||||||
|
|
||||||
class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_assert_ops_in_graph(self):
|
def test_assert_ops_in_graph(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
constant_op.constant(["hello", "taffy"], name="hello")
|
constant_op.constant(["hello", "taffy"], name="hello")
|
||||||
@ -60,6 +61,7 @@ class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
self.assertRaises(ValueError, test_util.assert_ops_in_graph,
|
self.assertRaises(ValueError, test_util.assert_ops_in_graph,
|
||||||
{"hello": "Variable"}, ops.get_default_graph())
|
{"hello": "Variable"}, ops.get_default_graph())
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_session_functions(self):
|
def test_session_functions(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
sess_ref = weakref.ref(sess)
|
sess_ref = weakref.ref(sess)
|
||||||
@ -551,6 +553,7 @@ class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
self.assertAllLessEqual(x, 95.0)
|
self.assertAllLessEqual(x, 95.0)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testAssertAllInRangeWithNonNumericValuesFails(self):
|
def testAssertAllInRangeWithNonNumericValuesFails(self):
|
||||||
s1 = constant_op.constant("Hello, ", name="s1")
|
s1 = constant_op.constant("Hello, ", name="s1")
|
||||||
c = constant_op.constant([1 + 2j, -3 + 5j], name="c")
|
c = constant_op.constant([1 + 2j, -3 + 5j], name="c")
|
||||||
@ -614,6 +617,7 @@ class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
self.assertAllInSet(x, (42,))
|
self.assertAllInSet(x, (42,))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testRandomSeed(self):
|
def testRandomSeed(self):
|
||||||
# Call setUp again for WithCApi case (since it makes a new defeault graph
|
# Call setUp again for WithCApi case (since it makes a new defeault graph
|
||||||
# after setup).
|
# after setup).
|
||||||
@ -706,6 +710,7 @@ class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
test_util.run_in_graph_and_eager_modes(_test)(self)
|
test_util.run_in_graph_and_eager_modes(_test)(self)
|
||||||
self.assertEqual(modes, ["graph"])
|
self.assertEqual(modes, ["graph"])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def test_run_in_graph_and_eager_modes_setup_in_same_mode(self):
|
def test_run_in_graph_and_eager_modes_setup_in_same_mode(self):
|
||||||
modes = []
|
modes = []
|
||||||
mode_name = lambda: "eager" if context.executing_eagerly() else "graph"
|
mode_name = lambda: "eager" if context.executing_eagerly() else "graph"
|
||||||
|
@ -38,6 +38,7 @@ from tensorflow.python.training import adam
|
|||||||
|
|
||||||
class CostAnalysisTest(test.TestCase):
|
class CostAnalysisTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBasicCost(self):
|
def testBasicCost(self):
|
||||||
"""Make sure arguments can be passed correctly."""
|
"""Make sure arguments can be passed correctly."""
|
||||||
a = constant_op.constant(10, name="a")
|
a = constant_op.constant(10, name="a")
|
||||||
@ -62,6 +63,7 @@ class CostAnalysisTest(test.TestCase):
|
|||||||
# Also print the report to make it easier to debug
|
# Also print the report to make it easier to debug
|
||||||
print("{}".format(report))
|
print("{}".format(report))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testVerbose(self):
|
def testVerbose(self):
|
||||||
"""Make sure the full report is generated with verbose=True."""
|
"""Make sure the full report is generated with verbose=True."""
|
||||||
a = constant_op.constant(10, name="a")
|
a = constant_op.constant(10, name="a")
|
||||||
@ -81,6 +83,7 @@ class CostAnalysisTest(test.TestCase):
|
|||||||
# Also print the report to make it easier to debug
|
# Also print the report to make it easier to debug
|
||||||
print("{}".format(report))
|
print("{}".format(report))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSmallNetworkCost(self):
|
def testSmallNetworkCost(self):
|
||||||
image = array_ops.placeholder(dtypes.float32, shape=[1, 28, 28, 1])
|
image = array_ops.placeholder(dtypes.float32, shape=[1, 28, 28, 1])
|
||||||
label = array_ops.placeholder(dtypes.float32, shape=[1, 10])
|
label = array_ops.placeholder(dtypes.float32, shape=[1, 10])
|
||||||
@ -129,6 +132,7 @@ class CostAnalysisTest(test.TestCase):
|
|||||||
# self.assertTrue(0 < upper)
|
# self.assertTrue(0 < upper)
|
||||||
# self.assertTrue(lower <= upper)
|
# self.assertTrue(lower <= upper)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBasicMemory(self):
|
def testBasicMemory(self):
|
||||||
"""Make sure arguments can be passed correctly."""
|
"""Make sure arguments can be passed correctly."""
|
||||||
with test_util.device(use_gpu=False):
|
with test_util.device(use_gpu=False):
|
||||||
|
@ -24,6 +24,7 @@ from tensorflow.python.framework import errors_impl
|
|||||||
from tensorflow.python.framework import meta_graph
|
from tensorflow.python.framework import meta_graph
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.grappler import item
|
from tensorflow.python.grappler import item
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import gen_array_ops
|
from tensorflow.python.ops import gen_array_ops
|
||||||
@ -107,6 +108,7 @@ class ItemTest(test.TestCase):
|
|||||||
newest_tf_item = grappler_item.tf_item
|
newest_tf_item = grappler_item.tf_item
|
||||||
self.assertEqual(new_tf_item, newest_tf_item)
|
self.assertEqual(new_tf_item, newest_tf_item)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testColocationContraints(self):
|
def testColocationContraints(self):
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
c = constant_op.constant([10])
|
c = constant_op.constant([10])
|
||||||
|
@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import random_seed
|
from tensorflow.python.framework import random_seed
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.grappler import cluster as gcluster
|
from tensorflow.python.grappler import cluster as gcluster
|
||||||
from tensorflow.python.grappler import tf_optimizer
|
from tensorflow.python.grappler import tf_optimizer
|
||||||
from tensorflow.python.layers import convolutional as conv_layers
|
from tensorflow.python.layers import convolutional as conv_layers
|
||||||
@ -1441,6 +1442,7 @@ class LayoutOptimizerTest(test.TestCase):
|
|||||||
self._assert_trans_nchw_to_nhwc('Add-0-0', nodes)
|
self._assert_trans_nchw_to_nhwc('Add-0-0', nodes)
|
||||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
|
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testGradient(self):
|
def testGradient(self):
|
||||||
meta_graph = _simple_metagraph()
|
meta_graph = _simple_metagraph()
|
||||||
config = config_pb2.ConfigProto()
|
config = config_pb2.ConfigProto()
|
||||||
@ -1458,6 +1460,7 @@ class LayoutOptimizerTest(test.TestCase):
|
|||||||
self.assertEqual(node.attr['data_format'].s, b'NCHW')
|
self.assertEqual(node.attr['data_format'].s, b'NCHW')
|
||||||
self.assertEqual(found, 5)
|
self.assertEqual(found, 5)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testDepthwise(self):
|
def testDepthwise(self):
|
||||||
meta_graph = _simple_metagraph(depthwise=True)
|
meta_graph = _simple_metagraph(depthwise=True)
|
||||||
config = config_pb2.ConfigProto()
|
config = config_pb2.ConfigProto()
|
||||||
|
@ -25,6 +25,7 @@ from tensorflow.python.client import session
|
|||||||
from tensorflow.python.framework import meta_graph
|
from tensorflow.python.framework import meta_graph
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import random_seed
|
from tensorflow.python.framework import random_seed
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.grappler import tf_optimizer
|
from tensorflow.python.grappler import tf_optimizer
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import nn
|
from tensorflow.python.ops import nn
|
||||||
@ -37,6 +38,7 @@ from tensorflow.python.training import training as train
|
|||||||
class MemoryOptimizerSwapTest(test.TestCase):
|
class MemoryOptimizerSwapTest(test.TestCase):
|
||||||
"""Tests the Grappler memory optimizer."""
|
"""Tests the Grappler memory optimizer."""
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testNoSwapping(self):
|
def testNoSwapping(self):
|
||||||
"""Make sure the graph is preserved when there is nothing to swap."""
|
"""Make sure the graph is preserved when there is nothing to swap."""
|
||||||
a = variables.VariableV1(10, name='a')
|
a = variables.VariableV1(10, name='a')
|
||||||
@ -60,6 +62,7 @@ class MemoryOptimizerSwapTest(test.TestCase):
|
|||||||
self.assertEqual(len(graph.node), graph_size)
|
self.assertEqual(len(graph.node), graph_size)
|
||||||
self.assertItemsEqual([node.name for node in graph.node], nodes)
|
self.assertItemsEqual([node.name for node in graph.node], nodes)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testSimpleSwap(self):
|
def testSimpleSwap(self):
|
||||||
"""Check that the swap annotations are followed."""
|
"""Check that the swap annotations are followed."""
|
||||||
a = variables.VariableV1(10, name='a')
|
a = variables.VariableV1(10, name='a')
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user