Skip to content

Commit

Permalink
fix: Don't panic in SQL temporal string check; raise suitable `Column…
Browse files Browse the repository at this point in the history
…NotFound` error (#19473)
  • Loading branch information
alexander-beedie authored Oct 27, 2024
1 parent 4ddae71 commit f103fa8
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 11 deletions.
21 changes: 11 additions & 10 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,34 +374,35 @@ impl SQLExprVisitor<'_> {
},
// identify "CAST(expr AS type) <op> string" and/or "expr::type <op> string" expressions
(Expr::Cast { expr, dtype, .. }, Expr::Literal(LiteralValue::String(s))) => {
if let Expr::Column(name) = &**expr {
(Some(name.clone()), Some(s), Some(dtype))
} else {
(None, Some(s), Some(dtype))
match &**expr {
Expr::Column(name) => (Some(name.clone()), Some(s), Some(dtype)),
_ => (None, Some(s), Some(dtype)),
}
},
_ => (None, None, None),
} {
if expr_dtype.is_none() && self.active_schema.is_none() {
right.clone()
} else {
let left_dtype = expr_dtype
.unwrap_or_else(|| self.active_schema.as_ref().unwrap().get(&name).unwrap());

let left_dtype = expr_dtype.or_else(|| {
self.active_schema
.as_ref()
.and_then(|schema| schema.get(&name))
});
match left_dtype {
DataType::Time if is_iso_time(s) => {
Some(DataType::Time) if is_iso_time(s) => {
right.clone().str().to_time(StrptimeOptions {
strict: true,
..Default::default()
})
},
DataType::Date if is_iso_date(s) => {
Some(DataType::Date) if is_iso_date(s) => {
right.clone().str().to_date(StrptimeOptions {
strict: true,
..Default::default()
})
},
DataType::Datetime(tu, tz) if is_iso_datetime(s) || is_iso_date(s) => {
Some(DataType::Datetime(tu, tz)) if is_iso_datetime(s) || is_iso_date(s) => {
if s.len() == 10 {
// handle upcast from ISO date string (10 chars) to datetime
lit(format!("{}T00:00:00", s))
Expand Down
25 changes: 24 additions & 1 deletion py-polars/tests/unit/sql/test_miscellaneous.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

import polars as pl
from polars.exceptions import SQLInterfaceError, SQLSyntaxError
from polars.exceptions import ColumnNotFoundError, SQLInterfaceError, SQLSyntaxError
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
Expand Down Expand Up @@ -362,3 +362,26 @@ def test_global_variable_inference_17398() -> None:
eager=True,
)
assert_frame_equal(res, users)


@pytest.mark.parametrize(
"query",
[
"SELECT invalid_column FROM self",
"SELECT key, invalid_column FROM self",
"SELECT invalid_column * 2 FROM self",
"SELECT * FROM self ORDER BY invalid_column",
"SELECT * FROM self WHERE invalid_column = 200",
"SELECT * FROM self WHERE invalid_column = '200'",
"SELECT key, SUM(n) AS sum_n FROM self GROUP BY invalid_column",
],
)
def test_invalid_cols(query: str) -> None:
df = pl.DataFrame(
{
"key": ["xx", "xx", "yy"],
"n": ["100", "200", "300"],
}
)
with pytest.raises(ColumnNotFoundError, match="invalid_column"):
df.sql(query)

0 comments on commit f103fa8

Please sign in to comment.