Use tf.function's default autograph=True in saved_model/integration_test.
There is currently no demonstrable need to do something non-obvious here. PiperOrigin-RevId: 259782705
This commit is contained in:
parent
7251a1efe4
commit
d0eeef269d
@ -117,7 +117,7 @@ def wrap_keras_model_for_export(model, batch_input_shape,
|
||||
# the desired argspec.
|
||||
def wrapped(*args, **kwargs): # TODO(arnoegw): Can we use call_fn itself?
|
||||
return call_fn(*args, **kwargs)
|
||||
traced_call_fn = tf.function(autograph=False)(
|
||||
traced_call_fn = tf.function(
|
||||
tf_decorator.make_decorator(call_fn, wrapped, decorator_argspec=argspec))
|
||||
|
||||
# Now we need to trigger traces for all supported combinations of the
|
||||
|
@ -37,7 +37,7 @@ def main(argv):
|
||||
root.rnn_cell = tf.keras.layers.LSTMCell(units=10, recurrent_initializer=None)
|
||||
|
||||
# Wrap the rnn_cell.__call__ function and assign to next_state.
|
||||
root.next_state = tf.function(root.rnn_cell.__call__, autograph=False)
|
||||
root.next_state = tf.function(root.rnn_cell.__call__)
|
||||
|
||||
# Wrap the rnn_cell.get_initial_function using a decorator and assign to an
|
||||
# attribute with the same name.
|
||||
|
Loading…
Reference in New Issue
Block a user