Skip to content

Commit

Permalink
Fix a few typehints and make code more DRY
Browse files Browse the repository at this point in the history
  • Loading branch information
mvanderlee committed Mar 13, 2024
1 parent 66c6310 commit 3af56ee
Showing 1 changed file with 38 additions and 35 deletions.
73 changes: 38 additions & 35 deletions pypika/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,21 @@
import uuid
from datetime import date
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, List, Optional, Sequence, Set, Type, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
)

from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order
from pypika.utils import (
Expand Down Expand Up @@ -323,11 +337,11 @@ def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = idx_plac
self._parameters = list()

@property
def placeholder(self):
def placeholder(self) -> str:
if callable(self._placeholder):
return self._placeholder(len(self._parameters))

return self._placeholder
return str(self._placeholder)

def get_parameters(self, **kwargs):
return self._parameters
Expand All @@ -342,11 +356,11 @@ def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = named_pl
self._parameters = dict()

@property
def placeholder(self):
def placeholder(self) -> str:
if callable(self._placeholder):
return self._placeholder(len(self._parameters))

return self._placeholder
return str(self._placeholder)

def get_parameters(self, **kwargs):
return self._parameters
Expand Down Expand Up @@ -439,6 +453,12 @@ def get_formatted_value(cls, value: Any, **kwargs):
return "null"
return str(value)

def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]:
param_sql = parameter.get_sql(**kwargs)
param_key = parameter.get_param_key(placeholder=param_sql)

return param_sql, param_key

def get_sql(
self,
quote_char: Optional[str] = None,
Expand All @@ -449,45 +469,28 @@ def get_sql(
if parameter is None:
sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs)
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)

# Don't stringify numbers when using a parameter
if isinstance(self.value, (int, float)):
value_sql = self.value
else:
# Don't stringify numbers when using a parameter
if isinstance(self.value, (int, float)):
value_sql = self.value
else:
value_sql = self.get_value_sql(quote_char=quote_char, **kwargs)
param_sql = parameter.get_sql(**kwargs)
param_key = parameter.get_param_key(placeholder=param_sql)
parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs)
value_sql = self.get_value_sql(quote_char=quote_char, **kwargs)
param_sql, param_key = self._get_param_data(parameter, **kwargs)
parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs)

return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs)
return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs)


class ParameterValueWrapper(ValueWrapper):
def __init__(self, parameter: Parameter, value: Any, alias: Optional[str] = None) -> None:
super().__init__(value, alias)
self._parameter = parameter

def get_sql(
self,
quote_char: Optional[str] = None,
secondary_quote_char: str = "'",
parameter: Parameter = None,
**kwargs: Any,
) -> str:
if parameter is None:
sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs)
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)
else:
# Don't stringify numbers when using a parameter
if isinstance(self.value, (int, float)):
value_sql = self.value
else:
value_sql = self.get_value_sql(quote_char=quote_char, **kwargs)
param_sql = self._parameter.get_sql(**kwargs)
param_key = self._parameter.get_param_key(placeholder=param_sql)
parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs)

return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs)
def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]:
param_sql = self._parameter.get_sql(**kwargs)
param_key = self._parameter.get_param_key(placeholder=param_sql)

return param_sql, param_key


class JSON(Term):
Expand Down

0 comments on commit 3af56ee

Please sign in to comment.