Move ImmediateExecutionTensorHandle::DataType() to AbstractTensorHandle.

Minor signature change to AddInputList: it now takes
absl::Span<AbstractTensorHandle* const> instead of
absl::Span<AbstractTensorHandle*>.

Implement AddInputList for GraphContext. This is needed for the
downstream implementation of gradients.

PiperOrigin-RevId: 320513498
Change-Id: I16ee672c6d6f1f9b260b319c954158694c82a1db
Authored by Saurabh Saxena on 2020-07-09 18:06:17 -07:00, committed by TensorFlower Gardener
parent 78199b484c
commit fdce31d861
8 changed files with 41 additions and 11 deletions
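
A note on the signature change: absl::Span<AbstractTensorHandle* const> is a span of const pointers, so call sites can pass a const container of handles, or a braced list, without first copying into a mutable vector. A minimal illustrative sketch follows; AddPair and its arguments are hypothetical and not code from this change.

// Illustrative sketch only (not part of this change): the const element type
// lets a caller hand AddInputList a const vector of handles directly.
#include <vector>

#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {

Status AddPair(AbstractOperation* op, AbstractTensorHandle* a,
               AbstractTensorHandle* b) {
  const std::vector<AbstractTensorHandle*> inputs = {a, b};
  // Converts implicitly to absl::Span<AbstractTensorHandle* const>.
  return op->AddInputList(inputs);  // a braced list also works: {a, b}
}

}  // namespace tensorflow

Under the old Span<AbstractTensorHandle*> signature, the caller would have needed a mutable, non-const vector to build the span from.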

@@ -177,7 +177,9 @@ cc_library(
     visibility = [
         "//tensorflow:internal",
     ],
-    deps = [],
+    deps = [
+        "//tensorflow/core:protos_all_cc",
+    ],
 )
 
 cc_library(

@@ -73,7 +73,8 @@ class AbstractOperation {
   virtual Status SetDeviceName(const char* name) = 0;
 
   virtual Status AddInput(AbstractTensorHandle* input) = 0;
-  virtual Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) = 0;
+  virtual Status AddInputList(
+      absl::Span<AbstractTensorHandle* const> inputs) = 0;
   virtual Status Execute(absl::Span<AbstractTensorHandle*> retvals,
                          int* num_retvals) = 0;

@@ -16,6 +16,8 @@ limitations under the License.
 #define TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
 
 #include <memory>
 
+#include "tensorflow/core/framework/types.pb.h"
+
 namespace tensorflow {
 
 // Abstract interface to a Tensor handle in either tracing or immediate
@@ -27,6 +29,9 @@ class AbstractTensorHandle {
   virtual ~AbstractTensorHandle() {}
 
  public:
+  // Returns tensor dtype.
+  virtual tensorflow::DataType DataType() const = 0;
+
   AbstractTensorHandleKind getKind() const { return kind_; }
 
   // Release any underlying resources, including the interface object.
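
With DataType() promoted to the AbstractTensorHandle base class, code that only sees abstract handles, such as the planned gradient support, can query dtypes without downcasting to ImmediateExecutionTensorHandle. A hypothetical helper sketch, not part of this change:

// Hypothetical helper (not in this commit): a dtype check that works for
// graph, MLIR, and eager handles alike now that DataType() is on the base.
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {

Status ExpectSameDType(AbstractTensorHandle* x, AbstractTensorHandle* y) {
  if (x->DataType() != y->DataType()) {
    return errors::InvalidArgument("Expected matching dtypes, got ",
                                   DataTypeString(x->DataType()), " and ",
                                   DataTypeString(y->DataType()));
  }
  return Status::OK();
}

}  // namespace tensorflow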

@@ -49,6 +49,10 @@ class GraphTensor : public TracingTensorHandle {
   explicit GraphTensor(TF_Output output)
       : TracingTensorHandle(kGraph), output_(output) {}
 
   void Release() override { delete this; }
 
+  tensorflow::DataType DataType() const override {
+    return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));
+  }
+
   TF_Output output_;
 
   // For LLVM style RTTI.
@@ -102,9 +106,18 @@ class GraphOperation : public TracingOperation {
     TF_AddInput(op_.get(), t->output_);
     return Status::OK();
   }
-  Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) override {
-    return tensorflow::errors::Unimplemented(
-        "AddInputList has not been implemented yet.");
+  Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override {
+    std::vector<TF_Output> tf_outputs(inputs.size());
+    for (int i = 0; i < inputs.size(); i++) {
+      GraphTensor* t = dyn_cast<GraphTensor>(inputs[i]);
+      if (!t) {
+        return tensorflow::errors::InvalidArgument(
+            "Unable to cast input to GraphTensor");
+      }
+      tf_outputs[i] = t->output_;
+    }
+    TF_AddInputList(op_.get(), tf_outputs.data(), tf_outputs.size());
+    return Status::OK();
   }
   Status Execute(absl::Span<AbstractTensorHandle*> retvals,
                  int* num_retvals) override {
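
For context, here is a rough sketch of a tracing call path that would exercise the new GraphOperation::AddInputList. The function, op name, and error handling below are illustrative placeholders that assume the unified-API tracing classes (tracing::TracingContext, tracing::TracingOperation); none of this code is part of the commit.

// Rough, illustrative sketch (not part of this change) of tracing a
// multi-input op through the unified API, which now forwards the whole
// input list to TF_AddInputList in one call.
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {

Status TraceIdentityN(tracing::TracingContext* ctx,
                      absl::Span<AbstractTensorHandle* const> inputs,
                      absl::Span<AbstractTensorHandle*> outputs) {
  AbstractOperationPtr op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(op->Reset("IdentityN", /*raw_device_name=*/nullptr));
  tracing::TracingOperation* tracing_op =
      dyn_cast<tracing::TracingOperation>(op.get());
  if (!tracing_op) {
    return errors::Internal("Expected a tracing operation.");
  }
  TF_RETURN_IF_ERROR(tracing_op->SetOpName("identity_n"));
  // All inputs are registered in one call instead of a per-handle loop.
  TF_RETURN_IF_ERROR(op->AddInputList(inputs));
  int num_retvals = static_cast<int>(outputs.size());
  return op->Execute(outputs, &num_retvals);
}

}  // namespace tensorflow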

@@ -33,8 +33,6 @@ namespace tensorflow {
 // is needed a static_cast can be applied.
 class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
  public:
-  // Returns tensor dtype.
-  virtual tensorflow::DataType DataType() const = 0;
   // Returns number of dimensions.
   virtual Status NumDims(int* num_dims) const = 0;
   // Returns number of elements across all dimensions.

@@ -93,6 +93,15 @@ class MlirTensor : public TracingTensorHandle {
   explicit MlirTensor(Value value)
       : TracingTensorHandle(kMlir), value_(value) {}
 
+  tensorflow::DataType DataType() const override {
+    tensorflow::DataType type;
+    Status s = ConvertScalarTypeToDataType(value_.getType(), &type);
+    if (!s.ok()) {
+      return tensorflow::DT_INVALID;
+    }
+    return type;
+  }
+
   void Release() override { delete this; }
 
   Value getValue() { return value_; }
@@ -127,7 +136,7 @@ class MlirAbstractOp : public TracingOperation {
   Status SetDeviceName(const char* name) override;
 
   Status AddInput(AbstractTensorHandle* input) override;
-  Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) override;
+  Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
   Status Execute(absl::Span<AbstractTensorHandle*> retvals,
                  int* num_retvals) override;
@@ -464,7 +473,8 @@ Status MlirAbstractOp::SetDeviceName(const char* name) {
   return Status::OK();
 }
 
-Status MlirAbstractOp::AddInputList(absl::Span<AbstractTensorHandle*> inputs) {
+Status MlirAbstractOp::AddInputList(
+    absl::Span<AbstractTensorHandle* const> inputs) {
   return tensorflow::errors::Unimplemented(
       "AddInputList has not been implemented yet.");
 }

@@ -260,7 +260,8 @@ Status EagerOperation::AddInput(AbstractTensorHandle* input) {
   return MaybeInferSingleInputAttrs(h);
 }
 
-Status EagerOperation::AddInputList(absl::Span<AbstractTensorHandle*> inputs) {
+Status EagerOperation::AddInputList(
+    absl::Span<AbstractTensorHandle* const> inputs) {
   for (auto& input : inputs) {
     TensorHandle* h = TensorHandleFromInterface(input);
     AddTensorHandle(h);

@@ -80,7 +80,7 @@ class EagerOperation : public ImmediateExecutionOperation {
   Status SetAttrValue(const char* attr_name, const AttrValue& value);
 
   Status AddInput(AbstractTensorHandle* input) override;
-  Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) override;
+  Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
   Status Execute(absl::Span<AbstractTensorHandle*> retvals,
                  int* num_retvals) override;
   const tensorflow::OpDef* OpDef() const override { return op_def_; };