PrintV2 shape inference should succeed if rank is unknown

PiperOrigin-RevId: 276737419
Change-Id: I07ce500cb932d9f3e9e31952bd77cc9ddc53f00b
This commit is contained in:
Prakalp Srivastava 2019-10-25 12:22:41 -07:00 committed by TensorFlower Gardener
parent 80b16fc82a
commit a582a54b5b
3 changed files with 12 additions and 0 deletions

View File

@ -143,6 +143,10 @@ class PrintV2Op : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor* input_;
OP_REQUIRES_OK(ctx, ctx->input("input", &input_));
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(input_->shape()),
errors::InvalidArgument("Input is expected to be scalar, but got ",
input_->shape()));
const string& msg = input_->scalar<tstring>()();
string ended_msg = strings::StrCat(msg, end_);

View File

@ -50,6 +50,12 @@ TEST_F(PrintingV2GraphTest, InvalidOutputStream) {
ASSERT_NE(::tensorflow::Status::OK(), (Init("invalid_output_stream")));
}
TEST_F(PrintingV2GraphTest, InvalidInputRank) {
TF_ASSERT_OK(Init());
AddInputFromArray<tstring>(TensorShape({2}), {"bar", "foo"});
ASSERT_NE(::tensorflow::Status::OK(), RunOpKernel());
}
class PrintingGraphTest : public OpsTestBase {
protected:
Status Init(DataType input_type1, DataType input_type2, string msg = "",

View File

@ -52,6 +52,8 @@ REGISTER_OP("PrintV2")
.Attr("output_stream: string = 'stderr'")
.Attr("end: string = '\n'")
.SetShapeFn([](InferenceContext* c) {
// Early exit if rank is unknown.
if (!c->RankKnown(c->input(0))) return Status::OK();
// Make sure that the input is a scalar.
if (c->Rank(c->input(0)) != 0) {
return errors::InvalidArgument("input must be a scalar, but has rank: ",