Merge pull request #20159 from luk-ai:cast-quantized
PiperOrigin-RevId: 205881436
Commit 4f4091db6d
@@ -482,6 +482,7 @@ class Tensor {
   friend class VariableOp;            // For access to set_shape
   friend class AutoReloadVariableOp;  // For access to set_shape
   friend class TensorTestHelper;      // For access to set_shape
+  friend class CastOpBase;            // For access to set_dtype;
   friend class OpKernelContext;       // For access to RefCountIsOne().
   friend class ScopedAllocator;       // For access to buf_.
   friend class XlaTensor;             // For access to RefCountIsOne().
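The new friend declaration gives the cast kernel access to Tensor's private set_dtype, which the hunks below use to relabel a buffer between a quantized dtype and its plain integer counterpart without copying. A standalone sketch of why that relabeling is byte-safe (QUInt8 here is a stand-in modeled on tensorflow::quint8, not the real definition):

#include <cassert>
#include <cstdint>

// Stand-in modeled on tensorflow::quint8 (assumption: a thin wrapper around
// uint8_t, which is what makes relabeling the dtype a no-op on the bytes).
struct QUInt8 {
  uint8_t value;
  explicit QUInt8(uint8_t v) : value(v) {}
};

int main() {
  static_assert(sizeof(QUInt8) == sizeof(uint8_t),
                "quantized wrapper adds no storage overhead");
  QUInt8 q(200);
  assert(q.value == 200);  // same bits, different static type
  return 0;
}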
@@ -55,8 +55,39 @@ typedef Eigen::SyclDevice SYCLDevice;
   FN(arg0, std::complex<double>)
 
 CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
-  OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_));
-  OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &external_src_dtype_));
+
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &external_dst_dtype_));
+
+  // Quantized data types use the same underlying format as their non quantized
+  // version so we use the non quantized implementation for casting.
+  if (external_dst_dtype_ == DT_QUINT8) {
+    dst_dtype_ = DT_UINT8;
+  } else if (external_dst_dtype_ == DT_QINT8) {
+    dst_dtype_ = DT_INT8;
+  } else if (external_dst_dtype_ == DT_QINT32) {
+    dst_dtype_ = DT_INT32;
+  } else if (external_dst_dtype_ == DT_QINT16) {
+    dst_dtype_ = DT_INT16;
+  } else if (external_dst_dtype_ == DT_QUINT16) {
+    dst_dtype_ = DT_UINT16;
+  } else {
+    dst_dtype_ = external_dst_dtype_;
+  }
+
+  if (external_src_dtype_ == DT_QUINT8) {
+    src_dtype_ = DT_UINT8;
+  } else if (external_src_dtype_ == DT_QINT8) {
+    src_dtype_ = DT_INT8;
+  } else if (external_src_dtype_ == DT_QINT32) {
+    src_dtype_ = DT_INT32;
+  } else if (external_src_dtype_ == DT_QINT16) {
+    src_dtype_ = DT_INT16;
+  } else if (external_src_dtype_ == DT_QUINT16) {
+    src_dtype_ = DT_UINT16;
+  } else {
+    src_dtype_ = external_src_dtype_;
+  }
 }
 
 void CastOpBase::Compute(OpKernelContext* ctx) {
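Both if/else ladders in the constructor encode the same quantized-to-storage mapping, once for DstT and once for SrcT. A hypothetical refactor (UnquantizedType is an invented name, not part of this patch) that captures the mapping in one place:

#include "tensorflow/core/framework/types.pb.h"  // DataType and the DT_* values

namespace tensorflow {

// Hypothetical helper: map a quantized DataType to the plain integer type
// that shares its in-memory representation; other types map to themselves.
static DataType UnquantizedType(DataType dt) {
  switch (dt) {
    case DT_QUINT8:  return DT_UINT8;
    case DT_QINT8:   return DT_INT8;
    case DT_QINT32:  return DT_INT32;
    case DT_QINT16:  return DT_INT16;
    case DT_QUINT16: return DT_UINT16;
    default:         return dt;
  }
}

}  // namespace tensorflow

With such a helper, the two ladders would reduce to dst_dtype_ = UnquantizedType(external_dst_dtype_); and src_dtype_ = UnquantizedType(external_src_dtype_);.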
@@ -64,15 +95,20 @@ void CastOpBase::Compute(OpKernelContext* ctx) {
   if (work_ == nullptr) {
     ctx->set_output(0, inp);
   } else {
+    Tensor in;
+    in.UnsafeCopyFromInternal(inp, src_dtype_, inp.shape());
     Tensor* out = nullptr;
-    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
-    work_(ctx, inp, out);
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
+    out->set_dtype(dst_dtype_);
+    work_(ctx, in, out);
+    out->set_dtype(external_dst_dtype_);
   }
 }
 
 Status CastOpBase::Unimplemented() {
-  return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ",
-                               DataTypeString(dst_dtype_), " is not supported");
+  return errors::Unimplemented("Cast ", DataTypeString(external_src_dtype_),
+                               " to ", DataTypeString(external_dst_dtype_),
+                               " is not supported");
 }
 
 CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
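Reading the Compute hunk as a dataflow: UnsafeCopyFromInternal makes in an alias of the input buffer tagged with the storage dtype (no copy happens), allocate_output produces out with the externally visible DstT, the first set_dtype relabels out to the storage dtype so the existing non-quantized work_ functor can run on it, and the second set_dtype relabels the finished output back to external_dst_dtype_ for downstream consumers. This relabeling is what the new friend class CastOpBase declaration in the Tensor hunk enables, and it is why Unimplemented now reports the external dtype pair the user actually requested.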
@@ -80,7 +116,7 @@ CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
 }
 
 Status CpuCastOp::Prepare() {
-  if (src_dtype_ == dst_dtype_) {
+  if (external_src_dtype_ == external_dst_dtype_) {
     work_ = nullptr;  // Identity
     return Status::OK();
   }
@@ -133,7 +169,7 @@ class GpuCastOp : public CastOpBase {
 
  private:
   Status Prepare() {
-    if (src_dtype_ == dst_dtype_) {
+    if (external_src_dtype_ == external_dst_dtype_) {
       work_ = nullptr;  // Identity
       return Status::OK();
     }
@@ -215,7 +251,7 @@ class SyclCastOp : public CastOpBase {
 
  private:
   Status Prepare() {
-    if (src_dtype_ == dst_dtype_) {
+    if (external_src_dtype_ == external_dst_dtype_) {
       work_ = nullptr;  // Identity
       return Status::OK();
     }
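All three device-specific Prepare() methods now take the identity shortcut when the external types match, so a cast between identical quantized types (for example quint8 to quint8) forwards the input tensor unchanged instead of looking up a work function.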
@@ -36,6 +36,8 @@ class CastOpBase : public OpKernel {
  protected:
   DataType src_dtype_;
   DataType dst_dtype_;
+  DataType external_src_dtype_;
+  DataType external_dst_dtype_;
   std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr;
 
   Status Unimplemented();
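After this header change, src_dtype_ and dst_dtype_ hold the storage types the kernels actually compute with, while the new external_src_dtype_ and external_dst_dtype_ preserve the SrcT and DstT attribute values as seen in the graph; for non-quantized casts the two pairs are identical.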
@@ -78,7 +78,12 @@ class CastOpTest : public OpsTestBase {
   TEST_CAST(in, half);     \
   TEST_CAST(in, float);    \
   TEST_CAST(in, double);   \
-  TEST_CAST(in, bfloat16);
+  TEST_CAST(in, bfloat16); \
+  TEST_CAST(in, quint8);   \
+  TEST_CAST(in, qint8);    \
+  TEST_CAST(in, qint32);   \
+  TEST_CAST(in, qint16);   \
+  TEST_CAST(in, quint16);
 
 TEST_ALL_CASTS_FROM(uint8)
 TEST_ALL_CASTS_FROM(uint16)
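For a sense of what one expansion of the extended macro exercises, here is a hedged sketch of a single quantized round-trip check in the style of this test file (MakeOp and the fixture details are assumptions inferred from the macros above, not a verified API):

TEST_F(CastOpTest, Uint8ToQuint8) {
  MakeOp(DT_UINT8, DT_QUINT8);  // assumed fixture helper setting SrcT/DstT
  AddInputFromArray<uint8>(TensorShape({2}), {0, 255});
  TF_ASSERT_OK(RunOpKernel());
  Tensor expected(allocator(), DT_QUINT8, TensorShape({2}));
  expected.flat<quint8>()(0) = quint8(0);
  expected.flat<quint8>()(1) = quint8(255);
  // The bytes are unchanged; only the dtype tag differs from the input.
  test::ExpectTensorEqual<quint8>(expected, *GetOutput(0));
}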
@@ -91,6 +96,11 @@ TEST_ALL_CASTS_FROM(half)
 TEST_ALL_CASTS_FROM(float)
 TEST_ALL_CASTS_FROM(double)
 TEST_ALL_CASTS_FROM(bfloat16)
+TEST_ALL_CASTS_FROM(quint8)
+TEST_ALL_CASTS_FROM(qint8)
+TEST_ALL_CASTS_FROM(qint32)
+TEST_ALL_CASTS_FROM(qint16)
+TEST_ALL_CASTS_FROM(quint16)
 
 #undef TEST_ALL_CASTS_FROM
 #undef TEST_CAST
@@ -311,8 +311,10 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testUnsupportedDtype(self):
     with self.test_session():
+      param = ops.convert_to_tensor(
+          np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype),
+          dtype=dtypes.qint16)
       with self.assertRaises(TypeError):
-        param = array_ops.ones([int(2**11+1)], dtype=dtypes.qint16)
         du.embed_check_categorical_event_shape(param)
 
 
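The Python test change follows from the kernel change: param is now built outside the assertRaises block (via ops.convert_to_tensor rather than array_ops.ones), so the expected TypeError can only come from embed_check_categorical_event_shape itself and not from tensor construction, which may no longer raise now that Cast accepts quantized dtypes.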