From 502a4bf641e730ab7a269384b0be1bcdec9c3f61 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Tue, 9 Jun 2020 08:31:28 -0700 Subject: [PATCH] Remove the deps to Keras model from gradients_test.py The Mnist model was not using any Keras model functionality like compile/fit. We can just use a layer to replace it, which already include __call__(). PiperOrigin-RevId: 315488824 Change-Id: Ibbffd05b448f8d02f0211b6f4e4f40e5e065e3de --- tensorflow/python/ops/parallel_for/gradients_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow/python/ops/parallel_for/gradients_test.py b/tensorflow/python/ops/parallel_for/gradients_test.py index a6d29b646a8..fdb70c52778 100644 --- a/tensorflow/python/ops/parallel_for/gradients_test.py +++ b/tensorflow/python/ops/parallel_for/gradients_test.py @@ -30,7 +30,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.keras.engine import training as keras_training from tensorflow.python.layers import layers as tf_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops as tf_control_flow_ops @@ -215,7 +214,7 @@ def create_lstm_per_eg_grad(batch_size, state_size, steps, inputs_size=None): # Importing the code from tensorflow_models seems to cause errors. Hence we # duplicate the model definition here. # TODO(agarwal): Use the version in tensorflow_models/official instead. -class Mnist(keras_training.Model): +class Mnist(tf_layers.Layer): def __init__(self, data_format): """Creates a model for classifying a hand-written digit.