Internal change
PiperOrigin-RevId: 321900773 Change-Id: If154d6c4f44e002db2a4b4f431f17f715dcc06ca
This commit is contained in:
parent
9b59779bc3
commit
b4ee16059c
@ -35,6 +35,9 @@
|
||||
* `tf.types.experimental.TensorLike` is a new `Union` type that can be used as
|
||||
type annotation for variables representing a Tensor or a value that can be
|
||||
converted to Tensor by `tf.convert_to_tensor`.
|
||||
* Calling ops with a python constants or numpy values is now consistent with
|
||||
tf.convert_to_tensor behavior. This avoids operations like tf.reshape
|
||||
truncating inputs such as from int64 to int32.
|
||||
* `tf.data`:
|
||||
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
|
||||
the complement of `select_cols`; at most one of these should be specified.
|
||||
|
@ -233,7 +233,7 @@ def make_tensor(v, arg_name):
|
||||
(repr(v), arg_name))
|
||||
|
||||
|
||||
def args_to_matching_eager(l, ctx, default_dtype=None):
|
||||
def args_to_matching_eager(l, ctx, allowed_dtypes, default_dtype=None):
|
||||
"""Convert sequence `l` to eager same-type Tensors."""
|
||||
if (not l) and (default_dtype is not None):
|
||||
return default_dtype, [] # List is empty; assume default dtype.
|
||||
@ -243,8 +243,6 @@ def args_to_matching_eager(l, ctx, default_dtype=None):
|
||||
break
|
||||
else: # note: intentional for-else
|
||||
return l[0]._datatype_enum(), l # pylint: disable=protected-access
|
||||
# TODO(josh11b): Could we do a better job if we also passed in the
|
||||
# allowed dtypes when that was known?
|
||||
|
||||
# Is some input already a Tensor with a dtype?
|
||||
dtype = None
|
||||
@ -256,13 +254,28 @@ def args_to_matching_eager(l, ctx, default_dtype=None):
|
||||
if dtype is None:
|
||||
# Infer a dtype based on the first value, and use that dtype for the
|
||||
# remaining values.
|
||||
|
||||
ret = []
|
||||
for t in l:
|
||||
ret.append(
|
||||
ops.convert_to_tensor(
|
||||
t, dtype, preferred_dtype=default_dtype, ctx=ctx))
|
||||
tensor = None
|
||||
# First see if we can get a valid dtype with the default conversion
|
||||
# and see if it matches an allowed dtypes. Some ops like ConcatV2 may
|
||||
# not list allowed dtypes, in which case we should skip this.
|
||||
if dtype is None and allowed_dtypes:
|
||||
tensor = ops.convert_to_tensor(t, ctx=ctx)
|
||||
# If we did not match an allowed dtype, try again with the default
|
||||
# dtype. This could be because we have an empty tensor and thus we
|
||||
# picked the wrong type.
|
||||
if tensor.dtype not in allowed_dtypes:
|
||||
tensor = None
|
||||
|
||||
if tensor is None:
|
||||
tensor = ops.convert_to_tensor(
|
||||
t, dtype, preferred_dtype=default_dtype, ctx=ctx)
|
||||
|
||||
ret.append(tensor)
|
||||
if dtype is None:
|
||||
dtype = ret[-1].dtype
|
||||
dtype = tensor.dtype
|
||||
else:
|
||||
ret = [ops.convert_to_tensor(t, dtype, ctx=ctx) for t in l]
|
||||
|
||||
|
@ -326,17 +326,42 @@ class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
def testArgsToMatchingEagerDefault(self):
|
||||
# Uses default
|
||||
ctx = context.context()
|
||||
t, r = execute.args_to_matching_eager([[3, 4]], ctx, dtypes.int32)
|
||||
allowed_dtypes = [dtypes.int32, dtypes.int64]
|
||||
|
||||
# Follows standard int conversion rules
|
||||
t, r = execute.args_to_matching_eager([[3, 4]], ctx, allowed_dtypes,
|
||||
dtypes.int32)
|
||||
self.assertEqual(t, dtypes.int32)
|
||||
self.assertEqual(r[0].dtype, dtypes.int32)
|
||||
t, r = execute.args_to_matching_eager([[3, 4]], ctx, dtypes.int64)
|
||||
t, r = execute.args_to_matching_eager([[3, 4]], ctx, allowed_dtypes,
|
||||
dtypes.int64)
|
||||
self.assertEqual(t, dtypes.int32)
|
||||
self.assertEqual(r[0].dtype, dtypes.int32)
|
||||
# Use int64 since it is a better fit
|
||||
t, r = execute.args_to_matching_eager([[2**48]], ctx, allowed_dtypes,
|
||||
dtypes.int32)
|
||||
self.assertEqual(t, dtypes.int64)
|
||||
self.assertEqual(r[0].dtype, dtypes.int64)
|
||||
t, r = execute.args_to_matching_eager([], ctx, dtypes.int64)
|
||||
|
||||
# When the regular tensor conversion fails, then use the default type as a
|
||||
# hint.
|
||||
allowed_dtypes = [dtypes.uint32, dtypes.uint32]
|
||||
t, r = execute.args_to_matching_eager([[3, 4]], ctx, allowed_dtypes,
|
||||
dtypes.uint32)
|
||||
self.assertEqual(t, dtypes.uint32)
|
||||
self.assertEqual(r[0].dtype, dtypes.uint32)
|
||||
t, r = execute.args_to_matching_eager([[3, 4]], ctx, allowed_dtypes,
|
||||
dtypes.uint64)
|
||||
self.assertEqual(t, dtypes.uint64)
|
||||
self.assertEqual(r[0].dtype, dtypes.uint64)
|
||||
|
||||
t, r = execute.args_to_matching_eager([], ctx, allowed_dtypes, dtypes.int64)
|
||||
self.assertEqual(t, dtypes.int64)
|
||||
|
||||
# Doesn't use default
|
||||
t, r = execute.args_to_matching_eager(
|
||||
[['string', 'arg']], ctx, dtypes.int32)
|
||||
allowed_dtypes = [dtypes.int32, dtypes.string]
|
||||
t, r = execute.args_to_matching_eager([['string', 'arg']], ctx,
|
||||
allowed_dtypes, dtypes.int32)
|
||||
self.assertEqual(t, dtypes.string)
|
||||
self.assertEqual(r[0].dtype, dtypes.string)
|
||||
|
||||
|
@ -40,7 +40,7 @@ def _eager_reshape(tensor, shape, ctx):
|
||||
"""Eager-only version of Reshape op; requires tensor is an eager Tensor."""
|
||||
attr_t = tensor._datatype_enum() # pylint: disable=protected-access
|
||||
attr_tshape, (shape,) = execute.args_to_matching_eager(
|
||||
[shape], ctx, dtypes.int32)
|
||||
[shape], ctx, [dtypes.int32, dtypes.int64], dtypes.int32)
|
||||
inputs_flat = [tensor, shape]
|
||||
attrs = ("T", attr_t, "Tshape", attr_tshape)
|
||||
result, = execute.execute(
|
||||
|
@ -337,6 +337,7 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in
|
||||
# on the other. Handling this will require restructuring this code
|
||||
# significantly.
|
||||
default_type_attr_map = {}
|
||||
allowed_list_attr_map = {}
|
||||
for attr_def in op_def.attr:
|
||||
if attr_def.type != "type":
|
||||
continue
|
||||
@ -344,6 +345,8 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in
|
||||
if attr_def.HasField("default_value"):
|
||||
default_type_attr_map[key] = dtypes.as_dtype(
|
||||
attr_def.default_value.type)
|
||||
if attr_def.HasField("allowed_values"):
|
||||
allowed_list_attr_map[key] = attr_def.allowed_values.list.type
|
||||
|
||||
# Requires that op_def has passed validation (using the C++
|
||||
# ValidateOpDef() from ../framework/op_def_util.h).
|
||||
@ -451,6 +454,7 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in
|
||||
# arguments to that type.
|
||||
dtype = None
|
||||
default_dtype = None
|
||||
allowed_list = None
|
||||
if input_arg.type != types_pb2.DT_INVALID:
|
||||
dtype = input_arg.type
|
||||
elif input_arg.type_attr in attrs:
|
||||
@ -460,14 +464,41 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in
|
||||
# so we prefer the attr's default, so code that adds a new attr
|
||||
# with a default is backwards compatible.
|
||||
default_dtype = default_type_attr_map[input_arg.type_attr]
|
||||
allowed_list = allowed_list_attr_map.get(input_arg.type_attr)
|
||||
|
||||
try:
|
||||
values = ops.convert_to_tensor(
|
||||
values,
|
||||
name=input_arg.name,
|
||||
dtype=dtype,
|
||||
as_ref=input_arg.is_ref,
|
||||
preferred_dtype=default_dtype)
|
||||
# First see if we can get a valid dtype with the default conversion
|
||||
# and see if it matches an allowed dtypes. Some ops like ConcatV2 may
|
||||
# not list allowed dtypes, in which case we should skip this.
|
||||
if dtype is None and allowed_list:
|
||||
inferred = None
|
||||
try:
|
||||
inferred = ops.convert_to_tensor(
|
||||
values, name=input_arg.name, as_ref=input_arg.is_ref)
|
||||
except TypeError as err:
|
||||
# When converting a python object such as a list of Dimensions, we
|
||||
# need a dtype to be specified, thus tensor conversion may throw
|
||||
# an exception which we will ignore and try again below.
|
||||
pass
|
||||
|
||||
# If we did not match an allowed dtype, try again with the default
|
||||
# dtype. This could be because we have an empty tensor and thus we
|
||||
# picked the wrong type.
|
||||
if inferred is not None and inferred.dtype in allowed_list:
|
||||
values = inferred
|
||||
else:
|
||||
values = ops.convert_to_tensor(
|
||||
values,
|
||||
name=input_arg.name,
|
||||
as_ref=input_arg.is_ref,
|
||||
preferred_dtype=default_dtype)
|
||||
else:
|
||||
values = ops.convert_to_tensor(
|
||||
values,
|
||||
name=input_arg.name,
|
||||
dtype=dtype,
|
||||
as_ref=input_arg.is_ref,
|
||||
preferred_dtype=default_dtype)
|
||||
except TypeError as err:
|
||||
if dtype is None:
|
||||
raise err
|
||||
|
@ -1008,6 +1008,16 @@ void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) {
|
||||
FlattenInputs(&arg_list->second, &output_sizes);
|
||||
string conversion = strings::StrCat("_execute.args_to_matching_eager(",
|
||||
flattened, ", ctx");
|
||||
|
||||
strings::StrAppend(&conversion, ", [");
|
||||
for (int t : attr.allowed_values().list().type()) {
|
||||
DataType dtype = static_cast<DataType>(t);
|
||||
const string py_dtype =
|
||||
python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.");
|
||||
strings::StrAppend(&conversion, py_dtype, ", ");
|
||||
}
|
||||
strings::StrAppend(&conversion, "]");
|
||||
|
||||
if (attr.has_default_value()) {
|
||||
strings::StrAppend(
|
||||
&conversion, ", ",
|
||||
|
@ -22,6 +22,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
@ -191,6 +192,24 @@ class ReshapeTest(test.TestCase):
|
||||
dtypes.float32, shape=[None, 37, None])))
|
||||
self.assertEqual([None, 37, None], y.get_shape().as_list())
|
||||
|
||||
def testTensorShape(self):
|
||||
x = array_ops.zeros([1, 100])
|
||||
y = array_ops.reshape(
|
||||
x, [tensor_shape.Dimension(100),
|
||||
tensor_shape.Dimension(1)])
|
||||
self.assertEqual([100, 1], y.get_shape().as_list())
|
||||
y = array_ops.reshape(x, tensor_shape.TensorShape([100, 1]))
|
||||
self.assertEqual([100, 1], y.get_shape().as_list())
|
||||
|
||||
def testInt64Shape(self):
|
||||
x = array_ops.zeros([50000, 50000])
|
||||
# Provide dimension larger than int32
|
||||
y = array_ops.reshape(x, [50000**2])
|
||||
self.assertEqual([50000**2], y.get_shape().as_list())
|
||||
# Even if first dimension is within int32, ensure we correctly go to int64
|
||||
y = array_ops.reshape(x, [1, 50000**2])
|
||||
self.assertEqual([1, 50000**2], y.get_shape().as_list())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user