"""Question answer database module"""
import csv
import datetime
import os
import sqlite3
import threading
import pandas as pd
from pykoi.chat.db.constants import QA_CSV_HEADER
[docs]class QuestionAnswerDatabase:
"""Question Answer Database class"""
def __init__(
self,
db_file: str = os.path.join(os.getcwd(), "qd.db"),
debug: bool = False,
):
"""
Initializes a new instance of the QuestionAnswerDatabase class.
Args:
db_file (str): The path to the SQLite database file.
debug (bool, optional): Whether to print debug messages. Defaults to False.
"""
self._db_file = db_file
self._debug = debug
self._local = threading.local() # Thread-local storage
self._lock = threading.Lock() # Lock for concurrent write operations
self.create_table()
[docs] def get_connection(self):
"""Returns the thread-local database connection"""
if not hasattr(self._local, "connection"):
self._local.connection = sqlite3.connect(self._db_file)
return self._local.connection
[docs] def get_cursor(self):
"""Returns the thread-local database cursor"""
if not hasattr(self._local, "cursor"):
self._local.cursor = self.get_connection().cursor()
return self._local.cursor
[docs] def create_table(self):
"""
Creates the question_answer table if it does not already exist in the database.
The table has four columns: id (primary key), question, answer, and vote_status.
vote_status is a text field that can only have the values 'up', 'down', or 'n/a'.
"""
query = """
CREATE TABLE IF NOT EXISTS question_answer (
id INTEGER PRIMARY KEY AUTOINCREMENT,
question TEXT,
answer TEXT,
vote_status TEXT CHECK (vote_status IN ('up', 'down', 'n/a')),
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
"""
with self._lock:
cursor = self.get_cursor()
cursor.execute(query)
self.get_connection().commit()
if self._debug:
rows = self.retrieve_all_question_answers()
print("Table contents after creating table:")
self.print_table(rows)
[docs] def insert_question_answer(self, question: str, answer: str):
"""
Inserts a new question-answer pair into the database with the given question and answer.
The vote_status field is set to 'n/a' by default.
Returns the ID of the newly inserted row.
Args:
question (str): The question to insert.
answer (str): The answer to insert.
Returns:
int: The ID of the newly inserted row.
"""
timestamp = datetime.datetime.now()
query = """
INSERT INTO question_answer (question, answer, vote_status, timestamp)
VALUES (?, ?, 'n/a', ?);
"""
with self._lock:
cursor = self.get_cursor()
cursor.execute(query, (question, answer, timestamp))
self.get_connection().commit()
if self._debug:
rows = self.retrieve_all_question_answers()
print("Table contents after inserting table:")
self.print_table(rows)
return cursor.lastrowid
[docs] def update_vote_status(self, id, vote_status):
"""
Updates the vote status of a question-answer pair with the given ID.
Args:
id (int): The ID of the question-answer pair to update.
vote_status (str): The new vote status to set. Must be one of 'up', 'down', or 'n/a'.
Raises:
ValueError: If the question with the given ID does not exist.
"""
query = """
UPDATE question_answer
SET vote_status = ?
WHERE id = ?;
"""
with self._lock:
cursor = self.get_cursor()
cursor.execute(query, (vote_status, id))
self.get_connection().commit()
if cursor.rowcount == 0:
raise ValueError(f"Question with ID {id} does not exist.")
if self._debug:
rows = self.retrieve_all_question_answers()
print("Table contents after updating table:")
self.print_table(rows)
[docs] def retrieve_all_question_answers(self):
"""
Retrieves all question-answer pairs from the database.
Returns:
list: A list of tuples representing the question-answer pairs.
"""
query = """
SELECT * FROM question_answer;
"""
with self._lock:
cursor = self.get_cursor()
cursor.execute(query)
rows = cursor.fetchall()
return rows
[docs] def retrieve_all_question_answers_as_pandas(self):
"""
Retrieves all question-answer pairs from the database as a pandas dataframe.
Returns:
DataFrame: A pandas dataframe.
"""
rows = self.retrieve_all_question_answers()
rows_to_pd = pd.DataFrame(rows)
rows_to_pd.columns = QA_CSV_HEADER
return rows_to_pd
[docs] def close_connection(self):
"""
Closes the connection to the database.
"""
if hasattr(self._local, "cursor"):
self._local.cursor.close()
del self._local.cursor
if hasattr(self._local, "connection"):
self._local.connection.close()
del self._local.connection
[docs] def print_table(self, rows):
"""
Prints the contents of the table in a formatted manner.
Args:
rows (list): A list of tuples where each tuple represents a row in the table.
Each tuple contains five elements: ID, Question, Answer, Timestamp, and Vote Status.
"""
for row in rows:
print(
f"ID: {row[0]}, Question: {row[1]}, "
f"Answer: {row[2]}, Vote Status: {row[3]}, Timestamp: {row[4]}"
)
[docs] def save_to_csv(self, csv_file_name="question_answer_votes.csv"):
"""
This method saves the contents of the question_answer table into a CSV file.
Args:
csv_file_name (str, optional): The name of the CSV file to which the data will be written.
Defaults to "question_answer_votes.csv".
The CSV file will have the following columns: ID, Question, Answer, Vote Status. Each row in the
CSV file corresponds to a row in the question_answer table.
This method first retrieves all question-answer pairs from the database by calling the
retrieve_all_question_answers method. It then writes this data to the CSV file.
"""
my_sql_data = self.retrieve_all_question_answers()
with open(csv_file_name, "w", newline="") as file:
writer = csv.writer(file)
writer.writerow(QA_CSV_HEADER)
writer.writerows(my_sql_data)