Moving ComputeArgumentShapes to tpu_compile_op_support.

PiperOrigin-RevId: 324064650
Change-Id: I057003860e7849b63351d855b4a9c0a166a10cad
This commit is contained in:
Henry Tan 2020-07-30 12:59:17 -07:00 committed by TensorFlower Gardener
parent c1336e9a40
commit 34bd3aaad4
4 changed files with 46 additions and 49 deletions

View File

@ -413,46 +413,6 @@ Status TpuCompileOpKernelCommon::CompileTFFunctionToHlo(
return Status::OK();
}
/* static */
Status TpuCompileOpKernelCommon::ComputeArgumentShapes(
const tpu::TPUCompileMetadataProto& metadata,
const std::vector<TensorShape>& dynamic_shapes,
std::vector<TensorShape>* arg_shapes) {
arg_shapes->resize(metadata.args_size());
int dynamic_shape_pos = 0;
for (int i = 0; i < metadata.args_size(); ++i) {
const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i);
// The XLA compiler determines the shape of each constant by inspecting the
// value of its corresponding host-memory tensor. As a result, we don't need
// to give the compiler graph-inferred shapes for constant arguments.
if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
continue;
}
TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape()));
PartialTensorShape static_shape(arg.shape());
TensorShape& shape = (*arg_shapes)[i];
if (static_shape.IsFullyDefined()) {
TF_RET_CHECK(static_shape.AsTensorShape(&shape));
} else {
TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size())
<< "Too few dynamic shapes";
shape = dynamic_shapes[dynamic_shape_pos++];
if (!static_shape.IsCompatibleWith(shape)) {
return errors::InvalidArgument(
"Mismatch between static and dynamic shape for argument. Static "
"shape: ",
static_shape.DebugString(),
"; dynamic shape: ", shape.DebugString());
}
}
}
// Checks we consumed all of the dynamic shapes.
TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size())
<< "Too many dynamic shapes";
return Status::OK();
}
// Function arguments and return values lose their device assignments, so we
// must recreate them.
/* static */ Status TpuCompileOpKernelCommon::AssignDevicesToArgsAndRetvals(

View File

@ -99,15 +99,6 @@ class TpuCompileOpKernelCommon {
const std::vector<TensorShape>& arg_shapes,
TpuProgramGroupInterface* tpu_program_group) = 0;
// Computes shapes for each argument. Uses both the static shape from the
// metadata, and the dynamic shapes where the static shape is not
// defined. There must be one dynamic_shape for each argument with a
// partially defined shape, in index order.
static Status ComputeArgumentShapes(
const tpu::TPUCompileMetadataProto& metadata,
const std::vector<TensorShape>& dynamic_shapes,
std::vector<TensorShape>* arg_shapes);
// Performs shape inference on `computation`, filling shape_info with operator
// shapes. The shapes of the _Arg nodes are taken from `arg_shapes`.
static Status RunShapeInferenceOnComputation(

View File

@ -540,5 +540,43 @@ Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
}
return Status::OK();
}
Status ComputeArgumentShapes(const tpu::TPUCompileMetadataProto& metadata,
const std::vector<TensorShape>& dynamic_shapes,
std::vector<TensorShape>* arg_shapes) {
arg_shapes->resize(metadata.args_size());
int dynamic_shape_pos = 0;
for (int i = 0; i < metadata.args_size(); ++i) {
const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i);
// The XLA compiler determines the shape of each constant by inspecting the
// value of its corresponding host-memory tensor. As a result, we don't need
// to give the compiler graph-inferred shapes for constant arguments.
if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
continue;
}
TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape()));
PartialTensorShape static_shape(arg.shape());
TensorShape& shape = (*arg_shapes)[i];
if (static_shape.IsFullyDefined()) {
TF_RET_CHECK(static_shape.AsTensorShape(&shape));
} else {
TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size())
<< "Too few dynamic shapes";
shape = dynamic_shapes[dynamic_shape_pos++];
if (!static_shape.IsCompatibleWith(shape)) {
return errors::InvalidArgument(
"Mismatch between static and dynamic shape for argument. Static "
"shape: ",
static_shape.DebugString(),
"; dynamic shape: ", shape.DebugString());
}
}
}
// Checks we consumed all of the dynamic shapes.
TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size())
<< "Too many dynamic shapes";
return Status::OK();
}
} // namespace tpu
} // namespace tensorflow

View File

@ -159,6 +159,14 @@ se::port::Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
TPUCompileMetadataProto* metadata,
NameAttrList* function_name,
std::string* mlir_module);
// Computes shapes for each argument. Uses both the static shape from the
// metadata, and the dynamic shapes where the static shape is not
// defined. There must be one dynamic_shape for each argument with a
// partially defined shape, in index order.
Status ComputeArgumentShapes(const TPUCompileMetadataProto& metadata,
const std::vector<TensorShape>& dynamic_shapes,
std::vector<TensorShape>* arg_shapes);
} // namespace tpu
} // namespace tensorflow