"""Application module."""
import os
import socket
import threading
import time
from datetime import datetime
from typing import List, Optional, Any, Dict, Union
from fastapi import FastAPI, Depends, HTTPException, UploadFile, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from passlib.context import CryptContext
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from pyngrok import ngrok
from starlette.middleware.cors import CORSMiddleware
from pykoi.component.base import Dropdown
from pykoi.interactives.chatbot import Chatbot
from pykoi.telemetry.telemetry import Telemetry
from pykoi.telemetry.events import AppStartEvent, AppStopEvent
oauth_scheme = HTTPBasic()
[docs]class UpdateQATable(BaseModel):
id: int
vote_status: str
[docs]class RankingTableUpdate(BaseModel):
question: str
up_ranking_answer: str
low_ranking_answer: str
[docs]class InferenceRankingTable(BaseModel):
n: Optional[int] = 2
[docs]class ModelAnswer(BaseModel):
model: str
qid: int
rank: int
answer: str
[docs]class ComparatorInsertRequest(BaseModel):
data: List[ModelAnswer]
[docs]class UserInDB:
def __init__(self, username: str, hashed_password: str):
self.username = username
self.hashed_password = hashed_password
[docs]class Application:
"""
The Application class.
"""
def __init__(
self,
share: bool = False,
debug: bool = False,
username: Union[None, str, List] = None,
password: Union[None, str, List] = None,
host: str = "127.0.0.1",
port: int = 5000,
enable_telemetry: bool = True,
):
"""
Initialize the Application.
Args:
share (bool, optional): If True, the application will be shared via ngrok. Defaults to False.
debug (bool, optional): If True, the application will run in debug mode. Defaults to False.
username (str, optional): The username for authentication. Defaults to None.
password (str, optional): The password for authentication. Defaults to None.
host (str): The host to run the application on. Defaults to None.
port (int): The port to run the application on. Defaults to None.
enable_telemetry (bool, optional): If True, enable_telemetry will be enabled. Defaults to True.
"""
self._debug = debug
self._share = share
self._host = host
self._port = port
self.data_sources = {}
self.components = []
if username and password:
self._auth = True
else:
self._auth = False
if isinstance(username, str):
username = [username]
if isinstance(password, str):
password = [password]
if (
username is not None
and password is not None
and len(username) != len(password)
):
raise ValueError("The length of username and password must be the same.")
self._pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
self._fake_users_db = {}
if username is not None and password is not None:
for user_name, pass_word in zip(username, password):
self._fake_users_db[user_name] = UserInDB(
username=user_name,
hashed_password=self._pwd_context.hash(pass_word),
)
self._telemetry = Telemetry(enable_telemetry)
[docs] def authenticate_user(self, fake_db, username: str, password: str):
if self._auth:
user = fake_db.get(username)
if not user:
return False
if not self._pwd_context.verify(password, user.hashed_password):
return False
return user
else:
return "no_auth"
[docs] def auth_required(self, credentials: HTTPBasicCredentials = Depends(oauth_scheme)):
user = self.authenticate_user(
self._fake_users_db, credentials.username, credentials.password
)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Basic"},
)
return user
[docs] def dummy_auth(self):
return None
[docs] def get_auth_dependency(self):
if self._auth:
return self.auth_required
else:
return self.dummy_auth
[docs] def add_component(self, component: Any):
"""
Add a component to the application.
Args:
component (Any): The component to be added.
"""
if component.data_source:
self.data_sources[component.id] = component.data_source
# set data_endpoint if it's a Dropdown component
if isinstance(component, Dropdown):
component.props["data_endpoint"] = component.id
self.components.append(
{
"id": component.id,
"component": component,
"svelte_component": component.svelte_component,
"props": component.props,
}
)
[docs] def create_chatbot_route(self, app: FastAPI, component: Dict[str, Any]):
"""
Create chatbot routes for the application.
Args:
app (FastAPI): The FastAPI application.
component (Dict[str, Any]): The component for which the routes are being created.
"""
@app.post("/chat/{message}")
async def inference(
message: str,
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()),
):
try:
output = component["component"].model.predict(message)[0]
# insert question and answer into database
id = component["component"].database.insert_question_answer(
message, output
)
return {
"id": id,
"log": "Inference complete",
"status": "200",
"question": message,
"answer": output,
}
except Exception as ex:
return {"log": f"Inference failed: {ex}", "status": "500"}
@app.post("/chat/qa_table/update")
async def update_qa_table(
request_body: UpdateQATable,
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()),
):
try:
component["component"].database.update_vote_status(
request_body.id, request_body.vote_status
)
return {"log": "Table updated", "status": "200"}
except Exception as ex:
return {"log": f"Table update failed: {ex}", "status": "500"}
@app.get("/chat/qa_table/close")
async def close_qa_table(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
try:
component["component"].database.close_connection()
return {"log": "Table closed", "status": "200"}
except Exception as ex:
return {"log": f"Table close failed: {ex}", "status": "500"}
@app.post("/chat/multi_responses/{message}")
async def inference_ranking_table(
message: str,
request_body: InferenceRankingTable,
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()),
):
try:
num_of_response = request_body.n
output = component["component"].model.predict(message, num_of_response)
# Check the type of each item in the output list
return {
"log": "Inference complete",
"status": "200",
"question": message,
"answer": output,
}
except Exception as ex:
return {"log": f"Inference failed: {ex}", "status": "500"}
@app.post("/chat/ranking_table/update")
async def update_ranking_table(
request_body: RankingTableUpdate,
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()),
):
try:
component["component"].database.insert_ranking(
request_body.question,
request_body.up_ranking_answer,
request_body.low_ranking_answer,
)
return {"log": "Table updated", "status": "200"}
except Exception as ex:
return {"log": f"Table update failed: {ex}", "status": "500"}
@app.get("/chat/ranking_table/retrieve")
async def retrieve_ranking_table(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
try:
print("retrieve_ranking_table")
rows = component["component"].database.retrieve_all_question_answers()
return {"rows": rows, "log": "Table retrieved", "status": "200"}
except Exception as ex:
return {"log": f"Table retrieval failed: {ex}", "status": "500"}
[docs] def create_feedback_route(self, app: FastAPI, component: Dict[str, Any]):
"""
Create feedback routes for the application.
Args:
app (FastAPI): The FastAPI application.
component (Dict[str, Any]): The component for which the routes are being created.
"""
@app.get("/chat/qa_table/retrieve")
async def retrieve_qa_table(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
try:
rows = component["component"].database.retrieve_all_question_answers()
return {"rows": rows, "log": "Table retrieved", "status": "200"}
except Exception as ex:
return {"log": f"Table retrieval failed: {ex}", "status": "500"}
@app.get("/chat/ranking_table/close")
async def close_ranking_table(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
try:
component["component"].database.close_connection()
return {"log": "Table closed", "status": "200"}
except Exception as ex:
return {"log": f"Table close failed: {ex}", "status": "500"}
[docs] def create_chatbot_comparator_route(self, app: FastAPI, component: Dict[str, Any]):
"""
Create chatbot comparator routes for the application.
Args:
app (FastAPI): The FastAPI application.
component (Dict[str, Any]): The component for which the routes are being created.
"""
@app.post("/chat/comparator/{message}")
async def compare(
message: str,
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()),
):
try:
output_dict = {}
# insert question and answer into database
qid = component["component"].question_db.insert(
question=message,
)
# TODO: refactor to run multiple models in parallel using threading
for model_name, model in component["component"].models.items():
output = model.predict(message)[0]
# TODO: refactor this into using another comparator database
output_dict[model_name] = output
component["component"].comparator_db.insert(
model=model_name,
qid=qid,
rank=1, # default rank is 1
answer=output,
)
return {
"qid": qid,
"log": "Inference complete",
"status": "200",
"question": message,
"answer": output_dict,
}
except Exception as ex:
return {"log": f"Inference failed: {ex}", "status": "500"}
@app.post("/chat/comparator/db/update")
async def update_comparator(
request: ComparatorInsertRequest,
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()),
):
print("REQ", request.data)
try:
for model_answer in request.data:
component["component"].comparator_db.update(
model=model_answer.model,
qid=model_answer.qid,
rank=model_answer.rank,
)
return {"log": "Table updated", "status": "200"}
except Exception as ex:
return {"log": f"Table update failed: {ex}", "status": "500"}
@app.get("/chat/comparator/db/retrieve")
async def retrieve_comparator(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
try:
rows = component["component"].comparator_db.retrieve_all()
data = []
for row in rows:
_, model_name, qid, rank, answer, _ = row
data.append(
{
"model": model_name,
"qid": qid,
"rank": rank,
"answer": answer,
}
)
return {"data": data, "log": "Table retrieved", "status": "200"}
except Exception as ex:
return {"log": f"Table retrieval failed: {ex}", "status": "500"}
@app.get("/chat/comparator/db/close")
async def close_comparator(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
try:
component["component"].question_db.close_connection()
component["component"].comparator_db.close_connection()
return {"log": "Table closed", "status": "200"}
except Exception as ex:
return {"log": f"Table close failed: {ex}", "status": "500"}
[docs] def create_qa_retrieval_route(self, app: FastAPI, component: Dict[str, Any]):
"""
Create QA retrieval routes for the application.
Args:
app (FastAPI): The FastAPI application.
component (Dict[str, Any]): The component for which the routes are being created.
"""
@app.get("/retrieval/file/get")
async def get_files(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
print("[/retrieval/file/get]: getting files...")
# create folder if it doesn't exist
dir_path = os.environ["DOC_PATH"]
if not os.path.exists(dir_path):
os.makedirs(dir_path)
files = os.listdir(dir_path)
# Create a list of dictionaries, each containing the file's name, size and type
file_data = []
for file in files:
size = os.path.getsize(
os.path.join(dir_path, file)
) # get size of file in bytes
_, ext = os.path.splitext(
file
) # split the file name into name and extension
file_data.append(
{
"name": file,
"size": size,
"type": ext[1:], # remove the period from the extension
}
)
return {"files": file_data}
@app.post("/retrieval/file/upload")
async def upload_files(
files: List[UploadFile],
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()),
):
try:
# create folder if it doesn't exist
if not os.path.exists(os.getenv("DOC_PATH")):
os.makedirs(os.getenv("DOC_PATH"))
print("[/retrieval/file/upload]: upload files...")
# Check if any file is sent
if not files:
raise HTTPException(status_code=400, detail="No file part")
filenames = []
# Iterate over each file
for file in files:
# Check if file is selected
if not file.filename:
raise HTTPException(status_code=400, detail="No selected file")
print(f"[/retrieval/file/upload]: saving file {file.filename}")
# Save or process the file
with open(
os.path.join(os.getenv("DOC_PATH"), file.filename), "wb"
) as buffer:
buffer.write(await file.read())
filenames.append(file.filename)
# List all files in the DOC_PATH directory
file_list = os.listdir(os.getenv("DOC_PATH"))
return JSONResponse(
{"status": "ok", "filenames": filenames, "files": file_list}
)
except Exception as e:
return JSONResponse({"status": "error", "message": str(e)})
@app.post("/retrieval/vector_db/index")
async def index_vector_db(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
try:
print("[/retrieval/vector_db/index]: indexing files...")
component["component"].vector_db.index()
return {"log": "Indexing complete", "status": "200"}
except Exception as ex:
return {"log": f"Indexing failed: {ex}", "status": "500"}
@app.get("/retrieval/{message}")
async def inference(
message: str,
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()),
):
try:
print("[/retrieval]: model inference...")
output = component["component"].retrieval_model.run(message)
id = component["component"].database.insert_question_answer(
message, output
)
return {
"id": id,
"log": "Inference complete",
"status": "200",
"question": message,
"answer": output,
}
except Exception as ex:
return {"log": f"Inference failed: {ex}", "status": "500"}
@app.get("/retrieval/vector_db/get")
async def get_vector_db(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
try:
print("[/retrieval/vector_db/get]: get embedding...")
response_dict = component["component"].vector_db.get_embedding()
return response_dict
except Exception as ex:
return {
"log": "Failed to get embedding: {}".format(ex),
"status": "500",
}
[docs] def create_nvml_route(self, app: FastAPI, component: Dict[str, Any]):
"""
Create NVML routes for the application.
Args:
app (FastAPI): The FastAPI application.
component (Dict[str, Any]): The component for which the routes are being created.
"""
@app.get("/nvml")
async def get_nvml_info(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
try:
print("[/nvml]: get nvml info...")
response_dict = component["component"].nvml.get()
return response_dict
except Exception as ex:
return {
"log": "Failed to get nvml info: {}".format(ex),
"status": "500",
}
[docs] def run(self):
"""
Run the application.
"""
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.post("/token")
def login(credentials: HTTPBasicCredentials = Depends(oauth_scheme)):
user = self.authenticate_user(
self._fake_users_db, credentials.username, credentials.password
)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Basic"},
)
return {"message": "Logged in successfully"}
@app.get("/components")
async def get_components(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
return JSONResponse(
[
{
"id": component["id"],
"svelte_component": component["svelte_component"],
"props": component["props"],
}
for component in self.components
]
)
def create_data_route(id: str, data_source: Any):
"""
Create data route for the application.
Args:
id (str): The id of the data source.
data_source (Any): The data source.
"""
@app.get(f"/data/{id}")
async def get_data(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
data = data_source.fetch_func()
return JSONResponse(data)
for id, data_source in self.data_sources.items():
create_data_route(id, data_source)
for component in self.components:
if component["svelte_component"] == "Chatbot":
self.create_chatbot_route(app, component)
if component["svelte_component"] == "Feedback":
self.create_feedback_route(app, component)
if component["svelte_component"] == "Compare":
self.create_chatbot_comparator_route(app, component)
if component["svelte_component"] == "RetrievalQA":
self.create_qa_retrieval_route(app, component)
if component["svelte_component"] == "Nvml":
self.create_nvml_route(app, component)
app.mount(
"/",
StaticFiles(
directory=os.path.join(
os.path.dirname(os.path.realpath(__file__)), "frontend/dist"
),
html=True,
),
name="static",
)
@app.get("/{path:path}")
async def read_item(
path: str,
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()),
):
return {"path": path}
# debug mode should be set to False in production because
# it will start two processes when debug mode is enabled.
# Set the ngrok tunnel if share is True
start_event = AppStartEvent(
start_time=time.time(), date_time=datetime.utcfromtimestamp(time.time())
)
self._telemetry.capture(start_event)
if self._share:
public_url = ngrok.connect(self._host + ":" + str(self._port))
print("Public URL:", public_url)
import uvicorn
uvicorn.run(app, host=self._host, port=self._port)
print("Stopping server...")
ngrok.disconnect(public_url)
else:
import uvicorn
uvicorn.run(app, host=self._host, port=self._port)
self._telemetry.capture(
AppStopEvent(
end_time=time.time(),
date_time=datetime.utcfromtimestamp(time.time()),
duration=time.time() - start_event.start_time,
)
)
[docs] def display(self):
"""
Run the application.
"""
import nest_asyncio
nest_asyncio.apply()
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.post("/token")
def login(credentials: HTTPBasicCredentials = Depends(oauth_scheme)):
user = self.authenticate_user(
self._fake_users_db, credentials.username, credentials.password
)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Basic"},
)
return {"message": "Logged in successfully"}
@app.get("/components")
async def get_components(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
return JSONResponse(
[
{
"id": component["id"],
"svelte_component": component["svelte_component"],
"props": component["props"],
}
for component in self.components
]
)
def create_data_route(id: str, data_source: Any):
"""
Create data route for the application.
Args:
id (str): The id of the data source.
data_source (Any): The data source.
"""
@app.get(f"/data/{id}")
async def get_data(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
data = data_source.fetch_func()
return JSONResponse(data)
for id, data_source in self.data_sources.items():
create_data_route(id, data_source)
for component in self.components:
if component["svelte_component"] == "Chatbot":
self.create_chatbot_route(app, component)
if component["svelte_component"] == "Feedback":
self.create_feedback_route(app, component)
if component["svelte_component"] == "Compare":
self.create_chatbot_comparator_route(app, component)
app.mount(
"/",
StaticFiles(
directory=os.path.join(
os.path.dirname(os.path.realpath(__file__)), "frontend/dist"
),
html=True,
),
name="static",
)
@app.get("/{path:path}")
async def read_item(
path: str, user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
return {"path": path}
# debug mode should be set to False in production because
# it will start two processes when debug mode is enabled.
# Set the ngrok tunnel if share is True
if self._share:
public_url = ngrok.connect(self._port)
print("Public URL:", public_url)
import uvicorn
uvicorn.run(app, host=self._host, port=self._port)
print("Stopping server...")
ngrok.disconnect(public_url)
else:
import uvicorn
def run_uvicorn():
uvicorn.run(app, host=self._host, port=self._port)
t = threading.Thread(target=run_uvicorn)
t.start()
return Chatbot()(port=self._host)