1414)
1515
1616from llm_client import LLMClient , Model , Provider
17+ from db_client import DBClient , Account
1718
1819
1920logging .basicConfig (
2526
2627# Environment variables names
2728_TELEGRAM_BOT_TOKEN_VAR_NAME = "TELEGRAM_BOT_TOKEN"
29+ _DB_URI_VAR_NAME = "DB_URI"
2830
2931_START_MESSAGE = (
3032 "Hi there! I'm a bot designed to assist you in rephrasing your text into polished English. "
@@ -38,30 +40,51 @@ async def start_command(update: Update, _: ContextTypes.DEFAULT_TYPE) -> None:
3840
3941
4042async def rewrite (
41- update : Update , context : ContextTypes .DEFAULT_TYPE , llm_client : LLMClient
43+ update : Update ,
44+ context : ContextTypes .DEFAULT_TYPE ,
45+ db_client : DBClient ,
46+ llm_client : LLMClient ,
4247) -> None :
48+ # Handle the message
4349 input_message = None
4450 if update .message is not None :
4551 input_message = update .message .text
4652 elif update .edited_message is not None :
4753 input_message = update .edited_message .text
4854 assert input_message is not None , "No message to rewrite."
49- rewritten_text = llm_client .rewrite (input_message )
55+ # Handle the user
56+ user_id = update .message .from_user .id
57+ username = update .message .from_user .username
58+ account = db_client .get_or_create_account (user_id = user_id , username = username )
59+ # Check if the user has run out of tokens.
60+ # If the user is a friend, they have an unlimited token balance. ;)
61+ if not account .is_friend and account .tokens_balance <= 0 :
62+ await context .bot .send_message (
63+ chat_id = update .effective_chat .id ,
64+ text = "You have run out of tokens. 🥲\n Please contact the bot owner to get more." ,
65+ )
66+ return
67+ # Rewrite the message. Do not touch the user's token balance if they are a friend.
68+ rewritten_text , num_tokens = llm_client .rewrite (input_message )
69+ if not account .is_friend :
70+ db_client .decrease_token_balance (account = account , num_tokens = num_tokens )
5071 await context .bot .send_message (
5172 chat_id = update .effective_chat .id ,
5273 text = rewritten_text ,
5374 )
5475
5576
5677if __name__ == "__main__" :
78+ db_client = DBClient (db_url = os .getenv (_DB_URI_VAR_NAME ))
79+ llm_client = LLMClient (provider = Provider .GROQ , model = Model .GEMMA )
80+
5781 app = ApplicationBuilder ().token (os .getenv (_TELEGRAM_BOT_TOKEN_VAR_NAME )).build ()
5882
5983 app .add_handler (CommandHandler ("start" , start_command ))
60-
61- llm_client = LLMClient (provider = Provider .GROQ , model = Model .GEMMA )
6284 app .add_handler (
6385 MessageHandler (
64- filters .TEXT & (~ filters .COMMAND ), partial (rewrite , llm_client = llm_client )
86+ filters .TEXT & (~ filters .COMMAND ),
87+ partial (rewrite , db_client = db_client , llm_client = llm_client ),
6588 )
6689 )
6790
0 commit comments