diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index be391334..df36ad40 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -6,7 +6,6 @@ import warnings from concurrent.futures import ThreadPoolExecutor, Future from datetime import datetime -from dis import specialized from typing import Any, List, Optional, Union, Dict, Iterator import orjson @@ -46,6 +45,7 @@ from apps.system.crud.aimodel_manage import get_ai_model_list_by_workspace from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds from apps.system.crud.parameter_manage import get_groups +from apps.system.crud.user import user_ws_list from apps.system.schemas.system_schema import AssistantOutDsSchema from apps.terminology.curd.terminology import get_terminology_template from common.core.config import settings @@ -87,6 +87,7 @@ def extract_tables_from_sql(sql: str, ds_type: str = None) -> set: class LLMService: ds: CoreDatasource chat_question: ChatQuestion + oid: int record: ChatRecord config: LLMConfig llm: BaseChatModel @@ -127,15 +128,27 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C self.table_name_list = [] + chat_question.lang = get_lang_name(current_user.language) + self.trans = i18n(lang=current_user.language) + chat_id = chat_question.chat_id chat: Chat | None = session.get(Chat, chat_id) if not chat: raise SingleMessageError(f"Chat with id {chat_id} not found") + self.oid = chat.oid + + if self.oid: + w_list = user_ws_list(session, self.current_user.id) + oid_list = [item.id for item in w_list] + if int(self.oid) not in oid_list: + raise SingleMessageError("Current user cannot access this chat") + + self.current_user.oid = chat.oid ds: CoreDatasource | AssistantOutDsSchema | None = None if not chat.datasource and chat_question.datasource_id: _ds = session.get(CoreDatasource, chat_question.datasource_id) if _ds: - if _ds.oid != 
self.oid: raise SingleMessageError( f"Datasource with id {chat_question.datasource_id} does not belong to current workspace") chat.datasource = _ds.id @@ -165,9 +178,6 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C self.change_title = not get_chat_brief_generate(session=session, chat_id=chat_id) - chat_question.lang = get_lang_name(current_user.language) - self.trans = i18n(lang=current_user.language) - self.ds = ( ds if isinstance(ds, AssistantOutDsSchema) else CoreDatasource(**ds.model_dump())) if ds else None self.chat_question = chat_question @@ -345,7 +355,7 @@ def filter_terminology_template(self, _session: Session, oid: int = None, ds_id: calculate_oid = oid calculate_ds_id = ds_id if self.current_assistant: - calculate_oid = self.current_assistant.oid if self.current_assistant.type != 4 else self.current_user.oid + calculate_oid = self.current_assistant.oid if self.current_assistant.type != 4 else self.oid if self.current_assistant.type == 1: calculate_ds_id = None if self.current_assistant and self.current_assistant.type == 1: @@ -373,7 +383,7 @@ def filter_custom_prompts(self, _session: Session, custom_prompt_type: CustomPro calculate_oid = oid calculate_ds_id = ds_id if self.current_assistant: - calculate_oid = self.current_assistant.oid if self.current_assistant.type != 4 else self.current_user.oid + calculate_oid = self.current_assistant.oid if self.current_assistant.type != 4 else self.oid if self.current_assistant.type == 1: calculate_ds_id = None if self.current_assistant and self.current_assistant.type == 1: @@ -399,7 +409,7 @@ def filter_training_template(self, _session: Session, oid: int = None, ds_id: in calculate_oid = oid calculate_ds_id = ds_id if self.current_assistant: - calculate_oid = self.current_assistant.oid if self.current_assistant.type != 4 else self.current_user.oid + calculate_oid = self.current_assistant.oid if self.current_assistant.type != 4 else self.oid if self.current_assistant.type == 1: 
calculate_ds_id = None if self.current_assistant and self.current_assistant.type == 1: @@ -451,9 +461,9 @@ def generate_analysis(self, _session: Session): ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None - self.filter_terminology_template(_session, self.current_user.oid, ds_id) + self.filter_terminology_template(_session, self.oid, ds_id) - self.filter_custom_prompts(_session, CustomPromptTypeEnum.ANALYSIS, self.current_user.oid, ds_id) + self.filter_custom_prompts(_session, CustomPromptTypeEnum.ANALYSIS, self.oid, ds_id) analysis_msg.append(SystemPromptMessage(content=self.chat_question.analysis_sys_question())) analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question())) @@ -504,7 +514,7 @@ def generate_predict(self, _session: Session): self.chat_question.data = orjson.dumps(data.get('data')).decode() ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None - self.filter_custom_prompts(_session, CustomPromptTypeEnum.PREDICT_DATA, self.current_user.oid, ds_id) + self.filter_custom_prompts(_session, CustomPromptTypeEnum.PREDICT_DATA, self.oid, ds_id) predict_msg: List[Union[BaseMessage, dict[str, Any]]] = [] predict_msg.append(SystemPromptMessage(content=self.chat_question.predict_sys_question())) @@ -624,7 +634,7 @@ def select_datasource(self, _session: Session): _ds_list = get_assistant_ds(session=_session, llm_service=self) else: stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where( - and_(CoreDatasource.oid == self.current_user.oid)) + and_(CoreDatasource.oid == self.oid)) _ds_list = [ { "id": ds.id, @@ -643,7 +653,7 @@ def select_datasource(self, _session: Session): if not ignore_auto_select: if settings.TABLE_EMBEDDING_ENABLED and ( not self.current_assistant or (self.current_assistant and self.current_assistant.type != 1)): - _ds_list = get_ds_embedding(_session, self.current_user, _ds_list, self.out_ds_instance, + _ds_list = get_ds_embedding(_session, _ds_list, 
self.out_ds_instance, self.chat_question.question, self.current_assistant) # yield {'content': '{"id":' + str(ds.get('id')) + '}'} diff --git a/backend/apps/datasource/embedding/ds_embedding.py b/backend/apps/datasource/embedding/ds_embedding.py index 9bfe4a48..f75ee7df 100644 --- a/backend/apps/datasource/embedding/ds_embedding.py +++ b/backend/apps/datasource/embedding/ds_embedding.py @@ -11,11 +11,11 @@ from apps.system.crud.assistant import AssistantOutDs from common.core.config import settings from common.core.deps import CurrentAssistant -from common.core.deps import SessionDep, CurrentUser +from common.core.deps import SessionDep from common.utils.utils import SQLBotLogUtil -def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, out_ds: AssistantOutDs, +def get_ds_embedding(session: SessionDep, _ds_list, out_ds: AssistantOutDs, question: str, current_assistant: Optional[CurrentAssistant] = None): _list = [] diff --git a/backend/apps/system/crud/user.py b/backend/apps/system/crud/user.py index 1f5ccb59..5f9928b1 100644 --- a/backend/apps/system/crud/user.py +++ b/backend/apps/system/crud/user.py @@ -1,21 +1,23 @@ - from typing import Optional + from sqlmodel import Session, func, select, delete as sqlmodel_delete + from apps.system.models.system_model import UserWsModel, WorkspaceModel from apps.system.schemas.auth import CacheName, CacheNamespace from apps.system.schemas.system_schema import EMAIL_REGEX, PWD_REGEX, BaseUserDTO, UserInfoDTO, UserWs from common.core.deps import SessionDep +from common.core.security import verify_md5pwd from common.core.sqlbot_cache import cache, clear_cache -from common.utils.locale import I18n +from common.utils.locale import I18n, I18nHelper from common.utils.utils import SQLBotLogUtil from ..models.user import UserModel, UserPlatformModel -from common.core.security import verify_md5pwd -import re + def get_db_user(*, session: Session, user_id: int) -> UserModel: db_user = session.get(UserModel, 
user_id) return db_user + def get_user_by_account(*, session: Session, account: str) -> BaseUserDTO | None: statement = select(UserModel).where(UserModel.account == account) db_user = session.exec(statement).first() @@ -23,19 +25,22 @@ def get_user_by_account(*, session: Session, account: str) -> BaseUserDTO | None return None return BaseUserDTO.model_validate(db_user.model_dump()) + @cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="user_id") async def get_user_info(*, session: Session, user_id: int) -> UserInfoDTO | None: - db_user: UserModel = get_db_user(session = session, user_id = user_id) + db_user: UserModel = get_db_user(session=session, user_id=user_id) if not db_user: return None userInfo = UserInfoDTO.model_validate(db_user.model_dump()) userInfo.isAdmin = userInfo.id == 1 and userInfo.account == 'admin' if userInfo.isAdmin: return userInfo - ws_model: UserWsModel = session.exec(select(UserWsModel).where(UserWsModel.uid == userInfo.id, UserWsModel.oid == userInfo.oid)).first() + ws_model: UserWsModel = session.exec( + select(UserWsModel).where(UserWsModel.uid == userInfo.id, UserWsModel.oid == userInfo.oid)).first() userInfo.weight = ws_model.weight if ws_model else -1 return userInfo + def authenticate(*, session: Session, account: str, password: str) -> BaseUserDTO | None: db_user = get_user_by_account(session=session, account=account) if not db_user: @@ -44,7 +49,8 @@ def authenticate(*, session: Session, account: str, password: str) -> BaseUserDT return None return db_user -async def user_ws_options(session: Session, uid: int, trans: Optional[I18n] = None) -> list[UserWs]: + +def user_ws_list(session: Session, uid: int, trans: Optional[I18n | I18nHelper] = None) -> list[UserWs]: if uid == 1: stmt = select(WorkspaceModel.id, WorkspaceModel.name).order_by(WorkspaceModel.name, WorkspaceModel.create_time) else: @@ -57,16 +63,20 @@ async def user_ws_options(session: Session, uid: int, trans: Optional[I18n] = No if 
not trans: return result.all() list_result = [ - UserWs(id = id, name = trans(name) if name.startswith('i18n') else name) + UserWs(id=id, name=trans(name) if name.startswith('i18n') else name) for id, name in result.all() ] if list_result: list_result.sort(key=lambda x: x.name) return list_result - + +async def user_ws_options(session: Session, uid: int, trans: Optional[I18n | I18nHelper] = None) -> list[UserWs]: + return user_ws_list(session, uid, trans) + + @clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="id") async def single_delete(session: SessionDep, id: int): - user_model: UserModel = get_db_user(session = session, user_id = id) + user_model: UserModel = get_db_user(session=session, user_id=id) del_stmt = sqlmodel_delete(UserWsModel).where(UserWsModel.uid == id) session.exec(del_stmt) if user_model and user_model.origin and user_model.origin != 0: @@ -75,20 +85,23 @@ async def single_delete(session: SessionDep, id: int): session.delete(user_model) session.commit() -@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="id") + +@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="id") async def clean_user_cache(id: int): SQLBotLogUtil.info(f"User cache for [{id}] has been cleaned") def check_account_exists(*, session: Session, account: str) -> bool: return session.exec(select(func.count()).select_from(UserModel).where(UserModel.account == account)).one() > 0 + + def check_email_exists(*, session: Session, email: str) -> bool: return session.exec(select(func.count()).select_from(UserModel).where(UserModel.email == email)).one() > 0 - def check_email_format(email: str) -> bool: return bool(EMAIL_REGEX.fullmatch(email)) + def check_pwd_format(pwd: str) -> bool: return bool(PWD_REGEX.fullmatch(pwd))