Update keras local layer to direct import python modules, rather than use it from backend.
PiperOrigin-RevId: 335569756 Change-Id: I2c639a96f23a9c9da9033dfec97b428f749fa2ea
This commit is contained in:
parent
739e277d0b
commit
ee4897dc41
@ -29,6 +29,8 @@ from tensorflow.python.keras.engine.base_layer import Layer
|
||||
from tensorflow.python.keras.engine.input_spec import InputSpec
|
||||
from tensorflow.python.keras.utils import conv_utils
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
@ -775,7 +777,7 @@ def local_conv_matmul(inputs, kernel, kernel_mask, output_shape):
|
||||
kernel = kernel_mask * kernel
|
||||
kernel = make_2d(kernel, split_dim=K.ndim(kernel) // 2)
|
||||
|
||||
output_flat = K.math_ops.sparse_matmul(inputs_flat, kernel, b_is_sparse=True)
|
||||
output_flat = math_ops.sparse_matmul(inputs_flat, kernel, b_is_sparse=True)
|
||||
output = K.reshape(output_flat,
|
||||
[K.shape(output_flat)[0],] + output_shape.as_list()[1:])
|
||||
return output
|
||||
@ -834,11 +836,11 @@ def make_2d(tensor, split_dim):
|
||||
Tensor of shape
|
||||
`(d0 * ... * d(split_dim-1), d(split_dim) * ... * d(N-1))`.
|
||||
"""
|
||||
shape = K.array_ops.shape(tensor)
|
||||
shape = array_ops.shape(tensor)
|
||||
in_dims = shape[:split_dim]
|
||||
out_dims = shape[split_dim:]
|
||||
|
||||
in_size = K.math_ops.reduce_prod(in_dims)
|
||||
out_size = K.math_ops.reduce_prod(out_dims)
|
||||
in_size = math_ops.reduce_prod(in_dims)
|
||||
out_size = math_ops.reduce_prod(out_dims)
|
||||
|
||||
return K.array_ops.reshape(tensor, (in_size, out_size))
|
||||
return array_ops.reshape(tensor, (in_size, out_size))
|
||||
|
||||
@ -22,9 +22,13 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training.rmsprop import RMSPropOptimizer
|
||||
|
||||
@ -423,9 +427,9 @@ def get_inputs(data_format, filters, height, num_samples, width):
|
||||
def xent(y_true, y_pred):
|
||||
y_true = keras.backend.cast(
|
||||
keras.backend.reshape(y_true, (-1,)),
|
||||
keras.backend.dtypes_module.int32)
|
||||
dtypes.int32)
|
||||
|
||||
return keras.backend.nn.sparse_softmax_cross_entropy_with_logits(
|
||||
return nn.sparse_softmax_cross_entropy_with_logits(
|
||||
labels=y_true,
|
||||
logits=y_pred)
|
||||
|
||||
@ -496,9 +500,9 @@ def copy_lc_weights_2_to_1(lc_layer_2_from, lc_layer_1_to):
|
||||
lc_2_kernel_masked = keras.backend.permute_dimensions(
|
||||
lc_2_kernel_masked, permutation)
|
||||
|
||||
lc_2_kernel_mask = keras.backend.math_ops.not_equal(
|
||||
lc_2_kernel_mask = math_ops.not_equal(
|
||||
lc_2_kernel_masked, 0)
|
||||
lc_2_kernel_flat = keras.backend.array_ops.boolean_mask(
|
||||
lc_2_kernel_flat = array_ops.boolean_mask(
|
||||
lc_2_kernel_masked, lc_2_kernel_mask)
|
||||
lc_2_kernel_reshaped = keras.backend.reshape(lc_2_kernel_flat,
|
||||
lc_layer_1_to.kernel.shape)
|
||||
@ -516,8 +520,8 @@ def copy_lc_weights_2_to_3(lc_layer_2_from, lc_layer_3_to):
|
||||
lc_2_kernel_masked = keras.layers.local.make_2d(
|
||||
lc_2_kernel_masked, split_dim=keras.backend.ndim(lc_2_kernel_masked) // 2)
|
||||
lc_2_kernel_masked = keras.backend.transpose(lc_2_kernel_masked)
|
||||
lc_2_kernel_mask = keras.backend.math_ops.not_equal(lc_2_kernel_masked, 0)
|
||||
lc_2_kernel_flat = keras.backend.array_ops.boolean_mask(
|
||||
lc_2_kernel_mask = math_ops.not_equal(lc_2_kernel_masked, 0)
|
||||
lc_2_kernel_flat = array_ops.boolean_mask(
|
||||
lc_2_kernel_masked, lc_2_kernel_mask)
|
||||
|
||||
lc_2_kernel_flat = keras.backend.get_value(lc_2_kernel_flat)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user