Skip to content

Commit

Permalink
fix(torii-grpc): sql query for typed enums in nested arrays (#2579)
Browse files Browse the repository at this point in the history
  • Loading branch information
Larkooo authored Oct 26, 2024
1 parent e54c4c2 commit 51b0297
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 34 deletions.
116 changes: 83 additions & 33 deletions crates/torii/core/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,21 +234,37 @@ pub fn build_sql_query(
limit: Option<u32>,
offset: Option<u32>,
) -> Result<(String, HashMap<String, String>, String), Error> {
#[derive(Default)]
struct TableInfo {
table_name: String,
parent_table: Option<String>,
is_optional: bool,
depth: usize, // Track nesting depth for proper ordering
}

#[allow(clippy::too_many_arguments)]
fn parse_ty(
path: &str,
name: &str,
ty: &Ty,
selections: &mut Vec<String>,
tables: &mut Vec<String>,
arrays_queries: &mut HashMap<String, (Vec<String>, Vec<String>)>,
tables: &mut Vec<TableInfo>,
arrays_queries: &mut HashMap<String, (Vec<String>, Vec<TableInfo>)>,
parent_is_optional: bool,
depth: usize,
) {
match &ty {
Ty::Struct(s) => {
// struct can be the main entrypoint to our model schema
// so we dont format the table name if the path is empty
let table_name =
if path.is_empty() { name.to_string() } else { format!("{}${}", path, name) };

tables.push(TableInfo {
table_name: table_name.clone(),
parent_table: if path.is_empty() { None } else { Some(path.to_string()) },
is_optional: parent_is_optional,
depth,
});

for child in &s.children {
parse_ty(
&table_name,
Expand All @@ -257,13 +273,21 @@ pub fn build_sql_query(
selections,
tables,
arrays_queries,
parent_is_optional,
depth + 1,
);
}

tables.push(table_name);
}
Ty::Tuple(t) => {
let table_name = format!("{}${}", path, name);

tables.push(TableInfo {
table_name: table_name.clone(),
parent_table: Some(path.to_string()),
is_optional: parent_is_optional,
depth,
});

for (i, child) in t.iter().enumerate() {
parse_ty(
&table_name,
Expand All @@ -272,16 +296,22 @@ pub fn build_sql_query(
selections,
tables,
arrays_queries,
parent_is_optional,
depth + 1,
);
}

tables.push(table_name);
}
Ty::Array(t) => {
let table_name = format!("{}${}", path, name);
let is_optional = true;

let mut array_selections = Vec::new();
let mut array_tables = vec![table_name.clone()];
let mut array_tables = vec![TableInfo {
table_name: table_name.clone(),
parent_table: Some(path.to_string()),
is_optional: true,
depth,
}];

parse_ty(
&table_name,
Expand All @@ -290,12 +320,15 @@ pub fn build_sql_query(
&mut array_selections,
&mut array_tables,
arrays_queries,
is_optional,
depth + 1,
);

arrays_queries.insert(table_name, (array_selections, array_tables));
}
Ty::Enum(e) => {
let table_name = format!("{}${}", path, name);
let is_optional = true;

let mut is_typed = false;
for option in &e.options {
Expand All @@ -312,26 +345,31 @@ pub fn build_sql_query(
selections,
tables,
arrays_queries,
is_optional,
depth + 1,
);
is_typed = true;
}

selections.push(format!("[{path}].external_{name} AS \"{path}.{name}\""));
selections.push(format!("[{}].external_{} AS \"{}.{}\"", path, name, path, name));
if is_typed {
tables.push(table_name);
tables.push(TableInfo {
table_name,
parent_table: Some(path.to_string()),
is_optional: parent_is_optional || is_optional,
depth,
});
}
}
_ => {
// alias selected columns to avoid conflicts in `JOIN`
selections.push(format!("[{path}].external_{name} AS \"{path}.{name}\""));
selections.push(format!("[{}].external_{} AS \"{}.{}\"", path, name, path, name));
}
}
}

let mut global_selections = Vec::new();
let mut global_tables = Vec::new();

let mut arrays_queries: HashMap<String, (Vec<String>, Vec<String>)> = HashMap::new();
let mut arrays_queries: HashMap<String, (Vec<String>, Vec<TableInfo>)> = HashMap::new();

for model in schemas {
parse_ty(
Expand All @@ -341,45 +379,56 @@ pub fn build_sql_query(
&mut global_selections,
&mut global_tables,
&mut arrays_queries,
false,
0,
);
}

// TODO: Fallback to subqueries, SQLite has a max limit of 64 on 'table 'JOIN'
if global_tables.len() > 64 {
return Err(QueryError::SqliteJoinLimit.into());
}

// Sort tables by depth to ensure proper join order
global_tables.sort_by_key(|table| table.depth);

let selections_clause = global_selections.join(", ");
let join_clause = global_tables
.into_iter()
.iter()
.map(|table| {
format!(" JOIN [{table}] ON {entities_table}.id = [{table}].{entity_relation_column}")
let join_type = if table.is_optional { "LEFT JOIN" } else { "JOIN" };
let join_condition =
format!("{entities_table}.id = [{}].{entity_relation_column}", table.table_name);
format!(" {join_type} [{}] ON {join_condition}", table.table_name)
})
.collect::<Vec<_>>()
.join(" ");

let mut formatted_arrays_queries: HashMap<String, String> = arrays_queries
.into_iter()
.map(|(table, (selections, tables))| {
.map(|(table, (selections, mut tables))| {
let mut selections_clause = selections.join(", ");
if !selections_clause.is_empty() {
selections_clause = format!(", {}", selections_clause);
}

// Sort array tables by depth
tables.sort_by_key(|table| table.depth);

let join_clause = tables
.iter()
.enumerate()
.map(|(idx, table)| {
if idx == 0 {
.map(|table| {
if table.parent_table.is_none() {
format!(
" JOIN [{table}] ON {entities_table}.id = \
[{table}].{entity_relation_column}"
" JOIN [{}] ON {entities_table}.id = [{}].{entity_relation_column}",
table.table_name, table.table_name
)
} else {
let join_type = if table.is_optional { "LEFT JOIN" } else { "JOIN" };
format!(
" JOIN [{table}] ON [{table}].full_array_id = \
[{prev_table}].full_array_id",
prev_table = tables[idx - 1]
" {join_type} [{}] ON [{}].full_array_id = [{}].full_array_id",
table.table_name,
table.table_name,
table.parent_table.as_ref().unwrap()
)
}
})
Expand All @@ -401,7 +450,7 @@ pub fn build_sql_query(
{entities_table}{join_clause}"
);
let mut count_query =
format!("SELECT COUNT({entities_table}.id) FROM {entities_table}{join_clause}",);
format!("SELECT COUNT({entities_table}.id) FROM {entities_table}{join_clause}");

if let Some(where_clause) = where_clause {
query += &format!(" WHERE {}", where_clause);
Expand Down Expand Up @@ -1023,11 +1072,12 @@ mod tests {
[Test-Position$vec].external_y AS \"Test-Position$vec.y\", \
[Test-PlayerConfig$favorite_item].external_Some AS \
\"Test-PlayerConfig$favorite_item.Some\", [Test-PlayerConfig].external_favorite_item \
AS \"Test-PlayerConfig.favorite_item\" FROM entities JOIN [Test-Position$vec] ON \
entities.id = [Test-Position$vec].entity_id JOIN [Test-Position] ON entities.id = \
[Test-Position].entity_id JOIN [Test-PlayerConfig$favorite_item] ON entities.id = \
[Test-PlayerConfig$favorite_item].entity_id JOIN [Test-PlayerConfig] ON entities.id \
= [Test-PlayerConfig].entity_id ORDER BY entities.event_id DESC";
AS \"Test-PlayerConfig.favorite_item\" FROM entities JOIN [Test-Position] ON \
entities.id = [Test-Position].entity_id JOIN [Test-PlayerConfig] ON entities.id = \
[Test-PlayerConfig].entity_id JOIN [Test-Position$vec] ON entities.id = \
[Test-Position$vec].entity_id LEFT JOIN [Test-PlayerConfig$favorite_item] ON \
entities.id = [Test-PlayerConfig$favorite_item].entity_id ORDER BY entities.event_id \
DESC";
// todo: completely tests arrays
assert_eq!(query.0, expected_query);
}
Expand Down
2 changes: 1 addition & 1 deletion crates/torii/grpc/src/types/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl From<Enum> for proto::types::Enum {
fn from(r#enum: Enum) -> Self {
proto::types::Enum {
name: r#enum.name,
option: r#enum.option.expect("option value") as u32,
option: r#enum.option.unwrap_or_default() as u32,
options: r#enum.options.into_iter().map(Into::into).collect::<Vec<_>>(),
}
}
Expand Down

0 comments on commit 51b0297

Please sign in to comment.