Unify bitcast errors between eager and graph mode

PiperOrigin-RevId: 324251337
Change-Id: I5945713530d5ed00e647db98be281e545bc73d09
This commit is contained in:
Gaurav Jain 2020-07-31 11:44:36 -07:00 committed by TensorFlower Gardener
parent 4f2eefab89
commit 7f3772b7b8
2 changed files with 27 additions and 28 deletions

View File

@ -22,8 +22,19 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
static void ComputeNewShape(TF_ShapeInferenceContext* ctx, static void ComputeNewShape(TF_ShapeInferenceContext* ctx,
TF_ShapeHandle* shape, size_t input_type_size, TF_ShapeHandle* shape, TF_DataType input_type,
size_t output_type_size, TF_Status* status) { TF_DataType output_type, TF_Status* status) {
size_t input_type_size = TF_DataTypeSize(input_type);
size_t output_type_size = TF_DataTypeSize(output_type);
if (input_type_size == 0 || output_type_size == 0) {
std::ostringstream err;
err << "Cannot bitcast type " << input_type << " to " << output_type
<< " because one of the type sizes is zero";
TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
return;
}
TF_SetStatus(status, TF_OK, ""); TF_SetStatus(status, TF_OK, "");
if (input_type_size < output_type_size) { if (input_type_size < output_type_size) {
TF_ShapeInferenceContextWithRankAtLeast(ctx, shape, 1, shape, status); TF_ShapeInferenceContextWithRankAtLeast(ctx, shape, 1, shape, status);
@ -37,9 +48,9 @@ static void ComputeNewShape(TF_ShapeInferenceContext* ctx,
TF_ShapeInferenceContextSubshape(ctx, shape, 0, -1, shape, status); TF_ShapeInferenceContextSubshape(ctx, shape, 0, -1, shape, status);
} else { } else {
std::ostringstream err; std::ostringstream err;
err << "Cannot bitcast due to shape. " err << "Cannot bitcast from " << input_type << " to " << output_type
<< TF_DimensionHandleValue(last_dim) << " does not match " << " due to shape. " << TF_DimensionHandleValue(last_dim)
<< divisor_val; << " does not match " << divisor_val;
TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str()); TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
} }
TF_DeleteDimensionHandle(last_dim); TF_DeleteDimensionHandle(last_dim);
@ -78,23 +89,8 @@ static void bitcast_shape_inference_fn(TF_ShapeInferenceContext* ctx,
TF_ShapeInferenceContext_GetAttrType(ctx, "type", &output_type, status); TF_ShapeInferenceContext_GetAttrType(ctx, "type", &output_type, status);
} }
size_t input_type_size;
size_t output_type_size;
if (TF_GetCode(status) == TF_OK) { if (TF_GetCode(status) == TF_OK) {
input_type_size = TF_DataTypeSize(input_type); ComputeNewShape(ctx, result, input_type, output_type, status);
output_type_size = TF_DataTypeSize(output_type);
if (input_type_size == 0 || output_type_size == 0) {
std::ostringstream err;
err << "Cannot bitcast type " << input_type << " to " << output_type
<< " because one of the type sizes is zero";
TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
}
}
if (TF_GetCode(status) == TF_OK) {
ComputeNewShape(ctx, result, input_type_size, output_type_size, status);
} }
if (TF_GetCode(status) == TF_OK) { if (TF_GetCode(status) == TF_OK) {

View File

@ -21,6 +21,8 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -60,11 +62,11 @@ class BitcastTest(test.TestCase):
shape = [3, 4] shape = [3, 4]
self._testBitcast(x, dtypes.int64, shape) self._testBitcast(x, dtypes.int64, shape)
@test_util.run_deprecated_v1
def testErrors(self): def testErrors(self):
x = np.zeros([1, 1], np.int8) x = np.zeros([1, 1], np.int8)
datatype = dtypes.int32 datatype = dtypes.int32
with self.assertRaisesRegex(ValueError, "Cannot bitcast due to shape"): with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"Cannot bitcast from 6 to 3"):
array_ops.bitcast(x, datatype, None) array_ops.bitcast(x, datatype, None)
def testEmpty(self): def testEmpty(self):
@ -73,8 +75,9 @@ class BitcastTest(test.TestCase):
shape = [4] shape = [4]
self._testBitcast(x, datatype, shape) self._testBitcast(x, datatype, shape)
@test_util.run_deprecated_v1 def testUnknownShape(self):
def testUnknown(self): # Need to use placeholder for unknown shape
with ops.Graph().as_default():
x = array_ops.placeholder(dtypes.float32) x = array_ops.placeholder(dtypes.float32)
datatype = dtypes.int8 datatype = dtypes.int8
array_ops.bitcast(x, datatype, None) array_ops.bitcast(x, datatype, None)