"""
Cursor object for Kinetica which fits the DB API spec.
"""
import json
import re
import datetime, uuid, decimal
# pylint: disable=c-extension-no-member
import weakref
from enum import Enum, unique
from itertools import islice
from typing import (
Optional,
Sequence,
Type,
Union,
TYPE_CHECKING,
List,
Iterator,
Tuple,
Any,
)
from gpudb import GPUdbSqlIterator, GPUdbException
from gpudb.dbapi.core.exceptions import (
ProgrammingError,
NotSupportedError,
convert_runtime_errors,
)
from gpudb.dbapi.core.utils import raise_if_closed, ignore_transaction_error
from gpudb.dbapi.pep249 import (
SQLQuery,
QueryParameters,
ColumnDescription,
ProcName,
ProcArgs,
ResultRow,
ResultSet,
CursorConnectionMixin,
IterableCursorMixin,
TransactionalCursor,
)
if TYPE_CHECKING:
# pylint: disable=cyclic-import
from gpudb.dbapi.core import KineticaConnection
@unique
class ParamStyle(Enum):
QMARK = "qmark"
NUMERIC = "numeric"
FORMAT = "format"
NUMERIC_DOLLAR = "numeric_dollar"
__all__ = ["Cursor", "ParamStyle"]
kinetica_to_python_type_map = {
"boolean": "bool",
"int8": "int",
"int16": "int",
"int": "int",
"long": "int",
"float": "float",
"double": "float",
"decimal": "decimal.Decimal",
"string": "str",
"char1": "str",
"char2": "str",
"char4": "str",
"char8": "str",
"char16": "str",
"char32": "str",
"char64": "str",
"char128": "str",
"char256": "str",
"ipv4": "int",
"uuid": "uuid.UUID",
"wkt": "str",
"date": "datetime.date",
"datetime": "datetime.datetime",
"time": "datetime.time",
"timestamp": "int",
"bytes": "bytes",
"wkb": "str",
}
# pylint: disable=too-many-ancestors
[docs]
class Cursor(CursorConnectionMixin, IterableCursorMixin, TransactionalCursor):
"""
A DB API 2.0 compliant cursor for Kinetica, as outlined in
PEP 249.
"""
__INSERT_BATCH_SIZE = 50000
def __init__(
self,
connection: "KineticaConnection",
query: SQLQuery = None,
query_params: QueryParameters = None,
):
self._connection: KineticaConnection = weakref.proxy(connection)
self._sql = query
self._query_params = query_params
self._cursors: List[GPUdbSqlIterator] = []
self.__closed = False
@classmethod
def from_connection(cls, conn: "KineticaConnection"):
return cls(conn)
@property
def sql(self):
return self._sql
@sql.setter
def sql(self, value):
self._sql = value
@property
def query_params(self):
return self._query_params
@query_params.setter
def query_params(self, value):
self._query_params = value
@property
def _closed(self) -> bool:
# pylint: disable=protected-access
try:
return self.__closed or self.connection.closed
except ReferenceError:
# Parent connection already GC'd.
return True
@_closed.setter
def _closed(self, value: bool):
self.__closed = value
@property
def connection(self) -> "KineticaConnection":
return self._connection
@property
def description(self) -> Optional[Sequence[ColumnDescription]]:
"""
This read-only attribute is a sequence of 7-item sequences.
Each of these sequences contains information describing one result column:
name
type_code
display_size
internal_size
precision
scale
null_ok
The first two items (name and type_code) are mandatory, the other five
are optional and are set to None if no meaningful values can be provided.
This attribute will be None for operations that do not return rows or
if the cursor has not had an operation invoked via the .execute*() method yet.
Returns:
Optional[Sequence[ColumnDescription]]: a sequence of immutable tuples like:
[('field_1', <class 'int'>, None, None, None, None, None),
('field_2', <class 'int'>, None, None, None, None, None),
('field_3', <class 'str'>, None, None, None, None, None),
('field_4', <class 'float'>, None, None, None, None, None)]
"""
if self._cursors and len(self._cursors) > 0 and self._cursors[-1].type_map:
try:
return [
(n, eval(kinetica_to_python_type_map[t]), None, None, None, None, None)
for n, t in self._cursors[-1].type_map.items()
]
except RuntimeError:
return None
return None
@property
def rowcount(self) -> int:
return self._cursors[-1].total_count
[docs]
@raise_if_closed
@convert_runtime_errors
def commit(self) -> None:
return None
[docs]
@raise_if_closed
@ignore_transaction_error
@convert_runtime_errors
def rollback(self) -> None:
return None
[docs]
@convert_runtime_errors
def close(self) -> None:
"""Close the cursor."""
if self._closed:
return
for cursor in self._cursors:
cursor.close()
self._closed = True
[docs]
def callproc(
self, procname: ProcName, parameters: Optional[ProcArgs] = None
) -> Optional[ProcArgs]:
sql_statement = f"EXECUTE PROCEDURE {procname}"
cursor = self.execute(sql_statement, parameters[:-1] if len(parameters) > 0 else None, parameters[-1] if len(parameters) > 0 else None)
cursor.close()
return None
[docs]
def nextset(self) -> Optional[bool]:
raise NotSupportedError(
"Kinetica Cursors do not support more than one result set."
)
[docs]
@raise_if_closed
def setoutputsize(self, size: int, column: Optional[int]) -> None:
pass
[docs]
@raise_if_closed
@convert_runtime_errors
def execute(
self, operation: SQLQuery, parameters: Optional[QueryParameters] = None, default_schema: Optional[str] = None
) -> "Cursor":
"""Executes an SQL statement and returns a Cursor instance which can
used to iterate over the results of the query
Args:
operation (SQLQuery): an SQL statement
parameters (Optional[QueryParameters], optional): the parameters
to the SQL statement; typically a heterogeneous list. Defaults to None.
default_schema (Optional[str], optional): the default schema to
Returns:
Cursor: Returns an instance of self which is used by KineticaConnection
"""
sql_statement = None
valid, placeholder = Cursor.__is_valid_statement(operation)
if valid:
if placeholder == ParamStyle.QMARK:
sql_statement = Cursor.__process_params_qmark(operation)
elif placeholder == ParamStyle.NUMERIC:
sql_statement = Cursor.__process_params_numeric(operation)
elif placeholder == ParamStyle.FORMAT:
sql_statement = Cursor.__process_params_format(operation)
else:
raise ProgrammingError(
"Invalid SQL statement {}; has non-supported parameter placeholders {}".format(
operation, placeholder
)
)
sql_statement = sql_statement or operation
internal_cursor = GPUdbSqlIterator(
self._connection.connection,
sql_statement,
sql_params=list(parameters) if parameters else [],
sql_opts={"current_schema": default_schema} if default_schema and len(default_schema) > 0 else {},
)
self._cursors.append(internal_cursor)
self.arraysize = self._cursors[-1].batch_size
self._cursors[-1]._GPUdbSqlIterator__execute(
sql_statement, parameters=json.dumps(parameters)
)
self._closed = False
return self
[docs]
def executescript(self, script: SQLQuery) -> "Cursor":
"""Not supported as of now."""
raise NotSupportedError("Kinetica does not support this call ...")
[docs]
@raise_if_closed
@convert_runtime_errors
def executemany(
self, operation: SQLQuery, seq_of_parameters: Sequence[QueryParameters], default_schema: Optional[str] = None
) -> "Cursor":
def split_list_iter(lst, size):
it = iter(lst)
return [list(islice(it, size)) for _ in range(0, len(lst), size)]
valid, placeholder = Cursor.__is_valid_statement(operation)
if not valid:
raise GPUdbException(f"Invalid SQL statement : {operation}")
statement_type = Cursor.__check_sql_statement_type(operation)
if statement_type == "INSERT":
json_lists = split_list_iter(seq_of_parameters, Cursor.__INSERT_BATCH_SIZE)
for json_list in json_lists:
options = {"query_parameters": json.dumps(json_list)}
if default_schema and len(default_schema) > 0:
options["default_schema"] = default_schema
resp = self.connection.connection.execute_sql_and_decode(
operation, options=options
)
if resp and resp["status_info"]["status"] == "ERROR":
raise GPUdbException(resp["status_info"]["message"])
return self
else:
cursor_list = []
for params in seq_of_parameters:
cursor = self.execute(operation, params, default_schema)
cursor_list.append(cursor)
last_cursor = cursor_list[:-1][0]
for cursor in cursor_list:
cursor.close()
return last_cursor
[docs]
@raise_if_closed
@convert_runtime_errors
def fetchone(self) -> Optional[ResultRow]:
row = next(self._cursors[-1], None)
return row
[docs]
@raise_if_closed
@convert_runtime_errors
def fetchmany(self, size: Optional[int] = None) -> ResultSet:
if size is None:
size = self.arraysize
result_set = []
for _ in range(size):
row = self.fetchone()
if row is None:
break
result_set.append(row)
return result_set
[docs]
@raise_if_closed
@convert_runtime_errors
def fetchall(self) -> ResultSet:
return self.fetchmany()
def __iter__(self) -> Union[Iterator[dict], Iterator[tuple]]:
"""Iteration over the result set."""
while True:
_next = self.fetchone()
if _next is None:
break
yield _next
@staticmethod
def __has_qmark_params(sql_statement):
valid, placeholder = Cursor.__is_valid_statement(sql_statement)
return placeholder == ParamStyle.QMARK and valid
@staticmethod
def __has_numeric_params(sql_statement):
valid, placeholder = Cursor.__is_valid_statement(sql_statement)
return placeholder == ParamStyle.NUMERIC and valid
@staticmethod
def __has_format_params(sql_statement):
valid, placeholder = Cursor.__is_valid_statement(sql_statement)
return placeholder == ParamStyle.FORMAT and valid
@staticmethod
def __is_valid_statement(
sql_statement: str,
) -> Tuple[Union[bool, Any], Union[str, None]]:
placeholders = list(
Cursor.__extract_parameter_placeholders(sql_statement).keys()
)
if len(placeholders) > 1:
raise ProgrammingError(
f"SQL statement {sql_statement} contains different parameter placeholder formats {placeholders}"
)
placeholder = placeholders[0] if len(placeholders) == 1 else None
supported_paramstyles = [e.value for e in ParamStyle]
return (
placeholder is None or placeholder.value in supported_paramstyles,
placeholder,
)
@staticmethod
def __extract_parameter_placeholders(sql_statement: str):
# Define regular expression patterns to match different DBAPI v2 placeholders
patterns = {
ParamStyle.QMARK: r"\?", # Question mark style
ParamStyle.NUMERIC: r":(\d+)", # Numeric style (e.g., :1, :2)
ParamStyle.NUMERIC_DOLLAR: r"\$\d+", # Numeric dollar style (e.g., $1, $2)
# 'named': r':\w+', # Named style (e.g., :name)
ParamStyle.FORMAT: r"%[sdifl]", # ANSI C printf format codes (e.g., %s)
# 'pyformat': r'%\(\w+\)s' # Python extended format codes (e.g., %(name)s)
}
# Dictionary to hold found placeholders
placeholders = {}
# Extract placeholders based on patterns
for key, pattern in patterns.items():
found = re.findall(pattern, sql_statement)
if len(found) > 0:
placeholders[key] = found
return placeholders
@staticmethod
def __transform_sql(sql: str, param_style: ParamStyle) -> str:
"""
Parses SQL to safely replace placeholders with $n syntax, skipping
placeholders found inside quotes (', ") or comments.
"""
# Regex to match SQL constructs:
# 1. Single-quoted strings (handling standard SQL escaped quotes '')
# 2. Double-quoted identifiers (handling escaped quotes "")
# 3. Block comments /* ... */
# 4. Line comments -- ...
# 5. The specific parameter placeholders we want to replace
# Base patterns for static SQL structures
pattern_string = r"('(''|[^'])*')"
pattern_ident = r'("(""|[^"])*")'
pattern_comment_blk = r"(/\*.*?\*/)"
pattern_comment_line = r"(--[^\r\n]*)"
# Define the placeholder regex based on the style
if param_style == ParamStyle.QMARK:
pattern_param = r"(\?)"
elif param_style == ParamStyle.NUMERIC:
pattern_param = r"(:[0-9]+)"
elif param_style == ParamStyle.FORMAT:
# FIX: Matches %s, %d, %i, %f, %l to align with detection logic
pattern_param = r"(%[sdifl])"
else:
return sql # Unsupported or no transformation needed
# Combine into one master regex
# Order matters: check strings/comments first to "consume" them
master_pattern = re.compile(
f"{pattern_string}|{pattern_ident}|{pattern_comment_blk}|{pattern_comment_line}|{pattern_param}",
re.DOTALL | re.MULTILINE
)
output = []
last_pos = 0
param_counter = 1
for match in master_pattern.finditer(sql):
# Append everything between the last match and this one
output.append(sql[last_pos:match.start()])
# match.group() returns the full match. We check what it looks like.
full_match = match.group(0)
# If it starts with ' or " or / or -, it's a safe zone (string/comment). Keep as is.
if full_match[0] in ("'", '"', '/', '-'):
output.append(full_match)
elif param_style == ParamStyle.QMARK and full_match == '?':
output.append(f"${param_counter}")
param_counter += 1
elif param_style == ParamStyle.NUMERIC and full_match.startswith(':'):
# :1 -> $1. Extract number.
num = full_match[1:]
output.append(f"${num}")
elif param_style == ParamStyle.FORMAT and full_match.startswith('%'):
output.append(f"${param_counter}")
param_counter += 1
else:
# Should not happen given the regex, but safe fallback
output.append(full_match)
last_pos = match.end()
# Append remaining string
output.append(sql[last_pos:])
return "".join(output)
@staticmethod
def __process_params_qmark(query: str):
"""Replace all occurrences of '?' with $1, $2, ..., $n in the SQL statement
safely ignoring strings and comments.
"""
return Cursor.__transform_sql(query, ParamStyle.QMARK)
@staticmethod
def __process_params_numeric(query: str):
"""Replace each ':n' with corresponding '$n' in the SQL statement
safely ignoring strings and comments.
"""
return Cursor.__transform_sql(query, ParamStyle.NUMERIC)
@staticmethod
def __process_params_format(query: str):
"""Replace each ANSI C format specifier with corresponding $n in the SQL statement
safely ignoring strings and comments.
"""
return Cursor.__transform_sql(query, ParamStyle.FORMAT)
@staticmethod
def __extract_table_name_from_insert_statement(sql):
# Define a regex pattern to match the table name in an INSERT statement
pattern = r"INSERT\s+INTO\s+([a-zA-Z0-9_]+\.[a-zA-Z0-9_]+|[a-zA-Z0-9_]+)"
# Search for the pattern in the SQL statement
match = re.search(pattern, sql, re.IGNORECASE)
# If a match is found, return the table name
if match:
return match.group(1)
else:
# If no match is found, return None or raise an error
return None
@staticmethod
def __check_sql_statement_type(sql):
# Trim leading/trailing whitespace and make the statement uppercase for comparison
sql = sql.strip().upper()
# Patterns to match INSERT, DELETE, and UPDATE statements
patterns = {
"INSERT": r"^INSERT\s+INTO",
"DELETE": r"^DELETE\s+FROM",
"UPDATE": r"^UPDATE\s+",
}
# Check each pattern against the SQL statement
for statement_type, pattern in patterns.items():
if re.match(pattern, sql):
return statement_type
# If no match is found, return None
return None
@staticmethod
def classify_sql_statement(sql):
# Remove leading/trailing whitespace and convert to uppercase
sql = sql.strip().upper()
# Check for DDL
ddl_keywords = ['CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'RENAME']
if any(sql.startswith(keyword) for keyword in ddl_keywords):
return 'DDL'
# Check for DML
dml_keywords = ['INSERT', 'UPDATE', 'DELETE', 'MERGE']
if any(sql.startswith(keyword) for keyword in dml_keywords):
return 'DML'
# Check for Query (SELECT statement)
if sql.startswith('SELECT'):
return 'QUERY'
# Check for other types if needed (optional)
other_keywords = ['GRANT', 'REVOKE', 'CALL', 'EXPLAIN']
if any(sql.startswith(keyword) for keyword in other_keywords):
return 'OTHER'
# If no matches found
return 'UNKNOWN'