Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import warnings
from concurrent.futures import ThreadPoolExecutor, Future
from datetime import datetime
from dis import specialized
from typing import Any, List, Optional, Union, Dict, Iterator

import orjson
Expand Down Expand Up @@ -46,6 +45,7 @@
from apps.system.crud.aimodel_manage import get_ai_model_list_by_workspace
from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds
from apps.system.crud.parameter_manage import get_groups
from apps.system.crud.user import user_ws_list
from apps.system.schemas.system_schema import AssistantOutDsSchema
from apps.terminology.curd.terminology import get_terminology_template
from common.core.config import settings
Expand Down Expand Up @@ -87,6 +87,7 @@ def extract_tables_from_sql(sql: str, ds_type: str = None) -> set:
class LLMService:
ds: CoreDatasource
chat_question: ChatQuestion
oid: int
record: ChatRecord
config: LLMConfig
llm: BaseChatModel
Expand Down Expand Up @@ -127,15 +128,27 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C

self.table_name_list = []

chat_question.lang = get_lang_name(current_user.language)
self.trans = i18n(lang=current_user.language)

chat_id = chat_question.chat_id
chat: Chat | None = session.get(Chat, chat_id)
if not chat:
raise SingleMessageError(f"Chat with id {chat_id} not found")
self.oid = chat.oid

if self.oid:
w_list = user_ws_list(session, self.current_user.id)
oid_list = [item.id for item in w_list]
if int(self.oid) not in oid_list:
raise SingleMessageError("Current user cannot not access this chat")

self.current_user.oid = chat.oid
ds: CoreDatasource | AssistantOutDsSchema | None = None
if not chat.datasource and chat_question.datasource_id:
_ds = session.get(CoreDatasource, chat_question.datasource_id)
if _ds:
if _ds.oid != current_user.oid:
if _ds.oid != self.oid:
raise SingleMessageError(
f"Datasource with id {chat_question.datasource_id} does not belong to current workspace")
chat.datasource = _ds.id
Expand Down Expand Up @@ -165,9 +178,6 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C

self.change_title = not get_chat_brief_generate(session=session, chat_id=chat_id)

chat_question.lang = get_lang_name(current_user.language)
self.trans = i18n(lang=current_user.language)

