diff --git a/.dockerignore b/.dockerignore
index 4508f9d..0f21929 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1,3 +1,4 @@
+README.md
*.md
.github
res
\ No newline at end of file
diff --git a/bot/func/controller.py b/bot/func/functions.py
similarity index 54%
rename from bot/func/controller.py
rename to bot/func/functions.py
index 940cdd8..54c1ce8 100644
--- a/bot/func/controller.py
+++ b/bot/func/functions.py
@@ -1,21 +1,22 @@
-import json
import logging
import os
-from asyncio import Lock
-
import aiohttp
+import json
+from aiogram import types
+from asyncio import Lock
+from functools import wraps
from dotenv import load_dotenv
+# --- Environment
load_dotenv()
-system_info = os.uname()
+# --- Environment Checker
token = os.getenv("TOKEN")
-ollama_base_url = os.getenv("OLLAMA_BASE_URL")
allowed_ids = list(map(int, os.getenv("USER_IDS", "").split(",")))
admin_ids = list(map(int, os.getenv("ADMIN_IDS", "").split(",")))
-# Will be implemented soon
-# content = []
-
+ollama_base_url = os.getenv("OLLAMA_BASE_URL")
log_level_str = os.getenv("LOG_LEVEL", "INFO")
+
+# --- Other
log_levels = list(logging._levelToName.values())
# ['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET']
@@ -28,6 +29,7 @@
logging.basicConfig(level=log_level)
+# Ollama API
async def model_list():
async with aiohttp.ClientSession() as session:
url = f"http://{ollama_base_url}:11434/api/tags"
@@ -53,7 +55,48 @@ async def generate(payload: dict, modelname: str, prompt: str):
yield json.loads(decoded_chunk)
-# Telegram-related
+# Aiogram functions & wraps
+def perms_allowed(func):
+ @wraps(func)
+ async def wrapper(message: types.Message = None, query: types.CallbackQuery = None):
+ user_id = message.from_user.id if message else query.from_user.id
+ if user_id in admin_ids or user_id in allowed_ids:
+ if message:
+ return await func(message)
+ elif query:
+ return await func(query=query)
+ else:
+ if message:
+ await message.answer("Access Denied")
+ elif query:
+ await query.answer("Access Denied")
+
+ return wrapper
+
+
+def perms_admins(func):
+ @wraps(func)
+ async def wrapper(message: types.Message = None, query: types.CallbackQuery = None):
+ user_id = message.from_user.id if message else query.from_user.id
+ if user_id in admin_ids:
+ if message:
+ return await func(message)
+ elif query:
+ return await func(query=query)
+ else:
+ if message:
+ await message.answer("Access Denied")
+ logging.info(
+ f"[MSG] {message.from_user.first_name} {message.from_user.last_name}({message.from_user.id}) is not allowed to use this bot."
+ )
+ elif query:
+ await query.answer("Access Denied")
+ logging.info(
+                    f"[QUERY] {query.from_user.first_name} {query.from_user.last_name}({query.from_user.id}) is not allowed to use this bot."
+ )
+
+ return wrapper
+
def md_autofixer(text: str) -> str:
# In MarkdownV2, these characters must be escaped: _ * [ ] ( ) ~ ` > # + - = | { } . !
escape_chars = r"_[]()~>#+-=|{}.!"
@@ -61,6 +104,7 @@ def md_autofixer(text: str) -> str:
return "".join("\\" + char if char in escape_chars else char for char in text)
+# Context-Related
class contextLock:
lock = Lock()
diff --git a/bot/run.py b/bot/run.py
index 101169d..5dbe3ab 100644
--- a/bot/run.py
+++ b/bot/run.py
@@ -1,14 +1,15 @@
-import asyncio
-import traceback
-
-import io
-import base64
-from aiogram import Bot, Dispatcher, types
+from aiogram import Bot, Dispatcher
from aiogram.enums import ParseMode
from aiogram.filters.command import Command, CommandStart
from aiogram.types import Message
+from aiogram import F
from aiogram.utils.keyboard import InlineKeyboardBuilder
-from func.controller import *
+from func.functions import *
+# Other
+import asyncio
+import traceback
+import io
+import base64
bot = Bot(token=token)
dp = Dispatcher()
@@ -21,37 +22,38 @@
commands = [
types.BotCommand(command="start", description="Start"),
types.BotCommand(command="reset", description="Reset Chat"),
- types.BotCommand(command="getcontext", description="Get chat context json"),
+ types.BotCommand(command="history", description="Look through messages"),
]
-
+# Context variables for OllamaAPI
ACTIVE_CHATS = {}
ACTIVE_CHATS_LOCK = contextLock()
-
modelname = os.getenv("INITMODEL")
+mention = None
+
+async def get_bot_info():
+ global mention
+ if mention is None:
+ get = await bot.get_me()
+ mention = (f"@{get.username}")
+ return mention
+
+# /start command
@dp.message(CommandStart())
async def command_start_handler(message: Message) -> None:
- if message.from_user.id in allowed_ids:
- start_message = f"Welcome to OllamaTelegram Bot, ***{message.from_user.full_name}***!\nSource code: https://github.com/ruecat/ollama-telegram"
- start_message_md = md_autofixer(start_message)
- await message.answer(
- start_message_md,
- parse_mode=ParseMode.MARKDOWN_V2,
- reply_markup=builder.as_markup(),
- disable_web_page_preview=True,
- )
- else:
- await message.answer(
- f"{message.from_user.full_name} [AuthBlocked]\nContact staff to whitelist you",
- parse_mode=ParseMode.MARKDOWN_V2,
- )
- logging.info(
- f"[Interactions] {message.from_user.first_name} {message.from_user.last_name}({message.from_user.id}) is not allowed to use this bot. Value in environment: {allowed_ids}"
- )
+ start_message = f"Welcome to OllamaTelegram Bot, ***{message.from_user.full_name}***!\nSource code: https://github.com/ruecat/ollama-telegram"
+ start_message_md = md_autofixer(start_message)
+ await message.answer(
+ start_message_md,
+ parse_mode=ParseMode.MARKDOWN_V2,
+ reply_markup=builder.as_markup(),
+ disable_web_page_preview=True,
+ )
+# /reset command, wipes context (history)
@dp.message(Command("reset"))
async def command_reset_handler(message: Message) -> None:
if message.from_user.id in allowed_ids:
@@ -65,7 +67,8 @@ async def command_reset_handler(message: Message) -> None:
)
-@dp.message(Command("getcontext"))
+# /history command | Displays dialogs between LLM and USER
+@dp.message(Command("history"))
async def command_get_context_handler(message: Message) -> None:
if message.from_user.id in allowed_ids:
if message.from_user.id in ACTIVE_CHATS:
@@ -87,26 +90,23 @@ async def command_get_context_handler(message: Message) -> None:
@dp.callback_query(lambda query: query.data == "modelmanager")
async def modelmanager_callback_handler(query: types.CallbackQuery):
- if query.from_user.id in admin_ids:
- models = await model_list()
- modelmanager_builder = InlineKeyboardBuilder()
- for model in models:
- modelname = model["name"]
- modelfamilies = ""
- if model["details"]["families"]:
- modelicon = {"llama": "š¦","clip":"š·"}
- modelfamilies = "".join([modelicon[family] for family in model['details']['families']])
- # Add a button for each model
- modelmanager_builder.row(
- types.InlineKeyboardButton(
- text=f"{modelname} {modelfamilies}", callback_data=f"model_{modelname}"
- )
+ models = await model_list()
+ modelmanager_builder = InlineKeyboardBuilder()
+ for model in models:
+ modelname = model["name"]
+ modelfamilies = ""
+ if model["details"]["families"]:
+            modelicon = {"llama": "🦙", "clip": "📷"}
+ modelfamilies = "".join([modelicon[family] for family in model['details']['families']])
+ # Add a button for each model
+ modelmanager_builder.row(
+ types.InlineKeyboardButton(
+ text=f"{modelname} {modelfamilies}", callback_data=f"model_{modelname}"
)
- await query.message.edit_text(
- f"Choose model:", reply_markup=modelmanager_builder.as_markup()
)
- else:
- await query.answer("Access Denied")
+ await query.message.edit_text(
+ f"Choose model:", reply_markup=modelmanager_builder.as_markup()
+ )
@dp.callback_query(lambda query: query.data.startswith("model_"))
@@ -118,127 +118,123 @@ async def model_callback_handler(query: types.CallbackQuery):
@dp.callback_query(lambda query: query.data == "info")
+@perms_admins
async def systeminfo_callback_handler(query: types.CallbackQuery):
- if query.from_user.id in admin_ids:
- await bot.send_message(
- chat_id=query.message.chat.id,
- text=f"š¦ LLM\nCurrent model: {modelname}
\n\nš§ Hardware\nKernel: {system_info[0]}\n
\n(Other options will be added soon..)",
- parse_mode="HTML",
- )
- else:
- await query.answer("Access Denied")
+ await bot.send_message(
+ chat_id=query.message.chat.id,
+        text=f"🦙 LLM\nModel: {modelname}\n\n",
+ parse_mode="HTML",
+ )
+# React on message | LLM will respond on user's message or mention in groups
@dp.message()
+@perms_allowed
async def handle_message(message: types.Message):
+ await get_bot_info()
+ if message.chat.type == "private":
+ await ollama_request(message, bot)
+    if message.chat.type == "supergroup" and message.text and message.text.startswith(mention):
+ # Remove the mention from the message
+ text_without_mention = message.text.replace(mention, "").strip()
+ # Pass the modified text and bot instance to ollama_request
+ await ollama_request(types.Message(
+ message_id=message.message_id,
+ from_user=message.from_user,
+ date=message.date,
+ chat=message.chat,
+ text=text_without_mention
+ ), bot)
+
+
+async def ollama_request(message: types.Message, bot: types.Bot):
try:
- botinfo = await bot.get_me()
- is_allowed_user = message.from_user.id in allowed_ids
- is_private_chat = message.chat.type == "private"
- is_supergroup = message.chat.type == "supergroup"
- bot_mentioned = any(
- entity.type == "mention"
- and message.text[entity.offset : entity.offset + entity.length]
- == f"@{botinfo.username}"
- for entity in message.entities or []
- )
- if (
- is_allowed_user
- and (message.text or message.caption)
- and (is_private_chat or (is_supergroup and bot_mentioned))
- ):
- if is_supergroup and bot_mentioned:
- cutmention = len(botinfo.username) + 2
- prompt = message.text[cutmention:] or message.caption[cutmention:] # + ""
+ await bot.send_chat_action(message.chat.id, "typing")
+ prompt = message.text or message.caption
+ image_base64 = ''
+ if message.content_type == 'photo':
+ image_buffer = io.BytesIO()
+ await bot.download(
+ message.photo[-1],
+ destination=image_buffer
+ )
+ image_base64 = base64.b64encode(image_buffer.getvalue()).decode('utf-8')
+ full_response = ""
+ sent_message = None
+ last_sent_text = None
+
+ async with ACTIVE_CHATS_LOCK:
+ # Add prompt to active chats object
+ if ACTIVE_CHATS.get(message.from_user.id) is None:
+ ACTIVE_CHATS[message.from_user.id] = {
+ "model": modelname,
+ "messages": [{"role": "user", "content": prompt, "images": [image_base64]}],
+ "stream": True,
+ }
else:
- prompt = message.text or message.caption
-
- image_base64=''
- if message.content_type=='photo':
- image_buffer = io.BytesIO()
- await bot.download(
- message.photo[-1],
- destination=image_buffer
+ ACTIVE_CHATS[message.from_user.id]["messages"].append(
+ {"role": "user", "content": prompt, "images": [image_base64]}
)
- image_base64 = base64.b64encode(image_buffer.getvalue()).decode('utf-8')
-
- await bot.send_chat_action(message.chat.id, "typing")
- full_response = ""
- sent_message = None
- last_sent_text = None
-
- async with ACTIVE_CHATS_LOCK:
- # Add prompt to active chats object
- if ACTIVE_CHATS.get(message.from_user.id) is None:
- ACTIVE_CHATS[message.from_user.id] = {
- "model": modelname,
- "messages": [{"role": "user", "content": prompt, "images": [image_base64]}],
- "stream": True,
- }
+ logging.info(
+ f"[Request]: Processing '{prompt}' for {message.from_user.first_name} {message.from_user.last_name}"
+ )
+ payload = ACTIVE_CHATS.get(message.from_user.id)
+ async for response_data in generate(payload, modelname, prompt):
+ msg = response_data.get("message")
+ if msg is None:
+ continue
+ chunk = msg.get("content", "")
+ full_response += chunk
+ full_response_stripped = full_response.strip()
+
+ # avoid Bad Request: message text is empty
+ if full_response_stripped == "":
+ continue
+
+ if "." in chunk or "\n" in chunk or "!" in chunk or "?" in chunk:
+ if sent_message:
+ if last_sent_text != full_response_stripped:
+ await sent_message.edit_text(full_response_stripped)
+ last_sent_text = full_response_stripped
else:
- ACTIVE_CHATS[message.from_user.id]["messages"].append(
- {"role": "user", "content": prompt, "images": [image_base64]}
+ sent_message = await message.answer(
+ full_response_stripped,
+ reply_to_message_id=message.message_id,
)
- logging.info(
- f"[Request]: Processing '{prompt}' for {message.from_user.first_name} {message.from_user.last_name}"
- )
- payload = ACTIVE_CHATS.get(message.from_user.id)
- async for response_data in generate(payload, modelname, prompt):
- msg = response_data.get("message")
- if msg is None:
- continue
- chunk = msg.get("content", "")
- full_response += chunk
- full_response_stripped = full_response.strip()
+ last_sent_text = full_response_stripped
- # avoid Bad Request: message text is empty
- if full_response_stripped == "":
- continue
-
- if "." in chunk or "\n" in chunk or "!" in chunk or "?" in chunk:
+ if response_data.get("done"):
+ if (
+ full_response_stripped
+ and last_sent_text != full_response_stripped
+ ):
if sent_message:
- if last_sent_text != full_response_stripped:
- await sent_message.edit_text(full_response_stripped)
- last_sent_text = full_response_stripped
+ await sent_message.edit_text(full_response_stripped)
else:
- sent_message = await message.answer(
- full_response_stripped,
- reply_to_message_id=message.message_id,
- )
- last_sent_text = full_response_stripped
-
- if response_data.get("done"):
- if (
+ sent_message = await message.answer(full_response_stripped)
+ await sent_message.edit_text(
+ md_autofixer(
full_response_stripped
- and last_sent_text != full_response_stripped
- ):
- if sent_message:
- await sent_message.edit_text(full_response_stripped)
- else:
- sent_message = await message.answer(full_response_stripped)
- await sent_message.edit_text(
- md_autofixer(
- full_response_stripped
- + f"\n\nCurrent Model: `{modelname}`**\n**Generated in {response_data.get('total_duration')/1e9:.2f}s"
- ),
- parse_mode=ParseMode.MARKDOWN_V2,
- )
+ + f"\n\nCurrent Model: `{modelname}`**\n**Generated in {response_data.get('total_duration') / 1e9:.2f}s"
+ ),
+ parse_mode=ParseMode.MARKDOWN_V2,
+ )
- async with ACTIVE_CHATS_LOCK:
- if ACTIVE_CHATS.get(message.from_user.id) is not None:
- # Add response to active chats object
- ACTIVE_CHATS[message.from_user.id]["messages"].append(
- {"role": "assistant", "content": full_response_stripped}
- )
- logging.info(
- f"[Response]: '{full_response_stripped}' for {message.from_user.first_name} {message.from_user.last_name}"
- )
- else:
- await bot.send_message(
- chat_id=message.chat.id, text="Chat was reset"
- )
+ async with ACTIVE_CHATS_LOCK:
+ if ACTIVE_CHATS.get(message.from_user.id) is not None:
+ # Add response to active chats object
+ ACTIVE_CHATS[message.from_user.id]["messages"].append(
+ {"role": "assistant", "content": full_response_stripped}
+ )
+ logging.info(
+ f"[Response]: '{full_response_stripped}' for {message.from_user.first_name} {message.from_user.last_name}"
+ )
+ else:
+ await bot.send_message(
+ chat_id=message.chat.id, text="Chat was reset"
+ )
- break
+ break
except Exception as e:
await bot.send_message(
chat_id=message.chat.id,
@@ -249,7 +245,6 @@ async def handle_message(message: types.Message):
async def main():
await bot.set_my_commands(commands)
-
await dp.start_polling(bot, skip_update=True)