Update all context.graph_context() to ops.get_default_graph().as_default().
PiperOrigin-RevId: 295230258 Change-Id: Id7da36f985d6eae3f9e884e1c8e1352565c4454f
This commit is contained in:
parent
81323b7924
commit
79ed5077ce
@ -571,24 +571,22 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
|||||||
# with the shape the Layer will be called on (these users will have to
|
# with the shape the Layer will be called on (these users will have to
|
||||||
# implement `compute_output_shape` themselves).
|
# implement `compute_output_shape` themselves).
|
||||||
self._maybe_build(input_shape)
|
self._maybe_build(input_shape)
|
||||||
with context.graph_mode():
|
with func_graph.FuncGraph('graph').as_default():
|
||||||
graph = func_graph.FuncGraph('graph')
|
input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
|
||||||
with graph.as_default():
|
def _make_placeholder_like(shape):
|
||||||
input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
|
ph = backend.placeholder(shape=shape, dtype=self.dtype)
|
||||||
def _make_placeholder_like(shape):
|
ph._keras_mask = None
|
||||||
ph = backend.placeholder(shape=shape, dtype=self.dtype)
|
return ph
|
||||||
ph._keras_mask = None
|
inputs = nest.map_structure(_make_placeholder_like, input_shape)
|
||||||
return ph
|
try:
|
||||||
inputs = nest.map_structure(_make_placeholder_like, input_shape)
|
outputs = self(inputs, training=False)
|
||||||
try:
|
except TypeError as e:
|
||||||
outputs = self(inputs, training=False)
|
six.raise_from(
|
||||||
except TypeError as e:
|
NotImplementedError(
|
||||||
six.raise_from(
|
'We could not automatically infer the static shape of the '
|
||||||
NotImplementedError(
|
'layer\'s output. Please implement the '
|
||||||
'We could not automatically infer the static shape of the '
|
'`compute_output_shape` method on your layer (%s).' %
|
||||||
'layer\'s output. Please implement the '
|
self.__class__.__name__), e)
|
||||||
'`compute_output_shape` method on your layer (%s).' %
|
|
||||||
self.__class__.__name__), e)
|
|
||||||
return nest.map_structure(lambda t: t.shape, outputs)
|
return nest.map_structure(lambda t: t.shape, outputs)
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -101,7 +101,7 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
@keras_parameterized.run_with_all_model_types
|
@keras_parameterized.run_with_all_model_types
|
||||||
def test_dynamic_layer_error_running_in_graph_mode(self):
|
def test_dynamic_layer_error_running_in_graph_mode(self):
|
||||||
with context.graph_mode():
|
with ops.get_default_graph().as_default():
|
||||||
model = testing_utils.get_model_from_layers([DynamicLayer(dynamic=True)],
|
model = testing_utils.get_model_from_layers([DynamicLayer(dynamic=True)],
|
||||||
input_shape=(3,))
|
input_shape=(3,))
|
||||||
self.assertEqual(model.dynamic, True)
|
self.assertEqual(model.dynamic, True)
|
||||||
|
@ -535,7 +535,7 @@ class Layer(base_layer.Layer):
|
|||||||
# with the shape the Layer will be called on (these users will have to
|
# with the shape the Layer will be called on (these users will have to
|
||||||
# implement `compute_output_shape` themselves).
|
# implement `compute_output_shape` themselves).
|
||||||
self._maybe_build(input_shape)
|
self._maybe_build(input_shape)
|
||||||
with context.graph_mode():
|
with ops.get_default_graph().as_default():
|
||||||
graph = func_graph.FuncGraph('graph')
|
graph = func_graph.FuncGraph('graph')
|
||||||
with graph.as_default():
|
with graph.as_default():
|
||||||
input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
|
input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
|
||||||
|
@ -154,7 +154,7 @@ class CompileTest(keras_parameterized.TestCase):
|
|||||||
self.assertAllEqual(model._loss_weights_list, [1., 2.])
|
self.assertAllEqual(model._loss_weights_list, [1., 2.])
|
||||||
|
|
||||||
def test_compile_with_multi_output_and_loss_weights_dict(self):
|
def test_compile_with_multi_output_and_loss_weights_dict(self):
|
||||||
with context.graph_mode():
|
with ops.get_default_graph().as_default():
|
||||||
model = self._get_multi_output_model()
|
model = self._get_multi_output_model()
|
||||||
loss_weights = {'dense_1': 1., 'dense_2': 2.}
|
loss_weights = {'dense_1': 1., 'dense_2': 2.}
|
||||||
model.compile(optimizer='adam', loss='mse', loss_weights=loss_weights)
|
model.compile(optimizer='adam', loss='mse', loss_weights=loss_weights)
|
||||||
@ -2142,7 +2142,7 @@ class LossWeightingTest(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
def test_sample_weight_tensor(self):
|
def test_sample_weight_tensor(self):
|
||||||
"""Tests that sample weight may be defined as a tensor in the graph."""
|
"""Tests that sample weight may be defined as a tensor in the graph."""
|
||||||
with context.graph_mode():
|
with ops.get_default_graph().as_default():
|
||||||
# Create a simple pass-through model
|
# Create a simple pass-through model
|
||||||
input_layer = keras.layers.Input(shape=1, name='input_layer')
|
input_layer = keras.layers.Input(shape=1, name='input_layer')
|
||||||
model = keras.Model(inputs=input_layer, outputs=input_layer)
|
model = keras.Model(inputs=input_layer, outputs=input_layer)
|
||||||
|
@ -27,6 +27,7 @@ from absl.testing import parameterized
|
|||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
@ -399,10 +400,11 @@ def run_all_keras_modes(test_or_class=None,
|
|||||||
|
|
||||||
|
|
||||||
def _v1_session_test(f, test_or_class, config, *args, **kwargs):
|
def _v1_session_test(f, test_or_class, config, *args, **kwargs):
|
||||||
with context.graph_mode(), testing_utils.run_eagerly_scope(False):
|
with ops.get_default_graph().as_default():
|
||||||
with testing_utils.experimental_run_tf_function_scope(False):
|
with testing_utils.run_eagerly_scope(False):
|
||||||
with test_or_class.test_session(use_gpu=True, config=config):
|
with testing_utils.experimental_run_tf_function_scope(False):
|
||||||
f(test_or_class, *args, **kwargs)
|
with test_or_class.test_session(use_gpu=True, config=config):
|
||||||
|
f(test_or_class, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _v2_eager_test(f, test_or_class, *args, **kwargs):
|
def _v2_eager_test(f, test_or_class, *args, **kwargs):
|
||||||
|
@ -1063,8 +1063,9 @@ class LSTMPerformanceTest(test.Benchmark):
|
|||||||
' of normal LSTM, got {0:.2f}'.format(v2_vs_normal))
|
' of normal LSTM, got {0:.2f}'.format(v2_vs_normal))
|
||||||
|
|
||||||
def benchmark_performance_graph(self):
|
def benchmark_performance_graph(self):
|
||||||
with context.graph_mode(), session_lib.Session(config=_config):
|
with ops.get_default_graph().as_default():
|
||||||
self._benchmark_performance_with_standard_cudnn_impl()
|
with session_lib.Session(config=_config):
|
||||||
|
self._benchmark_performance_with_standard_cudnn_impl()
|
||||||
|
|
||||||
def benchmark_performance_eager(self):
|
def benchmark_performance_eager(self):
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
|
@ -115,7 +115,7 @@ class KerasSumTest(test.TestCase):
|
|||||||
self.assertAlmostEqual(self.evaluate(m.total), 63.75, 2)
|
self.assertAlmostEqual(self.evaluate(m.total), 63.75, 2)
|
||||||
|
|
||||||
def test_sum_graph_with_placeholder(self):
|
def test_sum_graph_with_placeholder(self):
|
||||||
with context.graph_mode(), self.cached_session() as sess:
|
with ops.get_default_graph().as_default(), self.cached_session() as sess:
|
||||||
m = metrics.Sum()
|
m = metrics.Sum()
|
||||||
v = array_ops.placeholder(dtypes.float32)
|
v = array_ops.placeholder(dtypes.float32)
|
||||||
w = array_ops.placeholder(dtypes.float32)
|
w = array_ops.placeholder(dtypes.float32)
|
||||||
@ -265,7 +265,7 @@ class MeanTest(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes
|
@keras_parameterized.run_all_keras_modes
|
||||||
def test_mean_graph_with_placeholder(self):
|
def test_mean_graph_with_placeholder(self):
|
||||||
with context.graph_mode(), self.cached_session() as sess:
|
with ops.get_default_graph().as_default(), self.cached_session() as sess:
|
||||||
m = metrics.Mean()
|
m = metrics.Mean()
|
||||||
v = array_ops.placeholder(dtypes.float32)
|
v = array_ops.placeholder(dtypes.float32)
|
||||||
w = array_ops.placeholder(dtypes.float32)
|
w = array_ops.placeholder(dtypes.float32)
|
||||||
@ -575,7 +575,7 @@ class KerasAccuracyTest(test.TestCase):
|
|||||||
self.assertAlmostEqual(result, 0.93, 2) # 2.5/2.7
|
self.assertAlmostEqual(result, 0.93, 2) # 2.5/2.7
|
||||||
|
|
||||||
def test_sparse_categorical_accuracy_mismatched_dims_dynamic(self):
|
def test_sparse_categorical_accuracy_mismatched_dims_dynamic(self):
|
||||||
with context.graph_mode(), self.cached_session() as sess:
|
with ops.get_default_graph().as_default(), self.cached_session() as sess:
|
||||||
acc_obj = metrics.SparseCategoricalAccuracy(name='my_acc')
|
acc_obj = metrics.SparseCategoricalAccuracy(name='my_acc')
|
||||||
self.evaluate(variables.variables_initializer(acc_obj.variables))
|
self.evaluate(variables.variables_initializer(acc_obj.variables))
|
||||||
|
|
||||||
|
@ -607,7 +607,7 @@ class OptimizerTest(test.TestCase):
|
|||||||
self.assertLen(var_list(), 4)
|
self.assertLen(var_list(), 4)
|
||||||
|
|
||||||
def testVarKey(self):
|
def testVarKey(self):
|
||||||
with context.graph_mode():
|
with ops.get_default_graph().as_default():
|
||||||
a = variables.Variable([1., 2.], name='var')
|
a = variables.Variable([1., 2.], name='var')
|
||||||
b = variables.Variable([1.], name='var')
|
b = variables.Variable([1.], name='var')
|
||||||
self.assertTrue(a._in_graph_mode)
|
self.assertTrue(a._in_graph_mode)
|
||||||
@ -618,7 +618,7 @@ class OptimizerTest(test.TestCase):
|
|||||||
self.assertEqual('var_1', var_key)
|
self.assertEqual('var_1', var_key)
|
||||||
|
|
||||||
def testVarName(self):
|
def testVarName(self):
|
||||||
with context.graph_mode():
|
with ops.get_default_graph().as_default():
|
||||||
var = variables.Variable([1., 2.], name='var')
|
var = variables.Variable([1., 2.], name='var')
|
||||||
loss = var + 1.
|
loss = var + 1.
|
||||||
opt = adam.Adam()
|
opt = adam.Adam()
|
||||||
|
@ -1007,7 +1007,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
|||||||
model.load_weights(fname)
|
model.load_weights(fname)
|
||||||
|
|
||||||
def test_no_graph_pollution(self):
|
def test_no_graph_pollution(self):
|
||||||
with context.graph_mode():
|
with ops.get_default_graph().as_default():
|
||||||
graph = ops.Graph()
|
graph = ops.Graph()
|
||||||
with graph.as_default(), self.session(graph) as session:
|
with graph.as_default(), self.session(graph) as session:
|
||||||
model = SubclassedModel()
|
model = SubclassedModel()
|
||||||
|
@ -24,7 +24,6 @@ import numpy as np
|
|||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
from tensorflow.python.eager import context
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
from tensorflow.python.keras.engine import base_layer
|
from tensorflow.python.keras.engine import base_layer
|
||||||
@ -152,7 +151,7 @@ class SplitUtilsTest(keras_parameterized.TestCase):
|
|||||||
model = keras.Sequential([keras.layers.Dense(1)])
|
model = keras.Sequential([keras.layers.Dense(1)])
|
||||||
model.compile('sgd', 'mse')
|
model.compile('sgd', 'mse')
|
||||||
x, y = np.ones((10, 10)), np.ones((10, 1))
|
x, y = np.ones((10, 10)), np.ones((10, 1))
|
||||||
with context.graph_mode():
|
with ops.get_default_graph().as_default():
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError, 'instance was constructed with eager mode enabled'):
|
ValueError, 'instance was constructed with eager mode enabled'):
|
||||||
model.fit(x, y, batch_size=2)
|
model.fit(x, y, batch_size=2)
|
||||||
|
Loading…
Reference in New Issue
Block a user