self.ds = (
ds if isinstance(ds, AssistantOutDsSchema) else CoreDatasource(**ds.model_dump())) if ds else None
self.chat_question = chat_question
Expand Down Expand Up @@ -345,7 +355,7 @@ def filter_terminology_template(self, _session: Session, oid: int = None, ds_id:
calculate_oid = oid
calculate_ds_id = ds_id
if self.current_assistant:
calculate_oid = self.current_assistant.oid if self.current_assistant.type != 4 else self.current_user.oid
calculate_oid = self.current_assistant.oid if self.current_assistant.type != 4 else self.oid
if self.current_assistant.type == 1:
calculate_ds_id = None
if self.current_assistant and self.current_assistant.type == 1:
Expand Down Expand Up @@ -373,7 +383,7 @@ def filter_custom_prompts(self, _session: Session, custom_prompt_type: CustomPro
calculate_oid = oid
calculate_ds_id = ds_id
if self.current_assistant:
calculate_oid = self.current_assistant.oid if self.current_assistant.type != 4 else self.current_user.oid
calculate_oid = self.current_assistant.oid if self.current_assistant.type != 4 else self.oid
if self.current_assistant.type == 1:
calculate_ds_id = None
if self.current_assistant and self.current_assistant.type == 1:
Expand All @@ -399,7 +409,7 @@ def filter_training_template(self, _session: Session, oid: int = None, ds_id: in
calculate_oid = oid
calculate_ds_id = ds_id
if self.current_assistant:
calculate_oid = self.current_assistant.oid if self.current_assistant.type != 4 else self.current_user.oid
calculate_oid = self.current_assistant.oid if self.current_assistant.type != 4 else self.oid
if self.current_assistant.type == 1:
calculate_ds_id = None
if self.current_assistant and self.current_assistant.type == 1:
Expand Down Expand Up @@ -451,9 +461,9 @@ def generate_analysis(self, _session: Session):

ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None

self.filter_terminology_template(_session, self.current_user.oid, ds_id)
self.filter_terminology_template(_session, self.oid, ds_id)

self.filter_custom_prompts(_session, CustomPromptTypeEnum.ANALYSIS, self.current_user.oid, ds_id)
self.filter_custom_prompts(_session, CustomPromptTypeEnum.ANALYSIS, self.oid, ds_id)

analysis_msg.append(SystemPromptMessage(content=self.chat_question.analysis_sys_question()))
analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question()))
Expand Down Expand Up @@ -504,7 +514,7 @@ def generate_predict(self, _session: Session):
self.chat_question.data = orjson.dumps(data.get('data')).decode()

ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None
self.filter_custom_prompts(_session, CustomPromptTypeEnum.PREDICT_DATA, self.current_user.oid, ds_id)
self.filter_custom_prompts(_session, CustomPromptTypeEnum.PREDICT_DATA, self.oid, ds_id)

predict_msg: List[Union[BaseMessage, dict[str, Any]]] = []
predict_msg.append(SystemPromptMessage(content=self.chat_question.predict_sys_question()))
Expand Down Expand Up @@ -624,7 +634,7 @@ def select_datasource(self, _session: Session):
_ds_list = get_assistant_ds(session=_session, llm_service=self)
else:
stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where(
and_(CoreDatasource.oid == self.current_user.oid))
and_(CoreDatasource.oid == self.oid))
_ds_list = [
{
"id": ds.id,
Expand All @@ -643,7 +653,7 @@ def select_datasource(self, _session: Session):
if not ignore_auto_select:
if settings.TABLE_EMBEDDING_ENABLED and (
not self.current_assistant or (self.current_assistant and self.current_assistant.type != 1)):
_ds_list = get_ds_embedding(_session, self.current_user, _ds_list, self.out_ds_instance,
_ds_list = get_ds_embedding(_session, _ds_list, self.out_ds_instance,
self.chat_question.question, self.current_assistant)
# yield {'content': '{"id":' + str(ds.get('id')) + '}'}

Expand Down
4 changes: 2 additions & 2 deletions backend/apps/datasource/embedding/ds_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from apps.system.crud.assistant import AssistantOutDs
from common.core.config import settings
from common.core.deps import CurrentAssistant
from common.core.deps import SessionDep, CurrentUser
from common.core.deps import SessionDep
from common.utils.utils import SQLBotLogUtil


def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, out_ds: AssistantOutDs,
def get_ds_embedding(session: SessionDep, _ds_list, out_ds: AssistantOutDs,
question: str,
current_assistant: Optional[CurrentAssistant] = None):
_list = []
Expand Down
37 changes: 25 additions & 12 deletions backend/apps/system/crud/user.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,46 @@

from typing import Optional

from sqlmodel import Session, func, select, delete as sqlmodel_delete

from apps.system.models.system_model import UserWsModel, WorkspaceModel
from apps.system.schemas.auth import CacheName, CacheNamespace
from apps.system.schemas.system_schema import EMAIL_REGEX, PWD_REGEX, BaseUserDTO, UserInfoDTO, UserWs
from common.core.deps import SessionDep
from common.core.security import verify_md5pwd
from common.core.sqlbot_cache import cache, clear_cache
from common.utils.locale import I18n
from common.utils.locale import I18n, I18nHelper
from common.utils.utils import SQLBotLogUtil
from ..models.user import UserModel, UserPlatformModel
from common.core.security import verify_md5pwd
import re


def get_db_user(*, session: Session, user_id: int) -> UserModel:
db_user = session.get(UserModel, user_id)
return db_user


def get_user_by_account(*, session: Session, account: str) -> BaseUserDTO | None:
statement = select(UserModel).where(UserModel.account == account)
db_user = session.exec(statement).first()
if not db_user:
return None
return BaseUserDTO.model_validate(db_user.model_dump())


@cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="user_id")
async def get_user_info(*, session: Session, user_id: int) -> UserInfoDTO | None:
db_user: UserModel = get_db_user(session = session, user_id = user_id)
db_user: UserModel = get_db_user(session=session, user_id=user_id)
if not db_user:
return None
userInfo = UserInfoDTO.model_validate(db_user.model_dump())
userInfo.isAdmin = userInfo.id == 1 and userInfo.account == 'admin'
if userInfo.isAdmin:
return userInfo
ws_model: UserWsModel = session.exec(select(UserWsModel).where(UserWsModel.uid == userInfo.id, UserWsModel.oid == userInfo.oid)).first()
ws_model: UserWsModel = session.exec(
select(UserWsModel).where(UserWsModel.uid == userInfo.id, UserWsModel.oid == userInfo.oid)).first()
userInfo.weight = ws_model.weight if ws_model else -1
return userInfo


def authenticate(*, session: Session, account: str, password: str) -> BaseUserDTO | None:
db_user = get_user_by_account(session=session, account=account)
if not db_user:
Expand All @@ -44,7 +49,8 @@ def authenticate(*, session: Session, account: str, password: str) -> BaseUserDT
return None
return db_user

async def user_ws_options(session: Session, uid: int, trans: Optional[I18n] = None) -> list[UserWs]:

def user_ws_list(session: Session, uid: int, trans: Optional[I18n | I18nHelper] = None) -> list[UserWs]:
if uid == 1:
stmt = select(WorkspaceModel.id, WorkspaceModel.name).order_by(WorkspaceModel.name, WorkspaceModel.create_time)
else:
Expand All @@ -57,16 +63,20 @@ async def user_ws_options(session: Session, uid: int, trans: Optional[I18n] = No
if not trans:
return result.all()
list_result = [
UserWs(id = id, name = trans(name) if name.startswith('i18n') else name)
UserWs(id=id, name=trans(name) if name.startswith('i18n') else name)
for id, name in result.all()
]
if list_result:
list_result.sort(key=lambda x: x.name)
return list_result


async def user_ws_options(session: Session, uid: int, trans: Optional[I18n | I18nHelper] = None) -> list[UserWs]:
return user_ws_list(session, uid, trans)


@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="id")
async def single_delete(session: SessionDep, id: int):
user_model: UserModel = get_db_user(session = session, user_id = id)
user_model: UserModel = get_db_user(session=session, user_id=id)
del_stmt = sqlmodel_delete(UserWsModel).where(UserWsModel.uid == id)
session.exec(del_stmt)
if user_model and user_model.origin and user_model.origin != 0:
Expand All @@ -75,20 +85,23 @@ async def single_delete(session: SessionDep, id: int):
session.delete(user_model)
session.commit()

@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="id")

@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="id")
async def clean_user_cache(id: int):
SQLBotLogUtil.info(f"User cache for [{id}] has been cleaned")


def check_account_exists(*, session: Session, account: str) -> bool:
return session.exec(select(func.count()).select_from(UserModel).where(UserModel.account == account)).one() > 0


def check_email_exists(*, session: Session, email: str) -> bool:
return session.exec(select(func.count()).select_from(UserModel).where(UserModel.email == email)).one() > 0



def check_email_format(email: str) -> bool:
return bool(EMAIL_REGEX.fullmatch(email))


def check_pwd_format(pwd: str) -> bool:
return bool(PWD_REGEX.fullmatch(pwd))
Loading