PrintV2 shape inference should succeed if rank is unknown
PiperOrigin-RevId: 276737419 Change-Id: I07ce500cb932d9f3e9e31952bd77cc9ddc53f00b
This commit is contained in:
parent
80b16fc82a
commit
a582a54b5b
tensorflow/core
@ -143,6 +143,10 @@ class PrintV2Op : public OpKernel {
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor* input_;
|
||||
OP_REQUIRES_OK(ctx, ctx->input("input", &input_));
|
||||
OP_REQUIRES(
|
||||
ctx, TensorShapeUtils::IsScalar(input_->shape()),
|
||||
errors::InvalidArgument("Input is expected to be scalar, but got ",
|
||||
input_->shape()));
|
||||
const string& msg = input_->scalar<tstring>()();
|
||||
|
||||
string ended_msg = strings::StrCat(msg, end_);
|
||||
|
@ -50,6 +50,12 @@ TEST_F(PrintingV2GraphTest, InvalidOutputStream) {
|
||||
ASSERT_NE(::tensorflow::Status::OK(), (Init("invalid_output_stream")));
|
||||
}
|
||||
|
||||
TEST_F(PrintingV2GraphTest, InvalidInputRank) {
|
||||
TF_ASSERT_OK(Init());
|
||||
AddInputFromArray<tstring>(TensorShape({2}), {"bar", "foo"});
|
||||
ASSERT_NE(::tensorflow::Status::OK(), RunOpKernel());
|
||||
}
|
||||
|
||||
class PrintingGraphTest : public OpsTestBase {
|
||||
protected:
|
||||
Status Init(DataType input_type1, DataType input_type2, string msg = "",
|
||||
|
@ -52,6 +52,8 @@ REGISTER_OP("PrintV2")
|
||||
.Attr("output_stream: string = 'stderr'")
|
||||
.Attr("end: string = '\n'")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
// Early exit if rank is unknown.
|
||||
if (!c->RankKnown(c->input(0))) return Status::OK();
|
||||
// Make sure that the input is a scalar.
|
||||
if (c->Rank(c->input(0)) != 0) {
|
||||
return errors::InvalidArgument("input must be a scalar, but has rank: ",
|
||||
|
Loading…
Reference in New Issue
Block a user