Make "missing variant function" error status a bit more readable

PiperOrigin-RevId: 357657033
Change-Id: Id97aad3ec193dee78dcfe0c8dc3d9fdb4e236ab7
This commit is contained in:
Sanjoy Das 2021-02-15 23:41:21 -08:00 committed by TensorFlower Gardener
parent 460d16750d
commit 603a21810f
2 changed files with 32 additions and 8 deletions

View File

@ -26,6 +26,26 @@ limitations under the License.
namespace tensorflow {
const char* VariantUnaryOpToString(VariantUnaryOp op) {
switch (op) {
case INVALID_VARIANT_UNARY_OP:
return "INVALID";
case ZEROS_LIKE_VARIANT_UNARY_OP:
return "ZEROS_LIKE";
case CONJ_VARIANT_UNARY_OP:
return "CONJ";
}
}
const char* VariantBinaryOpToString(VariantBinaryOp op) {
switch (op) {
case INVALID_VARIANT_BINARY_OP:
return "INVALID";
case ADD_VARIANT_BINARY_OP:
return "ADD";
}
}
std::unordered_set<string>* UnaryVariantOpRegistry::PersistentStringStorage() {
static std::unordered_set<string>* string_storage =
new std::unordered_set<string>();

View File

@ -44,11 +44,15 @@ enum VariantUnaryOp {
CONJ_VARIANT_UNARY_OP = 2,
};
const char* VariantUnaryOpToString(VariantUnaryOp op);
enum VariantBinaryOp {
INVALID_VARIANT_BINARY_OP = 0,
ADD_VARIANT_BINARY_OP = 1,
};
const char* VariantBinaryOpToString(VariantBinaryOp op);
enum VariantDeviceCopyDirection {
INVALID_DEVICE_COPY_DIRECTION = 0,
HOST_TO_DEVICE = 1,
@ -311,9 +315,10 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
if (unary_op_fn == nullptr) {
return errors::Internal(
"No unary variant unary_op function found for unary variant op enum: ",
op, " Variant type_name: ", v.TypeName(), " for device type: ", device);
return errors::Internal("No unary variant unary_op function found for op ",
VariantUnaryOpToString(op),
" Variant type_name: ", v.TypeName(),
" for device type: ", device);
}
return (*unary_op_fn)(ctx, v, v_out);
}
@ -340,11 +345,10 @@ Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
if (binary_op_fn == nullptr) {
return errors::Internal(
"No unary variant binary_op function found for binary variant op "
"enum: ",
op, " Variant type_name: '", a.TypeName(), "' for device type: ",
device);
return errors::Internal("No unary variant binary_op function found for op ",
VariantBinaryOpToString(op),
" Variant type_name: '", a.TypeName(),
"' for device type: ", device);
}
return (*binary_op_fn)(ctx, a, b, out);
}