Unify bitcast errors between eager and graph mode

PiperOrigin-RevId: 324251337
Change-Id: I5945713530d5ed00e647db98be281e545bc73d09
This commit is contained in:
Gaurav Jain 2020-07-31 11:44:36 -07:00 committed by TensorFlower Gardener
parent 4f2eefab89
commit 7f3772b7b8
2 changed files with 27 additions and 28 deletions

View File

@ -22,8 +22,19 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
static void ComputeNewShape(TF_ShapeInferenceContext* ctx,
TF_ShapeHandle* shape, size_t input_type_size,
size_t output_type_size, TF_Status* status) {
TF_ShapeHandle* shape, TF_DataType input_type,
TF_DataType output_type, TF_Status* status) {
size_t input_type_size = TF_DataTypeSize(input_type);
size_t output_type_size = TF_DataTypeSize(output_type);
if (input_type_size == 0 || output_type_size == 0) {
std::ostringstream err;
err << "Cannot bitcast type " << input_type << " to " << output_type
<< " because one of the type sizes is zero";
TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
return;
}
TF_SetStatus(status, TF_OK, "");
if (input_type_size < output_type_size) {
TF_ShapeInferenceContextWithRankAtLeast(ctx, shape, 1, shape, status);
@ -37,9 +48,9 @@ static void ComputeNewShape(TF_ShapeInferenceContext* ctx,
TF_ShapeInferenceContextSubshape(ctx, shape, 0, -1, shape, status);
} else {
std::ostringstream err;
err << "Cannot bitcast due to shape. "
<< TF_DimensionHandleValue(last_dim) << " does not match "
<< divisor_val;
err << "Cannot bitcast from " << input_type << " to " << output_type
<< " due to shape. " << TF_DimensionHandleValue(last_dim)
<< " does not match " << divisor_val;
TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
}
TF_DeleteDimensionHandle(last_dim);
@ -78,23 +89,8 @@ static void bitcast_shape_inference_fn(TF_ShapeInferenceContext* ctx,
TF_ShapeInferenceContext_GetAttrType(ctx, "type", &output_type, status);
}
size_t input_type_size;
size_t output_type_size;
if (TF_GetCode(status) == TF_OK) {
input_type_size = TF_DataTypeSize(input_type);
output_type_size = TF_DataTypeSize(output_type);
if (input_type_size == 0 || output_type_size == 0) {
std::ostringstream err;
err << "Cannot bitcast type " << input_type << " to " << output_type
<< " because one of the type sizes is zero";
TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
}
}
if (TF_GetCode(status) == TF_OK) {
ComputeNewShape(ctx, result, input_type_size, output_type_size, status);
ComputeNewShape(ctx, result, input_type, output_type, status);
}
if (TF_GetCode(status) == TF_OK) {

View File

@ -21,6 +21,8 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@ -60,11 +62,11 @@ class BitcastTest(test.TestCase):
shape = [3, 4]
self._testBitcast(x, dtypes.int64, shape)
@test_util.run_deprecated_v1
def testErrors(self):
x = np.zeros([1, 1], np.int8)
datatype = dtypes.int32
with self.assertRaisesRegex(ValueError, "Cannot bitcast due to shape"):
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"Cannot bitcast from 6 to 3"):
array_ops.bitcast(x, datatype, None)
def testEmpty(self):
@ -73,11 +75,12 @@ class BitcastTest(test.TestCase):
shape = [4]
self._testBitcast(x, datatype, shape)
@test_util.run_deprecated_v1
def testUnknown(self):
x = array_ops.placeholder(dtypes.float32)
datatype = dtypes.int8
array_ops.bitcast(x, datatype, None)
def testUnknownShape(self):
# Need to use placeholder for unknown shape
with ops.Graph().as_default():
x = array_ops.placeholder(dtypes.float32)
datatype = dtypes.int8
array_ops.bitcast(x, datatype, None)
def testQuantizedType(self):
shape = [3, 4]