Tighten up constant_op.constant casting in eager-mode.
Previously, eager would always cast all values to the requested dtype.
This didn't match graph mode, which would only allow casting values between
'compatible' types (e.g. all integer types are compatible with each other, but
no floating type is compatible with any integer type).
Graph mode uses _AssertCompatible (dc10ac4559/tensorflow/python/framework/tensor_util.py (L345)
)
to determine type compatibility. Eager mode type inference is a little
different.
After this CL, the intention is that constant_op.constant behave identically in graph and eager.
Note that this doesn't check correctly for overflows (in graph or eager). This means "tf.constant(544444, dtype=tf.uint8) < 200" will both return True.
PiperOrigin-RevId: 218717988
This commit is contained in:
parent
52c3f583aa
commit
5982692eae
@ -347,7 +347,7 @@ class Mean(Metric):
|
||||
Raises:
|
||||
ValueError: if the optional argument is not bool
|
||||
"""
|
||||
# Convert the boolean to tensor for tf.cond, if it is not.
|
||||
# Convert the boolean to tensor for tf.cond, if it is not.
|
||||
if not isinstance(write_summary, ops.Tensor):
|
||||
write_summary = ops.convert_to_tensor(write_summary)
|
||||
t = self.numer / self.denom
|
||||
@ -487,6 +487,8 @@ class BinaryAccuracy(Mean):
|
||||
message="Shapes of labels and predictions are unequal")
|
||||
predictions = ops.convert_to_tensor(predictions)
|
||||
predictions = predictions > self.threshold
|
||||
# Convert labels to bool to match predictions.
|
||||
labels = math_ops.cast(labels, dtypes.bool)
|
||||
matches = math_ops.equal(labels, predictions)
|
||||
matches = math_ops.cast(matches, self.dtype)
|
||||
super(BinaryAccuracy, self).call(matches, weights=weights)
|
||||
|
@ -27,6 +27,7 @@ cc_library(
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
"//tensorflow/c/eager:tape",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/python:cpp_python_util",
|
||||
|
@ -574,10 +574,14 @@ def _num_elements(grad):
|
||||
raise ValueError("`grad` not a Tensor or IndexedSlices.")
|
||||
|
||||
|
||||
def _cast_constant(value, dtype):
|
||||
return math_ops.cast(constant_op.constant(value), dtype)
|
||||
|
||||
|
||||
def _fast_fill(value, shape, dtype):
|
||||
return array_ops.fill(
|
||||
constant_op.constant(shape, dtype=dtypes.int32),
|
||||
constant_op.constant(value, dtype=dtype))
|
||||
_cast_constant(shape, dtype=dtypes.int32),
|
||||
_cast_constant(value, dtype=dtype))
|
||||
|
||||
|
||||
def _zeros(shape, dtype):
|
||||
@ -605,7 +609,7 @@ def _ones(shape, dtype):
|
||||
return array_ops.ones(shape, dtype)
|
||||
|
||||
if shape == (): # pylint: disable=g-explicit-bool-comparison
|
||||
return constant_op.constant(1, dtype=dtype)
|
||||
return _cast_constant(1, dtype=dtype)
|
||||
return _fast_fill(1, shape, dtype)
|
||||
|
||||
|
||||
|
@ -27,6 +27,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
||||
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
||||
#include "structmember.h" // NOLINT // For PyMemberDef
|
||||
|
||||
// forward declare
|
||||
@ -135,6 +137,39 @@ PyObject* PyIntFromDataType(TF_DataType l) {
|
||||
} // namespace
|
||||
|
||||
namespace tensorflow {
|
||||
// This function checks whether the desired type is "compatible" with the
|
||||
// inferred type. At a high level, compatibility means that all integral types
|
||||
// are compatible with each other, and all floating types are compatible with
|
||||
// each other.
|
||||
//
|
||||
// Type compatibility doesn't consider overflows (i.e. int64 is *always*
|
||||
// compatible with int32). This is intended to match graph behavior.
|
||||
bool IsCompatible(int desired_dtype, TF_DataType returned_dtype) {
|
||||
tensorflow::DataType desired =
|
||||
static_cast<tensorflow::DataType>(desired_dtype);
|
||||
tensorflow::DataType returned =
|
||||
static_cast<tensorflow::DataType>(returned_dtype);
|
||||
|
||||
if (desired == returned) return true;
|
||||
|
||||
if (tensorflow::DataTypeIsInteger(desired) &&
|
||||
tensorflow::DataTypeIsInteger(returned)) {
|
||||
return true;
|
||||
} else if (tensorflow::DataTypeIsFloating(desired) &&
|
||||
tensorflow::DataTypeIsFloating(returned)) {
|
||||
return true;
|
||||
} else if (tensorflow::DataTypeIsComplex(desired) &&
|
||||
(tensorflow::DataTypeIsComplex(returned) ||
|
||||
tensorflow::DataTypeIsInteger(returned) ||
|
||||
tensorflow::DataTypeIsFloating(returned))) {
|
||||
return true;
|
||||
} else if (tensorflow::DataTypeIsQuantized(desired) &&
|
||||
tensorflow::DataTypeIsInteger(returned)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Casts data referred to by `handle` from type `src_type_enum` to type
|
||||
// `dst_type_enum`.
|
||||
TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
|
||||
@ -376,20 +411,33 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
|
||||
if (handle == nullptr) return -1;
|
||||
TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
|
||||
if (desired_dtype >= 0 && desired_dtype != handle_dtype) {
|
||||
handle = tensorflow::make_safe(tensorflow::EagerCast(
|
||||
GetContext(context), handle.get(), handle_dtype,
|
||||
static_cast<TF_DataType>(desired_dtype), self->status));
|
||||
if (TF_GetCode(self->status) != TF_OK) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Error while casting from DataType ", handle_dtype,
|
||||
" to ", desired_dtype, ". ", TF_Message(self->status))
|
||||
.c_str());
|
||||
// Cleanup self->status before returning.
|
||||
TF_SetStatus(self->status, TF_OK, "");
|
||||
// Check type compatibility.
|
||||
if (tensorflow::IsCompatible(desired_dtype, handle_dtype)) {
|
||||
handle = tensorflow::make_safe(tensorflow::EagerCast(
|
||||
GetContext(context), handle.get(), handle_dtype,
|
||||
static_cast<TF_DataType>(desired_dtype), self->status));
|
||||
if (TF_GetCode(self->status) != TF_OK) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
tensorflow::strings::StrCat("Error while casting from DataType ",
|
||||
handle_dtype, " to ", desired_dtype,
|
||||
". ", TF_Message(self->status))
|
||||
.c_str());
|
||||
// Cleanup self->status before returning.
|
||||
TF_SetStatus(self->status, TF_OK, "");
|
||||
return -1;
|
||||
}
|
||||
handle_dtype = TFE_TensorHandleDataType(handle.get());
|
||||
} else {
|
||||
tensorflow::Safe_PyObjectPtr value_str(PyObject_Str(value));
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Cannot convert value ", TFE_GetPythonString(value_str.get()),
|
||||
" to EagerTensor with requested dtype: ", desired_dtype)
|
||||
.c_str());
|
||||
return -1;
|
||||
}
|
||||
handle_dtype = TFE_TensorHandleDataType(handle.get());
|
||||
}
|
||||
|
||||
// Almost all TensorFlow kernels for GPU devices keep int32 tensors in host
|
||||
|
@ -26,6 +26,7 @@ tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor);
|
||||
tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor);
|
||||
|
||||
namespace tensorflow {
|
||||
bool IsCompatible(int desired_dtype, TF_DataType returned_dtype);
|
||||
TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype);
|
||||
|
||||
// TODO(nareshmodi): Move EagerCast and ReadVariableOp (which use the C API to
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include <cstring>
|
||||
#include <thread>
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
@ -2113,7 +2114,9 @@ bool CastTensor(const FastPathOpExecInfo& op_exec_info,
|
||||
*handle = tensorflow::make_safe(
|
||||
tensorflow::EagerCast(op_exec_info.ctx, handle->get(), input_dtype,
|
||||
static_cast<TF_DataType>(desired_dtype), status));
|
||||
if (!status->status.ok()) return false;
|
||||
if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
|
||||
return false;
|
||||
}
|
||||
output_dtype = desired_dtype;
|
||||
}
|
||||
|
||||
@ -2122,7 +2125,9 @@ bool CastTensor(const FastPathOpExecInfo& op_exec_info,
|
||||
// if copying to the same device.
|
||||
*handle = tensorflow::make_safe(TFE_TensorHandleCopyToDevice(
|
||||
handle->get(), op_exec_info.ctx, op_exec_info.device_name, status));
|
||||
if (!status->status.ok()) return false;
|
||||
if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -2205,14 +2210,19 @@ bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
|
||||
return true;
|
||||
}
|
||||
|
||||
// Supports only 2 cases at the moment:
|
||||
// i) input is an EagerTensor
|
||||
// Supports 3 cases at the moment:
|
||||
// i) input is an EagerTensor.
|
||||
// ii) input is a ResourceVariable - in this case, the is_variable param is
|
||||
// set to true.
|
||||
// iii) input is an arbitrary python list/tuple (note, this handling doesn't
|
||||
// support packing).
|
||||
//
|
||||
// NOTE: dtype_hint_getter must *always* return a PyObject that can be
|
||||
// decref'd. So if no hint is found, Py_RETURN_NONE (which correctly
|
||||
// increfs Py_None).
|
||||
//
|
||||
// NOTE: This function sets a python error directly, and returns false.
|
||||
// TF_Status is only passed since we don't want to have to reallocate it.
|
||||
bool ConvertToTensor(
|
||||
const FastPathOpExecInfo& op_exec_info, PyObject* input,
|
||||
tensorflow::Safe_PyObjectPtr* output_handle,
|
||||
@ -2239,25 +2249,43 @@ bool ConvertToTensor(
|
||||
tensorflow::make_safe(static_cast<TFE_TensorHandle*>(
|
||||
tensorflow::ConvertToEagerTensor(input, dtype_hint.get())));
|
||||
if (handle == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Unable to convert value to tensor");
|
||||
return false;
|
||||
return MaybeRaiseExceptionFromTFStatus(status, nullptr);
|
||||
}
|
||||
|
||||
int desired_dtype = -1;
|
||||
if (dtype_hint.get() != Py_None) {
|
||||
if (!ParseTypeValue("", dtype_hint.get(), status, &desired_dtype)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Expecting a DataType value for dtype. Got ",
|
||||
Py_TYPE(dtype_hint.get())->tp_name);
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Expecting a DataType value for dtype. Got ",
|
||||
Py_TYPE(dtype_hint.get())->tp_name)
|
||||
.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!CastTensor(op_exec_info, static_cast<TF_DataType>(desired_dtype),
|
||||
&handle, status)) {
|
||||
return false;
|
||||
}
|
||||
// Maybe cast to the desired type. This is intended to match python
|
||||
// convert_to_tensor behavior.
|
||||
TF_DataType output_dtype = TFE_TensorHandleDataType(handle.get());
|
||||
if (desired_dtype >= 0 && desired_dtype != output_dtype) {
|
||||
if (tensorflow::IsCompatible(desired_dtype, output_dtype)) {
|
||||
if (!CastTensor(op_exec_info, static_cast<TF_DataType>(desired_dtype),
|
||||
&handle, status)) {
|
||||
return false;
|
||||
}
|
||||
output_dtype = TFE_TensorHandleDataType(handle.get());
|
||||
} else {
|
||||
tensorflow::Safe_PyObjectPtr input_str(PyObject_Str(input));
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Cannot convert value ", TFE_GetPythonString(input_str.get()),
|
||||
" to EagerTensor with requested dtype: ", desired_dtype)
|
||||
.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
output_handle->reset(EagerTensorFromHandle(handle.release()));
|
||||
dtype_setter(output_dtype);
|
||||
|
||||
|
@ -139,8 +139,8 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllEqual(t, 1.0)
|
||||
|
||||
def testConstantDtype(self):
|
||||
self.assertEqual(constant_op.constant(1.0, dtype=np.int64).dtype,
|
||||
dtypes.int64)
|
||||
self.assertEqual(
|
||||
constant_op.constant(1, dtype=np.int64).dtype, dtypes.int64)
|
||||
|
||||
def testTensorAndNumpyMatrix(self):
|
||||
expected = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
|
||||
@ -250,6 +250,61 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaises(UnicodeDecodeError):
|
||||
io_ops.read_file(b"\xff")
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testConvertToTensorPreferredDtypeIsRespected(self):
|
||||
self.assertEqual(
|
||||
ops.convert_to_tensor(0.5, preferred_dtype=dtypes.int32).dtype,
|
||||
dtypes.float32)
|
||||
self.assertEqual(
|
||||
ops.convert_to_tensor(0.5, preferred_dtype=dtypes.float64).dtype,
|
||||
dtypes.float64)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testCompatibility(self):
|
||||
# TODO(nareshmodi): uint32, uint64 are not correctly handled in graph mode.
|
||||
integer_types = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
|
||||
dtypes.uint8, dtypes.uint16]
|
||||
|
||||
# Floats are not compatible with ints
|
||||
for t in integer_types:
|
||||
with self.assertRaises(TypeError):
|
||||
constant_op.constant(0.5, dtype=t)
|
||||
|
||||
# Ints compatible with floats
|
||||
self.assertEqual(
|
||||
self.evaluate(constant_op.constant(5, dtype=dtypes.float16)), 5.0)
|
||||
self.assertEqual(
|
||||
self.evaluate(constant_op.constant(5, dtype=dtypes.float32)), 5.0)
|
||||
self.assertEqual(
|
||||
self.evaluate(constant_op.constant(5, dtype=dtypes.float64)), 5.0)
|
||||
|
||||
# Ints and floats are compatible with complex types
|
||||
self.assertEqual(
|
||||
constant_op.constant([[1.0]], dtype=dtypes.complex128).dtype,
|
||||
dtypes.complex128)
|
||||
self.assertEqual(
|
||||
constant_op.constant([[1]], dtype=dtypes.complex128).dtype,
|
||||
dtypes.complex128)
|
||||
|
||||
# Quantized types are not compatible with floats
|
||||
quantized_types = [dtypes.qint16, dtypes.qint32, dtypes.qint8,
|
||||
dtypes.quint16, dtypes.quint8]
|
||||
|
||||
for t in quantized_types:
|
||||
with self.assertRaises(TypeError):
|
||||
constant_op.constant(0.5, dtype=t)
|
||||
|
||||
# TODO(b/118402529): quantized types are broken in eager.
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testCConvertToTensor(self):
|
||||
with self.assertRaises(TypeError):
|
||||
_ = constant_op.constant(0) < 0.5
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testConvertToTensorAllowsOverflow(self):
|
||||
_ = ops.convert_to_tensor(123456789, dtype=dtypes.uint8)
|
||||
|
||||
|
||||
class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user