Remove the deps to Keras model from gradients_test.py
The Mnist model was not using any Keras model functionality like compile/fit. We can just use a layer to replace it, which already include __call__(). PiperOrigin-RevId: 315488824 Change-Id: Ibbffd05b448f8d02f0211b6f4e4f40e5e065e3de
This commit is contained in:
parent
d2b02580e6
commit
502a4bf641
@ -30,7 +30,6 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.engine import training as keras_training
|
||||
from tensorflow.python.layers import layers as tf_layers
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops as tf_control_flow_ops
|
||||
@ -215,7 +214,7 @@ def create_lstm_per_eg_grad(batch_size, state_size, steps, inputs_size=None):
|
||||
# Importing the code from tensorflow_models seems to cause errors. Hence we
|
||||
# duplicate the model definition here.
|
||||
# TODO(agarwal): Use the version in tensorflow_models/official instead.
|
||||
class Mnist(keras_training.Model):
|
||||
class Mnist(tf_layers.Layer):
|
||||
|
||||
def __init__(self, data_format):
|
||||
"""Creates a model for classifying a hand-written digit.
|
||||
|
Loading…
x
Reference in New Issue
Block a user