Unify bitcast errors between eager and graph mode
PiperOrigin-RevId: 324251337 Change-Id: I5945713530d5ed00e647db98be281e545bc73d09
This commit is contained in:
parent
4f2eefab89
commit
7f3772b7b8
@ -22,8 +22,19 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
static void ComputeNewShape(TF_ShapeInferenceContext* ctx,
|
||||
TF_ShapeHandle* shape, size_t input_type_size,
|
||||
size_t output_type_size, TF_Status* status) {
|
||||
TF_ShapeHandle* shape, TF_DataType input_type,
|
||||
TF_DataType output_type, TF_Status* status) {
|
||||
size_t input_type_size = TF_DataTypeSize(input_type);
|
||||
size_t output_type_size = TF_DataTypeSize(output_type);
|
||||
|
||||
if (input_type_size == 0 || output_type_size == 0) {
|
||||
std::ostringstream err;
|
||||
err << "Cannot bitcast type " << input_type << " to " << output_type
|
||||
<< " because one of the type sizes is zero";
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
if (input_type_size < output_type_size) {
|
||||
TF_ShapeInferenceContextWithRankAtLeast(ctx, shape, 1, shape, status);
|
||||
@ -37,9 +48,9 @@ static void ComputeNewShape(TF_ShapeInferenceContext* ctx,
|
||||
TF_ShapeInferenceContextSubshape(ctx, shape, 0, -1, shape, status);
|
||||
} else {
|
||||
std::ostringstream err;
|
||||
err << "Cannot bitcast due to shape. "
|
||||
<< TF_DimensionHandleValue(last_dim) << " does not match "
|
||||
<< divisor_val;
|
||||
err << "Cannot bitcast from " << input_type << " to " << output_type
|
||||
<< " due to shape. " << TF_DimensionHandleValue(last_dim)
|
||||
<< " does not match " << divisor_val;
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
|
||||
}
|
||||
TF_DeleteDimensionHandle(last_dim);
|
||||
@ -78,23 +89,8 @@ static void bitcast_shape_inference_fn(TF_ShapeInferenceContext* ctx,
|
||||
TF_ShapeInferenceContext_GetAttrType(ctx, "type", &output_type, status);
|
||||
}
|
||||
|
||||
size_t input_type_size;
|
||||
size_t output_type_size;
|
||||
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
input_type_size = TF_DataTypeSize(input_type);
|
||||
output_type_size = TF_DataTypeSize(output_type);
|
||||
|
||||
if (input_type_size == 0 || output_type_size == 0) {
|
||||
std::ostringstream err;
|
||||
err << "Cannot bitcast type " << input_type << " to " << output_type
|
||||
<< " because one of the type sizes is zero";
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
ComputeNewShape(ctx, result, input_type_size, output_type_size, status);
|
||||
ComputeNewShape(ctx, result, input_type, output_type, status);
|
||||
}
|
||||
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
|
@ -21,6 +21,8 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -60,11 +62,11 @@ class BitcastTest(test.TestCase):
|
||||
shape = [3, 4]
|
||||
self._testBitcast(x, dtypes.int64, shape)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testErrors(self):
|
||||
x = np.zeros([1, 1], np.int8)
|
||||
datatype = dtypes.int32
|
||||
with self.assertRaisesRegex(ValueError, "Cannot bitcast due to shape"):
|
||||
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
|
||||
"Cannot bitcast from 6 to 3"):
|
||||
array_ops.bitcast(x, datatype, None)
|
||||
|
||||
def testEmpty(self):
|
||||
@ -73,11 +75,12 @@ class BitcastTest(test.TestCase):
|
||||
shape = [4]
|
||||
self._testBitcast(x, datatype, shape)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testUnknown(self):
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
datatype = dtypes.int8
|
||||
array_ops.bitcast(x, datatype, None)
|
||||
def testUnknownShape(self):
|
||||
# Need to use placeholder for unknown shape
|
||||
with ops.Graph().as_default():
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
datatype = dtypes.int8
|
||||
array_ops.bitcast(x, datatype, None)
|
||||
|
||||
def testQuantizedType(self):
|
||||
shape = [3, 4]
|
||||
|
Loading…
Reference in New Issue
Block a user