Source code for pykoi.chat.db.comparator_database

"""Comparator Database"""
import datetime
import os

from typing import List, Tuple

from pykoi.chat.db.abs_database import AbsDatabase


[docs]class ComparatorQuestionDatabase(AbsDatabase): """Comparator Question Database class""" def __init__( self, db_file: str = os.path.join(os.getcwd(), "comparator.db"), debug: bool = False, ) -> None: """ Initializes a new instance of the ComparatorQuestionDatabase class. Args: db_file (str): The path to the SQLite database file. debug (bool, optional): Whether to print debug messages. Defaults to False. """ query = """ CREATE TABLE IF NOT EXISTS comparator_question ( id INTEGER PRIMARY KEY AUTOINCREMENT, question TEXT, timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); """ super().__init__(db_file, debug) self.create_table(query)
[docs] def insert(self, **kwargs) -> None: """ Inserts question, timestamp into the database. Args: kwargs (dict): The key-value pairs to insert into the database. Returns: int: The ID of the newly inserted row. """ question = kwargs["question"] timestamp = datetime.datetime.now() query = """ INSERT INTO comparator_question (question, timestamp) VALUES (?, ?); """ with self._lock: cursor = self.get_cursor() cursor.execute(query, (question, timestamp)) self.get_connection().commit() if self._debug: rows = self.retrieve_all() print("Table contents after inserting table:") self.print_table(rows) return cursor.lastrowid
[docs] def update(self, **kwargs) -> None: """ Updates the database. """ raise NotImplementedError( "ComparatorQuestionDatabase does not support update." )
[docs] def retrieve_all(self) -> List[Tuple]: """ Retrieves all pairs from the database. Returns: list: A list of tuples. """ query = """ SELECT * FROM comparator_question; """ with self._lock: cursor = self.get_cursor() cursor.execute(query) rows = cursor.fetchall() return rows
[docs] def print_table(self, rows: List[Tuple]) -> None: """ Prints the contents of the table in a formatted manner. Args: rows (list): A list of tuples where each tuple represents a row in the table. Each tuple contains five elements: ID, Question. """ for row in rows: print(f"ID: {row[0]}, Question: {row[1]}, Timestamp: {row[2]}")
[docs]class ComparatorDatabase(AbsDatabase): """ComparatorDatabase class.""" def __init__( self, db_file: str = os.path.join(os.getcwd(), "comparator.db"), debug: bool = False, ) -> None: query = """ CREATE TABLE IF NOT EXISTS comparator ( id INTEGER PRIMARY KEY AUTOINCREMENT, model TEXT NOT NULL, qid INTEGER NOT NULL, rank INTEGER NOT NULL, answer TEXT NOT NULL, timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); """ super().__init__(db_file, debug) self.create_table(query)
[docs] def insert(self, **kwargs) -> None: """ Inserts a new row into the comparator table. Args: kwargs (dict): The key-value pairs to insert into the database. """ timestamp = datetime.datetime.now() check_query = """ SELECT * FROM comparator WHERE model = ? AND qid = ?; """ with self._lock: cursor = self.get_cursor() cursor.execute(check_query, (kwargs["model"], kwargs["qid"])) existing_row = cursor.fetchone() if existing_row is not None: raise ValueError( f"Row with model={kwargs['model']} and" f" qid={kwargs['qid']} already exists" ) query = """ INSERT INTO comparator (model, qid, rank, answer, timestamp) VALUES (?, ?, ?, ?, ?); """ cursor.execute( query, ( kwargs["model"], kwargs["qid"], kwargs["rank"], kwargs["answer"], timestamp, ), ) self.get_connection().commit() if self._debug: rows = self.retrieve_all() print("Table contents after inserting table") self.print_table(rows)
[docs] def update(self, **kwargs) -> None: """ Updates the rank of a row in the comparator table by its id. Args: kwargs (dict): The key-value pairs to update in the database. """ query = """ UPDATE comparator SET rank = ? WHERE qid = ? AND model = ?; """ with self._lock: cursor = self.get_cursor() cursor.execute( query, (kwargs["rank"], kwargs["qid"], kwargs["model"]) ) self.get_connection().commit() if self._debug: rows = self.retrieve_all() print("Table contents after updating table") self.print_table(rows)
[docs] def retrieve_all(self) -> List[Tuple]: """ Retrieves all pairs from the database. Returns: list: A list of tuples. """ query = """ SELECT * FROM comparator; """ with self._lock: cursor = self.get_cursor() cursor.execute(query) rows = cursor.fetchall() return rows
[docs] def print_table(self, rows: List[Tuple]) -> None: """ Prints the comparator table. Args: rows (list): A list of tuples where each tuple represents a row in the table. Each tuple contains five elements: ID, Model, QID, Rank, Answer, Timestamp. """ for row in rows: print( f"ID: {row[0]}, " f"Model: {row[1]}, " f"QID: {row[2]}, " f"Rank: {row[3]}, " f"Answer: {row[4]}, " f"Timestamp: {row[5]}" )