# Source code for pykoi.chat.db.ranking_database

"""This file contains the RankingDatabase class, which is used to store and retrieve ranked question/answer pairs in a SQLite database."""
import csv
import os
import sqlite3
import threading

import pandas as pd

from pykoi.chat.db.constants import RANKING_CSV_HEADER


class RankingDatabase:
    """SQLite-backed store for ranked question/answer pairs.

    Each row of the ``ranking`` table holds a question together with the
    answer that was ranked higher and the answer that was ranked lower.

    Every thread that uses an instance gets its own SQLite connection and
    cursor, created lazily and kept in ``threading.local`` storage. A single
    instance-wide lock serializes the statements issued through this object.
    """

    def __init__(
        self,
        db_file: str = os.path.join(os.getcwd(), "ranking.db"),
        debug: bool = False,
    ):
        """
        Initializes a new instance of the RankingDatabase class.

        The ``ranking`` table is created immediately if it does not exist.

        Args:
            db_file (str): The path to the SQLite database file. The default
                is ``ranking.db`` in the working directory; note that the
                default is evaluated once, when the class is defined, not on
                every call.
            debug (bool, optional): Whether to print debug messages.
                Defaults to False.
        """
        self._db_file = db_file
        self._debug = debug
        self._local = threading.local()  # Thread-local storage
        self._lock = threading.Lock()  # Lock for concurrent write operations
        self.create_table()

    def get_connection(self):
        """Returns the thread-local database connection.

        The connection is opened on first use by the calling thread and
        re-used by that thread afterwards.

        Returns:
            sqlite3.Connection: The calling thread's connection.
        """
        if not hasattr(self._local, "connection"):
            self._local.connection = sqlite3.connect(self._db_file)
        return self._local.connection

    def get_cursor(self):
        """Returns the thread-local database cursor.

        The cursor is created on first use from the calling thread's
        connection and re-used by that thread afterwards.

        Returns:
            sqlite3.Cursor: The calling thread's cursor.
        """
        if not hasattr(self._local, "cursor"):
            self._local.cursor = self.get_connection().cursor()
        return self._local.cursor

    def create_table(self):
        """
        Creates the ranking table if it does not already exist in the database.

        The table has four columns: id (primary key), question,
        up_ranking_answer, and low_ranking_answer. In debug mode the table
        contents are printed afterwards.
        """
        query = """
            CREATE TABLE IF NOT EXISTS ranking (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                question TEXT,
                up_ranking_answer TEXT,
                low_ranking_answer TEXT
            );
        """
        with self._lock:
            cursor = self.get_cursor()
            cursor.execute(query)
            self.get_connection().commit()

        # The debug dump must stay outside the ``with`` block above:
        # retrieve_all_question_answers() acquires the same non-reentrant lock.
        if self._debug:
            rows = self.retrieve_all_question_answers()
            print("Table contents after creating table:")
            self.print_table(rows)

    def insert_ranking(
        self, question: str, up_ranking_answer: str, low_ranking_answer: str
    ):
        """
        Inserts a new ranking entry into the database with the given question,
        up_ranking_answer, and low_ranking_answer.

        Args:
            question (str): The question of the ranking entry.
            up_ranking_answer (str): The higher ranked answer of the ranking
                entry.
            low_ranking_answer (str): The lower ranked answer of the ranking
                entry.

        Returns:
            int: The ID of the newly inserted row.
        """
        query = """
            INSERT INTO ranking (question, up_ranking_answer, low_ranking_answer)
            VALUES (?, ?, ?);
        """
        with self._lock:
            cursor = self.get_cursor()
            cursor.execute(
                query, (question, up_ranking_answer, low_ranking_answer)
            )
            # Capture the id right after the INSERT, before the shared
            # thread-local cursor is re-used by the debug dump below.
            new_row_id = cursor.lastrowid
            self.get_connection().commit()

        # Outside the lock for the same reason as in create_table().
        if self._debug:
            rows = self.retrieve_all_question_answers()
            print("Table contents after inserting entry:")
            self.print_table(rows)

        return new_row_id

    def retrieve_all_question_answers(self):
        """
        Retrieves all question-answer pairs from the database.

        Returns:
            list: A list of tuples, one per row of the ranking table, in the
            column order id, question, up_ranking_answer, low_ranking_answer.
        """
        query = """
            SELECT * FROM ranking;
        """
        with self._lock:
            cursor = self.get_cursor()
            cursor.execute(query)
            rows = cursor.fetchall()
        return rows

    def close_connection(self):
        """
        Closes the connection to the database.

        Only the calling thread's cursor and connection are closed; they are
        re-created on demand if the instance is used again from this thread.
        Calling this method when nothing is open is a no-op.
        """
        if hasattr(self._local, "cursor"):
            self._local.cursor.close()
            del self._local.cursor

        if hasattr(self._local, "connection"):
            self._local.connection.close()
            del self._local.connection

    def print_table(self, rows):
        """
        Prints the contents of the table in a formatted manner.

        Args:
            rows (list): A list of tuples where each tuple represents a row in
                the table. Each tuple contains four elements: ID, Question,
                Up_Ranking_Answer, Low_Ranking_Answer.
        """
        for row in rows:
            print(
                f"ID: {row[0]}, Question: {row[1]}, "
                f"Up_Ranking_Answer: {row[2]}, Low_Ranking_Answer: {row[3]}"
            )

    def save_to_csv(self, csv_file_name="ranking_data.csv"):
        """
        Saves the contents of the ranking table into a CSV file.

        The file starts with the RANKING_CSV_HEADER row, followed by one row
        per entry returned by retrieve_all_question_answers. An existing file
        with the same name is overwritten.

        Args:
            csv_file_name (str, optional): The name of the CSV file to which
                the data will be written. Defaults to "ranking_data.csv".
        """
        ranking_data = self.retrieve_all_question_answers()

        # Explicit UTF-8 so that non-ASCII question/answer text does not fail
        # (or get written in a locale-specific encoding) on platforms whose
        # default encoding is not UTF-8. newline="" is what the csv module
        # requires to control line endings itself.
        with open(csv_file_name, "w", newline="", encoding="utf-8") as file:
            writer = csv.writer(file)
            writer.writerow(RANKING_CSV_HEADER)
            writer.writerows(ranking_data)

    def retrieve_all_question_answers_as_pandas(self):
        """
        Retrieves pairs from the database as a pandas dataframe.

        Returns:
            DataFrame: A pandas dataframe with one row per ranking entry and
            RANKING_CSV_HEADER as column labels. An empty table yields an
            empty dataframe that still carries those columns.
        """
        rows = self.retrieve_all_question_answers()
        # Pass the labels to the constructor instead of assigning `.columns`
        # afterwards: with no rows the frame would have zero columns and the
        # assignment would raise a length-mismatch ValueError.
        return pd.DataFrame(rows, columns=list(RANKING_CSV_HEADER))