Improves constant shape inference for resource variables.
PiperOrigin-RevId: 223367586
This commit is contained in:
parent
d2253ab518
commit
2b559a9a08
@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
@ -137,6 +138,14 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
|||||||
self.evaluate(v[0].assign(2.0))
|
self.evaluate(v[0].assign(2.0))
|
||||||
self.assertAllEqual(self.evaluate(v), [2.0, 2.0])
|
self.assertAllEqual(self.evaluate(v), [2.0, 2.0])
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testVariableShape(self):
|
||||||
|
v = resource_variable_ops.ResourceVariable([1., 1.])
|
||||||
|
self.assertAllEqual(
|
||||||
|
tensor_util.constant_value(
|
||||||
|
resource_variable_ops.variable_shape(v.handle)),
|
||||||
|
[2])
|
||||||
|
|
||||||
def testDifferentAssignGraph(self):
|
def testDifferentAssignGraph(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
v = resource_variable_ops.ResourceVariable(1.0)
|
v = resource_variable_ops.ResourceVariable(1.0)
|
||||||
|
@ -26,6 +26,7 @@ from tensorflow.core.framework import variable_pb2
|
|||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import tape
|
from tensorflow.python.eager import tape
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import cpp_shape_inference_pb2
|
from tensorflow.python.framework import cpp_shape_inference_pb2
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -64,6 +65,7 @@ def eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
|
|||||||
name=name,
|
name=name,
|
||||||
container=container)
|
container=container)
|
||||||
if graph_mode:
|
if graph_mode:
|
||||||
|
handle._handle_data = get_resource_handle_data(handle) # pylint: disable=protected-access
|
||||||
return handle
|
return handle
|
||||||
|
|
||||||
# We do not want two distinct ResourceVariable objects for the same
|
# We do not want two distinct ResourceVariable objects for the same
|
||||||
@ -1410,13 +1412,23 @@ def _ReadGrad(_, grad):
|
|||||||
return grad
|
return grad
|
||||||
|
|
||||||
|
|
||||||
|
def variable_shape(handle, out_type=dtypes.int32):
|
||||||
|
if getattr(
|
||||||
|
handle, "_handle_data", None) is None or not handle._handle_data.is_set:
|
||||||
|
return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
|
||||||
|
shape_proto = handle._handle_data.shape_and_type[0].shape
|
||||||
|
if shape_proto.unknown_rank or any(x.size == -1 for x in shape_proto.dim):
|
||||||
|
return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
|
||||||
|
return constant_op.constant([x.size for x in shape_proto.dim], dtype=out_type)
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterGradient("ResourceGather")
|
@ops.RegisterGradient("ResourceGather")
|
||||||
def _GatherGrad(op, grad):
|
def _GatherGrad(op, grad):
|
||||||
"""Gradient for gather op."""
|
"""Gradient for gather op."""
|
||||||
# Build appropriately shaped IndexedSlices
|
# Build appropriately shaped IndexedSlices
|
||||||
handle = op.inputs[0]
|
handle = op.inputs[0]
|
||||||
indices = op.inputs[1]
|
indices = op.inputs[1]
|
||||||
params_shape = gen_resource_variable_ops.variable_shape(handle)
|
params_shape = variable_shape(handle)
|
||||||
size = array_ops.expand_dims(array_ops.size(indices), 0)
|
size = array_ops.expand_dims(array_ops.size(indices), 0)
|
||||||
values_shape = array_ops.concat([size, params_shape[1:]], 0)
|
values_shape = array_ops.concat([size, params_shape[1:]], 0)
|
||||||
values = array_ops.reshape(grad, values_shape)
|
values = array_ops.reshape(grad, values_shape)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user