Merge pull request #20159 from luk-ai:cast-quantized

PiperOrigin-RevId: 205881436
This commit is contained in:
TensorFlower Gardener 2018-07-24 13:38:33 -07:00
commit 4f4091db6d
5 changed files with 62 additions and 11 deletions

View File

@ -482,6 +482,7 @@ class Tensor {
friend class VariableOp; // For access to set_shape
friend class AutoReloadVariableOp; // For access to set_shape
friend class TensorTestHelper; // For access to set_shape
friend class CastOpBase; // For access to set_dtype.
friend class OpKernelContext; // For access to RefCountIsOne().
friend class ScopedAllocator; // For access to buf_.
friend class XlaTensor; // For access to RefCountIsOne().

View File

@ -55,8 +55,39 @@ typedef Eigen::SyclDevice SYCLDevice;
FN(arg0, std::complex<double>)
CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &external_src_dtype_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &external_dst_dtype_));
// Quantized data types use the same underlying format as their non-quantized
// versions, so we use the non-quantized implementation for casting.
if (external_dst_dtype_ == DT_QUINT8) {
dst_dtype_ = DT_UINT8;
} else if (external_dst_dtype_ == DT_QINT8) {
dst_dtype_ = DT_INT8;
} else if (external_dst_dtype_ == DT_QINT32) {
dst_dtype_ = DT_INT32;
} else if (external_dst_dtype_ == DT_QINT16) {
dst_dtype_ = DT_INT16;
} else if (external_dst_dtype_ == DT_QUINT16) {
dst_dtype_ = DT_UINT16;
} else {
dst_dtype_ = external_dst_dtype_;
}
if (external_src_dtype_ == DT_QUINT8) {
src_dtype_ = DT_UINT8;
} else if (external_src_dtype_ == DT_QINT8) {
src_dtype_ = DT_INT8;
} else if (external_src_dtype_ == DT_QINT32) {
src_dtype_ = DT_INT32;
} else if (external_src_dtype_ == DT_QINT16) {
src_dtype_ = DT_INT16;
} else if (external_src_dtype_ == DT_QUINT16) {
src_dtype_ = DT_UINT16;
} else {
src_dtype_ = external_src_dtype_;
}
}
void CastOpBase::Compute(OpKernelContext* ctx) {
@ -64,15 +95,20 @@ void CastOpBase::Compute(OpKernelContext* ctx) {
if (work_ == nullptr) {
ctx->set_output(0, inp);
} else {
Tensor in;
in.UnsafeCopyFromInternal(inp, src_dtype_, inp.shape());
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
work_(ctx, inp, out);
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
out->set_dtype(dst_dtype_);
work_(ctx, in, out);
out->set_dtype(external_dst_dtype_);
}
}
Status CastOpBase::Unimplemented() {
return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ",
DataTypeString(dst_dtype_), " is not supported");
return errors::Unimplemented("Cast ", DataTypeString(external_src_dtype_),
" to ", DataTypeString(external_dst_dtype_),
" is not supported");
}
CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
@ -80,7 +116,7 @@ CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
}
Status CpuCastOp::Prepare() {
if (src_dtype_ == dst_dtype_) {
if (external_src_dtype_ == external_dst_dtype_) {
work_ = nullptr; // Identity
return Status::OK();
}
@ -133,7 +169,7 @@ class GpuCastOp : public CastOpBase {
private:
Status Prepare() {
if (src_dtype_ == dst_dtype_) {
if (external_src_dtype_ == external_dst_dtype_) {
work_ = nullptr; // Identity
return Status::OK();
}
@ -215,7 +251,7 @@ class SyclCastOp : public CastOpBase {
private:
Status Prepare() {
if (src_dtype_ == dst_dtype_) {
if (external_src_dtype_ == external_dst_dtype_) {
work_ = nullptr; // Identity
return Status::OK();
}

View File

@ -36,6 +36,8 @@ class CastOpBase : public OpKernel {
protected:
DataType src_dtype_;
DataType dst_dtype_;
DataType external_src_dtype_;
DataType external_dst_dtype_;
std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr;
Status Unimplemented();

View File

@ -78,7 +78,12 @@ class CastOpTest : public OpsTestBase {
TEST_CAST(in, half); \
TEST_CAST(in, float); \
TEST_CAST(in, double); \
TEST_CAST(in, bfloat16);
TEST_CAST(in, bfloat16); \
TEST_CAST(in, quint8); \
TEST_CAST(in, qint8); \
TEST_CAST(in, qint32); \
TEST_CAST(in, qint16); \
TEST_CAST(in, quint16);
TEST_ALL_CASTS_FROM(uint8)
TEST_ALL_CASTS_FROM(uint16)
@ -91,6 +96,11 @@ TEST_ALL_CASTS_FROM(half)
TEST_ALL_CASTS_FROM(float)
TEST_ALL_CASTS_FROM(double)
TEST_ALL_CASTS_FROM(bfloat16)
TEST_ALL_CASTS_FROM(quint8)
TEST_ALL_CASTS_FROM(qint8)
TEST_ALL_CASTS_FROM(qint32)
TEST_ALL_CASTS_FROM(qint16)
TEST_ALL_CASTS_FROM(quint16)
#undef TEST_ALL_CASTS_FROM
#undef TEST_CAST

View File

@ -311,8 +311,10 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testUnsupportedDtype(self):
with self.test_session():
param = ops.convert_to_tensor(
np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype),
dtype=dtypes.qint16)
with self.assertRaises(TypeError):
param = array_ops.ones([int(2**11+1)], dtype=dtypes.qint16)
du.embed_check_categorical_event_shape(param)