Merge pull request #20159 from luk-ai:cast-quantized
PiperOrigin-RevId: 205881436
This commit is contained in:
commit
4f4091db6d
@ -482,6 +482,7 @@ class Tensor {
|
||||
friend class VariableOp; // For access to set_shape
|
||||
friend class AutoReloadVariableOp; // For access to set_shape
|
||||
friend class TensorTestHelper; // For access to set_shape
|
||||
friend class CastOpBase; // For access to set_dtype;
|
||||
friend class OpKernelContext; // For access to RefCountIsOne().
|
||||
friend class ScopedAllocator; // For access to buf_.
|
||||
friend class XlaTensor; // For access to RefCountIsOne().
|
||||
|
@ -55,8 +55,39 @@ typedef Eigen::SyclDevice SYCLDevice;
|
||||
FN(arg0, std::complex<double>)
|
||||
|
||||
CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &external_src_dtype_));
|
||||
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &external_dst_dtype_));
|
||||
|
||||
// Quantized data types use the same underlying format as their non quantized
|
||||
// version so we use the non quantized implementation for casting.
|
||||
if (external_dst_dtype_ == DT_QUINT8) {
|
||||
dst_dtype_ = DT_UINT8;
|
||||
} else if (external_dst_dtype_ == DT_QINT8) {
|
||||
dst_dtype_ = DT_INT8;
|
||||
} else if (external_dst_dtype_ == DT_QINT32) {
|
||||
dst_dtype_ = DT_INT32;
|
||||
} else if (external_dst_dtype_ == DT_QINT16) {
|
||||
dst_dtype_ = DT_INT16;
|
||||
} else if (external_dst_dtype_ == DT_QUINT16) {
|
||||
dst_dtype_ = DT_UINT16;
|
||||
} else {
|
||||
dst_dtype_ = external_dst_dtype_;
|
||||
}
|
||||
|
||||
if (external_src_dtype_ == DT_QUINT8) {
|
||||
src_dtype_ = DT_UINT8;
|
||||
} else if (external_src_dtype_ == DT_QINT8) {
|
||||
src_dtype_ = DT_INT8;
|
||||
} else if (external_src_dtype_ == DT_QINT32) {
|
||||
src_dtype_ = DT_INT32;
|
||||
} else if (external_src_dtype_ == DT_QINT16) {
|
||||
src_dtype_ = DT_INT16;
|
||||
} else if (external_src_dtype_ == DT_QUINT16) {
|
||||
src_dtype_ = DT_UINT16;
|
||||
} else {
|
||||
src_dtype_ = external_src_dtype_;
|
||||
}
|
||||
}
|
||||
|
||||
void CastOpBase::Compute(OpKernelContext* ctx) {
|
||||
@ -64,15 +95,20 @@ void CastOpBase::Compute(OpKernelContext* ctx) {
|
||||
if (work_ == nullptr) {
|
||||
ctx->set_output(0, inp);
|
||||
} else {
|
||||
Tensor in;
|
||||
in.UnsafeCopyFromInternal(inp, src_dtype_, inp.shape());
|
||||
Tensor* out = nullptr;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
|
||||
work_(ctx, inp, out);
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
|
||||
out->set_dtype(dst_dtype_);
|
||||
work_(ctx, in, out);
|
||||
out->set_dtype(external_dst_dtype_);
|
||||
}
|
||||
}
|
||||
|
||||
Status CastOpBase::Unimplemented() {
|
||||
return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ",
|
||||
DataTypeString(dst_dtype_), " is not supported");
|
||||
return errors::Unimplemented("Cast ", DataTypeString(external_src_dtype_),
|
||||
" to ", DataTypeString(external_dst_dtype_),
|
||||
" is not supported");
|
||||
}
|
||||
|
||||
CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
|
||||
@ -80,7 +116,7 @@ CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
|
||||
}
|
||||
|
||||
Status CpuCastOp::Prepare() {
|
||||
if (src_dtype_ == dst_dtype_) {
|
||||
if (external_src_dtype_ == external_dst_dtype_) {
|
||||
work_ = nullptr; // Identity
|
||||
return Status::OK();
|
||||
}
|
||||
@ -133,7 +169,7 @@ class GpuCastOp : public CastOpBase {
|
||||
|
||||
private:
|
||||
Status Prepare() {
|
||||
if (src_dtype_ == dst_dtype_) {
|
||||
if (external_src_dtype_ == external_dst_dtype_) {
|
||||
work_ = nullptr; // Identity
|
||||
return Status::OK();
|
||||
}
|
||||
@ -215,7 +251,7 @@ class SyclCastOp : public CastOpBase {
|
||||
|
||||
private:
|
||||
Status Prepare() {
|
||||
if (src_dtype_ == dst_dtype_) {
|
||||
if (external_src_dtype_ == external_dst_dtype_) {
|
||||
work_ = nullptr; // Identity
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -36,6 +36,8 @@ class CastOpBase : public OpKernel {
|
||||
protected:
|
||||
DataType src_dtype_;
|
||||
DataType dst_dtype_;
|
||||
DataType external_src_dtype_;
|
||||
DataType external_dst_dtype_;
|
||||
std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr;
|
||||
|
||||
Status Unimplemented();
|
||||
|
@ -78,7 +78,12 @@ class CastOpTest : public OpsTestBase {
|
||||
TEST_CAST(in, half); \
|
||||
TEST_CAST(in, float); \
|
||||
TEST_CAST(in, double); \
|
||||
TEST_CAST(in, bfloat16);
|
||||
TEST_CAST(in, bfloat16); \
|
||||
TEST_CAST(in, quint8); \
|
||||
TEST_CAST(in, qint8); \
|
||||
TEST_CAST(in, qint32); \
|
||||
TEST_CAST(in, qint16); \
|
||||
TEST_CAST(in, quint16);
|
||||
|
||||
TEST_ALL_CASTS_FROM(uint8)
|
||||
TEST_ALL_CASTS_FROM(uint16)
|
||||
@ -91,6 +96,11 @@ TEST_ALL_CASTS_FROM(half)
|
||||
TEST_ALL_CASTS_FROM(float)
|
||||
TEST_ALL_CASTS_FROM(double)
|
||||
TEST_ALL_CASTS_FROM(bfloat16)
|
||||
TEST_ALL_CASTS_FROM(quint8)
|
||||
TEST_ALL_CASTS_FROM(qint8)
|
||||
TEST_ALL_CASTS_FROM(qint32)
|
||||
TEST_ALL_CASTS_FROM(qint16)
|
||||
TEST_ALL_CASTS_FROM(quint16)
|
||||
|
||||
#undef TEST_ALL_CASTS_FROM
|
||||
#undef TEST_CAST
|
||||
|
@ -311,8 +311,10 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase):
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testUnsupportedDtype(self):
|
||||
with self.test_session():
|
||||
param = ops.convert_to_tensor(
|
||||
np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype),
|
||||
dtype=dtypes.qint16)
|
||||
with self.assertRaises(TypeError):
|
||||
param = array_ops.ones([int(2**11+1)], dtype=dtypes.qint16)
|
||||
du.embed_check_categorical_event_shape(param)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user