diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index d2f2609d3b8..1b19ab5da31 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -482,6 +482,7 @@ class Tensor {
   friend class VariableOp;            // For access to set_shape
   friend class AutoReloadVariableOp;  // For access to set_shape
   friend class TensorTestHelper;      // For access to set_shape
+  friend class CastOpBase;            // For access to set_dtype;
   friend class OpKernelContext;       // For access to RefCountIsOne().
   friend class ScopedAllocator;       // For access to buf_.
   friend class XlaTensor;             // For access to RefCountIsOne().
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index e6e388b3d10..b4c97df38b2 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -55,8 +55,39 @@ typedef Eigen::SyclDevice SYCLDevice;
   FN(arg0, std::complex<double>)
 
 CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
-  OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_));
-  OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &external_src_dtype_));
+
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &external_dst_dtype_));
+
+  // Quantized data types use the same underlying format as their non quantized
+  // version so we use the non quantized implementation for casting.
+  if (external_dst_dtype_ == DT_QUINT8) {
+    dst_dtype_ = DT_UINT8;
+  } else if (external_dst_dtype_ == DT_QINT8) {
+    dst_dtype_ = DT_INT8;
+  } else if (external_dst_dtype_ == DT_QINT32) {
+    dst_dtype_ = DT_INT32;
+  } else if (external_dst_dtype_ == DT_QINT16) {
+    dst_dtype_ = DT_INT16;
+  } else if (external_dst_dtype_ == DT_QUINT16) {
+    dst_dtype_ = DT_UINT16;
+  } else {
+    dst_dtype_ = external_dst_dtype_;
+  }
+
+  if (external_src_dtype_ == DT_QUINT8) {
+    src_dtype_ = DT_UINT8;
+  } else if (external_src_dtype_ == DT_QINT8) {
+    src_dtype_ = DT_INT8;
+  } else if (external_src_dtype_ == DT_QINT32) {
+    src_dtype_ = DT_INT32;
+  } else if (external_src_dtype_ == DT_QINT16) {
+    src_dtype_ = DT_INT16;
+  } else if (external_src_dtype_ == DT_QUINT16) {
+    src_dtype_ = DT_UINT16;
+  } else {
+    src_dtype_ = external_src_dtype_;
+  }
 }
 
 void CastOpBase::Compute(OpKernelContext* ctx) {
@@ -64,15 +95,20 @@ void CastOpBase::Compute(OpKernelContext* ctx) {
   if (work_ == nullptr) {
     ctx->set_output(0, inp);
   } else {
+    Tensor in;
+    in.UnsafeCopyFromInternal(inp, src_dtype_, inp.shape());
     Tensor* out = nullptr;
-    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
-    work_(ctx, inp, out);
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
+    out->set_dtype(dst_dtype_);
+    work_(ctx, in, out);
+    out->set_dtype(external_dst_dtype_);
   }
 }
 
 Status CastOpBase::Unimplemented() {
-  return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ",
-                               DataTypeString(dst_dtype_), " is not supported");
+  return errors::Unimplemented("Cast ", DataTypeString(external_src_dtype_),
+                               " to ", DataTypeString(external_dst_dtype_),
+                               " is not supported");
 }
 
 CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
@@ -80,7 +116,7 @@ CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
 }
 
 Status CpuCastOp::Prepare() {
-  if (src_dtype_ == dst_dtype_) {
+  if (external_src_dtype_ == external_dst_dtype_) {
     work_ = nullptr;  // Identity
     return Status::OK();
   }
@@ -133,7 +169,7 @@ class GpuCastOp : public CastOpBase {
 
  private:
   Status Prepare() {
-    if (src_dtype_ == dst_dtype_) {
+    if (external_src_dtype_ == external_dst_dtype_) {
       work_ = nullptr;  // Identity
       return Status::OK();
     }
@@ -215,7 +251,7 @@ class SyclCastOp : public CastOpBase {
 
  private:
   Status Prepare() {
-    if (src_dtype_ == dst_dtype_) {
+    if (external_src_dtype_ == external_dst_dtype_) {
       work_ = nullptr;  // Identity
       return Status::OK();
     }
diff --git a/tensorflow/core/kernels/cast_op.h b/tensorflow/core/kernels/cast_op.h
index 16d2e0e0a56..aae1e7ff190 100644
--- a/tensorflow/core/kernels/cast_op.h
+++ b/tensorflow/core/kernels/cast_op.h
@@ -36,6 +36,8 @@ class CastOpBase : public OpKernel {
  protected:
   DataType src_dtype_;
   DataType dst_dtype_;
+  DataType external_src_dtype_;
+  DataType external_dst_dtype_;
   std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr;
 
   Status Unimplemented();
diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc
index 75e21802c05..9bbf7afb162 100644
--- a/tensorflow/core/kernels/cast_op_test.cc
+++ b/tensorflow/core/kernels/cast_op_test.cc
@@ -78,7 +78,12 @@ class CastOpTest : public OpsTestBase {
   TEST_CAST(in, half);     \
   TEST_CAST(in, float);    \
   TEST_CAST(in, double);   \
-  TEST_CAST(in, bfloat16);
+  TEST_CAST(in, bfloat16); \
+  TEST_CAST(in, quint8);   \
+  TEST_CAST(in, qint8);    \
+  TEST_CAST(in, qint32);   \
+  TEST_CAST(in, qint16);   \
+  TEST_CAST(in, quint16);
 
 TEST_ALL_CASTS_FROM(uint8)
 TEST_ALL_CASTS_FROM(uint16)
@@ -91,6 +96,11 @@ TEST_ALL_CASTS_FROM(half)
 TEST_ALL_CASTS_FROM(float)
 TEST_ALL_CASTS_FROM(double)
 TEST_ALL_CASTS_FROM(bfloat16)
+TEST_ALL_CASTS_FROM(quint8)
+TEST_ALL_CASTS_FROM(qint8)
+TEST_ALL_CASTS_FROM(qint32)
+TEST_ALL_CASTS_FROM(qint16)
+TEST_ALL_CASTS_FROM(quint16)
 
 #undef TEST_ALL_CASTS_FROM
 #undef TEST_CAST
diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py
index 9d38ffcb4a9..61faa8466ed 100644
--- a/tensorflow/python/kernel_tests/distributions/util_test.py
+++ b/tensorflow/python/kernel_tests/distributions/util_test.py
@@ -311,8 +311,10 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testUnsupportedDtype(self):
     with self.test_session():
+      param = ops.convert_to_tensor(
+          np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype),
+          dtype=dtypes.qint16)
       with self.assertRaises(TypeError):
-        param = array_ops.ones([int(2**11+1)], dtype=dtypes.qint16)
         du.embed_check_categorical_event_shape(param)
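
Note (not part of the patch): the CastOpBase constructor change above amounts to mapping each quantized dtype to the dtype of its underlying integer storage, reusing the existing non-quantized cast kernels, and relabeling the result. The standalone C++ sketch below illustrates that mapping; the DataType enum and ToStorageType() helper are illustrative stand-ins, not the real TensorFlow definitions.

// Illustrative only: a stand-in DataType enum and a mapping function that
// mirrors the if/else chain in CastOpBase::CastOpBase() above.
#include <iostream>

enum class DataType {
  DT_UINT8, DT_INT8, DT_INT16, DT_UINT16, DT_INT32,
  DT_QUINT8, DT_QINT8, DT_QINT16, DT_QUINT16, DT_QINT32,
  DT_FLOAT
};

// Quantized dtypes share the memory layout of their integer counterparts,
// so casting can be delegated to the existing non-quantized kernels.
DataType ToStorageType(DataType external) {
  switch (external) {
    case DataType::DT_QUINT8:  return DataType::DT_UINT8;
    case DataType::DT_QINT8:   return DataType::DT_INT8;
    case DataType::DT_QINT16:  return DataType::DT_INT16;
    case DataType::DT_QUINT16: return DataType::DT_UINT16;
    case DataType::DT_QINT32:  return DataType::DT_INT32;
    default:                   return external;  // Already non-quantized.
  }
}

int main() {
  // A Cast op with SrcT = qint8 and DstT = float would internally run the
  // int8 -> float kernel and then tag the output with the external dtype.
  DataType internal_src = ToStorageType(DataType::DT_QINT8);
  std::cout << "qint8 handled as dtype enum value "
            << static_cast<int>(internal_src) << "\n";
  return 0;
}

The Compute() hunk applies the same idea on the tensor side: the input is relabeled to src_dtype_ via UnsafeCopyFromInternal, the output is temporarily set to dst_dtype_ so the existing cast functor runs, and it is set back to external_dst_dtype_ before being returned.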