Remove the deps to Keras model from gradients_test.py

The Mnist model was not using any Keras model functionality like compile/fit. We can just use a layer to replace it, which already include __call__().

PiperOrigin-RevId: 315488824
Change-Id: Ibbffd05b448f8d02f0211b6f4e4f40e5e065e3de
This commit is contained in:
Scott Zhu 2020-06-09 08:31:28 -07:00 committed by TensorFlower Gardener
parent d2b02580e6
commit 502a4bf641

View File

@ -30,7 +30,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.layers import layers as tf_layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops as tf_control_flow_ops
@ -215,7 +214,7 @@ def create_lstm_per_eg_grad(batch_size, state_size, steps, inputs_size=None):
# Importing the code from tensorflow_models seems to cause errors. Hence we
# duplicate the model definition here.
# TODO(agarwal): Use the version in tensorflow_models/official instead.
class Mnist(keras_training.Model):
class Mnist(tf_layers.Layer):
def __init__(self, data_format):
"""Creates a model for classifying a hand-written digit.