[tfdbg2] Fork local_cli_wrapper_test for keras related tests.
PiperOrigin-RevId: 316950499 Change-Id: I428273592694426f72427e6236c68bdfb4e95eba
This commit is contained in:
parent
9426d35abc
commit
cca9b615b2
|
@ -788,7 +788,6 @@ cuda_py_test(
|
|||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/keras",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
@ -1400,7 +1399,6 @@ py_test(
|
|||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/keras",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -22,10 +22,10 @@ import tempfile
|
|||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.debug.cli import cli_config
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.debug.cli import cli_config
|
||||
from tensorflow.python.debug.cli import cli_shared
|
||||
from tensorflow.python.debug.cli import debugger_cli_common
|
||||
from tensorflow.python.debug.cli import ui_factory
|
||||
|
@ -36,9 +36,6 @@ from tensorflow.python.framework import errors
|
|||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras.engine import sequential
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
@ -832,40 +829,6 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
|||
run_output = wrapped_sess.run([])
|
||||
self.assertEqual([], run_output)
|
||||
|
||||
def testDebuggingKerasFitWithSkippedRunsWorks(self):
|
||||
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
||||
[["run"], ["run"], ["run", "-t", "10"]], self.sess)
|
||||
|
||||
backend.set_session(wrapped_sess)
|
||||
|
||||
model = sequential.Sequential()
|
||||
model.add(core.Dense(4, input_shape=[2], activation="relu"))
|
||||
model.add(core.Dense(1))
|
||||
model.compile(loss="mse", optimizer="sgd")
|
||||
|
||||
x = np.zeros([8, 2])
|
||||
y = np.zeros([8, 1])
|
||||
model.fit(x, y, epochs=2)
|
||||
|
||||
self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))
|
||||
|
||||
def testDebuggingKerasFitWithProfilingWorks(self):
|
||||
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
||||
[["run", "-p"]] * 10, self.sess)
|
||||
|
||||
backend.set_session(wrapped_sess)
|
||||
|
||||
model = sequential.Sequential()
|
||||
model.add(core.Dense(4, input_shape=[2], activation="relu"))
|
||||
model.add(core.Dense(1))
|
||||
model.compile(loss="mse", optimizer="sgd")
|
||||
|
||||
x = np.zeros([8, 2])
|
||||
y = np.zeros([8, 1])
|
||||
model.fit(x, y, epochs=2)
|
||||
|
||||
self.assertEqual(0, len(wrapped_sess.observers["debug_dumps"]))
|
||||
|
||||
def testRunsWithEmptyNestedFetchWorks(self):
|
||||
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
||||
[["run"]], self.sess, dump_root="")
|
||||
|
|
Loading…
Reference in New Issue