Improves constant shape inference for resource variables.

PiperOrigin-RevId: 223367586
This commit is contained in:
Alexandre Passos 2018-11-29 10:21:35 -08:00 committed by TensorFlower Gardener
parent d2253ab518
commit 2b559a9a08
2 changed files with 22 additions and 1 deletions

View File

@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@ -137,6 +138,14 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.evaluate(v[0].assign(2.0))
self.assertAllEqual(self.evaluate(v), [2.0, 2.0])
@test_util.run_in_graph_and_eager_modes
def testVariableShape(self):
v = resource_variable_ops.ResourceVariable([1., 1.])
self.assertAllEqual(
tensor_util.constant_value(
resource_variable_ops.variable_shape(v.handle)),
[2])
def testDifferentAssignGraph(self):
with ops.Graph().as_default():
v = resource_variable_ops.ResourceVariable(1.0)

View File

@ -26,6 +26,7 @@ from tensorflow.core.framework import variable_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -64,6 +65,7 @@ def eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
name=name,
container=container)
if graph_mode:
handle._handle_data = get_resource_handle_data(handle) # pylint: disable=protected-access
return handle
# We do not want two distinct ResourceVariable objects for the same
@ -1410,13 +1412,23 @@ def _ReadGrad(_, grad):
return grad
def variable_shape(handle, out_type=dtypes.int32):
if getattr(
handle, "_handle_data", None) is None or not handle._handle_data.is_set:
return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
shape_proto = handle._handle_data.shape_and_type[0].shape
if shape_proto.unknown_rank or any(x.size == -1 for x in shape_proto.dim):
return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
return constant_op.constant([x.size for x in shape_proto.dim], dtype=out_type)
@ops.RegisterGradient("ResourceGather")
def _GatherGrad(op, grad):
"""Gradient for gather op."""
# Build appropriately shaped IndexedSlices
handle = op.inputs[0]
indices = op.inputs[1]
params_shape = gen_resource_variable_ops.variable_shape(handle)
params_shape = variable_shape(handle)
size = array_ops.expand_dims(array_ops.size(indices), 0)
values_shape = array_ops.concat([size, params_shape[1:]], 0)
values = array_ops.reshape(grad, values_shape)