diff --git a/tensorflow/python/ops/parallel_for/gradients_test.py b/tensorflow/python/ops/parallel_for/gradients_test.py index a6d29b646a8..fdb70c52778 100644 --- a/tensorflow/python/ops/parallel_for/gradients_test.py +++ b/tensorflow/python/ops/parallel_for/gradients_test.py @@ -30,7 +30,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.keras.engine import training as keras_training from tensorflow.python.layers import layers as tf_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops as tf_control_flow_ops @@ -215,7 +214,7 @@ def create_lstm_per_eg_grad(batch_size, state_size, steps, inputs_size=None): # Importing the code from tensorflow_models seems to cause errors. Hence we # duplicate the model definition here. # TODO(agarwal): Use the version in tensorflow_models/official instead. -class Mnist(keras_training.Model): +class Mnist(tf_layers.Layer): def __init__(self, data_format): """Creates a model for classifying a hand-written digit.