refactor: Refactor core input size checks #382

Merged · 9 commits · Sep 4, 2024
73 changes: 45 additions & 28 deletions src/infer_request.cc
@@ -910,6 +910,7 @@ Status
InferenceRequest::Normalize()
{
const inference::ModelConfig& model_config = model_raw_->Config();
+ const std::string& model_name = ModelName();

// Fill metadata for raw input
if (!raw_input_name_.empty()) {
@@ -922,7 +923,7 @@ InferenceRequest::Normalize()
std::to_string(original_inputs_.size()) +
") to be deduced but got " +
std::to_string(model_config.input_size()) + " inputs in '" +
ModelName() + "' model configuration");
model_name + "' model configuration");
}
auto it = original_inputs_.begin();
if (raw_input_name_ != it->first) {
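For context on this raw-input hunk: raw (unnamed) request payloads can only be normalized when the payload-to-tensor mapping is unambiguous. A minimal illustrative restatement of the rule being enforced (the names here are hypothetical, not Triton's API):

```cpp
#include <cstddef>

// A request that sends raw (unnamed) input bytes can only be deduced
// onto a model input when the configuration declares exactly one input;
// with more inputs the payload-to-tensor mapping would be ambiguous.
bool CanDeduceRawInput(std::size_t request_input_count,
                       int config_input_count) {
  return request_input_count == 1 && config_input_count == 1;
}
```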
@@ -1055,7 +1056,7 @@ InferenceRequest::Normalize()
Status::Code::INVALID_ARG,
LogRequest() + "input '" + input.Name() +
"' has no shape but model requires batch dimension for '" +
ModelName() + "'");
model_name + "'");
}

if (batch_size_ == 0) {
@@ -1064,7 +1065,7 @@ InferenceRequest::Normalize()
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input '" + input.Name() +
"' batch size does not match other inputs for '" + ModelName() +
"' batch size does not match other inputs for '" + model_name +
"'");
}

@@ -1080,7 +1081,7 @@ InferenceRequest::Normalize()
Status::Code::INVALID_ARG,
LogRequest() + "inference request batch-size must be <= " +
std::to_string(model_config.max_batch_size()) + " for '" +
ModelName() + "'");
model_name + "'");
}

// Verify that each input shape is valid for the model, make
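A standalone sketch of the batching rule the hunks above enforce, under the assumption (from the surrounding code) that a non-zero max_batch_size makes the leading dimension the batch dimension; `CheckBatchSize` and `CheckResult` are illustrative, not Triton types:

```cpp
#include <cstdint>
#include <string>

// Hypothetical result type standing in for triton::core::Status.
struct CheckResult {
  bool ok;
  std::string error;
};

// When the model supports batching (max_batch_size > 0), every input's
// leading dimension supplies the request batch size; all inputs must
// agree, and the shared value must not exceed the configured maximum.
CheckResult CheckBatchSize(int64_t batch_size, int32_t max_batch_size) {
  if (max_batch_size > 0 && batch_size > max_batch_size) {
    return {false, "inference request batch-size must be <= " +
                       std::to_string(max_batch_size)};
  }
  return {true, ""};
}
```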
@@ -1089,17 +1090,17 @@ InferenceRequest::Normalize()
const inference::ModelInput* input_config;
RETURN_IF_ERROR(model_raw_->GetInput(pr.second.Name(), &input_config));

- auto& input_id = pr.first;
+ auto& input_name = pr.first;
auto& input = pr.second;
auto shape = input.MutableShape();

if (input.DType() != input_config->data_type()) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "inference input '" + input_id + "' data-type is '" +
LogRequest() + "inference input '" + input_name + "' data-type is '" +
std::string(
triton::common::DataTypeToProtocolString(input.DType())) +
"', but model '" + ModelName() + "' expects '" +
"', but model '" + model_name + "' expects '" +
std::string(triton::common::DataTypeToProtocolString(
input_config->data_type())) +
"'");
Expand All @@ -1119,7 +1120,7 @@ InferenceRequest::Normalize()
Status::Code::INVALID_ARG,
LogRequest() +
"All input dimensions should be specified for input '" +
input_id + "' for model '" + ModelName() + "', got " +
input_name + "' for model '" + model_name + "', got " +
triton::common::DimsListToString(input.OriginalShape()));
} else if (
(config_dims[i] != triton::common::WILDCARD_DIM) &&
@@ -1148,8 +1149,8 @@ InferenceRequest::Normalize()
}
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "unexpected shape for input '" + input_id +
"' for model '" + ModelName() + "'. Expected " +
LogRequest() + "unexpected shape for input '" + input_name +
"' for model '" + model_name + "'. Expected " +
triton::common::DimsListToString(full_dims) + ", got " +
triton::common::DimsListToString(input.OriginalShape()) + ". " +
implicit_batch_note);
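The rejected-shape message above comes from a wildcard-aware comparison of the request shape against the configured dims. A self-contained sketch of that comparison (`ShapeMatches` is an illustrative helper; Triton's WILDCARD_DIM is represented here as -1):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Compare a request shape against configured dims: a config dim of -1
// (wildcard) matches anything, every other dim must match exactly.
// For batching models, the batch dimension is prepended to the config
// dims first, yielding the "full dims" named in the error message.
bool ShapeMatches(const std::vector<int64_t>& full_dims,
                  const std::vector<int64_t>& request_dims) {
  if (full_dims.size() != request_dims.size()) {
    return false;
  }
  for (std::size_t i = 0; i < full_dims.size(); ++i) {
    if (full_dims[i] != -1 && full_dims[i] != request_dims[i]) {
      return false;
    }
  }
  return true;
}
```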
@@ -1201,32 +1202,25 @@ InferenceRequest::Normalize()
// TensorRT backend.
if (!input.IsNonLinearFormatIo()) {
TRITONSERVER_MemoryType input_memory_type;
- // Because Triton expects STRING type to be in special format
- // (prepend 4 bytes to specify string length), so need to add all the
- // first 4 bytes for each element to find expected byte size
if (data_type == inference::DataType::TYPE_STRING) {
- RETURN_IF_ERROR(
-     ValidateBytesInputs(input_id, input, &input_memory_type));
-
- // FIXME: Temporarily skips byte size checks for GPU tensors. See
- // DLIS-6820.
+ RETURN_IF_ERROR(ValidateBytesInputs(
+     input_name, input, model_name, &input_memory_type));
} else {
// Shape tensor with dynamic batching does not introduce a new
// dimension to the tensor but adds an additional value to the 1-D
// array.
const std::vector<int64_t>& input_dims =
input.IsShapeTensor() ? input.OriginalShape()
: input.ShapeWithBatchDim();
- int64_t expected_byte_size = INT_MAX;
- expected_byte_size =
+ int64_t expected_byte_size =
triton::common::GetByteSize(data_type, input_dims);
const size_t& byte_size = input.Data()->TotalByteSize();
- if ((byte_size > INT_MAX) ||
+ if ((byte_size > LLONG_MAX) ||
(static_cast<int64_t>(byte_size) != expected_byte_size)) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input byte size mismatch for input '" +
input_id + "' for model '" + ModelName() + "'. Expected " +
input_name + "' for model '" + model_name + "'. Expected " +
std::to_string(expected_byte_size) + ", got " +
std::to_string(byte_size));
}
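For fixed-size datatypes the check above boils down to element count times element size. A minimal sketch of that computation (the real triton::common::GetByteSize may differ in corner cases, e.g. for shape tensors):

```cpp
#include <cstdint>
#include <vector>

// Expected byte size of a fixed-size-type tensor: the product of all
// dimensions times the per-element size. Returns -1 when a wildcard
// (negative) dimension makes the size non-computable.
int64_t ExpectedByteSize(int64_t element_size,
                         const std::vector<int64_t>& dims) {
  int64_t elements = 1;
  for (int64_t d : dims) {
    if (d < 0) {
      return -1;  // unknown dimension, size cannot be computed
    }
    elements *= d;
  }
  return elements * element_size;
}

// Example: a [2, 3] FP32 input must carry ExpectedByteSize(4, {2, 3})
// == 24 bytes; any other TotalByteSize() is rejected as a mismatch.
```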
@@ -1300,7 +1294,8 @@ InferenceRequest::ValidateRequestInputs()

Status
InferenceRequest::ValidateBytesInputs(
- const std::string& input_id, const Input& input,
+ const std::string& input_name, const Input& input,
+ const std::string& model_name,
TRITONSERVER_MemoryType* buffer_memory_type) const
{
const auto& input_dims = input.ShapeWithBatchDim();
@@ -1325,27 +1320,48 @@ InferenceRequest::ValidateBytesInputs(
buffer_next_idx++, (const void**)(&buffer), &remaining_buffer_size,
buffer_memory_type, &buffer_memory_id));

+ // GPU tensors are validated at platform backends to avoid additional
+ // data copying. Check "ValidateStringBuffer" in backend_common.cc.
+ if (*buffer_memory_type == TRITONSERVER_MEMORY_GPU) {
+   return Status::Success;
+ }
}

- constexpr size_t kElementSizeIndicator = sizeof(uint32_t);
// Get the next element if not currently processing one.
if (!remaining_element_size) {
+ // Triton expects STRING type to be in special format
+ // (prepend 4 bytes to specify string length), so need to add the
+ // first 4 bytes for each element to find expected byte size.
+ constexpr size_t kElementSizeIndicator = sizeof(uint32_t);

// FIXME: Assume the string element's byte size indicator is not spread
// across buffer boundaries for simplicity.
if (remaining_buffer_size < kElementSizeIndicator) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() +
"element byte size indicator exceeds the end of the buffer.");
"incomplete string length indicator for inference input '" +
input_name + "' for model '" + model_name + "', expecting " +
std::to_string(sizeof(uint32_t)) + " bytes but only " +
std::to_string(remaining_buffer_size) +
" bytes available. Please make sure the string length "
"indicator is in one buffer.");
}

// Start the next element and reset the remaining element size.
remaining_element_size = *(reinterpret_cast<const uint32_t*>(buffer));
element_checked++;

+ // Early stop
+ if (element_checked > element_count) {
+   return Status(
+       Status::Code::INVALID_ARG,
+       LogRequest() + "unexpected number of string elements " +
+           std::to_string(element_checked) + " for inference input '" +
+           input_name + "' for model '" + model_name + "', expecting " +
+           std::to_string(element_count));
+ }

// Advance pointer and remainder by the indicator size.
buffer += kElementSizeIndicator;
remaining_buffer_size -= kElementSizeIndicator;
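For reference, the wire layout this loop decodes: each STRING/BYTES element is a 4-byte length followed by that many raw bytes. A hedged client-side sketch of producing that layout (`SerializeStringTensor` is illustrative, not a Triton client API):

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Serialize string elements into Triton's length-prefixed STRING layout:
// [len0][bytes0][len1][bytes1]..., where each len is a 4-byte uint32
// size indicator, matching kElementSizeIndicator in the validation above.
std::vector<char> SerializeStringTensor(
    const std::vector<std::string>& elements) {
  std::vector<char> buffer;
  for (const std::string& s : elements) {
    const uint32_t len = static_cast<uint32_t>(s.size());
    const char* len_bytes = reinterpret_cast<const char*>(&len);
    buffer.insert(buffer.end(), len_bytes, len_bytes + sizeof(len));
    buffer.insert(buffer.end(), s.begin(), s.end());
  }
  return buffer;
}

// A [2] STRING tensor {"ab", "xyz"} serializes to 4+2+4+3 = 13 bytes and
// walks as exactly two elements in the loop above.
```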
@@ -1371,16 +1387,17 @@ InferenceRequest::ValidateBytesInputs(
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "expected " + std::to_string(buffer_count) +
" buffers for inference input '" + input_id + "', got " +
std::to_string(buffer_next_idx));
" buffers for inference input '" + input_name + "' for model '" +
model_name + "', got " + std::to_string(buffer_next_idx));
}

// Validate the number of processed elements exactly match expectations.
if (element_checked != element_count) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "expected " + std::to_string(element_count) +
" string elements for inference input '" + input_id + "', got " +
" string elements for inference input '" + input_name +
"' for model '" + model_name + "', got " +
std::to_string(element_checked));
}
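Putting the hunks together, a condensed single-buffer version of the validation walk (the real implementation streams multiple buffers via GetNextInput and returns early for GPU memory; the name below is deliberately different from the ValidateStringBuffer in backend_common.cc mentioned above):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Validate that a contiguous buffer holds exactly `expected_elements`
// length-prefixed strings with no trailing bytes. Mirrors the early-stop
// and exact-count checks in ValidateBytesInputs, simplified to one buffer.
bool ValidateLengthPrefixedStrings(const char* buffer, std::size_t byte_size,
                                   std::size_t expected_elements) {
  std::size_t offset = 0;
  std::size_t elements = 0;
  while (offset < byte_size) {
    if (byte_size - offset < sizeof(uint32_t)) {
      return false;  // incomplete length indicator
    }
    uint32_t len = 0;
    std::memcpy(&len, buffer + offset, sizeof(len));
    offset += sizeof(len);
    if (len > byte_size - offset) {
      return false;  // string payload runs past the end of the buffer
    }
    offset += len;
    if (++elements > expected_elements) {
      return false;  // early stop: more elements than the shape implies
    }
  }
  return elements == expected_elements;
}
```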

1 change: 1 addition & 0 deletions src/infer_request.h
@@ -775,6 +775,7 @@ class InferenceRequest {

Status ValidateBytesInputs(
const std::string& input_id, const Input& input,
+ const std::string& model_name,
TRITONSERVER_MemoryType* buffer_memory_type) const;

// Helpers for pending request metrics