Source code for pykoi.chat.db.abs_database
"""abs database"""
import abc
import sqlite3
import threading
from typing import List, Tuple
[docs]class AbsDatabase:
"""Base Database class"""
def __init__(self, db_file: str, debug: bool = False) -> None:
"""
Initializes a new instance of the BaseDatabase class.
Args:
db_file (str): The path to the SQLite database file.
debug (bool, optional): Whether to print debug messages. Defaults to False.
"""
self._db_file = db_file
self._debug = debug
self._local = threading.local() # Thread-local storage
self._lock = threading.Lock() # Lock for concurrent write operations
[docs] def get_connection(self) -> sqlite3.Connection:
"""Returns the thread-local database connection"""
if not hasattr(self._local, "connection"):
self._local.connection = sqlite3.connect(self._db_file)
return self._local.connection
[docs] def get_cursor(self) -> sqlite3.Cursor:
"""Returns the thread-local database cursor"""
if not hasattr(self._local, "cursor"):
self._local.cursor = self.get_connection().cursor()
return self._local.cursor
[docs] def create_table(self, query: str) -> None:
"""
Creates the table if it does not already exist in the database.
Args:
query (str): The SQL query to create the table.
"""
with self._lock:
cursor = self.get_cursor()
cursor.execute(query)
self.get_connection().commit()
if self._debug:
rows = self.retrieve_all()
print("Table contents after creating table:")
self.print_table(rows)
[docs] def close_connection(self):
"""
Closes the connection to the database.
"""
if hasattr(self._local, "cursor"):
self._local.cursor.close()
del self._local.cursor
if hasattr(self._local, "connection"):
self._local.connection.close()
del self._local.connection
[docs] @abc.abstractmethod
def insert(self, **kwargs) -> None:
"""
Inserts into the database.
Args:
kwargs (dict): The key-value pairs to insert into the database.
"""
raise NotImplementedError(
"Insert method must be implemented by subclasses."
)
[docs] @abc.abstractmethod
def update(self, **kwargs) -> None:
"""
Updates the database.
Args:
kwargs (dict): The key-value pairs to update in the database.
"""
raise NotImplementedError(
"Update method must be implemented by subclasses."
)
[docs] def retrieve_all(self) -> List[Tuple]:
"""
Retrieves all pairs from the database.
"""
raise NotImplementedError(
"Retrieve method must be implemented by subclasses."
)
[docs] @abc.abstractmethod
def print_table(self, rows: str) -> None:
"""
Prints the table to the console.
Args:
rows (str): The rows to print.
"""
raise NotImplementedError(
"Print method must be implemented by subclasses."
)