commit b4ee16059c
parent 9b59779bc3
Author:    A. Unique TensorFlower, 2020-07-17 20:11:34 -07:00
Committed: TensorFlower Gardener

Internal change

PiperOrigin-RevId: 321900773
Change-Id: If154d6c4f44e002db2a4b4f431f17f715dcc06ca

7 changed files with 120 additions and 19 deletions

diff --git a/RELEASE.md b/RELEASE.md

@@ -35,6 +35,9 @@
 *   `tf.types.experimental.TensorLike` is a new `Union` type that can be used as
     type annotation for variables representing a Tensor or a value that can be
     converted to Tensor by `tf.convert_to_tensor`.
+*   Calling ops with Python constants or NumPy values is now consistent with
+    `tf.convert_to_tensor` behavior. This avoids operations like `tf.reshape`
+    truncating inputs such as from int64 to int32.
 *   `tf.data`:
     *   Added optional `exclude_cols` parameter to CsvDataset. This parameter is
         the complement of `select_cols`; at most one of these should be specified.
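
For a sense of the user-visible behavior this note describes, a minimal sketch (assumes TensorFlow 2.x with eager execution; illustrative, not part of the diff):

    import tensorflow as tf

    # Values that fit in int32 still infer int32, exactly as before.
    print(tf.convert_to_tensor([3, 4]).dtype)   # <dtype: 'int32'>
    # Values outside the int32 range now infer int64 instead of being
    # truncated when passed to ops such as tf.reshape.
    print(tf.convert_to_tensor([2**48]).dtype)  # <dtype: 'int64'>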

diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py

@@ -233,7 +233,7 @@ def make_tensor(v, arg_name):
                     (repr(v), arg_name))


-def args_to_matching_eager(l, ctx, default_dtype=None):
+def args_to_matching_eager(l, ctx, allowed_dtypes, default_dtype=None):
   """Convert sequence `l` to eager same-type Tensors."""
   if (not l) and (default_dtype is not None):
     return default_dtype, []  # List is empty; assume default dtype.
@@ -243,8 +243,6 @@ def args_to_matching_eager(l, ctx, default_dtype=None):
       break
   else:  # note: intentional for-else
     return l[0]._datatype_enum(), l  # pylint: disable=protected-access

-  # TODO(josh11b): Could we do a better job if we also passed in the
-  # allowed dtypes when that was known?
   # Is some input already a Tensor with a dtype?
   dtype = None
@@ -256,13 +254,28 @@ def args_to_matching_eager(l, ctx, default_dtype=None):
   if dtype is None:
     # Infer a dtype based on the first value, and use that dtype for the
     # remaining values.
     ret = []
     for t in l:
-      ret.append(
-          ops.convert_to_tensor(
-              t, dtype, preferred_dtype=default_dtype, ctx=ctx))
+      tensor = None
+      # First see if we can get a valid dtype with the default conversion and
+      # check whether it matches one of the allowed dtypes. Some ops like
+      # ConcatV2 may not list allowed dtypes, in which case we should skip this.
+      if dtype is None and allowed_dtypes:
+        tensor = ops.convert_to_tensor(t, ctx=ctx)
+        # If we did not match an allowed dtype, try again with the default
+        # dtype. This could be because we have an empty tensor and thus we
+        # picked the wrong type.
+        if tensor.dtype not in allowed_dtypes:
+          tensor = None
+      if tensor is None:
+        tensor = ops.convert_to_tensor(
+            t, dtype, preferred_dtype=default_dtype, ctx=ctx)
+      ret.append(tensor)
       if dtype is None:
-        dtype = ret[-1].dtype
+        dtype = tensor.dtype
   else:
     ret = [ops.convert_to_tensor(t, dtype, ctx=ctx) for t in l]
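
Distilled, the new per-argument flow above is: if no dtype has been pinned by an earlier argument and the op declares allowed dtypes, try the default conversion and keep the result only when its dtype lands in the allowed list; otherwise fall back to converting with the pinned dtype or the `default_dtype` hint. A standalone sketch of that decision in plain Python (`infer_dtype` and `pick_dtype` are hypothetical stand-ins, not TF API):

    def infer_dtype(values):
        # Stand-in for default inference: small Python ints become "int32",
        # larger ones "int64", mirroring tf.convert_to_tensor.
        if all(-2**31 <= v < 2**31 for v in values):
            return "int32"
        return "int64"

    def pick_dtype(values, allowed_dtypes, default_dtype):
        if allowed_dtypes:
            inferred = infer_dtype(values)
            # Keep the default inference only if the op allows that dtype.
            if inferred in allowed_dtypes:
                return inferred
        # Otherwise fall back to the op's declared default as a hint.
        return default_dtype

    assert pick_dtype([3, 4], ["int32", "int64"], "int32") == "int32"
    assert pick_dtype([2**48], ["int32", "int64"], "int32") == "int64"
    assert pick_dtype([3, 4], ["uint32", "uint64"], "uint32") == "uint32"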

diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py

@@ -326,17 +326,42 @@ class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
   def testArgsToMatchingEagerDefault(self):
     # Uses default
     ctx = context.context()
-    t, r = execute.args_to_matching_eager([[3, 4]], ctx, dtypes.int32)
+    allowed_dtypes = [dtypes.int32, dtypes.int64]
+
+    # Follows standard int conversion rules
+    t, r = execute.args_to_matching_eager([[3, 4]], ctx, allowed_dtypes,
+                                          dtypes.int32)
     self.assertEqual(t, dtypes.int32)
     self.assertEqual(r[0].dtype, dtypes.int32)
-    t, r = execute.args_to_matching_eager([[3, 4]], ctx, dtypes.int64)
+    t, r = execute.args_to_matching_eager([[3, 4]], ctx, allowed_dtypes,
+                                          dtypes.int64)
     self.assertEqual(t, dtypes.int32)
     self.assertEqual(r[0].dtype, dtypes.int32)
-    t, r = execute.args_to_matching_eager([], ctx, dtypes.int64)
+    # Use int64 since it is a better fit
+    t, r = execute.args_to_matching_eager([[2**48]], ctx, allowed_dtypes,
+                                          dtypes.int32)
+    self.assertEqual(t, dtypes.int64)
+    self.assertEqual(r[0].dtype, dtypes.int64)
+
+    # When the regular tensor conversion fails, use the default type as a
+    # hint.
+    allowed_dtypes = [dtypes.uint32, dtypes.uint64]
+    t, r = execute.args_to_matching_eager([[3, 4]], ctx, allowed_dtypes,
+                                          dtypes.uint32)
+    self.assertEqual(t, dtypes.uint32)
+    self.assertEqual(r[0].dtype, dtypes.uint32)
+    t, r = execute.args_to_matching_eager([[3, 4]], ctx, allowed_dtypes,
+                                          dtypes.uint64)
+    self.assertEqual(t, dtypes.uint64)
+    self.assertEqual(r[0].dtype, dtypes.uint64)
+
+    t, r = execute.args_to_matching_eager([], ctx, allowed_dtypes, dtypes.int64)
     self.assertEqual(t, dtypes.int64)
+
     # Doesn't use default
-    t, r = execute.args_to_matching_eager(
-        [['string', 'arg']], ctx, dtypes.int32)
+    allowed_dtypes = [dtypes.int32, dtypes.string]
+    t, r = execute.args_to_matching_eager([['string', 'arg']], ctx,
+                                          allowed_dtypes, dtypes.int32)
     self.assertEqual(t, dtypes.string)
     self.assertEqual(r[0].dtype, dtypes.string)
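
The cases above run against a private helper, but they can be reproduced directly in a Python shell (internal API, signature as of this commit, not a stable interface):

    from tensorflow.python.eager import context, execute
    from tensorflow.python.framework import dtypes

    ctx = context.context()
    ctx.ensure_initialized()
    # 2**48 overflows int32, so the matched dtype widens to int64 even
    # though the default hint is int32.
    t, r = execute.args_to_matching_eager(
        [[2**48]], ctx, [dtypes.int32, dtypes.int64], dtypes.int32)
    print(t)  # <dtype: 'int64'>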

diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py

@@ -40,7 +40,7 @@ def _eager_reshape(tensor, shape, ctx):
   """Eager-only version of Reshape op; requires tensor is an eager Tensor."""
   attr_t = tensor._datatype_enum()  # pylint: disable=protected-access
   attr_tshape, (shape,) = execute.args_to_matching_eager(
-      [shape], ctx, dtypes.int32)
+      [shape], ctx, [dtypes.int32, dtypes.int64], dtypes.int32)
   inputs_flat = [tensor, shape]
   attrs = ("T", attr_t, "Tshape", attr_tshape)
   result, = execute.execute(
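
`_eager_reshape` is an internal fast path, but the effect is the same one visible through the public op: a shape that is already int64 is matched against Tshape's two allowed dtypes instead of being forced to the int32 default. A small illustration (assumes TensorFlow 2.x eager mode):

    import numpy as np
    import tensorflow as tf

    x = tf.zeros([4, 2])
    # An explicitly int64 shape is accepted as-is; no int32 truncation.
    y = tf.reshape(x, np.array([2, 4], dtype=np.int64))
    print(y.shape)  # (2, 4)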

diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py

@@ -337,6 +337,7 @@ def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
     # on the other.  Handling this will require restructuring this code
     # significantly.
     default_type_attr_map = {}
+    allowed_list_attr_map = {}
     for attr_def in op_def.attr:
       if attr_def.type != "type":
         continue
@@ -344,6 +345,8 @@ def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
       if attr_def.HasField("default_value"):
         default_type_attr_map[key] = dtypes.as_dtype(
             attr_def.default_value.type)
+      if attr_def.HasField("allowed_values"):
+        allowed_list_attr_map[key] = attr_def.allowed_values.list.type

     # Requires that op_def has passed validation (using the C++
     # ValidateOpDef() from ../framework/op_def_util.h).
@@ -451,6 +454,7 @@ def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
       # arguments to that type.
       dtype = None
       default_dtype = None
+      allowed_list = None
       if input_arg.type != types_pb2.DT_INVALID:
         dtype = input_arg.type
       elif input_arg.type_attr in attrs:
@@ -460,14 +464,41 @@ def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
           # so we prefer the attr's default, so code that adds a new attr
           # with a default is backwards compatible.
           default_dtype = default_type_attr_map[input_arg.type_attr]
+          allowed_list = allowed_list_attr_map.get(input_arg.type_attr)

       try:
-        values = ops.convert_to_tensor(
-            values,
-            name=input_arg.name,
-            dtype=dtype,
-            as_ref=input_arg.is_ref,
-            preferred_dtype=default_dtype)
+        # First see if we can get a valid dtype with the default conversion
+        # and check whether it matches one of the allowed dtypes. Some ops
+        # like ConcatV2 may not list allowed dtypes, in which case we should
+        # skip this.
+        if dtype is None and allowed_list:
+          inferred = None
+          try:
+            inferred = ops.convert_to_tensor(
+                values, name=input_arg.name, as_ref=input_arg.is_ref)
+          except TypeError:
+            # When converting a python object such as a list of Dimensions,
+            # a dtype must be specified, so the conversion may raise an
+            # exception, which we ignore here and retry below.
+            pass
+          # If we did not match an allowed dtype, try again with the default
+          # dtype. This could be because we have an empty tensor and thus we
+          # picked the wrong type.
+          if inferred is not None and inferred.dtype in allowed_list:
+            values = inferred
+          else:
+            values = ops.convert_to_tensor(
+                values,
+                name=input_arg.name,
+                as_ref=input_arg.is_ref,
+                preferred_dtype=default_dtype)
+        else:
+          values = ops.convert_to_tensor(
+              values,
+              name=input_arg.name,
+              dtype=dtype,
+              as_ref=input_arg.is_ref,
+              preferred_dtype=default_dtype)
       except TypeError as err:
         if dtype is None:
           raise err
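
The graph-mode path mirrors the eager logic, with one extra wrinkle: the trial conversion itself may raise TypeError (a list of `tensor_shape.Dimension` objects, for instance, cannot be converted without an explicit dtype), so the inner try/except swallows that failure and falls through to the `preferred_dtype` retry. A hypothetical distillation of the control flow, where `convert` stands in for `ops.convert_to_tensor` (not TF API):

    def convert_with_allowed(values, allowed_list, default_dtype, convert):
        # Trial conversion with no dtype; may raise TypeError for values
        # that have no default conversion.
        inferred = None
        try:
            inferred = convert(values)
        except TypeError:
            pass
        # Keep the trial result only when its dtype is actually allowed.
        if inferred is not None and inferred.dtype in allowed_list:
            return inferred
        # Otherwise convert again, using the op's default only as a hint.
        return convert(values, preferred_dtype=default_dtype)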

diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc

@@ -1008,6 +1008,16 @@ void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) {
     FlattenInputs(&arg_list->second, &output_sizes);
     string conversion = strings::StrCat("_execute.args_to_matching_eager(",
                                         flattened, ", ctx");
+
+    strings::StrAppend(&conversion, ", [");
+    for (int t : attr.allowed_values().list().type()) {
+      DataType dtype = static_cast<DataType>(t);
+      const string py_dtype =
+          python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.");
+      strings::StrAppend(&conversion, py_dtype, ", ");
+    }
+    strings::StrAppend(&conversion, "]");
+
     if (attr.has_default_value()) {
       strings::StrAppend(
           &conversion, ", ",
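
For an attr whose allowed list is [DT_INT32, DT_INT64], the StrAppend loop above leaves a trailing ", " before the closing bracket (still valid Python), so the generated eager wrapper ends up calling roughly this (reconstructed by hand, using Reshape's Tshape attr as the example; the exact surrounding generated code varies per op):

    # Shape of the generated wrapper code after this change:
    _attr_Tshape, (shape,) = _execute.args_to_matching_eager(
        [shape], ctx, [_dtypes.int32, _dtypes.int64, ], _dtypes.int32)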

diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py

@@ -22,6 +22,7 @@ import numpy as np
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradient_checker
@@ -191,6 +192,24 @@ class ReshapeTest(test.TestCase):
             dtypes.float32, shape=[None, 37, None])))
     self.assertEqual([None, 37, None], y.get_shape().as_list())

+  def testTensorShape(self):
+    x = array_ops.zeros([1, 100])
+    y = array_ops.reshape(
+        x, [tensor_shape.Dimension(100),
+            tensor_shape.Dimension(1)])
+    self.assertEqual([100, 1], y.get_shape().as_list())
+    y = array_ops.reshape(x, tensor_shape.TensorShape([100, 1]))
+    self.assertEqual([100, 1], y.get_shape().as_list())
+
+  def testInt64Shape(self):
+    x = array_ops.zeros([50000, 50000])
+    # Provide a shape dimension that is too large for int32.
+    y = array_ops.reshape(x, [50000**2])
+    self.assertEqual([50000**2], y.get_shape().as_list())
+    # Even if the first dimension fits in int32, ensure we correctly
+    # promote the shape to int64.
+    y = array_ops.reshape(x, [1, 50000**2])
+    self.assertEqual([1, 50000**2], y.get_shape().as_list())
+

 if __name__ == "__main__":
   test.main()
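
The sizes in testInt64Shape are chosen so that each dimension fits in int32 while the flattened size does not; a quick sanity check of that boundary in plain Python:

    INT32_MAX = 2**31 - 1        # 2147483647
    assert 50000 <= INT32_MAX    # each dimension fits in int32
    assert 50000**2 > INT32_MAX  # but the 2.5e9-element product does not
    assert 50000**2 < 2**63 - 1  # and it is still representable as int64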