From edee7ab6936e035cc4ad17414f831963a9c066b8 Mon Sep 17 00:00:00 2001 From: writinwaters <93570324+writinwaters@users.noreply.github.com> Date: Tue, 10 Feb 2026 17:38:27 +0800 Subject: [PATCH 001/565] Docs: Added v0.24.0 release notes (#13096) ### What problem does this PR solve? Added v0.24.0 release notes. ### Type of change - [x] Documentation Update --- docs/release_notes.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/docs/release_notes.md b/docs/release_notes.md index fc779973afc..e6dbdc4d83f 100644 --- a/docs/release_notes.md +++ b/docs/release_notes.md @@ -9,6 +9,35 @@ sidebar_custom_props: { Key features, improvements and bug fixes in the latest releases. +## v0.24.0 + +Released on February 10, 2026. + +### New features + +- Memory + - Introduces APIs and an SDK for developer integration. + - Adds Memory extraction log display in the console for improved debugging and tracing. +- Dataset + - Added support for batch management of Metadata. + - Renamed "ToC (Table of Contents)" to "PageIndex". +- Agent + - Launches a new Chat-like Agent conversation management interface that retains Sessions and dialogue history. + - Introduces a multi-Sandbox mechanism, currently supporting local gVisor and Alibaba Cloud, with compatibility for mainstream Sandbox APIs (configurable in the Admin page). +- Chat + - Adds a new "Thinking" mode and removed the previous "Reasoning" configuration option. + - Optimizes retrieval strategies for deep-research scenarios, enhancing recall accuracy. +- Admin + - Adds support for configuring multiple Admin accounts. +- Model configuration center + - Adds a model connection test feature when adding new models. +- Ecosystem + - Adds support for OceanBase as a database alternative to MySQL. + - Adds support for PaddleOCR-VL. +- Model + - Adds new model support for Kimi 2.5, Stepfun 3, and doubao-embedding-vision, among others. +- Data sources + - Adds new data source integrations for Zendesk, Bitbucket, and others. ## v0.23.1 From a546c023449b5c75eeea05267bcdd6b70e9b70c4 Mon Sep 17 00:00:00 2001 From: Magicbook1108 Date: Wed, 11 Feb 2026 09:47:33 +0800 Subject: [PATCH 002/565] Fix: upload image files (#13071) ### What problem does this PR solve? 
Fix: upload image files ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- agent/component/llm.py | 124 ++++++++++++++++++++-- api/db/services/dialog_service.py | 169 ++++++++++++++++++++++++++++-- api/db/services/file_service.py | 15 ++- rag/app/picture.py | 3 +- rag/flow/parser/parser.py | 4 +- rag/llm/cv_model.py | 158 +++++++++++++++++++++++----- 6 files changed, 418 insertions(+), 55 deletions(-) diff --git a/agent/component/llm.py b/agent/component/llm.py index e9d8770684c..7538e0d736e 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -125,23 +125,118 @@ def _sys_prompt_and_msg(self, msg, args): msg.append(p) return msg, self.string_format(self._param.sys_prompt, args) + @staticmethod + def _extract_data_images(value) -> list[str]: + imgs = [] + + def walk(v): + if v is None: + return + if isinstance(v, str): + v = v.strip() + if v.startswith("data:image/"): + imgs.append(v) + return + if isinstance(v, (list, tuple, set)): + for item in v: + walk(item) + return + if isinstance(v, dict): + if "content" in v: + walk(v.get("content")) + else: + for item in v.values(): + walk(item) + + walk(value) + return imgs + + @staticmethod + def _uniq_images(images: list[str]) -> list[str]: + seen = set() + uniq = [] + for img in images: + if not isinstance(img, str): + continue + if not img.startswith("data:image/"): + continue + if img in seen: + continue + seen.add(img) + uniq.append(img) + return uniq + + @classmethod + def _remove_data_images(cls, value): + if value is None: + return None + + if isinstance(value, str): + return None if value.strip().startswith("data:image/") else value + + if isinstance(value, list): + cleaned = [] + for item in value: + v = cls._remove_data_images(item) + if v is None: + continue + if isinstance(v, (list, tuple, set, dict)) and not v: + continue + cleaned.append(v) + return cleaned + + if isinstance(value, tuple): + cleaned = [] + for item in value: + v = cls._remove_data_images(item) + if v is None: + continue + if isinstance(v, (list, tuple, set, dict)) and not v: + continue + cleaned.append(v) + return tuple(cleaned) + + if isinstance(value, set): + cleaned = [] + for item in value: + v = cls._remove_data_images(item) + if v is None: + continue + if isinstance(v, (list, tuple, set, dict)) and not v: + continue + cleaned.append(v) + return cleaned + + if isinstance(value, dict): + if value.get("type") in {"image_url", "input_image", "image"} and cls._extract_data_images(value): + return None + + cleaned = {} + for k, item in value.items(): + v = cls._remove_data_images(item) + if v is None: + continue + if isinstance(v, (list, tuple, set, dict)) and not v: + continue + cleaned[k] = v + return cleaned + + return value + def _prepare_prompt_variables(self): + self.imgs = [] if self._param.visual_files_var: - self.imgs = self._canvas.get_variable_value(self._param.visual_files_var) - if not self.imgs: - self.imgs = [] - self.imgs = [img for img in self.imgs if img[:len("data:image/")] == "data:image/"] - if self.imgs and TenantLLMService.llm_id2llm_type(self._param.llm_id) == LLMType.CHAT.value: - self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT.value, - self._param.llm_id, max_retries=self._param.max_retries, - retry_interval=self._param.delay_after_error - ) - + self.imgs.extend(self._extract_data_images(self._canvas.get_variable_value(self._param.visual_files_var))) args = {} vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs + extracted_imgs 
= [] for k, o in vars.items(): - args[k] = o["value"] + raw_value = o["value"] + extracted_imgs.extend(self._extract_data_images(raw_value)) + args[k] = self._remove_data_images(raw_value) + if args[k] is None: + args[k] = "" if not isinstance(args[k], str): try: args[k] = json.dumps(args[k], ensure_ascii=False) @@ -149,6 +244,13 @@ def _prepare_prompt_variables(self): args[k] = str(args[k]) self.set_input_value(k, args[k]) + self.imgs = self._uniq_images(self.imgs + extracted_imgs) + if self.imgs and TenantLLMService.llm_id2llm_type(self._param.llm_id) == LLMType.CHAT.value: + self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT.value, + self._param.llm_id, max_retries=self._param.max_retries, + retry_interval=self._param.delay_after_error + ) + msg, sys_prompt = self._sys_prompt_and_msg(self._canvas.get_history(self._param.message_history_window_size)[:-1], args) user_defined_prompt, sys_prompt = self._extract_prompts(sys_prompt) if self._param.cite and self._canvas.get_reference()["chunks"]: diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 66025d13ef8..0ed5d830b3f 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -180,10 +180,24 @@ def get_all_dialogs_by_tenant_id(cls, tenant_id): async def async_chat_solo(dialog, messages, stream=True): + llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id) attachments = "" + image_attachments = [] + image_files = [] if "files" in messages[-1]: - attachments = "\n\n".join(FileService.get_files(messages[-1]["files"])) - if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text": + if llm_type == "chat": + text_attachments, image_attachments = split_file_attachments(messages[-1]["files"]) + else: + text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True) + attachments = "\n\n".join(text_attachments) + + if llm_type == "image2text": + llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) + else: + llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) + factory = llm_model_config.get("llm_factory", "") if llm_model_config else "" + + if llm_type == "image2text": chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) else: chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) @@ -195,8 +209,13 @@ async def async_chat_solo(dialog, messages, stream=True): msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"] if attachments and msg: msg[-1]["content"] += attachments + if llm_type == "chat" and image_attachments: + convert_last_user_msg_to_multimodal(msg, image_attachments, factory) if stream: - stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting) + if llm_type == "chat": + stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting) + else: + stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting, images=image_files) async for kind, value, state in _stream_with_think_delta(stream_iter): if kind == "marker": flags = {"start_to_think": True} if value == "" else {"end_to_think": True} @@ -204,7 +223,10 @@ async def async_chat_solo(dialog, messages, stream=True): continue yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "prompt": "", "created_at": 
time.time(), "final": False} else: - answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting) + if llm_type == "chat": + answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting) + else: + answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting, images=image_files) user_content = msg[-1].get("content", "[content not available]") logging.debug("User: {}|Assistant: {}".format(user_content, answer)) yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()} @@ -235,6 +257,120 @@ def get_models(dialog): return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl +def split_file_attachments(files: list[dict] | None, raw: bool = False) -> tuple[list[str], list[str] | list[dict]]: + if not files: + return [], [] + + text_attachments = [] + if raw: + file_contents, image_files = FileService.get_files(files, raw=True) + for content in file_contents: + if not isinstance(content, str): + content = str(content) + text_attachments.append(content) + return text_attachments, image_files + + image_attachments = [] + for content in FileService.get_files(files, raw=False): + if not isinstance(content, str): + content = str(content) + if content.strip().startswith("data:"): + image_attachments.append(content.strip()) + continue + text_attachments.append(content) + return text_attachments, image_attachments + + +_DATA_URI_RE = re.compile(r"^data:(?P[^;]+);base64,(?P[A-Za-z0-9+/=\s]+)$") + + +def _parse_data_uri_or_b64(s: str, default_mime: str = "image/png") -> tuple[str, str]: + s = (s or "").strip() + match = _DATA_URI_RE.match(s) + if match: + mime = match.group("mime").strip() + b64 = match.group("b64").strip() + return mime, b64 + return default_mime, s + + +def _normalize_text_from_content(content) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + texts = [] + for blk in content: + if isinstance(blk, dict): + if blk.get("type") in {"text", "input_text"}: + txt = blk.get("text") + if txt: + texts.append(str(txt)) + elif "text" in blk and isinstance(blk.get("text"), (str, int, float)): + texts.append(str(blk["text"])) + return "\n".join(texts).strip() + return str(content) + + +def convert_last_user_msg_to_multimodal(msg: list[dict], image_data_uris: list[str], factory: str) -> None: + if not msg or not image_data_uris: + return + + factory_norm = (factory or "").strip().lower() + + for idx in range(len(msg) - 1, -1, -1): + if msg[idx].get("role") != "user": + continue + + original_content = msg[idx].get("content", "") + text = _normalize_text_from_content(original_content) + + if factory_norm == "gemini": + parts = [] + if text: + parts.append({"text": text}) + for image in image_data_uris: + mime, b64 = _parse_data_uri_or_b64(str(image), default_mime="image/png") + parts.append({"inline_data": {"mime_type": mime, "data": b64}}) + msg[idx]["content"] = parts + return + + if factory_norm == "anthropic": + blocks = [] + if text: + blocks.append({"type": "text", "text": text}) + for image in image_data_uris: + mime, b64 = _parse_data_uri_or_b64(str(image), default_mime="image/png") + blocks.append( + { + "type": "image", + "source": {"type": "base64", "media_type": mime, "data": b64}, + } + ) + msg[idx]["content"] = blocks + return + + multimodal_content = [] + if isinstance(original_content, list): + multimodal_content = deepcopy(original_content) + else: + text_content = "" 
if original_content is None else str(original_content) + if text_content: + multimodal_content.append({"type": "text", "text": text_content}) + + for data_uri in image_data_uris: + image_url = data_uri + if not isinstance(image_url, str): + image_url = str(image_url) + if not image_url.startswith("data:"): + image_url = f"data:image/png;base64,{image_url}" + multimodal_content.append({"type": "image_url", "image_url": {"url": image_url}}) + + msg[idx]["content"] = multimodal_content + return + + BAD_CITATION_PATTERNS = [ re.compile(r"\(\s*ID\s*[: ]*\s*(\d+)\s*\)"), # (ID: 12) re.compile(r"\[\s*ID\s*[: ]*\s*(\d+)\s*\]"), # [ID: 12] @@ -281,12 +417,13 @@ async def async_chat(dialog, messages, stream=True, **kwargs): return chat_start_ts = timer() - - if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text": + llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id) + if llm_type == "image2text": llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) else: llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) + factory = llm_model_config.get("llm_factory", "") if llm_model_config else "" max_tokens = llm_model_config.get("max_tokens", 8192) check_llm_ts = timer() @@ -316,10 +453,16 @@ async def async_chat(dialog, messages, stream=True, **kwargs): questions = [m["content"] for m in messages if m["role"] == "user"][-3:] attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else [] attachments_= "" + image_attachments = [] + image_files = [] if "doc_ids" in messages[-1]: attachments = messages[-1]["doc_ids"] if "files" in messages[-1]: - attachments_ = "\n\n".join(FileService.get_files(messages[-1]["files"])) + if llm_type == "chat": + text_attachments, image_attachments = split_file_attachments(messages[-1]["files"]) + else: + text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True) + attachments_ = "\n\n".join(text_attachments) prompt_config = dialog.prompt_config field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) @@ -464,6 +607,8 @@ async def callback(msg:str): prompt4citation = citation_prompt() msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]) used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95)) + if llm_type == "chat" and image_attachments: + convert_last_user_msg_to_multimodal(msg, image_attachments, factory) assert len(msg) >= 2, f"message_fit_in has bug: {msg}" prompt = msg[0]["content"] @@ -555,7 +700,10 @@ def decorate_answer(answer): ) if stream: - stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf) + if llm_type == "chat": + stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf) + else: + stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf, images=image_files) last_state = None async for kind, value, state in _stream_with_think_delta(stream_iter): last_state = state @@ -572,7 +720,10 @@ def decorate_answer(answer): final["answer"] = "" yield final else: - answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf) + if llm_type == "chat": + answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf) + else: + answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf, images=image_files) user_content = msg[-1].get("content", "[content not available]") 
logging.debug("User: {}|Assistant: {}".format(user_content, answer)) res = decorate_answer(answer) diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index eba59a3cf22..498199393e2 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -663,7 +663,7 @@ async def adownload(): return structured(file.filename, filename_type(file.filename), file.read(), file.content_type) @staticmethod - def get_files(files: Union[None, list[dict]]) -> list[str]: + def get_files(files: Union[None, list[dict]], raw: bool = False) -> Union[list[str], tuple[list[str], list[dict]]]: if not files: return [] def image_to_base64(file): @@ -671,10 +671,17 @@ def image_to_base64(file): base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8")) exe = ThreadPoolExecutor(max_workers=5) threads = [] + imgs = [] for file in files: if file["mime_type"].find("image") >=0: - threads.append(exe.submit(image_to_base64, file)) + if raw: + imgs.append(FileService.get_blob(file["created_by"], file["id"])) + else: + threads.append(exe.submit(image_to_base64, file)) continue threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"])) - return [th.result() for th in threads] - + + if raw: + return [th.result() for th in threads], imgs + else: + return [th.result() for th in threads] diff --git a/rag/app/picture.py b/rag/app/picture.py index 2ad773a3cd2..67540772ab1 100644 --- a/rag/app/picture.py +++ b/rag/app/picture.py @@ -51,8 +51,9 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs): } ) cv_mdl = LLMBundle(tenant_id, llm_type=LLMType.IMAGE2TEXT, lang=lang) + video_prompt = str(parser_config.get("video_prompt", "") or "") ans = asyncio.run( - cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename)) + cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename, video_prompt=video_prompt)) callback(0.8, "CV LLM respond: %s ..." 
% ans[:32]) ans += "\n" + ans tokenize(doc, ans, eng) diff --git a/rag/flow/parser/parser.py b/rag/flow/parser/parser.py index 7fcdde860f0..9ed7d65d721 100644 --- a/rag/flow/parser/parser.py +++ b/rag/flow/parser/parser.py @@ -161,6 +161,7 @@ def __init__(self): "mkv", ], "output_format": "text", + "prompt": "", }, } @@ -685,7 +686,8 @@ def _video(self, name, blob): self.set_output("output_format", conf["output_format"]) cv_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT, llm_name=conf["llm_id"]) - txt = asyncio.run(cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=blob, filename=name)) + video_prompt = str(conf.get("prompt", "") or "") + txt = asyncio.run(cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=blob, filename=name, video_prompt=video_prompt)) self.set_output("text", txt) diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index 9fdd9680a5d..e8f28f7ebaf 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -67,6 +67,61 @@ def _form_history(self, system, history, images=None): hist.append(h) return hist + @staticmethod + def _blob_to_data_url(blob, mime_type="image/png"): + if isinstance(blob, str): + blob = blob.strip() + if blob.startswith("data:") or blob.startswith("http://") or blob.startswith("https://") or blob.startswith("file://"): + return blob + return f"data:{mime_type};base64,{blob}" + if isinstance(blob, BytesIO): + blob = blob.getvalue() + if isinstance(blob, memoryview): + blob = blob.tobytes() + if isinstance(blob, bytearray): + blob = bytes(blob) + if isinstance(blob, bytes): + b64 = base64.b64encode(blob).decode("utf-8") + return f"data:{mime_type};base64,{b64}" + return None + + def _normalize_image(self, image): + if isinstance(image, dict): + inline_data = image.get("inline_data") + if isinstance(inline_data, dict): + mime = inline_data.get("mime_type") or "image/png" + data_url = self._blob_to_data_url(inline_data.get("data"), mime) + if data_url: + return data_url + + image_url = image.get("image_url") + if isinstance(image_url, dict): + data_url = self._blob_to_data_url(image_url.get("url"), image.get("mime_type") or "image/png") + if data_url: + return data_url + if isinstance(image_url, str): + data_url = self._blob_to_data_url(image_url, image.get("mime_type") or "image/png") + if data_url: + return data_url + + if "url" in image: + data_url = self._blob_to_data_url(image.get("url"), image.get("mime_type") or "image/png") + if data_url: + return data_url + + mime = image.get("mime_type") or image.get("media_type") or "image/png" + for key in ("blob", "data"): + if key in image: + data_url = self._blob_to_data_url(image.get(key), mime) + if data_url: + return data_url + + if isinstance(image, (bytes, bytearray, memoryview, BytesIO)): + return self.image2base64(image) + if isinstance(image, str): + return self._blob_to_data_url(image, "image/png") + return self.image2base64(image) + def _image_prompt(self, text, images): if not images: return text @@ -76,7 +131,11 @@ def _image_prompt(self, text, images): pmpt = [{"type": "text", "text": text}] for img in images: - pmpt.append({"type": "image_url", "image_url": {"url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"}}) + try: + pmpt.append({"type": "image_url", "image_url": {"url": self._normalize_image(img)}}) + except Exception: + logging.warning("[%s] Skip invalid image input in request payload.", self.__class__.__name__) + continue return pmpt async def async_chat(self, system, history, gen_conf, 
images=None, **kwargs): @@ -248,51 +307,86 @@ def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", base_url=N base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1" super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs) + @staticmethod + def _extract_text_from_content(content): + if isinstance(content, str): + return content.strip() + if isinstance(content, list): + texts = [] + for blk in content: + if not isinstance(blk, dict): + continue + if blk.get("type") in {"text", "input_text"} and blk.get("text"): + texts.append(str(blk["text"])) + elif "text" in blk and isinstance(blk.get("text"), (str, int, float)): + texts.append(str(blk["text"])) + return "\n".join(texts).strip() + return "" + + def _resolve_video_prompt(self, system, history, **kwargs): + prompt = kwargs.get("video_prompt") or kwargs.get("prompt") + if isinstance(prompt, str) and prompt.strip(): + return prompt.strip() + + for h in reversed(history or []): + if h.get("role") != "user": + continue + txt = self._extract_text_from_content(h.get("content")) + if txt: + return txt + + if isinstance(system, str) and system.strip(): + return system.strip() + + return "Please summarize this video in proper sentences." + async def async_chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs): if video_bytes: try: - summary, summary_num_tokens = self._process_video(video_bytes, filename) + summary, summary_num_tokens = self._process_video(video_bytes, filename, self._resolve_video_prompt(system, history, **kwargs)) return summary, summary_num_tokens except Exception as e: return "**ERROR**: " + str(e), 0 - return "**ERROR**: Method chat not supported yet.", 0 + return await super().async_chat(system, history, gen_conf, images=images, **kwargs) - def _process_video(self, video_bytes, filename): + def _process_video(self, video_bytes, filename, prompt): from dashscope import MultiModalConversation video_suffix = Path(filename).suffix or ".mp4" + tmp_path = None with tempfile.NamedTemporaryFile(delete=False, suffix=video_suffix) as tmp: tmp.write(video_bytes) tmp_path = tmp.name - video_path = f"file://{tmp_path}" - messages = [ - { - "role": "user", - "content": [ - { - "video": video_path, - "fps": 2, - }, - { - "text": "Please summarize this video in proper sentences.", - }, - ], - } - ] + video_path = f"file://{tmp_path}" + messages = [ + { + "role": "user", + "content": [ + { + "video": video_path, + "fps": 2, + }, + { + "text": prompt, + }, + ], + } + ] - def call_api(): - response = MultiModalConversation.call( - api_key=self.api_key, - model=self.model_name, - messages=messages, - ) - if response.get("message"): - raise Exception(response["message"]) - summary = response["output"]["choices"][0]["message"].content[0]["text"] - return summary, num_tokens_from_string(summary) + def call_api(): + response = MultiModalConversation.call( + api_key=self.api_key, + model=self.model_name, + messages=messages, + ) + if response.get("message"): + raise Exception(response["message"]) + summary = response["output"]["choices"][0]["message"].content[0]["text"] + return summary, num_tokens_from_string(summary) + try: try: return call_api() except Exception as e1: @@ -303,6 +397,12 @@ def call_api(): return call_api() except Exception as e2: raise RuntimeError(f"Both default and intl endpoint failed.\nFirst error: {e1}\nSecond error: {e2}") + finally: + if tmp_path and os.path.exists(tmp_path): + try: + os.remove(tmp_path) + except Exception: + 
logging.warning("[QWenCV] Failed to cleanup temp video file: %s", tmp_path) class HunyuanCV(GptV4): From b1e7111ae11c8c132f92f5ff31f13ad52e697a1e Mon Sep 17 00:00:00 2001 From: Ahmad Intisar <168020872+ahmadintisar@users.noreply.github.com> Date: Wed, 11 Feb 2026 06:49:48 +0500 Subject: [PATCH 003/565] Fix: Correct Gemini embedding model name in llm_factories.json (#13051) ## Problem RAGFlow was using incorrect model names for Google Gemini embeddings: - `embedding-001` (missing `gemini-` prefix) - `text-embedding-004` (OpenAI model name, not Gemini) This caused API errors when users tried to use Gemini embeddings. ## Solution - Updated `conf/llm_factories.json` to use the correct model name: `gemini-embedding-001` - Removed the incorrect `text-embedding-004` entry - Added volume mount in `docker-compose.yml` to ensure config changes persist ## Testing Tested with a valid Gemini API key and confirmed embeddings now work correctly. ## Changes - Modified `conf/llm_factories.json` - Modified `docker/docker-compose.yml` --------- Co-authored-by: Ahmad Intisar Co-authored-by: Kevin Hu From ffb87ea259fe12706e0f245c39eb8014527baf63 Mon Sep 17 00:00:00 2001 From: Jim Smith Date: Tue, 10 Feb 2026 19:51:53 -0700 Subject: [PATCH 004/565] Fix: Make time_utils tests timezone-independent (#13100) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Replace hardcoded CST (UTC+8) expected values in `test_time_utils.py` with dynamically computed local-time expectations using `time.localtime()` and `time.mktime()` - Tests previously failed in any timezone other than UTC+8; they now pass regardless of the system's local timezone ## Test plan - [x] `uv run pytest test/unit_test/ -v` — 317 passed, 25 skipped 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Jim Smith Co-authored-by: Claude Opus 4.6 --- test/unit_test/common/test_time_utils.py | 58 ++++++++++++------------ 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/test/unit_test/common/test_time_utils.py b/test/unit_test/common/test_time_utils.py index 7efc1d2902f..6a622e798a6 100644 --- a/test/unit_test/common/test_time_utils.py +++ b/test/unit_test/common/test_time_utils.py @@ -68,22 +68,23 @@ def test_basic_timestamp_conversion(self): # Test with a specific timestamp timestamp = 1704067200000 # 2024-01-01 00:00:00 UTC result = timestamp_to_date(timestamp) - expected = "2024-01-01 08:00:00" + expected = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp / 1000)) assert result == expected def test_custom_format_string(self): """Test conversion with custom format string""" timestamp = 1704067200000 # 2024-01-01 00:00:00 UTC + local = time.localtime(timestamp / 1000) # Test different format strings result1 = timestamp_to_date(timestamp, "%Y-%m-%d") - assert result1 == "2024-01-01" + assert result1 == time.strftime("%Y-%m-%d", local) result2 = timestamp_to_date(timestamp, "%H:%M:%S") - assert result2 == "08:00:00" + assert result2 == time.strftime("%H:%M:%S", local) result3 = timestamp_to_date(timestamp, "%Y/%m/%d %H:%M") - assert result3 == "2024/01/01 08:00" + assert result3 == time.strftime("%Y/%m/%d %H:%M", local) def test_zero_timestamp(self): """Test conversion with zero timestamp (epoch)""" @@ -104,14 +105,14 @@ def test_string_timestamp_input(self): """Test that string timestamp input is handled correctly""" timestamp_str = "1704067200000" result = timestamp_to_date(timestamp_str) - expected = "2024-01-01 08:00:00" + expected = 
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(timestamp_str) / 1000)) assert result == expected def test_float_timestamp_input(self): """Test that float timestamp input is handled correctly""" timestamp_float = 1704067200000.0 result = timestamp_to_date(timestamp_float) - expected = "2024-01-01 08:00:00" + expected = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(timestamp_float) / 1000)) assert result == expected def test_different_timezones_handled(self): @@ -130,19 +131,18 @@ def test_millisecond_precision(self): timestamp = 1704067200123 # 2024-01-01 00:00:00.123 UTC result = timestamp_to_date(timestamp) - # Should still return "08:00:00" since milliseconds are truncated - assert "08:00:00" in result + # Milliseconds are truncated, so result should match the base timestamp + expected = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(timestamp) / 1000)) + assert result == expected def test_various_timestamps(self): """Test conversion with various timestamp values""" - test_cases = [ - (1609459200000, "2021-01-01 08:00:00"), # 2020-12-31 16:00:00 UTC - (4102444800000, "2100-01-01"), # Future date - ] + test_cases = [1609459200000, 4102444800000] - for timestamp, expected_prefix in test_cases: + for timestamp in test_cases: result = timestamp_to_date(timestamp) - assert expected_prefix in result + expected = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp / 1000)) + assert result == expected def test_return_type_always_string(self): """Test that return type is always string regardless of input""" @@ -176,21 +176,22 @@ def test_basic_date_string_conversion(self): """Test basic date string to timestamp conversion with default format""" date_string = "2024-01-01 08:00:00" result = date_string_to_timestamp(date_string) - expected = 1704067200000 + expected = int(time.mktime(time.strptime(date_string, "%Y-%m-%d %H:%M:%S")) * 1000) assert result == expected def test_custom_format_string(self): """Test conversion with custom format strings""" # Test different date formats test_cases = [ - ("2024-01-01", "%Y-%m-%d", 1704038400000), - ("2024/01/01 12:30:45", "%Y/%m/%d %H:%M:%S", 1704083445000), - ("01-01-2024", "%m-%d-%Y", 1704038400000), - ("20240101", "%Y%m%d", 1704038400000), + ("2024-01-01", "%Y-%m-%d"), + ("2024/01/01 12:30:45", "%Y/%m/%d %H:%M:%S"), + ("01-01-2024", "%m-%d-%Y"), + ("20240101", "%Y%m%d"), ] - for date_string, format_string, expected in test_cases: + for date_string, format_string in test_cases: result = date_string_to_timestamp(date_string, format_string) + expected = int(time.mktime(time.strptime(date_string, format_string)) * 1000) assert result == expected def test_return_type_integer(self): @@ -213,14 +214,15 @@ def test_timestamp_in_milliseconds(self): def test_different_dates(self): """Test conversion with various date strings""" test_cases = [ - ("2024-01-01 00:00:00", 1704038400000), - ("2020-12-31 16:00:00", 1609401600000), - ("2023-06-15 14:30:00", 1686810600000), - ("2025-12-25 23:59:59", 1766678399000), + "2024-01-01 00:00:00", + "2020-12-31 16:00:00", + "2023-06-15 14:30:00", + "2025-12-25 23:59:59", ] - for date_string, expected in test_cases: + for date_string in test_cases: result = date_string_to_timestamp(date_string) + expected = int(time.mktime(time.strptime(date_string, "%Y-%m-%d %H:%M:%S")) * 1000) assert result == expected def test_epoch_date(self): @@ -236,15 +238,15 @@ def test_leap_year_date(self): """Test conversion with leap year date""" date_string = "2024-02-29 12:00:00" # Valid leap year date result = 
date_string_to_timestamp(date_string) - expected = 1709179200000 # 2024-02-29 12:00:00 in milliseconds + expected = int(time.mktime(time.strptime(date_string, "%Y-%m-%d %H:%M:%S")) * 1000) assert result == expected def test_date_only_string(self): """Test conversion with date-only format (assumes 00:00:00 time)""" date_string = "2024-01-01" result = date_string_to_timestamp(date_string, "%Y-%m-%d") - # Should be equivalent to "2024-01-01 00:00:00" - expected = 1704038400000 + # Should be equivalent to "2024-01-01 00:00:00" in local timezone + expected = int(time.mktime(time.strptime(date_string, "%Y-%m-%d")) * 1000) assert result == expected def test_with_whitespace(self): From b7009b876668ce7b94c49b97241462184ab665da Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Wed, 11 Feb 2026 11:21:03 +0800 Subject: [PATCH 005/565] Refa: Change aliyun repo. (#13103) ### Type of change - [x] Refactoring --- README_tzh.md | 2 +- README_zh.md | 2 +- docker/.env | 4 ++-- docker/README.md | 2 +- docs/configurations.md | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/README_tzh.md b/README_tzh.md index d46d06077ce..b63d0dfabb7 100644 --- a/README_tzh.md +++ b/README_tzh.md @@ -217,7 +217,7 @@ > 如果你遇到 Docker 映像檔拉不下來的問題,可以在 **docker/.env** 檔案內根據變數 `RAGFLOW_IMAGE` 的註解提示選擇華為雲或阿里雲的對應映像。 > > - 華為雲鏡像名:`swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow` -> - 阿里雲鏡像名:`registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow` +> - 阿里雲鏡像名:`infiniflow-registry.cn-shanghai.cr.aliyuncs.com/infiniflow/ragflow` 4. 伺服器啟動成功後再次確認伺服器狀態: diff --git a/README_zh.md b/README_zh.md index 5b194daa0ff..bfecfc00a34 100644 --- a/README_zh.md +++ b/README_zh.md @@ -218,7 +218,7 @@ > 如果你遇到 Docker 镜像拉不下来的问题,可以在 **docker/.env** 文件内根据变量 `RAGFLOW_IMAGE` 的注释提示选择华为云或者阿里云的相应镜像。 > > - 华为云镜像名:`swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow` - > - 阿里云镜像名:`registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow` + > - 阿里云镜像名:`infiniflow-registry.cn-shanghai.cr.aliyuncs.com/infiniflow/ragflow` 4. 服务器启动成功后再次确认服务器状态: diff --git a/docker/.env b/docker/.env index 7e1bdf801bc..57a42ac1abc 100644 --- a/docker/.env +++ b/docker/.env @@ -158,11 +158,11 @@ RAGFLOW_IMAGE=infiniflow/ragflow:v0.24.0 # If you cannot download the RAGFlow Docker image: # RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:v0.24.0 -# RAGFLOW_IMAGE=registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow:v0.24.0 +# RAGFLOW_IMAGE=infiniflow-registry.cn-shanghai.cr.aliyuncs.com/infiniflow/ragflow:v0.24.0 # # - For the `nightly` edition, uncomment either of the following: # RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:nightly -# RAGFLOW_IMAGE=registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow:nightly +# RAGFLOW_IMAGE=infiniflow-registry.cn-shanghai.cr.aliyuncs.com/infiniflow/ragflow:nightly # The embedding service image, model and port. # Important: To enable the embedding service, you need to uncomment one of the following two lines: diff --git a/docker/README.md b/docker/README.md index c6422bad8c7..b5f9bc66712 100644 --- a/docker/README.md +++ b/docker/README.md @@ -87,7 +87,7 @@ The [.env](./.env) file contains important environment variables for Docker. > > - For the `nightly` edition: > - `RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:nightly` or, -> - `RAGFLOW_IMAGE=registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow:nightly`. +> - `RAGFLOW_IMAGE=infiniflow-registry.cn-shanghai.cr.aliyuncs.com/infiniflow/ragflow:nightly`. 
### Timezone diff --git a/docs/configurations.md b/docs/configurations.md index 2b274c8e9b2..ca7e67e29ea 100644 --- a/docs/configurations.md +++ b/docs/configurations.md @@ -110,7 +110,7 @@ If you cannot download the RAGFlow Docker image, try the following mirrors. - For the `nightly` edition: - `RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:nightly` or, - - `RAGFLOW_IMAGE=registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow:nightly`. + - `RAGFLOW_IMAGE=infiniflow-registry.cn-shanghai.cr.aliyuncs.com/infiniflow/ragflow:nightly`. ::: ### Embedding service From 4d840d863cb56ac60a28cafde95ecee9bbb298be Mon Sep 17 00:00:00 2001 From: writinwaters <93570324+writinwaters@users.noreply.github.com> Date: Wed, 11 Feb 2026 17:58:08 +0800 Subject: [PATCH 006/565] Docs: Updated sandbox reference (#13114) ### What problem does this PR solve? Updated sandbox reference. ### Type of change - [x] Documentation Update --- docs/guides/agent/agent_component_reference/code.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/guides/agent/agent_component_reference/code.mdx b/docs/guides/agent/agent_component_reference/code.mdx index a9472ca5e03..d0af92cc184 100644 --- a/docs/guides/agent/agent_component_reference/code.mdx +++ b/docs/guides/agent/agent_component_reference/code.mdx @@ -23,7 +23,7 @@ We use gVisor to isolate code execution from the host system. Please follow [the ### 2. Ensure Sandbox is properly installed -RAGFlow Sandbox is a secure, pluggable code execution backend. It serves as the code executor for the **Code** component. Please follow the [instructions here](https://github.com/infiniflow/ragflow/tree/main/sandbox) to install RAGFlow Sandbox. +RAGFlow Sandbox is a secure, pluggable code execution backend. It serves as the code executor for the **Code** component. Please follow the [instructions here](https://github.com/infiniflow/ragflow/tree/main/agent/sandbox) to install RAGFlow Sandbox. :::note Docker client version The executor manager image now bundles Docker CLI `29.1.0` (API 1.44+). Older images shipped Docker 24.x and will fail against newer Docker daemons with `client version 1.43 is too old`. Pull the latest `infiniflow/sandbox-executor-manager:latest` or rebuild it in `./sandbox/executor_manager` if you encounter this error. From e8b68f2fb3163d2795022558908e54ff8df1b962 Mon Sep 17 00:00:00 2001 From: writinwaters <93570324+writinwaters@users.noreply.github.com> Date: Wed, 11 Feb 2026 18:08:56 +0800 Subject: [PATCH 007/565] Docs: Updated v0.24.0 release notes (#13115) ### What problem does this PR solve? Updated v0.24.0 release notes. ### Type of change - [x] Documentation Update --- docs/release_notes.md | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/docs/release_notes.md b/docs/release_notes.md index e6dbdc4d83f..d222f977f51 100644 --- a/docs/release_notes.md +++ b/docs/release_notes.md @@ -17,27 +17,36 @@ Released on February 10, 2026. - Memory - Introduces APIs and an SDK for developer integration. - - Adds Memory extraction log display in the console for improved debugging and tracing. + - Outputs Memory extraction log to the console for debugging and tracing. - Dataset - - Added support for batch management of Metadata. - - Renamed "ToC (Table of Contents)" to "PageIndex". + - Supports batch metadata management. + - Renames "ToC (Table of Contents)" to "PageIndex". See [here](./guides/dataset/extract_table_of_contents.md). 
 - Agent
-  - Launches a new Chat-like Agent conversation management interface that retains Sessions and dialogue history.
-  - Introduces a multi-Sandbox mechanism, currently supporting local gVisor and Alibaba Cloud, with compatibility for mainstream Sandbox APIs (configurable in the Admin page).
+  - Launches a new Chat-like Agent conversation management interface that retains sessions and dialogue history.
+  - Introduces a multi-Sandbox mechanism supporting local gVisor and Alibaba Cloud, with compatibility for mainstream Sandbox APIs (configurable in the Admin page).
 - Chat
-  - Adds a new "Thinking" mode and removed the previous "Reasoning" configuration option.
+  - Adds a new "Thinking" mode and removes the previous "Reasoning" configuration option.
   - Optimizes retrieval strategies for deep-research scenarios, enhancing recall accuracy.
 - Admin
-  - Adds support for configuring multiple Admin accounts.
+  - Supports configuring multiple Admin accounts.
 - Model configuration center
   - Adds a model connection test feature when adding new models.
-- Ecosystem
-  - Adds support for OceanBase as a database alternative to MySQL.
-  - Adds support for PaddleOCR-VL.
-- Model
-  - Adds new model support for Kimi 2.5, Stepfun 3, and doubao-embedding-vision, among others.
-- Data sources
-  - Adds new data source integrations for Zendesk, Bitbucket, and others.
+
+### MySQL alternative
+
+- Supports OceanBase as an alternative to MySQL.
+
+### Model support
+
+- Kimi 2.5
+- Stepfun 3
+- doubao-embedding-vision
+- PaddleOCR-VL
+
+### Data sources
+
+- Zendesk
+- Bitbucket
 
 ## v0.23.1

From d75ab78a27b72af27650ee4cff8d7e39a0765580 Mon Sep 17 00:00:00 2001
From: TheoG <45789400+TheoGuil@users.noreply.github.com>
Date: Wed, 11 Feb 2026 13:11:56 +0100
Subject: [PATCH 008/565] Fix graphrag extraction (#13113)

### What problem does this PR solve?

Fix an error when extracting the graph: a string was expected, but a tuple was provided.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
---
 rag/graphrag/general/extractor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/rag/graphrag/general/extractor.py b/rag/graphrag/general/extractor.py
index ccb0d3ba8bd..3328604b67a 100644
--- a/rag/graphrag/general/extractor.py
+++ b/rag/graphrag/general/extractor.py
@@ -78,7 +78,7 @@ def _chat(self, system, history, gen_conf={}, task_id=""):
             raise TaskCanceledException(f"Task {task_id} was cancelled")
         try:
             response = asyncio.run(self._llm.async_chat(system_msg[0]["content"], hist, conf))
-            response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
+            response = re.sub(r"^.*</think>", "", response[0], flags=re.DOTALL)
             if response.find("**ERROR**") >= 0:
                 raise Exception(response)
             set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)

From 3bdb7e621c7f52a32cf2f2d116f4477a294fcacb Mon Sep 17 00:00:00 2001
From: Levi <81591061+levischd@users.noreply.github.com>
Date: Thu, 12 Feb 2026 03:09:35 +0100
Subject: [PATCH 009/565] Fix: persist SSO auth token on root route loader (#12784)

### What problem does this PR solve?

This PR fixes SSO/OIDC login persistence after the Vite migration (#12568). Because wrappers are ignored by React Router, the OAuth callback never stored the auth token in localStorage, so auth only worked while ?auth= stayed in the URL. We move that logic into a route loader and remove the Bearer prefix from the signed token so the backend accepts it.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Contribution during my time at RAGcon GmbH.
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com> --- web/src/routes.tsx | 14 ++++++++++++-- web/src/utils/authorization-util.ts | 4 +--- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/web/src/routes.tsx b/web/src/routes.tsx index ab05987c926..198b95f496d 100644 --- a/web/src/routes.tsx +++ b/web/src/routes.tsx @@ -1,7 +1,8 @@ import { lazy, memo, Suspense } from 'react'; -import { createBrowserRouter, Navigate, type RouteObject } from 'react-router'; +import { createBrowserRouter, Navigate, redirect, type RouteObject } from 'react-router'; import FallbackComponent from './components/fallback-component'; import { IS_ENTERPRISE } from './pages/admin/utils'; +import authorizationUtil from './utils/authorization-util'; export enum Routes { Root = '/', @@ -141,7 +142,16 @@ const routeConfigOptions = [ path: Routes.Root, layout: false, Component: () => import('@/layouts/next'), - wrappers: ['@/wrappers/auth'], + loader: ({ request }) => { + const url = new URL(request.url); + const auth = url.searchParams.get('auth'); + if (auth) { + authorizationUtil.setAuthorization(auth); + url.searchParams.delete('auth'); + return redirect(`${url.pathname}${url.search}`); + } + return null; + }, children: [ { path: Routes.Root, diff --git a/web/src/utils/authorization-util.ts b/web/src/utils/authorization-util.ts index e25e4915f18..328def48de4 100644 --- a/web/src/utils/authorization-util.ts +++ b/web/src/utils/authorization-util.ts @@ -49,9 +49,7 @@ const storage = { export const getAuthorization = () => { const auth = getSearchValue('auth'); - const authorization = auth - ? 'Bearer ' + auth - : storage.getAuthorization() || ''; + const authorization = auth ? auth : storage.getAuthorization() || ''; return authorization; }; From 1e629a613aca7cfd28ad1e2d57640099bc983657 Mon Sep 17 00:00:00 2001 From: Lynn Date: Thu, 12 Feb 2026 10:11:50 +0800 Subject: [PATCH 010/565] Refactor: split memory API into gateway and service layers (#13111) ### What problem does this PR solve? Decouple the memory API into a gateway layer (for routing/param parse) and a service layer (for business logic). 
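In practice, the split keeps each route handler as a thin HTTP adapter while the service module owns validation and persistence. A minimal sketch of the pattern follows; the handler and service names and the numeric result codes are illustrative, not the exact ones introduced by this PR (only `ArgumentException` comes from this PR's `common/exceptions.py`):

```python
# Illustrative sketch of the gateway/service split; names and result
# codes are examples, not the exact functions introduced by this PR.


class ArgumentException(Exception):
    """Raised by the service layer when input validation fails."""


# Service layer: business logic only, no HTTP concerns.
async def create_memory_service(memory_info: dict) -> dict:
    name = (memory_info.get("name") or "").strip()
    if not name:
        raise ArgumentException("Memory name cannot be empty or whitespace.")
    # ... validate the remaining fields, persist, return the created record ...
    return {"name": name, "memory_type": memory_info.get("memory_type", [])}


# Gateway layer: parse the request, delegate, map exceptions to responses.
async def create_memory_handler(req: dict) -> tuple[dict, int]:
    try:
        data = await create_memory_service(req)
        return {"code": 0, "data": data}, 200
    except ArgumentException as e:
        return {"code": 101, "message": str(e)}, 400  # invalid argument
    except Exception:
        return {"code": 100, "message": "Internal server error"}, 500
```

Keeping the exception-to-response mapping in the gateway is what lets the route modules under `restful_apis/` stay free of business rules.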
### Type of change

- [x] Refactoring
---
 api/apps/__init__.py                    |   8 +-
 api/apps/restful_apis/memory_api.py     | 173 ++++++++++++++
 api/apps/sdk/memories.py                | 291 ------------------------
 api/apps/services/__init__.py           |   0
 api/apps/services/memory_api_service.py | 223 ++++++++++++++++++
 common/exceptions.py                    |  10 +
 6 files changed, 413 insertions(+), 292 deletions(-)
 create mode 100644 api/apps/restful_apis/memory_api.py
 delete mode 100644 api/apps/sdk/memories.py
 create mode 100644 api/apps/services/__init__.py
 create mode 100644 api/apps/services/memory_api_service.py

diff --git a/api/apps/__init__.py b/api/apps/__init__.py
index 7feae696e35..89078d9fb81 100644
--- a/api/apps/__init__.py
+++ b/api/apps/__init__.py
@@ -244,6 +244,10 @@ def search_pages_path(page_path):
         path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")
     ]
     app_path_list.extend(api_path_list)
+    restful_api_path_list = [
+        path for path in page_path.glob("*restful_apis/*.py") if not path.name.startswith(".")
+    ]
+    app_path_list.extend(restful_api_path_list)
     return app_path_list
 
@@ -263,8 +267,9 @@ def register_page(page_path):
         spec.loader.exec_module(page)
         page_name = getattr(page, "page_name", page_name)
         sdk_path = "\\sdk\\" if sys.platform.startswith("win") else "/sdk/"
+        restful_api_path = "\\restful_apis\\" if sys.platform.startswith("win") else "/restful_apis/"
         url_prefix = (
-            f"/api/{API_VERSION}" if sdk_path in path else f"/{API_VERSION}/{page_name}"
+            f"/api/{API_VERSION}" if sdk_path in path or restful_api_path in path else f"/{API_VERSION}/{page_name}"
         )
         app.register_blueprint(page.manager, url_prefix=url_prefix)
@@ -274,6 +279,7 @@
 pages_dir = [
     Path(__file__).parent,
     Path(__file__).parent.parent / "api" / "apps",
+    Path(__file__).parent.parent / "api" / "apps" / "restful_apis",
     Path(__file__).parent.parent / "api" / "apps" / "sdk",
 ]

diff --git a/api/apps/restful_apis/memory_api.py b/api/apps/restful_apis/memory_api.py
new file mode 100644
index 00000000000..53c7f866e27
--- /dev/null
+++ b/api/apps/restful_apis/memory_api.py
@@ -0,0 +1,173 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+import os
+import time
+
+from quart import request
+from common.constants import RetCode
+from common.exceptions import ArgumentException, NotFoundException
+from api.apps import login_required
+from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
+from api.apps.services import memory_api_service
+
+
+@manager.route("/memories", methods=["POST"]) # noqa: F821
+@login_required
+@validate_request("name", "memory_type", "embd_id", "llm_id")
+async def create_memory():
+    timing_enabled = os.getenv("RAGFLOW_API_TIMING")
+    t_start = time.perf_counter() if timing_enabled else None
+    req = await get_request_json()
+    t_parsed = time.perf_counter() if timing_enabled else None
+    try:
+        memory_info = {
+            "name": req["name"],
+            "memory_type": req["memory_type"],
+            "embd_id": req["embd_id"],
+            "llm_id": req["llm_id"]
+        }
+        success, res = await memory_api_service.create_memory(memory_info)
+        if timing_enabled:
+            logging.info(
+                "api_timing create_memory parse_ms=%.2f validate_and_db_ms=%.2f total_ms=%.2f path=%s",
+                (t_parsed - t_start) * 1000,
+                (time.perf_counter() - t_parsed) * 1000,
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
+        if success:
+            return get_json_result(message=True, data=res)
+        else:
+            return get_json_result(message=res, code=RetCode.SERVER_ERROR)
+
+    except ArgumentException as arg_error:
+        logging.error(arg_error)
+        if timing_enabled:
+            logging.info(
+                "api_timing create_memory error=%s parse_ms=%.2f total_ms=%.2f path=%s",
+                str(arg_error),
+                (t_parsed - t_start) * 1000,
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
+        return get_error_argument_result(str(arg_error))
+
+    except Exception as e:
+        logging.error(e)
+        if timing_enabled:
+            logging.info(
+                "api_timing create_memory error=%s parse_ms=%.2f total_ms=%.2f path=%s",
+                str(e),
+                (t_parsed - t_start) * 1000,
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
+        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
+
+
+@manager.route("/memories/<memory_id>", methods=["PUT"]) # noqa: F821
+@login_required
+async def update_memory(memory_id):
+    req = await get_request_json()
+    new_settings = {k: req[k] for k in [
+        "name", "permissions", "llm_id", "embd_id", "memory_type", "memory_size", "forgetting_policy", "temperature",
+        "avatar", "description", "system_prompt", "user_prompt"
+    ] if k in req}
+    try:
+        success, res = await memory_api_service.update_memory(memory_id, new_settings)
+        if success:
+            return get_json_result(message=True, data=res)
+        else:
+            return get_json_result(message=res, code=RetCode.SERVER_ERROR)
+    except NotFoundException as not_found_exception:
+        logging.error(not_found_exception)
+        return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception))
+    except ArgumentException as arg_error:
+        logging.error(arg_error)
+        return get_error_argument_result(str(arg_error))
+    except Exception as e:
+        logging.error(e)
+        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
+
+
+@manager.route("/memories/<memory_id>", methods=["DELETE"]) # noqa: F821
+@login_required
+async def delete_memory(memory_id):
+    try:
+        await memory_api_service.delete_memory(memory_id)
+        return get_json_result(message=True)
+    except NotFoundException as not_found_exception:
+        logging.error(not_found_exception)
+        return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception))
+    except Exception as e:
+        logging.error(e)
+        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
+
+
+@manager.route("/memories", methods=["GET"]) # noqa: F821
+@login_required
+async def list_memory():
+    filter_params = {
+        k: request.args.get(k) for k in ["memory_type", "tenant_id", "storage_type"] if k in request.args
+    }
+    keywords = request.args.get("keywords")
+    page = int(request.args.get("page", 1))
+    page_size = int(request.args.get("page_size", 50))
+    try:
+        res = await memory_api_service.list_memory(filter_params, keywords, page, page_size)
+        return get_json_result(message=True, data=res)
+    except Exception as e:
+        logging.error(e)
+        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
+
+
+@manager.route("/memories/<memory_id>/config", methods=["GET"]) # noqa: F821
+@login_required
+async def get_memory_config(memory_id):
+    try:
+        res = await memory_api_service.get_memory_config(memory_id)
+        return get_json_result(message=True, data=res)
+    except NotFoundException as not_found_exception:
+        logging.error(not_found_exception)
+        return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception))
+    except Exception as e:
+        logging.error(e)
+        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
+
+
+@manager.route("/memories/<memory_id>", methods=["GET"]) # noqa: F821
+@login_required
+async def get_memory_messages(memory_id):
+    args = request.args
+    agent_ids = args.getlist("agent_id")
+    if len(agent_ids) == 1 and ',' in agent_ids[0]:
+        agent_ids = agent_ids[0].split(',')
+    keywords = args.get("keywords", "")
+    keywords = keywords.strip()
+    page = int(args.get("page", 1))
+    page_size = int(args.get("page_size", 50))
+    try:
+        res = await memory_api_service.get_memory_messages(
+            memory_id, agent_ids, keywords, page, page_size
+        )
+        return get_json_result(message=True, data=res)
+    except NotFoundException as not_found_exception:
+        logging.error(not_found_exception)
+        return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception))
+    except Exception as e:
+        logging.error(e)
+        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")

diff --git a/api/apps/sdk/memories.py b/api/apps/sdk/memories.py
deleted file mode 100644
index ada4b34fab9..00000000000
--- a/api/apps/sdk/memories.py
+++ /dev/null
@@ -1,291 +0,0 @@
-#
-# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# -import logging -import os -import time - -from quart import request -from api.apps import login_required, current_user -from api.db import TenantPermission -from api.db.services.memory_service import MemoryService -from api.db.services.user_service import UserTenantService -from api.db.services.canvas_service import UserCanvasService -from api.db.services.task_service import TaskService -from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default -from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result -from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human -from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT -from memory.services.messages import MessageService -from memory.utils.prompt_util import PromptAssembler -from common.constants import MemoryType, RetCode, ForgettingPolicy - - -@manager.route("/memories", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("name", "memory_type", "embd_id", "llm_id") -async def create_memory(): - timing_enabled = os.getenv("RAGFLOW_API_TIMING") - t_start = time.perf_counter() if timing_enabled else None - req = await get_request_json() - t_parsed = time.perf_counter() if timing_enabled else None - # check name length - name = req["name"] - memory_name = name.strip() - if len(memory_name) == 0: - if timing_enabled: - logging.info( - "api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s", - (t_parsed - t_start) * 1000, - (time.perf_counter() - t_start) * 1000, - request.path, - ) - return get_error_argument_result("Memory name cannot be empty or whitespace.") - if len(memory_name) > MEMORY_NAME_LIMIT: - if timing_enabled: - logging.info( - "api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s", - (t_parsed - t_start) * 1000, - (time.perf_counter() - t_start) * 1000, - request.path, - ) - return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.") - # check memory_type valid - if not isinstance(req["memory_type"], list): - if timing_enabled: - logging.info( - "api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s", - (t_parsed - t_start) * 1000, - (time.perf_counter() - t_start) * 1000, - request.path, - ) - return get_error_argument_result("Memory type must be a list.") - memory_type = set(req["memory_type"]) - invalid_type = memory_type - {e.name.lower() for e in MemoryType} - if invalid_type: - if timing_enabled: - logging.info( - "api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s", - (t_parsed - t_start) * 1000, - (time.perf_counter() - t_start) * 1000, - request.path, - ) - return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.") - memory_type = list(memory_type) - - try: - t_before_db = time.perf_counter() if timing_enabled else None - res, memory = MemoryService.create_memory( - tenant_id=current_user.id, - name=memory_name, - memory_type=memory_type, - embd_id=req["embd_id"], - llm_id=req["llm_id"] - ) - if timing_enabled: - logging.info( - "api_timing create_memory parse_ms=%.2f validate_ms=%.2f db_ms=%.2f total_ms=%.2f path=%s", - (t_parsed - t_start) * 1000, - (t_before_db - t_parsed) * 1000, - (time.perf_counter() - t_before_db) * 1000, - (time.perf_counter() - t_start) * 1000, - request.path, - ) - - if res: - return get_json_result(message=True, data=format_ret_data_from_memory(memory)) - else: - return 
get_json_result(message=memory, code=RetCode.SERVER_ERROR) - - except Exception as e: - return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) - - -@manager.route("/memories/", methods=["PUT"]) # noqa: F821 -@login_required -async def update_memory(memory_id): - req = await get_request_json() - update_dict = {} - # check name length - if "name" in req: - name = req["name"] - memory_name = name.strip() - if len(memory_name) == 0: - return get_error_argument_result("Memory name cannot be empty or whitespace.") - if len(memory_name) > MEMORY_NAME_LIMIT: - return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.") - update_dict["name"] = memory_name - # check permissions valid - if req.get("permissions"): - if req["permissions"] not in [e.value for e in TenantPermission]: - return get_error_argument_result(f"Unknown permission '{req['permissions']}'.") - update_dict["permissions"] = req["permissions"] - if req.get("llm_id"): - update_dict["llm_id"] = req["llm_id"] - if req.get("embd_id"): - update_dict["embd_id"] = req["embd_id"] - if req.get("memory_type"): - memory_type = set(req["memory_type"]) - invalid_type = memory_type - {e.name.lower() for e in MemoryType} - if invalid_type: - return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.") - update_dict["memory_type"] = list(memory_type) - # check memory_size valid - if req.get("memory_size"): - if not 0 < int(req["memory_size"]) <= MEMORY_SIZE_LIMIT: - return get_error_argument_result(f"Memory size should be in range (0, {MEMORY_SIZE_LIMIT}] Bytes.") - update_dict["memory_size"] = req["memory_size"] - # check forgetting_policy valid - if req.get("forgetting_policy"): - if req["forgetting_policy"] not in [e.value for e in ForgettingPolicy]: - return get_error_argument_result(f"Forgetting policy '{req['forgetting_policy']}' is not supported.") - update_dict["forgetting_policy"] = req["forgetting_policy"] - # check temperature valid - if "temperature" in req: - temperature = float(req["temperature"]) - if not 0 <= temperature <= 1: - return get_error_argument_result("Temperature should be in range [0, 1].") - update_dict["temperature"] = temperature - # allow update to empty fields - for field in ["avatar", "description", "system_prompt", "user_prompt"]: - if field in req: - update_dict[field] = req[field] - current_memory = MemoryService.get_by_memory_id(memory_id) - if not current_memory: - return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") - - memory_dict = current_memory.to_dict() - memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)}) - to_update = {} - for k, v in update_dict.items(): - if isinstance(v, list) and set(memory_dict[k]) != set(v): - to_update[k] = v - elif memory_dict[k] != v: - to_update[k] = v - - if not to_update: - return get_json_result(message=True, data=memory_dict) - # check memory empty when update embd_id, memory_type - memory_size = get_memory_size_cache(memory_id, current_memory.tenant_id) - not_allowed_update = [f for f in ["embd_id", "memory_type"] if f in to_update and memory_size > 0] - if not_allowed_update: - return get_error_argument_result(f"Can't update {not_allowed_update} when memory isn't empty.") - if "memory_type" in to_update: - if "system_prompt" not in to_update and judge_system_prompt_is_default(current_memory.system_prompt, current_memory.memory_type): - # update old default prompt, assemble a new one - to_update["system_prompt"] = 
PromptAssembler.assemble_system_prompt({"memory_type": to_update["memory_type"]}) - - try: - MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update) - updated_memory = MemoryService.get_by_memory_id(memory_id) - return get_json_result(message=True, data=format_ret_data_from_memory(updated_memory)) - - except Exception as e: - logging.error(e) - return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) - - -@manager.route("/memories/", methods=["DELETE"]) # noqa: F821 -@login_required -async def delete_memory(memory_id): - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - return get_json_result(message=True, code=RetCode.NOT_FOUND) - try: - MemoryService.delete_memory(memory_id) - if MessageService.has_index(memory.tenant_id, memory_id): - MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id) - return get_json_result(message=True) - except Exception as e: - logging.error(e) - return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) - - -@manager.route("/memories", methods=["GET"]) # noqa: F821 -@login_required -async def list_memory(): - args = request.args - try: - tenant_ids = args.getlist("tenant_id") - memory_types = args.getlist("memory_type") - storage_type = args.get("storage_type") - keywords = args.get("keywords", "") - page = int(args.get("page", 1)) - page_size = int(args.get("page_size", 50)) - # make filter dict - filter_dict: dict = {"storage_type": storage_type} - if not tenant_ids: - # restrict to current user's tenants - user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id) - filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants] - else: - if len(tenant_ids) == 1 and ',' in tenant_ids[0]: - tenant_ids = tenant_ids[0].split(',') - filter_dict["tenant_id"] = tenant_ids - if memory_types and len(memory_types) == 1 and ',' in memory_types[0]: - memory_types = memory_types[0].split(',') - filter_dict["memory_type"] = memory_types - - memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size) - [memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list] - return get_json_result(message=True, data={"memory_list": memory_list, "total_count": count}) - - except Exception as e: - logging.error(e) - return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) - - -@manager.route("/memories//config", methods=["GET"]) # noqa: F821 -@login_required -async def get_memory_config(memory_id): - memory = MemoryService.get_with_owner_name_by_id(memory_id) - if not memory: - return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") - return get_json_result(message=True, data=format_ret_data_from_memory(memory)) - - -@manager.route("/memories/", methods=["GET"]) # noqa: F821 -@login_required -async def get_memory_detail(memory_id): - args = request.args - agent_ids = args.getlist("agent_id") - if len(agent_ids) == 1 and ',' in agent_ids[0]: - agent_ids = agent_ids[0].split(',') - keywords = args.get("keywords", "") - keywords = keywords.strip() - page = int(args.get("page", 1)) - page_size = int(args.get("page_size", 50)) - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") - messages = MessageService.list_message( - memory.tenant_id, memory_id, agent_ids, keywords, page, page_size) - agent_name_mapping = {} - extract_task_mapping = {} - 
if messages["message_list"]: - agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]]) - agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list} - task_list = TaskService.get_tasks_progress_by_doc_ids([memory_id]) - if task_list: - task_list.sort(key=lambda t: t["create_time"]) # asc, use newer when exist more than one task - for task in task_list: - # the 'digest' field carries the source_id when a task is created, so use 'digest' as key - extract_task_mapping.update({int(task["digest"]): task}) - for message in messages["message_list"]: - message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown") - message["task"] = extract_task_mapping.get(message["message_id"], {}) - for extract_msg in message["extract"]: - extract_msg["agent_name"] = agent_name_mapping.get(extract_msg["agent_id"], "Unknown") - return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True) diff --git a/api/apps/services/__init__.py b/api/apps/services/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/api/apps/services/memory_api_service.py b/api/apps/services/memory_api_service.py new file mode 100644 index 00000000000..53bb0f6e9ef --- /dev/null +++ b/api/apps/services/memory_api_service.py @@ -0,0 +1,223 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from api.apps import current_user +from api.db import TenantPermission +from api.db.services.memory_service import MemoryService +from api.db.services.user_service import UserTenantService +from api.db.services.canvas_service import UserCanvasService +from api.db.services.task_service import TaskService +from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default +from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human +from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT +from memory.services.messages import MessageService +from memory.utils.prompt_util import PromptAssembler +from common.constants import MemoryType, ForgettingPolicy +from common.exceptions import ArgumentException, NotFoundException + + +async def create_memory(memory_info: dict): + """ + :param memory_info: { + "name": str, + "memory_type": list[str], + "embd_id": str, + "llm_id": str + } + """ + # check name length + name = memory_info["name"] + memory_name = name.strip() + if len(memory_name) == 0: + raise ArgumentException("Memory name cannot be empty or whitespace.") + if len(memory_name) > MEMORY_NAME_LIMIT: + raise ArgumentException(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.") + # check memory_type valid + if not isinstance(memory_info["memory_type"], list): + raise ArgumentException("Memory type must be a list.") + memory_type = set(memory_info["memory_type"]) + invalid_type = memory_type - {e.name.lower() for e in MemoryType} + if invalid_type: + raise ArgumentException(f"Memory type '{invalid_type}' is not supported.") + memory_type = list(memory_type) + success, res = MemoryService.create_memory( + tenant_id=current_user.id, + name=memory_name, + memory_type=memory_type, + embd_id=memory_info["embd_id"], + llm_id=memory_info["llm_id"] + ) + if success: + return True, format_ret_data_from_memory(res) + else: + return False, res + + +async def update_memory(memory_id: str, new_memory_setting: dict): + """ + :param memory_id: str + :param new_memory_setting: { + "name": str, + "permissions": str, + "llm_id": str, + "embd_id": str, + "memory_type": list[str], + "memory_size": int, + "forgetting_policy": str, + "temperature": float, + "avatar": str, + "description": str, + "system_prompt": str, + "user_prompt": str + } + """ + update_dict = {} + # check name length + if "name" in new_memory_setting: + name = new_memory_setting["name"] + memory_name = name.strip() + if len(memory_name) == 0: + raise ArgumentException("Memory name cannot be empty or whitespace.") + if len(memory_name) > MEMORY_NAME_LIMIT: + raise ArgumentException(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.") + update_dict["name"] = memory_name + # check permissions valid + if new_memory_setting.get("permissions"): + if new_memory_setting["permissions"] not in [e.value for e in TenantPermission]: + raise ArgumentException(f"Unknown permission '{new_memory_setting['permissions']}'.") + update_dict["permissions"] = new_memory_setting["permissions"] + if new_memory_setting.get("llm_id"): + update_dict["llm_id"] = new_memory_setting["llm_id"] + if new_memory_setting.get("embd_id"): + update_dict["embd_id"] = new_memory_setting["embd_id"] + if new_memory_setting.get("memory_type"): + memory_type = set(new_memory_setting["memory_type"]) + invalid_type = memory_type - {e.name.lower() for e in MemoryType} + if invalid_type: + raise ArgumentException(f"Memory type '{invalid_type}' is not supported.") + 
update_dict["memory_type"] = list(memory_type) + # check memory_size valid + if new_memory_setting.get("memory_size"): + if not 0 < int(new_memory_setting["memory_size"]) <= MEMORY_SIZE_LIMIT: + raise ArgumentException(f"Memory size should be in range (0, {MEMORY_SIZE_LIMIT}] Bytes.") + update_dict["memory_size"] = new_memory_setting["memory_size"] + # check forgetting_policy valid + if new_memory_setting.get("forgetting_policy"): + if new_memory_setting["forgetting_policy"] not in [e.value for e in ForgettingPolicy]: + raise ArgumentException(f"Forgetting policy '{new_memory_setting['forgetting_policy']}' is not supported.") + update_dict["forgetting_policy"] = new_memory_setting["forgetting_policy"] + # check temperature valid + if "temperature" in new_memory_setting: + temperature = float(new_memory_setting["temperature"]) + if not 0 <= temperature <= 1: + raise ArgumentException("Temperature should be in range [0, 1].") + update_dict["temperature"] = temperature + # allow update to empty fields + for field in ["avatar", "description", "system_prompt", "user_prompt"]: + if field in new_memory_setting: + update_dict[field] = new_memory_setting[field] + current_memory = MemoryService.get_by_memory_id(memory_id) + if not current_memory: + raise NotFoundException(f"Memory '{memory_id}' not found.") + + memory_dict = current_memory.to_dict() + memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)}) + to_update = {} + for k, v in update_dict.items(): + if isinstance(v, list) and set(memory_dict[k]) != set(v): + to_update[k] = v + elif memory_dict[k] != v: + to_update[k] = v + + if not to_update: + return True, memory_dict + # check memory empty when update embd_id, memory_type + memory_size = get_memory_size_cache(memory_id, current_memory.tenant_id) + not_allowed_update = [f for f in ["embd_id", "memory_type"] if f in to_update and memory_size > 0] + if not_allowed_update: + raise ArgumentException(f"Can't update {not_allowed_update} when memory isn't empty.") + if "memory_type" in to_update: + if "system_prompt" not in to_update and judge_system_prompt_is_default(current_memory.system_prompt, current_memory.memory_type): + # update old default prompt, assemble a new one + to_update["system_prompt"] = PromptAssembler.assemble_system_prompt({"memory_type": to_update["memory_type"]}) + + MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update) + updated_memory = MemoryService.get_by_memory_id(memory_id) + return True, format_ret_data_from_memory(updated_memory) + + +async def delete_memory(memory_id): + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + raise NotFoundException(f"Memory '{memory_id}' not found.") + MemoryService.delete_memory(memory_id) + if MessageService.has_index(memory.tenant_id, memory_id): + MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id) + return True + + +async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size: int = 50): + filter_dict: dict = {"storage_type": filter_params.get("storage_type")} + tenant_ids = filter_params.get("tenant_id") + if not filter_params.get("tenant_id"): + # restrict to current user's tenants + user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id) + filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants] + else: + if len(tenant_ids) == 1 and ',' in tenant_ids[0]: + tenant_ids = tenant_ids[0].split(',') + filter_dict["tenant_id"] = tenant_ids + memory_types = 
filter_params.get("memory_type") + if memory_types and len(memory_types) == 1 and ',' in memory_types[0]: + memory_types = memory_types[0].split(',') + filter_dict["memory_type"] = memory_types + + memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size) + [memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list] + return { + "memory_list": memory_list, "total_count": count + } + + +async def get_memory_config(memory_id): + memory = MemoryService.get_with_owner_name_by_id(memory_id) + if not memory: + raise NotFoundException(f"Memory '{memory_id}' not found.") + return format_ret_data_from_memory(memory) + + +async def get_memory_messages(memory_id, agent_ids: list[str], keywords: str, page: int=1, page_size: int = 50): + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + raise NotFoundException(f"Memory '{memory_id}' not found.") + messages = MessageService.list_message( + memory.tenant_id, memory_id, agent_ids, keywords, page, page_size) + agent_name_mapping = {} + extract_task_mapping = {} + if messages["message_list"]: + agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]]) + agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list} + task_list = TaskService.get_tasks_progress_by_doc_ids([memory_id]) + if task_list: + task_list.sort(key=lambda t: t["create_time"]) # asc, use newer when exist more than one task + for task in task_list: + # the 'digest' field carries the source_id when a task is created, so use 'digest' as key + extract_task_mapping.update({int(task["digest"]): task}) + for message in messages["message_list"]: + message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown") + message["task"] = extract_task_mapping.get(message["message_id"], {}) + for extract_msg in message["extract"]: + extract_msg["agent_name"] = agent_name_mapping.get(extract_msg["agent_id"], "Unknown") + return {"messages": messages, "storage_type": memory.storage_type} diff --git a/common/exceptions.py b/common/exceptions.py index c0caac4842e..9511304720a 100644 --- a/common/exceptions.py +++ b/common/exceptions.py @@ -16,3 +16,13 @@ class TaskCanceledException(Exception): def __init__(self, msg): self.msg = msg + + +class ArgumentException(Exception): + def __init__(self, msg): + self.msg = msg + + +class NotFoundException(Exception): + def __init__(self, msg): + self.msg = msg From 7351762316af2e688a6da839e668cf7a22bb6e43 Mon Sep 17 00:00:00 2001 From: Liu An Date: Thu, 12 Feb 2026 10:15:09 +0800 Subject: [PATCH 011/565] Refa: test file location for better organization (#13107) ### What problem does this PR solve? Renamed test/unit/test_delete_query_construction.py to test/unit_test/common/test_delete_query_construction.py to align with the project's directory structure and improve test categorization. 
### Type of change - [x] Refactoring --- test/{unit => unit_test/common}/test_delete_query_construction.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/{unit => unit_test/common}/test_delete_query_construction.py (100%) diff --git a/test/unit/test_delete_query_construction.py b/test/unit_test/common/test_delete_query_construction.py similarity index 100% rename from test/unit/test_delete_query_construction.py rename to test/unit_test/common/test_delete_query_construction.py From ddc0959ef87df684830b4e9cc5c3aaafe2d40450 Mon Sep 17 00:00:00 2001 From: chanx <1243304602@qq.com> Date: Thu, 12 Feb 2026 13:42:12 +0800 Subject: [PATCH 012/565] Fix: Bugs fixed (#13109) (#13122) ### What problem does this PR solve? Fix: Bugs fixed (#13109) - chat pdf preview error - data source add box error - change route next-chat -> chat , next-search->search ... ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- .../document-preview/pdf-preview.tsx | 3 ++- web/src/components/pdf-drawer/index.tsx | 23 +++++++++++++------ web/src/components/pdf-previewer/hooks.ts | 13 ++++++++--- web/src/hooks/use-document-request.ts | 4 ++-- web/src/routes.tsx | 10 ++++---- web/vite.config.ts | 9 +++++++- 6 files changed, 43 insertions(+), 19 deletions(-) diff --git a/web/src/components/document-preview/pdf-preview.tsx b/web/src/components/document-preview/pdf-preview.tsx index 5fc9377c57d..babcd37efc2 100644 --- a/web/src/components/document-preview/pdf-preview.tsx +++ b/web/src/components/document-preview/pdf-preview.tsx @@ -8,12 +8,12 @@ import { Popup, } from 'react-pdf-highlighter'; -import { useCatchDocumentError } from '@/components/pdf-previewer/hooks'; import { Spin } from '@/components/ui/spin'; // import FileError from '@/pages/document-viewer/file-error'; import { Authorization } from '@/constants/authorization'; import FileError from '@/pages/document-viewer/file-error'; import { getAuthorization } from '@/utils/authorization-util'; +import { useCatchDocumentError } from '../pdf-previewer/hooks'; import styles from './index.module.less'; type PdfLoaderProps = React.ComponentProps & { httpHeaders?: Record; @@ -149,3 +149,4 @@ const PdfPreview = ({ }; export default memo(PdfPreview); +export { PdfPreview }; diff --git a/web/src/components/pdf-drawer/index.tsx b/web/src/components/pdf-drawer/index.tsx index 4b6caeca9ad..2d54da7f456 100644 --- a/web/src/components/pdf-drawer/index.tsx +++ b/web/src/components/pdf-drawer/index.tsx @@ -1,8 +1,12 @@ +import { + useGetChunkHighlights, + useGetDocumentUrl, +} from '@/hooks/use-document-request'; import { IModalProps } from '@/interfaces/common'; import { IReferenceChunk } from '@/interfaces/database/chat'; import { IChunk } from '@/interfaces/database/knowledge'; import { cn } from '@/lib/utils'; -import DocumentPreviewer from '../pdf-previewer'; +import PdfPreview from '../document-preview/pdf-preview'; import { Sheet, SheetContent, SheetHeader, SheetTitle } from '../ui/sheet'; interface IProps extends IModalProps { @@ -13,13 +17,15 @@ interface IProps extends IModalProps { } export const PdfSheet = ({ - visible = false, hideModal, documentId, chunk, width = '50vw', height, }: IProps) => { + const getDocumentUrl = useGetDocumentUrl(documentId); + const url = getDocumentUrl(documentId); + const { highlights, setWidthAndHeight } = useGetChunkHighlights(chunk); return ( Document Previewer - + {url && documentId && ( + + )} ); diff --git a/web/src/components/pdf-previewer/hooks.ts b/web/src/components/pdf-previewer/hooks.ts index 
0774a9d9ae7..164877c3a56 100644 --- a/web/src/components/pdf-previewer/hooks.ts +++ b/web/src/components/pdf-previewer/hooks.ts @@ -1,15 +1,22 @@ +import { Authorization } from '@/constants/authorization'; +import { getAuthorization } from '@/utils/authorization-util'; import axios from 'axios'; -import { useCallback, useEffect, useState } from 'react'; +import { useCallback, useEffect, useMemo, useState } from 'react'; export const useCatchDocumentError = (url: string) => { + const httpHeaders = useMemo(() => { + return { + [Authorization]: getAuthorization(), + }; + }, []); const [error, setError] = useState(''); const fetchDocument = useCallback(async () => { - const { data } = await axios.get(url); + const { data } = await axios.get(url, { headers: httpHeaders }); if (data.code !== 0) { setError(data?.message); } - }, [url]); + }, [url, httpHeaders]); useEffect(() => { fetchDocument(); }, [fetchDocument]); diff --git a/web/src/hooks/use-document-request.ts b/web/src/hooks/use-document-request.ts index 9075e50b108..6232569bede 100644 --- a/web/src/hooks/use-document-request.ts +++ b/web/src/hooks/use-document-request.ts @@ -469,8 +469,8 @@ export const useGetDocumentUrl = (documentId?: string) => { const getDocumentUrl = useCallback( (id?: string) => { return auth - ? `${ExternalApi}/v1/documents/${documentId || id}` - : `${api_host}/document/get/${documentId || id}`; + ? `${ExternalApi}/v1/documents/${id || documentId}` + : `${api_host}/document/get/${id || documentId}`; }, [documentId, auth], ); diff --git a/web/src/routes.tsx b/web/src/routes.tsx index 198b95f496d..300382dd27b 100644 --- a/web/src/routes.tsx +++ b/web/src/routes.tsx @@ -22,11 +22,11 @@ export enum Routes { MemoryMessage = '/memory-message', MemorySetting = '/memory-setting', AgentList = '/agent-list', - Searches = '/next-searches', - Search = '/next-search', - SearchShare = '/next-search/share', - Chats = '/next-chats', - Chat = '/next-chat', + Searches = '/searches', + Search = '/search', + SearchShare = '/search/share', + Chats = '/chats', + Chat = '/chat', Files = '/files', ProfileSetting = '/profile-setting', Profile = '/profile', diff --git a/web/vite.config.ts b/web/vite.config.ts index 833aebcb9ae..3504176923c 100644 --- a/web/vite.config.ts +++ b/web/vite.config.ts @@ -70,7 +70,14 @@ export default defineConfig(({ mode, command }) => { changeOrigin: true, ws: true, }, - '^/(api|v1)': { + + '/api': { + target: 'http://127.0.0.1:9380/', + changeOrigin: true, + ws: true, + }, + + '/v1': { target: 'http://127.0.0.1:9380/', changeOrigin: true, ws: true, From 7ddaf73ab751d02f0507de1609d501e2c1e65fa0 Mon Sep 17 00:00:00 2001 From: Lynn Date: Thu, 12 Feb 2026 14:43:52 +0800 Subject: [PATCH 013/565] Refactor: split message apis to gateway and service (#13126) ### What problem does this PR solve? 
Split message apis to gateway and service

### Type of change

- [x] Refactoring
---
 api/apps/restful_apis/memory_api.py     | 121 ++++++++++++++++++
 api/apps/sdk/messages.py                | 158 ------------------------
 api/apps/services/memory_api_service.py | 114 ++++++++++++++++-
 web/src/routes.tsx                      |   7 +-
 4 files changed, 240 insertions(+), 160 deletions(-)
 delete mode 100644 api/apps/sdk/messages.py

diff --git a/api/apps/restful_apis/memory_api.py b/api/apps/restful_apis/memory_api.py
index 53c7f866e27..c1cd9e5a9e7 100644
--- a/api/apps/restful_apis/memory_api.py
+++ b/api/apps/restful_apis/memory_api.py
@@ -171,3 +171,124 @@ async def get_memory_messages(memory_id):
     except Exception as e:
         logging.error(e)
         return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
+
+
+@manager.route("/messages", methods=["POST"])  # noqa: F821
+@login_required
+@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response")
+async def add_message():
+    req = await get_request_json()
+    memory_ids = req["memory_id"]
+
+    message_dict = {
+        "user_id": req.get("user_id"),
+        "agent_id": req["agent_id"],
+        "session_id": req["session_id"],
+        "user_input": req["user_input"],
+        "agent_response": req["agent_response"],
+    }
+
+    res, msg = await memory_api_service.add_message(memory_ids, message_dict)
+    if res:
+        return get_json_result(message=msg)
+
+    return get_json_result(message="Some messages failed to add. Detail:" + msg, code=RetCode.SERVER_ERROR)
+
+
+@manager.route("/messages/<memory_id>:<message_id>", methods=["DELETE"])  # noqa: F821
+@login_required
+async def forget_message(memory_id: str, message_id: int):
+    try:
+        res = await memory_api_service.forget_message(memory_id, message_id)
+        return get_json_result(message=res)
+    except NotFoundException as not_found_exception:
+        logging.error(not_found_exception)
+        return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception))
+    except Exception as e:
+        logging.error(e)
+        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
+
+
+@manager.route("/messages/<memory_id>:<message_id>", methods=["PUT"])  # noqa: F821
+@login_required
+@validate_request("status")
+async def update_message(memory_id: str, message_id: int):
+    req = await get_request_json()
+    status = req["status"]
+    if not isinstance(status, bool):
+        return get_error_argument_result("Status must be a boolean.")
+
+    try:
+        update_succeed = await memory_api_service.update_message_status(memory_id, message_id, status)
+        if update_succeed:
+            return get_json_result(message=update_succeed)
+        else:
+            return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.")
+    except NotFoundException as not_found_exception:
+        logging.error(not_found_exception)
+        return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception))
+    except Exception as e:
+        logging.error(e)
+        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
+
+
+@manager.route("/messages/search", methods=["GET"])  # noqa: F821
+@login_required
+async def search_message():
+    args = request.args
+    memory_ids = args.getlist("memory_id")
+    if len(memory_ids) == 1 and ',' in memory_ids[0]:
+        memory_ids = memory_ids[0].split(',')
+    query = args.get("query")
+    similarity_threshold = float(args.get("similarity_threshold", 0.2))
+    keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7))
+    top_n = int(args.get("top_n", 5))
+    agent_id = args.get("agent_id", "")
+    session_id = args.get("session_id", "")
+
+    filter_dict = {
+        "memory_id": memory_ids,
+        "agent_id": agent_id,
+        "session_id": session_id
+    }
+    params = {
+        "query": query,
+        "similarity_threshold": similarity_threshold,
+        "keywords_similarity_weight": keywords_similarity_weight,
+        "top_n": top_n
+    }
+    res = await memory_api_service.search_message(filter_dict, params)
+    return get_json_result(message=True, data=res)
+
+@manager.route("/messages", methods=["GET"])  # noqa: F821
+@login_required
+async def get_messages():
+    args = request.args
+    memory_ids = args.getlist("memory_id")
+    if len(memory_ids) == 1 and ',' in memory_ids[0]:
+        memory_ids = memory_ids[0].split(',')
+    agent_id = args.get("agent_id", "")
+    session_id = args.get("session_id", "")
+    limit = int(args.get("limit", 10))
+    if not memory_ids:
+        return get_error_argument_result("memory_ids is required.")
+    try:
+        res = await memory_api_service.get_messages(memory_ids, agent_id, session_id, limit)
+        return get_json_result(message=True, data=res)
+    except Exception as e:
+        logging.error(e)
+        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
+
+
+@manager.route("/messages/<memory_id>:<message_id>/content", methods=["GET"])  # noqa: F821
+@login_required
+async def get_message_content(memory_id: str, message_id: int):
+    try:
+        res = await memory_api_service.get_message_content(memory_id, message_id)
+        return get_json_result(message=True, data=res)
+    except NotFoundException as not_found_exception:
+        logging.error(not_found_exception)
+        return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception))
+    except Exception as e:
+        logging.error(e)
+        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
diff --git a/api/apps/sdk/messages.py b/api/apps/sdk/messages.py
deleted file mode 100644
index 5ed5902188a..00000000000
--- a/api/apps/sdk/messages.py
+++ /dev/null
@@ -1,158 +0,0 @@
-#
-# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# -from quart import request -from api.apps import login_required -from api.db.services.memory_service import MemoryService -from common.time_utils import current_timestamp, timestamp_to_date - -from memory.services.messages import MessageService -from api.db.joint_services import memory_message_service -from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result -from common.constants import RetCode - - -@manager.route("/messages", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response") -async def add_message(): - - req = await get_request_json() - memory_ids = req["memory_id"] - - message_dict = { - "user_id": req.get("user_id"), - "agent_id": req["agent_id"], - "session_id": req["session_id"], - "user_input": req["user_input"], - "agent_response": req["agent_response"], - } - - res, msg = await memory_message_service.queue_save_to_memory_task(memory_ids, message_dict) - - if res: - return get_json_result(message=msg) - - return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add. Detail:" + msg) - - -@manager.route("/messages/:", methods=["DELETE"]) # noqa: F821 -@login_required -async def forget_message(memory_id: str, message_id: int): - - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") - - forget_time = timestamp_to_date(current_timestamp()) - update_succeed = MessageService.update_message( - {"memory_id": memory_id, "message_id": int(message_id)}, - {"forget_at": forget_time}, - memory.tenant_id, memory_id) - if update_succeed: - return get_json_result(message=update_succeed) - else: - return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to forget message '{message_id}' in memory '{memory_id}'.") - - -@manager.route("/messages/:", methods=["PUT"]) # noqa: F821 -@login_required -@validate_request("status") -async def update_message(memory_id: str, message_id: int): - req = await get_request_json() - status = req["status"] - if not isinstance(status, bool): - return get_error_argument_result("Status must be a boolean.") - - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") - - update_succeed = MessageService.update_message({"memory_id": memory_id, "message_id": int(message_id)}, {"status": status}, memory.tenant_id, memory_id) - if update_succeed: - return get_json_result(message=update_succeed) - else: - return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.") - - -@manager.route("/messages/search", methods=["GET"]) # noqa: F821 -@login_required -async def search_message(): - args = request.args - empty_fields = [f for f in ["memory_id", "query"] if not args.get(f)] - if empty_fields: - return get_error_argument_result(f"{', '.join(empty_fields)} can't be empty.") - - memory_ids = args.getlist("memory_id") - if len(memory_ids) == 1 and ',' in memory_ids[0]: - memory_ids = memory_ids[0].split(',') - query = args.get("query") - similarity_threshold = float(args.get("similarity_threshold", 0.2)) - keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7)) - top_n = int(args.get("top_n", 5)) - agent_id = args.get("agent_id", "") - session_id = args.get("session_id", "") - - filter_dict = { - 
"memory_id": memory_ids, - "agent_id": agent_id, - "session_id": session_id - } - params = { - "query": query, - "similarity_threshold": similarity_threshold, - "keywords_similarity_weight": keywords_similarity_weight, - "top_n": top_n - } - res = memory_message_service.query_message(filter_dict, params) - return get_json_result(message=True, data=res) - - -@manager.route("/messages", methods=["GET"]) # noqa: F821 -@login_required -async def get_messages(): - args = request.args - memory_ids = args.getlist("memory_id") - if len(memory_ids) == 1 and ',' in memory_ids[0]: - memory_ids = memory_ids[0].split(',') - agent_id = args.get("agent_id", "") - session_id = args.get("session_id", "") - limit = int(args.get("limit", 10)) - if not memory_ids: - return get_error_argument_result("memory_ids is required.") - memory_list = MemoryService.get_by_ids(memory_ids) - uids = [memory.tenant_id for memory in memory_list] - res = MessageService.get_recent_messages( - uids, - memory_ids, - agent_id, - session_id, - limit - ) - return get_json_result(message=True, data=res) - - -@manager.route("/messages/:/content", methods=["GET"]) # noqa: F821 -@login_required -async def get_message_content(memory_id:str, message_id: int): - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") - - res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id) - if res: - return get_json_result(message=True, data=res) - else: - return get_json_result(code=RetCode.NOT_FOUND, message=f"Message '{message_id}' in memory '{memory_id}' not found.") diff --git a/api/apps/services/memory_api_service.py b/api/apps/services/memory_api_service.py index 53bb0f6e9ef..e49fe7ed0f8 100644 --- a/api/apps/services/memory_api_service.py +++ b/api/apps/services/memory_api_service.py @@ -19,13 +19,14 @@ from api.db.services.user_service import UserTenantService from api.db.services.canvas_service import UserCanvasService from api.db.services.task_service import TaskService -from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default +from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default, queue_save_to_memory_task, query_message from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT from memory.services.messages import MessageService from memory.utils.prompt_util import PromptAssembler from common.constants import MemoryType, ForgettingPolicy from common.exceptions import ArgumentException, NotFoundException +from common.time_utils import current_timestamp, timestamp_to_date async def create_memory(memory_info: dict): @@ -169,6 +170,16 @@ async def delete_memory(memory_id): async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size: int = 50): + """ + :param filter_params: { + "memory_type": list[str], + "tenant_id": list[str], + "storage_type": str + } + :param keywords: str + :param page: int + :param page_size: int + """ filter_dict: dict = {"storage_type": filter_params.get("storage_type")} tenant_ids = filter_params.get("tenant_id") if not filter_params.get("tenant_id"): @@ -221,3 +232,104 @@ async def get_memory_messages(memory_id, agent_ids: list[str], keywords: str, pa for extract_msg in message["extract"]: extract_msg["agent_name"] = agent_name_mapping.get(extract_msg["agent_id"], 
"Unknown") return {"messages": messages, "storage_type": memory.storage_type} + + +async def add_message(memory_ids: list[str], message_dict: dict): + """ + :param memory_ids: list[str] + :param message_dict: { + "agent_id": str, + "session_id": str, + "user_input": str, + "agent_response": str, + "message_type": str + } + """ + return await queue_save_to_memory_task(memory_ids, message_dict) + + +async def forget_message(memory_id: str, message_id: int): + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + raise NotFoundException(f"Memory '{memory_id}' not found.") + + forget_time = timestamp_to_date(current_timestamp()) + update_succeed = MessageService.update_message( + {"memory_id": memory_id, "message_id": int(message_id)}, + {"forget_at": forget_time}, + memory.tenant_id, memory_id) + if update_succeed: + return True + raise Exception(f"Failed to forget message '{message_id}' in memory '{memory_id}'.") + + +async def update_message_status(memory_id: str, message_id: int, status: bool): + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + raise NotFoundException(f"Memory '{memory_id}' not found.") + + update_succeed = MessageService.update_message( + {"memory_id": memory_id, "message_id": int(message_id)}, + {"status": status}, + memory.tenant_id, memory_id) + if update_succeed: + return True + raise Exception(f"Failed to set status for message '{message_id}' in memory '{memory_id}'.") + + +async def search_message(filter_dict: dict, params: dict): + """ + :param filter_dict: { + "memory_id": list[str], + "agent_id": str, + "session_id": str + } + :param params: { + "query": str, + "similarity_threshold": float, + "keywords_similarity_weight": float, + "top_n": int + } + """ + return query_message(filter_dict, params) + + +async def get_messages(memory_ids: list[str], agent_id: str = "", session_id: str = "", limit: int = 10): + """ + Get recent messages from specified memories. + + :param memory_ids: list of memory IDs + :param agent_id: optional agent ID for filtering + :param session_id: optional session ID for filtering + :param limit: maximum number of messages to return + :return: list of recent messages + """ + memory_list = MemoryService.get_by_ids(memory_ids) + uids = [memory.tenant_id for memory in memory_list] + res = MessageService.get_recent_messages( + uids, + memory_ids, + agent_id, + session_id, + limit + ) + return res + + +async def get_message_content(memory_id: str, message_id: int): + """ + Get content of a specific message from a memory. 
+ + :param memory_id: memory ID + :param message_id: message ID + :return: message content + :raises NotFoundException: if memory or message not found + """ + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + raise NotFoundException(f"Memory '{memory_id}' not found.") + + res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id) + if res: + return res + raise NotFoundException(f"Message '{message_id}' in memory '{memory_id}' not found.") \ No newline at end of file diff --git a/web/src/routes.tsx b/web/src/routes.tsx index 300382dd27b..8c6d538a6d5 100644 --- a/web/src/routes.tsx +++ b/web/src/routes.tsx @@ -1,5 +1,10 @@ import { lazy, memo, Suspense } from 'react'; -import { createBrowserRouter, Navigate, redirect, type RouteObject } from 'react-router'; +import { + createBrowserRouter, + Navigate, + redirect, + type RouteObject, +} from 'react-router'; import FallbackComponent from './components/fallback-component'; import { IS_ENTERPRISE } from './pages/admin/utils'; import authorizationUtil from './utils/authorization-util'; From caf7db97c6b3db0d1f38390a8eaac78aaecc8b5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=96=AF=E7=99=AB?= Date: Thu, 12 Feb 2026 15:40:15 +0800 Subject: [PATCH 014/565] Fix the bug where the mcp service tools/list does not return knowledge base IDs information. (#13123) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix the issue where the server-side parameter validation fails when the id parameter is None in the asynchronous list_datasets method. ### What problem does this PR solve? Fix the issue where the server-side parameter validation fails when the id parameter is None in the asynchronous list_datasets method. ### Type of change - [√ ] Bug Fix (non-breaking change which fixes an issue) --- mcp/server/server.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mcp/server/server.py b/mcp/server/server.py index 07cb10d9481..9b1bd4f1d12 100644 --- a/mcp/server/server.py +++ b/mcp/server/server.py @@ -138,7 +138,13 @@ async def list_datasets( id: str | None = None, name: str | None = None, ): - res = await self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name}, api_key=api_key) + params = {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc} + if id: + params['id'] = id + if name : + params['name'] = name + + res = await self._get("/datasets", params, api_key=api_key) if not res or res.status_code != 200: raise Exception([types.TextContent(type="text", text="Cannot process this operation.")]) From 31469265edbebb2cb628796aeb3cedc0d3d9287d Mon Sep 17 00:00:00 2001 From: Magicbook1108 Date: Thu, 12 Feb 2026 15:40:55 +0800 Subject: [PATCH 015/565] Improve: optimize file name (with path) in box container. (#13124) ### What problem does this PR solve? Refact: optimize file name (with path) in box container. 
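For context, the rule the connector change below applies can be shown in a few lines: each recursion level appends the current folder's name to a prefix, so a file is reported as `parent / child / name` rather than a bare file name. The sketch is illustrative only; `iter_paths` and the `tree` fixture are made-up stand-ins, not part of the Box connector or its SDK.

```python
# Minimal sketch of the relative-path rule (hypothetical names throughout).
def iter_paths(tree: list[dict], prefix: str = ""):
    for entry in tree:
        # Extend the prefix with this entry's name, matching the
        # relative_folder_path threading in the connector diff below.
        label = f"{prefix} / {entry['name']}" if prefix else entry["name"]
        if entry["type"] == "file":
            yield label
        else:  # folder: recurse with the extended prefix
            yield from iter_paths(entry.get("children", []), label)


tree = [
    {"type": "folder", "name": "reports", "children": [
        {"type": "file", "name": "q1.pdf"},
        {"type": "folder", "name": "archive", "children": [
            {"type": "file", "name": "q4-2025.pdf"},
        ]},
    ]},
]
print(list(iter_paths(tree)))
# ['reports / q1.pdf', 'reports / archive / q4-2025.pdf']
```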
### Type of change

- [x] Performance Improvement
---
 common/data_source/box_connector.py | 27 ++++++++++++++++++++-----
 web/vite.config.ts                  |  7 -------
 2 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/common/data_source/box_connector.py b/common/data_source/box_connector.py
index 3006e709c9c..253029d3c92 100644
--- a/common/data_source/box_connector.py
+++ b/common/data_source/box_connector.py
@@ -38,9 +38,10 @@ def validate_connector_settings(self):
 
     def _yield_files_recursive(
         self,
-        folder_id,
+        folder_id: str,
         start: SecondsSinceUnixEpoch | None,
-        end: SecondsSinceUnixEpoch | None
+        end: SecondsSinceUnixEpoch | None,
+        relative_folder_path: str = "",
     ) -> GenerateDocumentsOutput:
 
         if self.box_client is None:
@@ -59,6 +60,7 @@
                 file = self.box_client.files.get_file_by_id(
                     entry.id
                 )
+                modified_time: SecondsSinceUnixEpoch | None = None
                 raw_time = (
                     getattr(file, "created_at", None)
                     or getattr(file, "content_created_at", None)
@@ -72,13 +74,18 @@
                         continue
 
                 content_bytes = self.box_client.downloads.download_file(file.id)
+                semantic_identifier = (
+                    f"{relative_folder_path} / {file.name}"
+                    if relative_folder_path
+                    else file.name
+                )
 
                 batch.append(
                     Document(
                         id=f"box:{file.id}",
                         blob=content_bytes.read(),
                         source=DocumentSource.BOX,
-                        semantic_identifier=file.name,
+                        semantic_identifier=semantic_identifier,
                         extension=get_file_ext(file.name),
                         doc_updated_at=modified_time,
                         size_bytes=file.size,
@@ -86,7 +93,17 @@
                     )
                 )
             elif entry.type == 'folder':
-                yield from self._yield_files_recursive(folder_id=entry.id, start=start, end=end)
+                child_relative_path = (
+                    f"{relative_folder_path} / {entry.name}"
+                    if relative_folder_path
+                    else entry.name
+                )
+                yield from self._yield_files_recursive(
+                    folder_id=entry.id,
+                    start=start,
+                    end=end,
+                    relative_folder_path=child_relative_path
+                )
 
         if batch:
             yield batch
@@ -159,4 +176,4 @@ def load_from_state(self):
 
 if __name__ == "__main__":
     pass
-    # app.run(port=4999)
\ No newline at end of file
+    # app.run(port=4999)

diff --git a/web/vite.config.ts b/web/vite.config.ts
index 3504176923c..c1074708d23 100644
--- a/web/vite.config.ts
+++ b/web/vite.config.ts
@@ -65,18 +65,11 @@ export default defineConfig(({ mode, command }) => {
       overlay: false,
     },
     proxy: {
-      '/api/v1/admin': {
-        target: 'http://127.0.0.1:9381/',
-        changeOrigin: true,
-        ws: true,
-      },
-
       '/api': {
         target: 'http://127.0.0.1:9380/',
         changeOrigin: true,
         ws: true,
       },
-
      '/v1': {
        target: 'http://127.0.0.1:9380/',
        changeOrigin: true,

From 97264139626cd525d267c417ec5483ca8dfbe13f Mon Sep 17 00:00:00 2001
From: Ahmad Intisar <168020872+ahmadintisar@users.noreply.github.com>
Date: Thu, 12 Feb 2026 13:05:58 +0500
Subject: [PATCH 016/565] fix: register WebDAVConnector in data_source
 __init__.py (#13121)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What problem does this PR solve?

The sync_data_source.py module imports WebDAVConnector from common.data_source, but WebDAVConnector was never registered in the package's __init__.py. This causes an ImportError at startup, crashing the data sync service:

    ImportError: cannot import name 'WebDAVConnector' from 'common.data_source'

The webdav_connector.py file already exists in the common/data_source/ directory — it just wasn't exported. This PR adds the import and registers it in __all__.
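A cheap guard against this class of regression is an import smoke test at the package root. The test below is a sketch under the assumption that the repository is importable in the test environment; RAGFlow's actual test layout may differ, and the test name is hypothetical.

```python
# Hypothetical smoke test: fails if a connector module exists on disk but
# is not re-exported from the package root, which is exactly the bug here.
import common.data_source as ds


def test_webdav_connector_is_exported():
    assert "WebDAVConnector" in ds.__all__
    assert hasattr(ds, "WebDAVConnector")
```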
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: Ahmad Intisar
---
 common/data_source/__init__.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/common/data_source/__init__.py b/common/data_source/__init__.py
index 74baaee016f..099f3d7b3bd 100644
--- a/common/data_source/__init__.py
+++ b/common/data_source/__init__.py
@@ -41,6 +41,7 @@ from .zendesk_connector import ZendeskConnector
 from .seafile_connector import SeaFileConnector
 from .rdbms_connector import RDBMSConnector
+from .webdav_connector import WebDAVConnector
 from .config import BlobType, DocumentSource
 from .models import Document, TextSection, ImageSection, BasicExpertInfo
 from .exceptions import (
@@ -81,4 +82,5 @@
     "ZendeskConnector",
     "SeaFileConnector",
     "RDBMSConnector",
+    "WebDAVConnector",
 ]

From 7cacfa3fb34511acef53461a2474b016db3e1676 Mon Sep 17 00:00:00 2001
From: chanx <1243304602@qq.com>
Date: Thu, 12 Feb 2026 19:48:35 +0800
Subject: [PATCH 017/565] Fix: replace session page icons and fix nested list
 search functionality in filters (#13127)

### What problem does this PR solve?

Fix: replace session page icons and fix nested list search functionality in filters

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
---
 web/public/batch_delete2.png               | Bin 3074 -> 0 bytes
 web/public/return2.png                     | Bin 6472 -> 0 bytes
 .../list-filter-bar/filter-field.tsx       |   2 +-
 .../list-filter-bar/filter-popover.tsx     |  12 +++++-----
 web/src/locales/en.ts                      |   1 +
 web/src/locales/zh.ts                      |   1 +
 .../components/metedata/manage-modal.tsx   |   5 ++--
 web/src/pages/next-chats/chat/sessions.tsx |  22 ++++++++++++------
 8 files changed, 27 insertions(+), 16 deletions(-)
 delete mode 100644 web/public/batch_delete2.png
 delete mode 100644 web/public/return2.png

diff --git a/web/public/batch_delete2.png b/web/public/batch_delete2.png
deleted file mode 100644
index 91d5342cb861c91d69f9d0cee65ec0476ea53832..0000000000000000000000000000000000000000
GIT binary patch
[base85-encoded binary data omitted]

diff --git a/web/public/return2.png b/web/public/return2.png
deleted file mode 100644
index 4655fb319cb0e33aa48a7be0e28bd2fe7dd11cae..0000000000000000000000000000000000000000
GIT binary patch
[base85-encoded binary data omitted]

diff --git a/web/src/components/list-filter-bar/filter-field.tsx b/web/src/components/list-filter-bar/filter-field.tsx
index 8a66e33d99e..2b5ab8018a4 100644
--- a/web/src/components/list-filter-bar/filter-field.tsx
+++ b/web/src/components/list-filter-bar/filter-field.tsx
@@ -74,7 +74,7 @@ const FilterItem = memo(

handleCheckChange({ checked, field, item }) } diff --git a/web/src/components/list-filter-bar/filter-popover.tsx b/web/src/components/list-filter-bar/filter-popover.tsx index 31c5408e370..3da931e16e9 100644 --- a/web/src/components/list-filter-bar/filter-popover.tsx +++ b/web/src/components/list-filter-bar/filter-popover.tsx @@ -61,12 +61,12 @@ const filterNestedList = ( return false; }) .map((item) => { - if (item.list && item.list.length > 0) { - return { - ...item, - list: filterNestedList(item.list, searchTerm), - }; - } + // if (item.list && item.list.length > 0) { + // return { + // ...item, + // list: filterNestedList(item.list, searchTerm), + // }; + // } return item; }); }; diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 43bb024a5a0..4e95d1454cd 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -232,6 +232,7 @@ Example: A 1 KB message with 1024-dim embedding uses ~9 KB. The 5 MB default lim description: 'Description', fieldName: 'Field name', editMetadata: 'Edit metadata', + addMetadata: 'Add metadata', deleteWarn: 'This {{field}} will be removed from all associated files', deleteManageFieldAllWarn: 'This field and all its corresponding values will be deleted from all associated files.', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index 9c10a75b282..30dc591226c 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -207,6 +207,7 @@ export default { description: '描述', fieldName: '字段名称', editMetadata: '编辑元数据', + addMetadata: '添加元数据', deleteWarn: '此 {{field}} 将从所有关联文件中移除', deleteManageFieldAllWarn: '此字段及其所有对应值将从所有关联的文件中删除。', diff --git a/web/src/pages/dataset/components/metedata/manage-modal.tsx b/web/src/pages/dataset/components/metedata/manage-modal.tsx index 4cf0d42ed24..e94480195a7 100644 --- a/web/src/pages/dataset/components/metedata/manage-modal.tsx +++ b/web/src/pages/dataset/components/metedata/manage-modal.tsx @@ -552,7 +552,9 @@ export const ManageMetadataModal = (props: IManageModalProps) => { {metadataType === MetadataType.Setting || metadataType === MetadataType.SingleFileSetting ? t('knowledgeDetails.metadata.fieldSetting') - : t('knowledgeDetails.metadata.editMetadata')} + : isAddValueMode + ? t('knowledgeDetails.metadata.addMetadata') + : t('knowledgeDetails.metadata.editMetadata')}
} type={metadataType} @@ -569,7 +571,6 @@ export const ManageMetadataModal = (props: IManageModalProps) => { isShowValueSwitch={isShowValueSwitch} isShowType={true} isVerticalShowValue={isVerticalShowValue} - isAddValueMode={isAddValueMode} // handleDeleteSingleValue={handleDeleteSingleValue} // handleDeleteSingleRow={handleDeleteSingleRow} /> diff --git a/web/src/pages/next-chats/chat/sessions.tsx b/web/src/pages/next-chats/chat/sessions.tsx index 8ce21d9934f..4418df1818c 100644 --- a/web/src/pages/next-chats/chat/sessions.tsx +++ b/web/src/pages/next-chats/chat/sessions.tsx @@ -12,7 +12,14 @@ import { useRemoveConversation, } from '@/hooks/use-chat-request'; import { cn } from '@/lib/utils'; -import { Check, PanelLeftClose, Plus, Trash2 } from 'lucide-react'; +import { + Check, + CopyX, + PanelLeftClose, + Plus, + Trash2, + Undo2, +} from 'lucide-react'; import { useCallback, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { useHandleClickConversationCard } from '../hooks/use-click-card'; @@ -148,7 +155,7 @@ export function Sessions({ handleConversationCardClick }: SessionProps) { className="size-6" onClick={exitSelectionMode} > - 返回 + ) : ( - 批量删除 + + // 批量删除 )} From 8236ac14f73a0ea35729a29d003fec705e78c5fe Mon Sep 17 00:00:00 2001 From: Levi <81591061+levischd@users.noreply.github.com> Date: Thu, 12 Feb 2026 12:48:51 +0100 Subject: [PATCH 018/565] fix(metadata): handle unhashable list values in metadata split (#13116) ### What problem does this PR solve? This PR fixes missing metadata on documents synced from the Moodle connector, especially for **Book** modules. Background: - Moodle Book metadata includes fields like `chapters`, which is a `list[dict]`. - During metadata normalization in `DocMetadataService._split_combined_values`, list deduplication used `dict.fromkeys(...)`. - `dict.fromkeys(...)` fails for unhashable values (like `dict`), causing metadata update to fail. - Result: documents were imported, but metadata was not saved for affected module types (notably Books). What this PR changes: - Replaces hash-based list deduplication with `dedupe_list(...)`, which safely handles unhashable list items while preserving order. - This allows Book metadata (and other complex list metadata) to be persisted correctly. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [ ] New Feature (non-breaking change which adds functionality) - [ ] Documentation Update - [ ] Refactoring - [ ] Performance Improvement - [ ] Other (please describe): Contribution during my time at RAGcon GmbH. --- api/db/services/doc_metadata_service.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/api/db/services/doc_metadata_service.py b/api/db/services/doc_metadata_service.py index 339d51c3086..69f25de485e 100644 --- a/api/db/services/doc_metadata_service.py +++ b/api/db/services/doc_metadata_service.py @@ -231,8 +231,9 @@ def _split_combined_values(cls, meta_fields: Dict) -> Dict: new_values.append(item) else: new_values.append(item) - # Remove duplicates while preserving order - processed[key] = list(dict.fromkeys(new_values)) + # Remove duplicates while preserving order. + # Use string-based dedupe to support unhashable values (e.g. dict entries). 
+            processed[key] = dedupe_list(new_values)
         else:
             processed[key] = value


From 0c5e4e089a30f5528704644f35f10272e2933e67 Mon Sep 17 00:00:00 2001
From: writinwaters <93570324+writinwaters@users.noreply.github.com>
Date: Thu, 12 Feb 2026 20:14:05 +0800
Subject: [PATCH 019/565] Docs: Updated v0.24.0 release notes. (#13129)

### What problem does this PR solve?

Added more details to v0.24.0 release notes.

### Type of change

- [x] Documentation Update
---
 docs/release_notes.md | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/docs/release_notes.md b/docs/release_notes.md
index d222f977f51..6e6319dd258 100644
--- a/docs/release_notes.md
+++ b/docs/release_notes.md
@@ -16,8 +16,8 @@ Released on February 10, 2026.
 ### New features
 
 - Memory
-  - Introduces APIs and an SDK for developer integration.
-  - Outputs Memory extraction log to the console for debugging and tracing.
+  - Introduces memory management APIs (HTTP and Python).
+  - Outputs Memory extraction log to the console.
 - Dataset
   - Supports batch metadata management.
   - Renames "ToC (Table of Contents)" to "PageIndex". See [here](./guides/dataset/extract_table_of_contents.md).
@@ -28,9 +28,9 @@
 - Chat
   - Adds a new "Thinking" mode and removes the previous "Reasoning" configuration option.
   - Optimizes retrieval strategies for deep-research scenarios, enhancing recall accuracy.
 - Admin
-  - Supports configuring multiple Admin accounts.
+  - Supports multiple Admin accounts.
 - Model configuration center
-  - Adds a model connection test feature when adding new models.
+  - Adds model connection test for new models.
 
 ### MySQL alternative
@@ -48,6 +48,16 @@
 - Zendesk
 - Bitbucket
 
+### API changes
+
+#### HTTP API
+
+[Memory management API](./references/http_api_reference.md#memory-management)
+
+#### Python API
+
+[Memory management API](./references/python_api_reference.md#memory-management)
+
 ## v0.23.1
 
 Released on December 31, 2025.

From 397b4890a588ca4bb7dec982c1c933a63b0e979f Mon Sep 17 00:00:00 2001
From: chanx <1243304602@qq.com>
Date: Fri, 13 Feb 2026 18:40:41 +0800
Subject: [PATCH 020/565] Refactor: Remove ant design component (#13143)

### What problem does this PR solve?

Replaces the remaining Ant Design (antd) usage in the web client with the project's own UI stack: drops the antd `ConfigProvider`/`App`/locale wiring in `app.tsx`, deletes antd-based components (parse configuration, PDF previewer, the legacy header/right-toolbar/user layouts, `input-date`), and adds shadcn/ui equivalents (date picker, tooltips, toggle groups, modal), an `antd-compat` interface shim, and a custom notification utility.

### Type of change

- [x] Refactoring
---
 web/src/app.tsx | 88 +---
 web/src/components/document-preview/hooks.ts | 21 +
 .../document-preview/pdf-preview.tsx | 2 +-
 web/src/components/floating-chat-widget.tsx | 4 +-
 .../components/message-item/group-button.tsx | 80 ++-
 .../next-message-item/group-button.tsx | 54 +-
 .../components/parse-configuration/index.tsx | 217 --------
 .../raptor-form-fields-old.tsx | 146 ------
 web/src/components/pdf-previewer/hooks.ts | 25 -
 .../pdf-previewer/index.module.less | 12 -
 web/src/components/pdf-previewer/index.tsx | 135 -----
 web/src/components/ui/date-picker.tsx | 470 ++++++++++++++++++
 web/src/components/ui/input-date.tsx | 175 -------
 web/src/components/ui/modal/modal.tsx | 129 ++++-
 web/src/hooks/logic-hooks.ts | 11 +-
 web/src/hooks/use-chunk-request.ts | 3 +-
 web/src/hooks/use-file-request.ts | 2 +-
 web/src/hooks/use-llm-request.tsx | 2 +-
 web/src/hooks/use-user-setting-request.tsx | 18 +-
 web/src/interfaces/antd-compat.ts | 138 +++++
 .../components/header/index.module.less | 120 -----
 web/src/layouts/components/header/index.tsx | 116 -----
 .../components/right-toolbar/index.less | 25 -
 .../components/right-toolbar/index.tsx | 110 ----
 web/src/layouts/components/user/index.tsx | 28 --
 web/src/locales/config.ts | 123 ++---
 web/src/main.tsx | 15 +-
 web/src/pages/404.tsx | 40 +-
 .../agent/form/google-scholar-form/index.tsx | 24 +-
 .../pages/agent/hooks/use-get-begin-query.tsx | 2 +-
 web/src/pages/agent/interface.ts | 2 +-
 web/src/pages/agent/share/index.tsx | 4 +-
 web/src/pages/agent/utils.ts | 2 +-
 .../metedata/manage-modal-column.tsx | 4 +-
 .../metedata/manage-values-modal.tsx | 4 +-
 .../dataset-setting/category-panel.tsx | 27 +-
 .../chunk-method-learn-more.tsx | 4 +-
 .../dataset-setting/components/tag-item.tsx | 39 +-
 .../dataset/dataset-setting/tag-tabs.tsx | 4 +-
 web/src/pages/document-viewer/image/index.tsx | 43 --
 web/src/pages/files/hooks.ts | 2 +-
 web/src/pages/next-chats/chat/interface.ts | 2 +-
 .../hooks/use-send-shared-message.ts | 2 +-
 web/src/pages/next-chats/share/index.tsx | 4 +-
 web/src/pages/next-search/share/index.tsx | 4 +-
 web/src/utils/document-util.ts | 2 +-
 web/src/utils/file-util.ts | 2 +-
 web/src/utils/next-request.ts | 2 +-
 web/src/utils/notification.ts | 58 +++
 web/src/utils/request.ts | 3 +-
 web/vite.config.ts | 7 +
 51 files changed, 1073 insertions(+), 1483 deletions(-)
 delete mode 100644 web/src/components/parse-configuration/index.tsx
 delete mode 100644 web/src/components/parse-configuration/raptor-form-fields-old.tsx
 delete mode 100644 web/src/components/pdf-previewer/hooks.ts
 delete mode 100644 web/src/components/pdf-previewer/index.module.less
 delete mode 100644 web/src/components/pdf-previewer/index.tsx
 create mode 100644 web/src/components/ui/date-picker.tsx
 delete mode 100644 web/src/components/ui/input-date.tsx
 create mode 100644 web/src/interfaces/antd-compat.ts
 delete mode 100644 web/src/layouts/components/header/index.module.less
 delete mode 100644 web/src/layouts/components/header/index.tsx
 delete mode 100644 web/src/layouts/components/right-toolbar/index.less
 delete mode 100644 web/src/layouts/components/right-toolbar/index.tsx
 delete mode 100644 web/src/layouts/components/user/index.tsx
 delete mode 100644 web/src/pages/document-viewer/image/index.tsx
 create mode 100644 web/src/utils/notification.ts

diff --git a/web/src/app.tsx b/web/src/app.tsx
index 8bd234a0ecb..7a2086440b2 100644
--- a/web/src/app.tsx
+++ b/web/src/app.tsx
@@ -1,16 +1,8 @@ import { Toaster as Sonner } from '@/components/ui/sonner'; import { Toaster } from '@/components/ui/toaster'; -import i18n from '@/locales/config'; +import i18n, { changeLanguageAsync } from '@/locales/config'; import { QueryClient, QueryClientProvider } from '@tanstack/react-query'; import { configResponsive } from 'ahooks'; -import { App, ConfigProvider, ConfigProviderProps, theme } from 'antd'; -import pt_BR from 'antd/lib/locale/pt_BR'; -import deDE from 'antd/locale/de_DE'; -import enUS from 'antd/locale/en_US'; -import ru_RU from 'antd/locale/ru_RU'; -import vi_VN from 'antd/locale/vi_VN'; -import zhCN from 'antd/locale/zh_CN'; -import zh_HK from 'antd/locale/zh_HK'; import dayjs from 'dayjs'; import advancedFormat from 'dayjs/plugin/advancedFormat'; import customParseFormat from 'dayjs/plugin/customParseFormat'; @@ -18,13 +10,12 @@ import localeData from 'dayjs/plugin/localeData'; import weekOfYear from 'dayjs/plugin/weekOfYear'; import weekYear from 'dayjs/plugin/weekYear'; import weekday from 'dayjs/plugin/weekday'; -import React, { useEffect, useState } from 'react'; +import React, { useEffect } from 'react'; import { RouterProvider } from 'react-router'; -import { ThemeProvider, useTheme } from './components/theme-provider'; +import { ThemeProvider } from './components/theme-provider'; import { SidebarProvider } from './components/ui/sidebar'; import { TooltipProvider } from './components/ui/tooltip'; import { ThemeEnum } from './constants/common'; -// import { getRouter } from './routes'; import { routers } from './routes'; import storage from './utils/authorization-util'; @@ -47,24 +38,6 @@ dayjs.extend(localeData); dayjs.extend(weekOfYear); dayjs.extend(weekYear); -const AntLanguageMap = { - en: enUS, - zh: zhCN, - 'zh-TRADITIONAL': zh_HK, - ru: ru_RU, - vi: vi_VN, - 'pt-BR': pt_BR, - de: deDE, -}; - -// if (process.env.NODE_ENV === 'development') { -// const whyDidYouRender = require('@welldone-software/why-did-you-render'); -// whyDidYouRender(React, { -// trackAllPureComponents: true, -// trackExtraHooks: [], -// logOnDifferentValues: true, -// }); -// } if (process.env.NODE_ENV === 'development') { import('@welldone-software/why-did-you-render').then( (whyDidYouRenderModule) => { @@ -78,6 +51,7 @@ if (process.env.NODE_ENV === 'development') { }, ); } + const queryClient = new QueryClient({ defaultOptions: { queries: { @@ -87,53 +61,31 @@ const queryClient = new QueryClient({ }, }); -type Locale = ConfigProviderProps['locale']; - function Root({ children }: React.PropsWithChildren) { - const { theme: themeragflow } = useTheme(); - const getLocale = (lng: string) => - AntLanguageMap[lng as keyof typeof AntLanguageMap] ?? enUS; - - const [locale, setLocal] = useState(getLocale(storage.getLanguage())); + useEffect(() => { + const lng = storage.getLanguage(); + if (lng) { + document.documentElement.lang = lng; + } + }, []); i18n.on('languageChanged', function (lng: string) { storage.setLanguage(lng); - setLocal(getLocale(lng)); - // Should reflect to document.documentElement.lang = lng; }); return ( - <> - - - {children} - - - - - {/* */} - + +
{children}
+
); } const RootProvider = ({ children }: React.PropsWithChildren) => { useEffect(() => { - // Because the language is saved in the backend, a token is required to obtain the api. However, the login page cannot obtain the language through the getUserInfo api, so the language needs to be saved in localstorage. const lng = storage.getLanguage(); if (lng) { - i18n.changeLanguage(lng); + changeLanguageAsync(lng); } }, []); @@ -145,6 +97,8 @@ const RootProvider = ({ children }: React.PropsWithChildren) => { storageKey="ragflow-ui-theme" > {children} + + @@ -159,16 +113,6 @@ const RouterProviderWrapper: React.FC<{ router: typeof routers }> = ({ RouterProviderWrapper.whyDidYouRender = false; export default function AppContainer() { - // const [router, setRouter] = useState(null); - - // useEffect(() => { - // getRouter().then(setRouter); - // }, []); - - // if (!router) { - // return
Loading...
; - // } - return ( diff --git a/web/src/components/document-preview/hooks.ts b/web/src/components/document-preview/hooks.ts index 097185b246c..6ecdf38152c 100644 --- a/web/src/components/document-preview/hooks.ts +++ b/web/src/components/document-preview/hooks.ts @@ -165,3 +165,24 @@ export const useFetchDocx = (filePath: string) => { return { succeed, containerRef, error }; }; + +export const useCatchDocumentError = (url: string) => { + const httpHeaders = useMemo(() => { + return { + [Authorization]: getAuthorization(), + }; + }, []); + const [error, setError] = useState(''); + + const fetchDocument = useCallback(async () => { + const { data } = await axios.get(url, { headers: httpHeaders }); + if (data.code !== 0) { + setError(data?.message); + } + }, [url, httpHeaders]); + useEffect(() => { + fetchDocument(); + }, [fetchDocument]); + + return error; +}; diff --git a/web/src/components/document-preview/pdf-preview.tsx b/web/src/components/document-preview/pdf-preview.tsx index babcd37efc2..d8fb31c97c6 100644 --- a/web/src/components/document-preview/pdf-preview.tsx +++ b/web/src/components/document-preview/pdf-preview.tsx @@ -13,7 +13,7 @@ import { Spin } from '@/components/ui/spin'; import { Authorization } from '@/constants/authorization'; import FileError from '@/pages/document-viewer/file-error'; import { getAuthorization } from '@/utils/authorization-util'; -import { useCatchDocumentError } from '../pdf-previewer/hooks'; +import { useCatchDocumentError } from './hooks'; import styles from './index.module.less'; type PdfLoaderProps = React.ComponentProps & { httpHeaders?: Record; diff --git a/web/src/components/floating-chat-widget.tsx b/web/src/components/floating-chat-widget.tsx index 14b85018d89..c0bd169aa9b 100644 --- a/web/src/components/floating-chat-widget.tsx +++ b/web/src/components/floating-chat-widget.tsx @@ -3,7 +3,7 @@ import { useClickDrawer } from '@/components/pdf-drawer/hooks'; import { MessageType, SharedFrom } from '@/constants/chat'; import { useFetchExternalAgentInputs } from '@/hooks/use-agent-request'; import { useFetchExternalChatInfo } from '@/hooks/use-chat-request'; -import i18n from '@/locales/config'; +import i18n, { changeLanguageAsync } from '@/locales/config'; import { useSendNextSharedMessage } from '@/pages/agent/hooks/use-send-shared-message'; import { MessageCircle, Minimize2, Send, X } from 'lucide-react'; import React, { useCallback, useEffect, useRef, useState } from 'react'; @@ -136,7 +136,7 @@ const FloatingChatWidget = () => { }, 50); if (locale && i18n.language !== locale) { - i18n.changeLanguage(locale); + changeLanguageAsync(locale); } return () => clearTimeout(timer); diff --git a/web/src/components/message-item/group-button.tsx b/web/src/components/message-item/group-button.tsx index aad19cb40a1..58b451f261b 100644 --- a/web/src/components/message-item/group-button.tsx +++ b/web/src/components/message-item/group-button.tsx @@ -1,5 +1,11 @@ import { PromptIcon } from '@/assets/icon/next-icon'; import CopyToClipboard from '@/components/copy-to-clipboard'; +import { ToggleGroup, ToggleGroupItem } from '@/components/ui/toggle-group'; +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from '@/components/ui/tooltip'; import { useSetModalState } from '@/hooks/common-hooks'; import { IRemoveMessageById } from '@/hooks/logic-hooks'; import { @@ -10,7 +16,6 @@ import { SoundOutlined, SyncOutlined, } from '@ant-design/icons'; -import { Radio, Tooltip } from 'antd'; import { useCallback } from 'react'; import { useTranslation } from 
'react-i18next'; import FeedbackDialog from '../feedback-dialog'; @@ -50,34 +55,44 @@ export const AssistantGroupButton = ({ return ( <> - - + + - + {showLoudspeaker && ( - - - {isPlaying ? : } + + + + + {isPlaying ? : } + + + {t('chat.read')} - + )} {showLikeButton && ( <> - + - - + + - + )} {prompt && ( - + - + )} - + {visible && ( - + + - + {regenerateMessage && ( - - - + + + + + {t('chat.regenerate')} - + )} {removeMessageById && ( - - - + + + + + + {t('common.delete')} - + )} - + ); }; diff --git a/web/src/components/next-message-item/group-button.tsx b/web/src/components/next-message-item/group-button.tsx index 1719511de32..6c0620f75f3 100644 --- a/web/src/components/next-message-item/group-button.tsx +++ b/web/src/components/next-message-item/group-button.tsx @@ -1,5 +1,10 @@ import { PromptIcon } from '@/assets/icon/next-icon'; import CopyToClipboard from '@/components/copy-to-clipboard'; +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from '@/components/ui/tooltip'; import { useSetModalState } from '@/hooks/common-hooks'; import { IRemoveMessageById } from '@/hooks/logic-hooks'; import { AgentChatContext } from '@/pages/agent/context'; @@ -13,7 +18,6 @@ import { SoundOutlined, SyncOutlined, } from '@ant-design/icons'; -import { Radio, Tooltip } from 'antd'; import { Download, NotebookText } from 'lucide-react'; import { useCallback, useContext } from 'react'; import { useTranslation } from 'react-i18next'; @@ -80,8 +84,13 @@ export const AssistantGroupButton = ({ {showLoudspeaker && ( - - {isPlaying ? : } + + + + {isPlaying ? : } + + + {t('chat.read')} @@ -97,9 +106,9 @@ export const AssistantGroupButton = ({ )} {prompt && ( - + - + )} {showLog && ( @@ -168,28 +177,39 @@ export const UserGroupButton = ({ const { t } = useTranslation(); return ( - - + + - + {regenerateMessage && ( - - - + + + + + {t('chat.regenerate')} - + )} {removeMessageById && ( - - - + + + + + + {t('common.delete')} - + )} - + ); }; diff --git a/web/src/components/parse-configuration/index.tsx b/web/src/components/parse-configuration/index.tsx deleted file mode 100644 index 446412bf45e..00000000000 --- a/web/src/components/parse-configuration/index.tsx +++ /dev/null @@ -1,217 +0,0 @@ -import { DocumentParserType } from '@/constants/knowledge'; -import { useTranslate } from '@/hooks/common-hooks'; -import { PlusOutlined } from '@ant-design/icons'; -import { Button, Flex, Form, Input, InputNumber, Slider, Switch } from 'antd'; -import random from 'lodash/random'; - -export const excludedParseMethods = [ - DocumentParserType.Table, - DocumentParserType.Resume, - DocumentParserType.One, - DocumentParserType.Picture, - DocumentParserType.KnowledgeGraph, - DocumentParserType.Qa, - DocumentParserType.Tag, -]; - -export const showRaptorParseConfiguration = ( - parserId: DocumentParserType | undefined, -) => { - return !excludedParseMethods.some((x) => x === parserId); -}; - -export const excludedTagParseMethods = [ - DocumentParserType.Table, - DocumentParserType.KnowledgeGraph, - DocumentParserType.Tag, -]; - -export const showTagItems = (parserId: DocumentParserType) => { - return !excludedTagParseMethods.includes(parserId); -}; - -// The three types "table", "resume" and "one" do not display this configuration. 
-const ParseConfiguration = () => { - const form = Form.useFormInstance(); - const { t } = useTranslate('knowledgeConfiguration'); - - const handleGenerate = () => { - form.setFieldValue( - ['parser_config', 'raptor', 'random_seed'], - random(10000), - ); - }; - - return ( - <> - - - - - prevValues.parser_config.raptor.use_raptor !== - curValues.parser_config.raptor.use_raptor - } - > - {({ getFieldValue }) => { - const useRaptor = getFieldValue([ - 'parser_config', - 'raptor', - 'use_raptor', - ]); - - return ( - useRaptor && ( - <> - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - ) - ); - }} - - - ); -}; - -export default ParseConfiguration; diff --git a/web/src/components/parse-configuration/raptor-form-fields-old.tsx b/web/src/components/parse-configuration/raptor-form-fields-old.tsx deleted file mode 100644 index 28cd79cedbc..00000000000 --- a/web/src/components/parse-configuration/raptor-form-fields-old.tsx +++ /dev/null @@ -1,146 +0,0 @@ -import { DocumentParserType } from '@/constants/knowledge'; -import { useTranslate } from '@/hooks/common-hooks'; -import random from 'lodash/random'; -import { Plus } from 'lucide-react'; -import { useCallback } from 'react'; -import { useFormContext, useWatch } from 'react-hook-form'; -import { SliderInputFormField } from '../slider-input-form-field'; -import { Button } from '../ui/button'; -import { - FormControl, - FormField, - FormItem, - FormLabel, - FormMessage, -} from '../ui/form'; -import { Input } from '../ui/input'; -import { Switch } from '../ui/switch'; -import { Textarea } from '../ui/textarea'; - -export const excludedParseMethods = [ - DocumentParserType.Table, - DocumentParserType.Resume, - DocumentParserType.One, - DocumentParserType.Picture, - DocumentParserType.KnowledgeGraph, - DocumentParserType.Qa, - DocumentParserType.Tag, -]; - -export const showRaptorParseConfiguration = ( - parserId: DocumentParserType | undefined, -) => { - return !excludedParseMethods.some((x) => x === parserId); -}; - -export const excludedTagParseMethods = [ - DocumentParserType.Table, - DocumentParserType.KnowledgeGraph, - DocumentParserType.Tag, -]; - -export const showTagItems = (parserId: DocumentParserType) => { - return !excludedTagParseMethods.includes(parserId); -}; - -const UseRaptorField = 'parser_config.raptor.use_raptor'; -const RandomSeedField = 'parser_config.raptor.random_seed'; - -// The three types "table", "resume" and "one" do not display this configuration. - -const RaptorFormFields = () => { - const form = useFormContext(); - const { t } = useTranslate('knowledgeConfiguration'); - const useRaptor = useWatch({ name: UseRaptorField }); - - const handleGenerate = useCallback(() => { - form.setValue(RandomSeedField, random(10000)); - }, [form]); - - return ( - <> - ( - - {t('useRaptor')} - - - - - - )} - /> - {useRaptor && ( -
- ( - - {t('prompt')} - - + @@ -74,7 +82,10 @@ export default function ChatBasicSetting() { {t('emptyResponse')} - + @@ -89,7 +100,10 @@ export default function ChatBasicSetting() { {t('setAnOpener')} - + diff --git a/web/src/pages/next-chats/chat/app-settings/chat-prompt-engine.tsx b/web/src/pages/next-chats/chat/app-settings/chat-prompt-engine.tsx index a7c05f7b4ab..8f84f090979 100644 --- a/web/src/pages/next-chats/chat/app-settings/chat-prompt-engine.tsx +++ b/web/src/pages/next-chats/chat/app-settings/chat-prompt-engine.tsx @@ -15,12 +15,14 @@ import { import { Textarea } from '@/components/ui/textarea'; import { UseKnowledgeGraphFormField } from '@/components/use-knowledge-graph-item'; import { useTranslate } from '@/hooks/common-hooks'; +import { getDirAttribute } from '@/utils/text-direction'; import { useFormContext } from 'react-hook-form'; import { DynamicVariableForm } from './dynamic-variable'; export function ChatPromptEngine() { const { t } = useTranslate('chat'); const form = useFormContext(); + const systemPromptValue = form.watch('prompt_config.system'); return (
@@ -36,6 +38,7 @@ export function ChatPromptEngine() { rows={8} placeholder={t('messagePlaceholder')} className="overflow-y-auto" + dir={getDirAttribute(systemPromptValue || '')} /> diff --git a/web/src/pages/next-chats/chat/app-settings/dynamic-variable.tsx b/web/src/pages/next-chats/chat/app-settings/dynamic-variable.tsx index 53a593c8a79..8dd4372c032 100644 --- a/web/src/pages/next-chats/chat/app-settings/dynamic-variable.tsx +++ b/web/src/pages/next-chats/chat/app-settings/dynamic-variable.tsx @@ -9,6 +9,7 @@ import { import { BlurInput } from '@/components/ui/input'; import { Separator } from '@/components/ui/separator'; import { Switch } from '@/components/ui/switch'; +import { getDirAttribute } from '@/utils/text-direction'; import { Plus, X } from 'lucide-react'; import { useCallback } from 'react'; import { useFieldArray, useFormContext } from 'react-hook-form'; @@ -58,53 +59,58 @@ export function DynamicVariableForm() {
- {fields.map((field, index) => ( -
- ( - - - - - - - )} - /> + {fields.map((field, index) => { + const typeField = `${name}.${index}.key`; + const keyValue = form.watch(typeField); + return ( +
+ ( + + + + + + + )} + /> - + - ( - - - - - - - )} - /> + ( + + + + + + + )} + /> - -
- ))} + +
+ ); + })}
diff --git a/web/src/pages/next-chats/chat/app-settings/saving-button.tsx b/web/src/pages/next-chats/chat/app-settings/saving-button.tsx index 83cd7a6e8dc..bc880f18d1f 100644 --- a/web/src/pages/next-chats/chat/app-settings/saving-button.tsx +++ b/web/src/pages/next-chats/chat/app-settings/saving-button.tsx @@ -9,7 +9,11 @@ export function SavingButton({ loading }: SaveButtonProps) { const { t } = useTranslation(); return ( - + {t('common.save')} ); diff --git a/web/src/pages/next-search/markdown-content/index.tsx b/web/src/pages/next-search/markdown-content/index.tsx index 118f2c2aadb..fd0895a106a 100644 --- a/web/src/pages/next-search/markdown-content/index.tsx +++ b/web/src/pages/next-search/markdown-content/index.tsx @@ -18,10 +18,13 @@ import 'katex/dist/katex.min.css'; // `rehype-katex` does not import the CSS for import { currentReg, + parseCitationIndex, preprocessLaTeX, replaceTextByOldReg, replaceThinkToSection, } from '@/utils/chat'; +import { citationMarkerReg } from '@/utils/citation-utils'; +import { getDirAttribute } from '@/utils/text-direction'; import { Button } from '@/components/ui/button'; import { @@ -46,7 +49,7 @@ const styles = { fileThumbnail: 'inline-block max-w-[40px]', }; -const getChunkIndex = (match: string) => Number(match); +const getChunkIndex = (match: string) => parseCitationIndex(match); // TODO: The display of the table is inconsistent with the display previously placed in the MessageItem. const MarkdownContent = ({ @@ -234,42 +237,51 @@ const MarkdownContent = ({ [getPopoverContent], ); + const dir = getDirAttribute(content.replace(citationMarkerReg, '')); + return ( - - renderReference(children), - code(props: any) { - const { children, className, ...rest } = props; - const restProps = omit(rest, 'node'); - const match = /language-(\w+)/.exec(className || ''); - return match ? ( - - {String(children).replace(/\n$/, '')} - - ) : ( - - {children} - - ); - }, - } as any - } > - {contentWithCursor} - + ( +

{children}

+ ), + 'custom-typography': ({ children }: { children: string }) => + renderReference(children), + code(props: any) { + const { children, className, ...rest } = props; + const restProps = omit(rest, 'node'); + const match = /language-(\w+)/.exec(className || ''); + return match ? ( + + {String(children).replace(/\n$/, '')} + + ) : ( + + {children} + + ); + }, + } as any + } + > + {contentWithCursor} +
+ ); }; diff --git a/web/src/pages/next-search/search-setting.tsx b/web/src/pages/next-search/search-setting.tsx index 1608bec6f43..eb5acb6d249 100644 --- a/web/src/pages/next-search/search-setting.tsx +++ b/web/src/pages/next-search/search-setting.tsx @@ -586,7 +586,11 @@ const SearchSetting: React.FC = ({ > {t('search.cancelText')} - - ))} - - )} + )} + + {title === 'login' && channels && channels.length > 0 && ( +
+ {channels.map((item) => ( + + ))} +
+ )} - {title === 'login' && registerEnabled && ( + {!disablePasswordLogin && title === 'login' && registerEnabled && (

{t('signInTip')} @@ -217,7 +223,7 @@ function LoginFormContent({

)} - {title === 'register' && ( + {!disablePasswordLogin && title === 'register' && (

{t('signUpTip')} @@ -369,14 +375,8 @@ const Login = () => {

{t('title')}

- {/* border border-accent-primary rounded-full */} - {/*
- {t('start')} -
*/}
- {/* Logo and Header */} - {/* Login Form */} { channels={channels || []} handleLoginWithChannel={handleLoginWithChannel} t={t} + disablePasswordLogin={!!config?.disablePasswordLogin} />
@@ -398,4 +399,4 @@ const Login = () => {
   );
 };
 
-export default Login;
+export default Login;
\ No newline at end of file

From e8b73cc5b48aa0f95c1482a881632e85ce9f609e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=B0=91=E5=8D=BF?= <121151546+shaoqing404@users.noreply.github.com>
Date: Mon, 2 Mar 2026 14:58:37 +0800
Subject: [PATCH 098/565] fix: absolute page index mix-up in DeepDoc PDF parser
 (#12848)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What problem does this PR solve?

Summary: This PR addresses critical indexing issues in deepdoc/parser/pdf_parser.py that occur when parsing long PDFs with chunk-based pagination:

- Normalize rotated table page numbering: rotated-table re-OCR now writes page_number in chunk-local 1-based form, eliminating the double-addition of the page_from offset that caused misalignment between table positions and document boxes.
- Convert absolute positions to chunk-local coordinates: when inserting tables/figures extracted via _extract_table_figure, positions are now converted from absolute (0-based) to chunk-local indices before distance matching and box insertion. This prevents IndexError and out-of-range accesses during paged parsing of long documents.

Root cause: the parser mixed absolute (0-based, document-global) and relative (1-based, chunk-local) page numbering systems. Table/figure positions from layout extraction carried absolute page numbers, but the insertion logic expected chunk-local coordinates aligned with self.boxes and page_cum_height.

Testing (what I did):

Manual verification: parse a 200+ page PDF with from_page > 0 and table rotation enabled, then confirm that:

- Tables and figures appear on the correct pages
- No IndexError or position mismatches occur
- Page numbers in the output match the expected chunk-local offsets

Automated testing: not done.

## Separate Discussion: Memory Optimization Strategy (from codex-5.2-max, claude 4.5 opus, and me)

### Context

The current implementation loads entire page ranges into memory (`__images__`, `page_chars`, intermediates), which can cause RAM exhaustion on large documents. While the page numbering fix resolves correctness issues, scalability remains a concern.

### Proposed Architecture

**Pipeline-Driven Chunking with Explicit Resource Management:**

1. **Authoritative chunk planning**: Accept page-range specifications from the upstream pipeline as the single source of truth. The parser should be a stateless worker that processes assigned chunks without making independent pagination decisions.

2. **Granular memory lifecycle**:
   ```python
   for chunk_spec in chunk_plan:
       # Load only chunk_spec.pages into __images__
       page_images = load_page_range(chunk_spec.start, chunk_spec.end)
       # Process with offset tracking
       results = process_chunk(page_images, offset=chunk_spec.start)
       # Explicit cleanup before next iteration
       del page_images, page_chars, layout_intermediates
       gc.collect()  # Force collection of large objects
   ```

3. **Persistent lightweight state**: Keep model instances (layout detector, OCR engine), document metadata (outlines, PDF structure), and configuration across chunks to avoid reinitialization overhead (~2-5s per chunk for model loading).

4. **Adaptive fallback**: Provide `max_pages_per_chunk` (default: 50) only when the pipeline doesn't supply a plan. Never exceed pipeline-specified ranges, so memory bounds stay predictable.

5. 
**Optional: Dynamic budgeting**: Expose a memory budget parameter that adjusts chunk size based on observed image dimensions and format (e.g., reduce chunk size for high-DPI scanned documents). ### Benefits - **Predictable memory footprint**: RAM usage bounded by `chunk_size × avg_page_size` rather than total document size - **Horizontal scalability**: Enables parallel chunk processing across workers - **Failure isolation**: Page extraction errors affect only current chunk, not entire document - **Cloud-friendly**: Works within container memory limits (e.g., 2-4GB per worker) ### Trade-offs - **Increased I/O**: Re-opening PDF for each chunk vs. keeping file handle (mitigated by page-range seeks) - **Complexity**: Requires careful offset tracking and stateful coordination between pipeline and parser - **Warmup cost**: Model initialization overhead amortized across chunks (acceptable for documents >100 pages) ### Implementation Priority This optimization should be **deferred to a separate PR** after the current correctness fix is merged, as: 1. It requires broader architectural changes across the pipeline 2. Current fix is critical for correctness and can be backported 3. Memory optimization needs comprehensive benchmarking on representative document corpus ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- deepdoc/parser/pdf_parser.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index 6681e4a893a..49880c3c552 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -1594,15 +1594,32 @@ def min_rectangle_distance(rect1, rect2): return math.sqrt(dx * dx + dy * dy) # + (pn2-pn1)*10000 for (img, txt), poss in tbls_or_figs: + # Positions coming from _extract_table_figure carry absolute 0-based page + # indices (page_from offset). Convert back to chunk-local indices so we + # stay consistent with self.boxes/page_cum_height, which are all relative + # to the current parsing window. 
+ local_poss = [] + for pn, left, right, top, bott in poss: + local_pn = pn - self.page_from + if 0 <= local_pn < len(self.page_cum_height) - 1: + local_poss.append((local_pn, left, right, top, bott)) + else: + logging.debug(f"Skip out-of-range table/figure position pn={pn}, page_from={self.page_from}") + if not local_poss: + logging.debug("No valid local positions for table/figure; skip insertion.") + continue + bboxes = [(i, (b["page_number"], b["x0"], b["x1"], b["top"], b["bottom"])) for i, b in enumerate(self.boxes)] dists = [ - (min_rectangle_distance((pn, left, right, top + self.page_cum_height[pn], bott + self.page_cum_height[pn]), rect), i) for i, rect in bboxes for pn, left, right, top, bott in poss + (min_rectangle_distance((pn, left, right, top + self.page_cum_height[pn], bott + self.page_cum_height[pn]), rect), i) + for i, rect in bboxes + for pn, left, right, top, bott in local_poss ] min_i = np.argmin(dists, axis=0)[0] min_i, rect = bboxes[dists[min_i][-1]] if isinstance(txt, list): txt = "\n".join(txt) - pn, left, right, top, bott = poss[0] + pn, left, right, top, bott = local_poss[0] if self.boxes[min_i]["bottom"] < top + self.page_cum_height[pn]: min_i += 1 self.boxes.insert( From a9f349dcde2e730a034c12ee5161377d3b8e9a58 Mon Sep 17 00:00:00 2001 From: liuxiaoyusky <49766325+liuxiaoyusky@users.noreply.github.com> Date: Mon, 2 Mar 2026 15:31:40 +0800 Subject: [PATCH 099/565] Fix: respect user-configured chunk_token_num for MinerU/docling/paddleocr parsers (#13234) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary When using MinerU, docling, TCADP, or paddleocr as the PDF parser with the General (naive) chunk method, the user-configured `chunk_token_num` is **unconditionally overwritten to 0** at [rag/app/naive.py#L858-L859](https://github.com/infiniflow/ragflow/blob/main/rag/app/naive.py#L858-L859), effectively disabling chunk merging regardless of what the user sets in the UI. ### Problem A user sets `chunk_token_num = 2048` in the dataset configuration UI, expecting small parser blocks to be merged into larger chunks. However, this line: ```python if name in ["tcadp", "docling", "mineru", "paddleocr"]: parser_config["chunk_token_num"] = 0 ``` silently overrides the user's setting. As a result, every MinerU output block becomes its own chunk. For short documents (e.g. a 3-page PDF fund factsheet parsed by MinerU), this produces **47 tiny chunks** — some as small as 11 characters (`"July 2025"`) or 15 characters (`"CIES Eligible"`). This severely degrades retrieval quality: vector embeddings of such short fragments have minimal semantic value, and keyword search produces excessive noise. ### Fix Only apply the `chunk_token_num = 0` override when the user has **not** explicitly configured a positive value: ```python if name in ["tcadp", "docling", "mineru", "paddleocr"]: if int(parser_config.get("chunk_token_num", 0)) <= 0: parser_config["chunk_token_num"] = 0 ``` This preserves the original default behavior (no merging) while respecting the user's explicit configuration. 
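
To make the two branches concrete, here is a minimal, self-contained sketch of the intended behavior (`effective_chunk_token_num` is a hypothetical helper introduced only for illustration; the actual patch inlines this check in `rag/app/naive.py`):

```python
def effective_chunk_token_num(parser_config: dict) -> int:
    # Keep a positive user-configured value; anything unset or non-positive
    # falls back to 0, i.e. the original "no merging" default.
    value = int(parser_config.get("chunk_token_num", 0))
    return value if value > 0 else 0

assert effective_chunk_token_num({"chunk_token_num": 2048}) == 2048  # user setting kept
assert effective_chunk_token_num({}) == 0                  # default unchanged
assert effective_chunk_token_num({"chunk_token_num": 0}) == 0  # explicit 0 still disables merging
```
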
### Before / After (MinerU, 3-page PDF, chunk_token_num=2048) | | Before | After | |---|---|---| | Chunks produced | 47 | ~8 (merged by token limit) | | Smallest chunk | 11 chars | ~500 chars | | User setting respected | No | Yes | ## Test plan - [ ] Parse a PDF with MinerU and `chunk_token_num = 2048` → verify chunks are merged up to token limit - [ ] Parse a PDF with MinerU and `chunk_token_num = 0` (or default) → verify original behavior (no merging) - [ ] Parse a PDF with DeepDOC parser → verify no change in behavior (not affected by this code path) - [ ] Repeat with docling/paddleocr if available --- rag/app/naive.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rag/app/naive.py b/rag/app/naive.py index ef84fa69cbc..22606c3b32c 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -881,7 +881,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca tables = append_context2table_image4pdf(sections, tables, image_context_size) if name in ["tcadp", "docling", "mineru", "paddleocr"]: - parser_config["chunk_token_num"] = 0 + if int(parser_config.get("chunk_token_num", 0)) <= 0: + parser_config["chunk_token_num"] = 0 res = tokenize_table(tables, doc, is_english) callback(0.8, "Finish parsing.") From a1c726b9b17d0949450ff654ff1dd86629a70655 Mon Sep 17 00:00:00 2001 From: Magicbook1108 Date: Mon, 2 Mar 2026 15:37:08 +0800 Subject: [PATCH 100/565] Feat: add more models for siliconflow and tongyi-qwen (#13311) ### What problem does this PR solve? Feat: add more models for siliconflow and tongyi-qwen ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- conf/llm_factories.json | 133 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 1dcbb852382..89c089444c9 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -495,6 +495,34 @@ "model_type": "chat", "is_tools": true }, + { + "llm_name": "qwen3.5-plus", + "tags": "LLM,CHAT,1M,IMAGE2TEXT", + "max_tokens": 1000000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "qwen3.5-plus-2026-02-15", + "tags": "LLM,CHAT,1M,IMAGE2TEXT", + "max_tokens": 1000000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "qwen3.5-flash", + "tags": "LLM,CHAT,1M,IMAGE2TEXT", + "max_tokens": 1000000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "qwen3.5-flash-2026-02-23", + "tags": "LLM,CHAT,1M,IMAGE2TEXT", + "max_tokens": 1000000, + "model_type": "chat", + "is_tools": true + }, { "llm_name": "qwen3-max", "tags": "LLM,CHAT,256k", @@ -3054,6 +3082,111 @@ "model_type": "chat", "is_tools": false }, + { + "llm_name": "Pro/MiniMaxAI/MiniMax-M2.5", + "tags": "LLM,CHAT,197k", + "max_tokens": 197000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Pro/zai-org/GLM-5", + "tags": "LLM,CHAT,205k", + "max_tokens": 205000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Pro/moonshotai/Kimi-K2.5", + "tags": "LLM,CHAT,IMAGE2TEXT,262k", + "max_tokens": 262000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Pro/zai-org/GLM-4.7", + "tags": "LLM,CHAT,205k", + "max_tokens": 205000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "deepseek-ai/DeepSeek-V3.2", + "tags": "LLM,CHAT,164k", + "max_tokens": 164000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Pro/deepseek-ai/DeepSeek-V3.2", + "tags": "LLM,CHAT,164k", + "max_tokens": 164000, + "model_type": 
"chat", + "is_tools": true + }, + { + "llm_name": "deepseek-ai/DeepSeek-V3.1-Terminus", + "tags": "LLM,CHAT,164k", + "max_tokens": 164000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Pro/deepseek-ai/DeepSeek-V3.1-Terminus", + "tags": "LLM,CHAT,164k", + "max_tokens": 164000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Pro/MiniMaxAI/MiniMax-M2.1", + "tags": "LLM,CHAT,197k", + "max_tokens": 197000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "stepfun-ai/Step-3.5-Flash", + "tags": "LLM,CHAT,262k", + "max_tokens": 262000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "zai-org/GLM-4.6V", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "moonshotai/Kimi-K2-Thinking", + "tags": "LLM,CHAT,262k", + "max_tokens": 262000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Pro/moonshotai/Kimi-K2-Thinking", + "tags": "LLM,CHAT,262k", + "max_tokens": 262000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "zai-org/GLM-4.6", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Kwaipilot/KAT-Dev", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, { "llm_name": "BAAI/bge-m3", "tags": "LLM,EMBEDDING,8k", From 70a3cba3e5e8b54cad384a9126ade64fbed10fb4 Mon Sep 17 00:00:00 2001 From: Magicbook1108 Date: Mon, 2 Mar 2026 15:37:42 +0800 Subject: [PATCH 101/565] Feat: Support siliconflow.com (#13308) ### What problem does this PR solve? Feat: Support siliconflow.com ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/apps/llm_app.py | 23 +- conf/llm_factories.json | 457 ++++++++++++++++++ rag/llm/embedding_model.py | 9 +- rag/llm/rerank_model.py | 9 +- web/src/hooks/use-llm-request.tsx | 1 + web/src/locales/bg.ts | 2 + web/src/locales/de.ts | 2 + web/src/locales/en.ts | 2 + web/src/locales/es.ts | 2 + web/src/locales/fr.ts | 2 + web/src/locales/id.ts | 2 + web/src/locales/it.ts | 2 + web/src/locales/ja.ts | 2 + web/src/locales/pt-br.ts | 2 + web/src/locales/ru.ts | 2 + web/src/locales/vi.ts | 2 + web/src/locales/zh-traditional.ts | 2 + web/src/locales/zh.ts | 2 + .../user-setting/setting-model/hooks.tsx | 13 +- .../modal/api-key-modal/index.tsx | 9 +- 20 files changed, 530 insertions(+), 17 deletions(-) diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 9d2fed80262..6fa9a8e3d8c 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -34,7 +34,7 @@ def factories(): try: fac = get_allowed_llm_factories() - fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI", "Builtin"]] + fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI", "Builtin", "siliconflow_intl"]] llms = LLMService.get_all() mdl_types = {} for m in llms: @@ -64,13 +64,22 @@ async def set_api_key(): # test if api key works chat_passed, embd_passed, rerank_passed = False, False, False factory = req["llm_factory"] + base_url = req.get("base_url", "") + source_factory = req.get("source_fid", factory) extra = {"provider": factory} timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10)) + source_llms = list(LLMService.query(fid=source_factory)) + if not source_llms: + msg = f"No models configured for {factory} (source: {source_factory})." 
+ if req.get("verify", False): + return get_json_result(data={"message": msg, "success": False}) + return get_data_error_result(message=msg) + msg = "" - for llm in LLMService.query(fid=factory): + for llm in source_llms: if not embd_passed and llm.model_type == LLMType.EMBEDDING.value: assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet." - mdl = EmbeddingModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url")) + mdl = EmbeddingModel[factory](req["api_key"], llm.llm_name, base_url=base_url) try: arr, tc = await asyncio.wait_for( asyncio.to_thread(mdl.encode, ["Test if the api key is available"]), @@ -83,7 +92,7 @@ async def set_api_key(): msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e) elif not chat_passed and llm.model_type == LLMType.CHAT.value: assert factory in ChatModel, f"Chat model from {factory} is not supported yet." - mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra) + mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=base_url, **extra) try: m, tc = await asyncio.wait_for( mdl.async_chat( @@ -100,7 +109,7 @@ async def set_api_key(): msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e) elif not rerank_passed and llm.model_type == LLMType.RERANK.value: assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet." - mdl = RerankModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url")) + mdl = RerankModel[factory](req["api_key"], llm.llm_name, base_url=base_url) try: arr, tc = await asyncio.wait_for( asyncio.to_thread(mdl.similarity, "What's the weather?", ["Is it sunny today?"]), @@ -122,12 +131,12 @@ async def set_api_key(): if msg: return get_data_error_result(message=msg) - llm_config = {"api_key": req["api_key"], "api_base": req.get("base_url", "")} + llm_config = {"api_key": req["api_key"], "api_base": base_url} for n in ["model_type", "llm_name"]: if n in req: llm_config[n] = req[n] - for llm in LLMService.query(fid=factory): + for llm in source_llms: llm_config["max_tokens"] = llm.max_tokens if not TenantLLMService.filter_update([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm.llm_name], llm_config): TenantLLMService.save( diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 89c089444c9..8f898da9021 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -3245,6 +3245,463 @@ } ] }, + { + "name": "siliconflow_intl", + "logo": "", + "tags": "LLM,TEXT EMBEDDING,TEXT RE-RANK,IMAGE2TEXT,TTS", + "status": "1", + "rank": "781", + "llm": [ + { + "llm_name": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "tags": "LLM,CHAT,33k", + "max_tokens": 33000, + "model_type": "chat", + "is_tools": false + }, + { + "llm_name": "MiniMaxAI/MiniMax-M2.5", + "tags": "LLM,CHAT,197k", + "max_tokens": 197000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "zai-org/GLM-5", + "tags": "LLM,CHAT,205k", + "max_tokens": 205000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "stepfun-ai/Step-3.5-Flash", + "tags": "LLM,CHAT,262k", + "max_tokens": 262000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "moonshotai/Kimi-K2.5", + "tags": "LLM,CHAT,262k", + "max_tokens": 262000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "MiniMaxAI/MiniMax-M2.1", + "tags": "LLM,CHAT,197k", + "max_tokens": 197000, + "model_type": "chat", + "is_tools": 
true + }, + { + "llm_name": "zai-org/GLM-4.7", + "tags": "LLM,CHAT,205k", + "max_tokens": 205000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "deepseek-ai/DeepSeek-V3.2", + "tags": "LLM,CHAT,164k", + "max_tokens": 164000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "deepseek-ai/DeepSeek-V3.2-Exp", + "tags": "LLM,CHAT,164k", + "max_tokens": 164000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "zai-org/GLM-4.6V", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "deepseek-ai/DeepSeek-V3.1-Terminus", + "tags": "LLM,CHAT,164k", + "max_tokens": 164000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "deepseek-ai/DeepSeek-V3.1", + "tags": "LLM,CHAT,164k", + "max_tokens": 164000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "deepseek-ai/DeepSeek-V3", + "tags": "LLM,CHAT,164k", + "max_tokens": 164000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "deepseek-ai/DeepSeek-R1", + "tags": "LLM,CHAT,154k", + "max_tokens": 154000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "nex-agi/DeepSeek-V3.1-Nex-N1", + "tags": "LLM,CHAT,164k", + "max_tokens": 164000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-VL-32B-Instruct", + "tags": "LLM,CHAT,262k", + "max_tokens": 262000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-VL-32B-Thinking", + "tags": "LLM,CHAT,262k", + "max_tokens": 262000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "zai-org/GLM-4.5V", + "tags": "LLM,CHAT,66k", + "max_tokens": 66000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "inclusionAI/Ling-mini-2.0", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "inclusionAI/Ring-flash-2.0", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": false + }, + { + "llm_name": "inclusionAI/Ling-flash-2.0", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "tencent/Hunyuan-MT-7B", + "tags": "LLM,CHAT,32k", + "max_tokens": 32000, + "model_type": "chat", + "is_tools": false + }, + { + "llm_name": "Qwen/Qwen3-Omni-30B-A3B-Captioner", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-Omni-30B-A3B-Thinking", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-Omni-30B-A3B-Instruct", + "tags": "LLM,CHAT,65k", + "max_tokens": 65000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-Next-80B-A3B-Thinking", + "tags": "LLM,CHAT,262k", + "max_tokens": 262000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-Next-80B-A3B-Instruct", + "tags": "LLM,CHAT,262k", + "max_tokens": 262000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-Coder-480B-A35B-Instruct", + "tags": "LLM,CHAT,262k", + "max_tokens": 262000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-Coder-30B-A3B-Instruct", + "tags": "LLM,CHAT,262k", + "max_tokens": 262000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-30B-A3B-Thinking-2507", + "tags": "LLM,CHAT,262k", + "max_tokens": 262000, + "model_type": "chat", + "is_tools": true + }, 
+ { + "llm_name": "Qwen/Qwen3-30B-A3B-Instruct-2507", + "tags": "LLM,CHAT,262k", + "max_tokens": 262000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-235B-A22B-Instruct-2507", + "tags": "LLM,CHAT,262k", + "max_tokens": 262000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-235B-A22B-Thinking-2507", + "tags": "LLM,CHAT,262k", + "max_tokens": 262000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "ByteDance-Seed/Seed-OSS-36B-Instruct", + "tags": "LLM,CHAT,262k", + "max_tokens": 262000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "baidu/ERNIE-4.5-300B-A47B", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": false + }, + { + "llm_name": "tencent/Hunyuan-A13B-Instruct", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": false + }, + { + "llm_name": "moonshotai/Kimi-K2-Instruct", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-32B", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-14B", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-8B", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-Reranker-8B", + "tags": "LLM,RE-RANK,33k", + "max_tokens": 33000, + "model_type": "rerank", + "is_tools": false + }, + { + "llm_name": "Qwen/Qwen3-Embedding-8B", + "tags": "LLM,EMBEDDING,33k", + "max_tokens": 33000, + "model_type": "embedding", + "is_tools": false + }, + { + "llm_name": "Qwen/Qwen3-Reranker-4B", + "tags": "LLM,RE-RANK,33k", + "max_tokens": 33000, + "model_type": "rerank", + "is_tools": false + }, + { + "llm_name": "Qwen/Qwen3-Embedding-4B", + "tags": "LLM,EMBEDDING,33k", + "max_tokens": 33000, + "model_type": "embedding", + "is_tools": false + }, + { + "llm_name": "Qwen/Qwen3-Reranker-0.6B", + "tags": "LLM,RE-RANK,33k", + "max_tokens": 33000, + "model_type": "rerank", + "is_tools": false + }, + { + "llm_name": "Qwen/Qwen3-Embedding-0.6B", + "tags": "LLM,EMBEDDING,33k", + "max_tokens": 33000, + "model_type": "embedding", + "is_tools": false + }, + { + "llm_name": "THUDM/GLM-Z1-32B-0414", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "THUDM/GLM-4-32B-0414", + "tags": "LLM,CHAT,33k", + "max_tokens": 33000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "THUDM/GLM-Z1-9B-0414", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "THUDM/GLM-4-9B-0414", + "tags": "LLM,CHAT,33k", + "max_tokens": 33000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen2.5-VL-32B-Instruct", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/QwQ-32B", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen2.5-VL-72B-Instruct", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen2.5-VL-7B-Instruct", + "tags": "LLM,CHAT,33k", + "max_tokens": 33000, + "model_type": "chat", + "is_tools": false + }, + { + "llm_name": 
"deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen2.5-Coder-32B-Instruct", + "tags": "LLM,CHAT,33k", + "max_tokens": 33000, + "model_type": "chat", + "is_tools": false + }, + { + "llm_name": "Qwen/Qwen2.5-72B-Instruct-128K", + "tags": "LLM,CHAT,131k", + "max_tokens": 131000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "deepseek-ai/deepseek-vl2", + "tags": "LLM,CHAT,4k", + "max_tokens": 4000, + "model_type": "chat", + "is_tools": false + }, + { + "llm_name": "Qwen/Qwen2.5-72B-Instruct", + "tags": "LLM,CHAT,33k", + "max_tokens": 33000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen2.5-32B-Instruct", + "tags": "LLM,CHAT,33k", + "max_tokens": 33000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen2.5-14B-Instruct", + "tags": "LLM,CHAT,33k", + "max_tokens": 33000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen2.5-7B-Instruct", + "tags": "LLM,CHAT,33k", + "max_tokens": 33000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "IndexTeam/IndexTTS-2", + "tags": "TTS", + "max_tokens": 1000, + "model_type": "tts", + "is_tools": false + } + ] + }, { "name": "PPIO", "logo": "", diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 699b8be3366..79dc96accff 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -770,14 +770,17 @@ class SILICONFLOWEmbed(Base): _FACTORY_NAME = "SILICONFLOW" def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"): - if not base_url: - base_url = "https://api.siliconflow.cn/v1/embeddings" + normalized_base_url = (base_url or "").strip() + if not normalized_base_url: + normalized_base_url = "https://api.siliconflow.cn/v1/embeddings" + if "/embeddings" not in normalized_base_url: + normalized_base_url = urljoin(f"{normalized_base_url.rstrip('/')}/", "embeddings").rstrip("/") self.headers = { "accept": "application/json", "content-type": "application/json", "authorization": f"Bearer {key}", } - self.base_url = base_url + self.base_url = normalized_base_url self.model_name = model_name def encode(self, texts: list): diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index d9a4a740592..b8fd19dacd4 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -274,10 +274,13 @@ class SILICONFLOWRerank(Base): _FACTORY_NAME = "SILICONFLOW" def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"): - if not base_url: - base_url = "https://api.siliconflow.cn/v1/rerank" + normalized_base_url = (base_url or "").strip() + if not normalized_base_url: + normalized_base_url = "https://api.siliconflow.cn/v1/rerank" + if "/rerank" not in normalized_base_url: + normalized_base_url = urljoin(f"{normalized_base_url.rstrip('/')}/", "rerank").rstrip("/") self.model_name = model_name - self.base_url = base_url + self.base_url = normalized_base_url self.headers = { "accept": "application/json", "content-type": "application/json", diff --git a/web/src/hooks/use-llm-request.tsx b/web/src/hooks/use-llm-request.tsx index 8579c8bd458..65da65d4972 100644 --- a/web/src/hooks/use-llm-request.tsx +++ b/web/src/hooks/use-llm-request.tsx @@ -271,6 +271,7 @@ export interface 
IApiKeySavingParams { llm_name?: string; model_type?: string; base_url?: string; + source_fid?: string; verify?: boolean; } diff --git a/web/src/locales/bg.ts b/web/src/locales/bg.ts index 6ce56f22a26..03b391cf975 100644 --- a/web/src/locales/bg.ts +++ b/web/src/locales/bg.ts @@ -1137,6 +1137,8 @@ The above is the content you need to summarize.`, 'Ако вашият API ключ е от OpenAI, просто го игнорирайте. Всеки друг междинен доставчик ще предостави този base url с API ключа.', tongyiBaseUrlTip: 'За китайски потребители не е необходимо да попълвате или използвайте https://dashscope.aliyuncs.com/compatible-mode/v1. За международни потребители използвайте https://dashscope-intl.aliyuncs.com/compatible-mode/v1', + siliconBaseUrlTip: + 'За китайски потребители не е необходимо да попълвате или използвайте https://api.siliconflow.cn/v1. За международни потребители използвайте https://api.siliconflow.com/v1', tongyiBaseUrlPlaceholder: '(Само за международни потребители, вижте съвета)', minimaxBaseUrlTip: diff --git a/web/src/locales/de.ts b/web/src/locales/de.ts index 147885ae596..508115b186a 100644 --- a/web/src/locales/de.ts +++ b/web/src/locales/de.ts @@ -1177,6 +1177,8 @@ Beispiel: Virtual Hosted Style`, 'Wenn Ihr API-Schlüssel von OpenAI stammt, ignorieren Sie dies. Andere Zwischenanbieter geben diese Basis-URL mit dem API-Schlüssel an.', tongyiBaseUrlTip: 'Für chinesische Benutzer ist keine Eingabe erforderlich oder verwenden Sie https://dashscope.aliyuncs.com/compatible-mode/v1. Für internationale Benutzer verwenden Sie https://dashscope-intl.aliyuncs.com/compatible-mode/v1', + siliconBaseUrlTip: + 'Für chinesische Benutzer ist keine Eingabe erforderlich oder verwenden Sie https://api.siliconflow.cn/v1. Für internationale Benutzer verwenden Sie https://api.siliconflow.com/v1', tongyiBaseUrlPlaceholder: '(Nur für internationale Benutzer, bitte Hinweis beachten)', minimaxBaseUrlTip: diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 357334abcf8..5117713a29d 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -1172,6 +1172,8 @@ Example: Virtual Hosted Style`, 'If your API key is from OpenAI, just ignore it. Any other intermediate providers will give this base url with the API key.', tongyiBaseUrlTip: 'For Chinese users, no need to fill in or use https://dashscope.aliyuncs.com/compatible-mode/v1. For international users, use https://dashscope-intl.aliyuncs.com/compatible-mode/v1', + siliconBaseUrlTip: + 'For Chinese users, no need to fill in or use https://api.siliconflow.cn/v1. For international users, use https://api.siliconflow.com/v1', tongyiBaseUrlPlaceholder: '(International users only, please see tip)', minimaxBaseUrlTip: 'International users only: use https://api.minimax.io/v1', diff --git a/web/src/locales/es.ts b/web/src/locales/es.ts index 03bbe4cf5e9..816353495fa 100644 --- a/web/src/locales/es.ts +++ b/web/src/locales/es.ts @@ -358,6 +358,8 @@ export default { 'Si tu clave API es de OpenAI, ignora esto. Cualquier otro proveedor intermedio proporcionará esta URL base junto con la clave API.', tongyiBaseUrlTip: 'Para usuarios chinos, no es necesario rellenar o usar https://dashscope.aliyuncs.com/compatible-mode/v1. Para usuarios internacionales, usar https://dashscope-intl.aliyuncs.com/compatible-mode/v1', + siliconBaseUrlTip: + 'Para usuarios chinos, no es necesario rellenar o usar https://api.siliconflow.cn/v1. 
Para usuarios internacionales, usar https://api.siliconflow.com/v1', tongyiBaseUrlPlaceholder: '(Solo para usuarios internacionales, por favor ver consejo)', minimaxBaseUrlTip: diff --git a/web/src/locales/fr.ts b/web/src/locales/fr.ts index 8e68e7d82e0..ccb735c4103 100644 --- a/web/src/locales/fr.ts +++ b/web/src/locales/fr.ts @@ -542,6 +542,8 @@ export default { "Si votre clé API provient d'OpenAI, ignorez ceci. Tout autre fournisseur intermédiaire fournira cette URL de base avec la clé API.", tongyiBaseUrlTip: 'Pour les utilisateurs chinois, pas besoin de remplir ou utiliser https://dashscope.aliyuncs.com/compatible-mode/v1. Pour les utilisateurs internationaux, utilisez https://dashscope-intl.aliyuncs.com/compatible-mode/v1', + siliconBaseUrlTip: + 'Pour les utilisateurs chinois, pas besoin de remplir ou utiliser https://api.siliconflow.cn/v1. Pour les utilisateurs internationaux, utilisez https://api.siliconflow.com/v1', tongyiBaseUrlPlaceholder: "(Utilisateurs internationaux uniquement, veuillez consulter l'astuce)", minimaxBaseUrlTip: diff --git a/web/src/locales/id.ts b/web/src/locales/id.ts index dc5f9a4b24d..7969d85666c 100644 --- a/web/src/locales/id.ts +++ b/web/src/locales/id.ts @@ -529,6 +529,8 @@ export default { 'Jika kunci API Anda berasal dari OpenAI, abaikan saja. Penyedia perantara lainnya akan memberikan base url ini dengan kunci API.', tongyiBaseUrlTip: 'Untuk pengguna Tiongkok, tidak perlu diisi atau gunakan https://dashscope.aliyuncs.com/compatible-mode/v1. Untuk pengguna internasional, gunakan https://dashscope-intl.aliyuncs.com/compatible-mode/v1', + siliconBaseUrlTip: + 'Untuk pengguna Tiongkok, tidak perlu diisi atau gunakan https://api.siliconflow.cn/v1. Untuk pengguna internasional, gunakan https://api.siliconflow.com/v1', tongyiBaseUrlPlaceholder: '(Hanya untuk pengguna internasional, silakan lihat tip)', minimaxBaseUrlTip: diff --git a/web/src/locales/it.ts b/web/src/locales/it.ts index cd003bc2886..04222f4607c 100644 --- a/web/src/locales/it.ts +++ b/web/src/locales/it.ts @@ -770,6 +770,8 @@ Quanto sopra è il contenuto che devi riassumere.`, baseUrl: 'URL Base', baseUrlTip: 'Se la tua chiave API è da OpenAI, ignoralo. Qualsiasi altro fornitore intermedio fornirà questo URL base con la chiave API.', + siliconBaseUrlTip: + 'For Chinese users, no need to fill in or use https://api.siliconflow.cn/v1. For international users, use https://api.siliconflow.com/v1', modify: 'Modifica', systemModelSettings: 'Imposta modelli predefiniti', chatModel: 'LLM', diff --git a/web/src/locales/ja.ts b/web/src/locales/ja.ts index 13a1a20eda4..3eb93aae5e0 100644 --- a/web/src/locales/ja.ts +++ b/web/src/locales/ja.ts @@ -569,6 +569,8 @@ export default { 'APIキーがOpenAIからのものであれば無視してください。他の中間プロバイダーはAPIキーと共にこのベースURLを提供します。', tongyiBaseUrlTip: '中国ユーザーの場合、記入不要または https://dashscope.aliyuncs.com/compatible-mode/v1 を使用してください。国際ユーザーは https://dashscope-intl.aliyuncs.com/compatible-mode/v1 を使用してください', + siliconBaseUrlTip: + '中国ユーザーの場合、入力不要または https://api.siliconflow.cn/v1 を使用してください。国際ユーザーは https://api.siliconflow.com/v1 を使用してください', tongyiBaseUrlPlaceholder: '(国際ユーザーのみ、ヒントをご覧ください)', minimaxBaseUrlTip: '国際ユーザーのみ:https://api.minimax.io/v1 を使用してください。', diff --git a/web/src/locales/pt-br.ts b/web/src/locales/pt-br.ts index ae313378b4d..1ce96814ca1 100644 --- a/web/src/locales/pt-br.ts +++ b/web/src/locales/pt-br.ts @@ -524,6 +524,8 @@ export default { 'Se sua chave da API for do OpenAI, ignore isso. 
Outros provedores intermediários fornecerão essa URL base com a chave da API.', tongyiBaseUrlTip: 'Para usuários chineses, não é necessário preencher ou usar https://dashscope.aliyuncs.com/compatible-mode/v1. Para usuários internacionais, use https://dashscope-intl.aliyuncs.com/compatible-mode/v1', + siliconBaseUrlTip: + 'Para usuários chineses, não é necessário preencher ou usar https://api.siliconflow.cn/v1. Para usuários internacionais, use https://api.siliconflow.com/v1', tongyiBaseUrlPlaceholder: '(Apenas para usuários internacionais, consulte a dica)', minimaxBaseUrlTip: diff --git a/web/src/locales/ru.ts b/web/src/locales/ru.ts index a9a59cefcb2..60b8d8ab1c7 100644 --- a/web/src/locales/ru.ts +++ b/web/src/locales/ru.ts @@ -894,6 +894,8 @@ export default { 'Если ваш API ключ от OpenAI, просто проигнорируйте это. Любые другие промежуточные провайдеры дадут этот базовый url вместе с API ключом.', tongyiBaseUrlTip: 'Для китайских пользователей не нужно заполнять или используйте https://dashscope.aliyuncs.com/compatible-mode/v1. Для международных пользователей используйте https://dashscope-intl.aliyuncs.com/compatible-mode/v1', + siliconBaseUrlTip: + 'Для китайских пользователей не нужно заполнять или используйте https://api.siliconflow.cn/v1. Для международных пользователей используйте https://api.siliconflow.com/v1', tongyiBaseUrlPlaceholder: '(Только для международных пользователей, см. подсказку)', minimaxBaseUrlTip: diff --git a/web/src/locales/vi.ts b/web/src/locales/vi.ts index 032bf236c07..ffe73121312 100644 --- a/web/src/locales/vi.ts +++ b/web/src/locales/vi.ts @@ -573,6 +573,8 @@ export default { baseUrl: 'Base-Url', baseUrlTip: 'Nếu khóa API của bạn từ OpenAI, chỉ cần bỏ qua nó. Bất kỳ nhà cung cấp trung gian nào khác sẽ cung cấp URL cơ sở này với khóa API.', + siliconBaseUrlTip: + 'For Chinese users, no need to fill in or use https://api.siliconflow.cn/v1. 
For international users, use https://api.siliconflow.com/v1', minimaxBaseUrlTip: 'Chỉ người dùng quốc tế: dùng https://api.minimax.io/v1.', minimaxBaseUrlPlaceholder: diff --git a/web/src/locales/zh-traditional.ts b/web/src/locales/zh-traditional.ts index ce1f4932cda..229c6bea5f1 100644 --- a/web/src/locales/zh-traditional.ts +++ b/web/src/locales/zh-traditional.ts @@ -626,6 +626,8 @@ export default { '如果您的 API 密鑰來自 OpenAI,請忽略它。任何其他中間提供商都會提供帶有 API 密鑰的基本 URL。', tongyiBaseUrlTip: '中國用戶無需填寫或使用 https://dashscope.aliyuncs.com/compatible-mode/v1。國際用戶請使用 https://dashscope-intl.aliyuncs.com/compatible-mode/v1', + siliconBaseUrlTip: + '中國用戶無需填寫或使用 https://api.siliconflow.cn/v1。國際用戶請使用 https://api.siliconflow.com/v1', tongyiBaseUrlPlaceholder: '(僅國際用戶,請參閱提示)', minimaxBaseUrlTip: '僅國際用戶:使用 https://api.minimax.io/v1。', minimaxBaseUrlPlaceholder: '(僅國際用戶填寫 https://api.minimax.io/v1)', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index 01e5b16716d..020abcc4fab 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -991,6 +991,8 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 '如果您的 API 密钥来自 OpenAI,请忽略它。 任何其他中间提供商都会提供带有 API 密钥的基本 URL。', tongyiBaseUrlTip: '对于中国用户,不需要填写或使用 https://dashscope.aliyuncs.com/compatible-mode/v1。对于国际用户,使用 https://dashscope-intl.aliyuncs.com/compatible-mode/v1。', + siliconBaseUrlTip: + '对于中国用户,不需要填写或使用 https://api.siliconflow.cn/v1。对于国际用户,使用 https://api.siliconflow.com/v1。', tongyiBaseUrlPlaceholder: '(仅国际用户需要)', minimaxBaseUrlTip: '仅国际用户:使用 https://api.minimax.io/v1。', minimaxBaseUrlPlaceholder: '(仅国际用户填写 https://api.minimax.io/v1)', diff --git a/web/src/pages/user-setting/setting-model/hooks.tsx b/web/src/pages/user-setting/setting-model/hooks.tsx index 3a5ba677b51..72bd087042d 100644 --- a/web/src/pages/user-setting/setting-model/hooks.tsx +++ b/web/src/pages/user-setting/setting-model/hooks.tsx @@ -43,11 +43,20 @@ export const useSubmitApiKey = () => { if (!isVerify) { setSaveLoading(true); } - const ret = await saveApiKey({ + const payload: IApiKeySavingParams = { ...savingParams, ...postBody, verify: isVerify, - }); + }; + if (savingParams.llm_factory === LLMFactory.SILICONFLOW) { + payload.source_fid = (postBody.base_url || '') + .toLowerCase() + .includes('api.siliconflow.com') + ? 'siliconflow_intl' + : LLMFactory.SILICONFLOW; + } + + const ret = await saveApiKey(payload); if (!isVerify) { setSaveLoading(false); if (ret.code === 0) { diff --git a/web/src/pages/user-setting/setting-model/modal/api-key-modal/index.tsx b/web/src/pages/user-setting/setting-model/modal/api-key-modal/index.tsx index 4be300c0242..ff7a559e260 100644 --- a/web/src/pages/user-setting/setting-model/modal/api-key-modal/index.tsx +++ b/web/src/pages/user-setting/setting-model/modal/api-key-modal/index.tsx @@ -41,6 +41,7 @@ const modelsWithBaseUrl = [ LLMFactory.AzureOpenAI, LLMFactory.TongYiQianWen, LLMFactory.MiniMax, + LLMFactory.SILICONFLOW, ]; const ApiKeyModal = ({ @@ -127,7 +128,9 @@ const ApiKeyModal = ({ ? t('minimaxBaseUrlTip') : llmFactory === LLMFactory.TongYiQianWen ? t('tongyiBaseUrlTip') - : t('baseUrlTip') + : llmFactory === LLMFactory.SILICONFLOW + ? t('siliconBaseUrlTip') + : t('baseUrlTip') } > {t('baseUrl')} @@ -140,7 +143,9 @@ const ApiKeyModal = ({ ? t('tongyiBaseUrlPlaceholder') : llmFactory === LLMFactory.MiniMax ? t('minimaxBaseUrlPlaceholder') - : 'https://api.openai.com/v1' + : llmFactory === LLMFactory.SILICONFLOW + ? 
'https://api.siliconflow.cn/v1' + : 'https://api.openai.com/v1' } onKeyDown={handleKeyDown} className="w-full" From 80706640330e21e723348e064332118696c9ccea Mon Sep 17 00:00:00 2001 From: balibabu Date: Mon, 2 Mar 2026 16:53:24 +0800 Subject: [PATCH 102/565] Feat: Modify the style of the classification operator and fix some console errors. (#13314) ### What problem does this PR solve? Feat: Modify the style of the classification operator and fix some console errors. ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- web/src/components/edit-tag/index.tsx | 4 +- web/src/components/originui/number-input.tsx | 198 +++++++++--------- web/src/pages/agent/flow-tooltip.tsx | 2 +- .../agent/form-sheet/form-config-map.tsx | 6 + .../categorize-form/dynamic-categorize.tsx | 82 ++++---- .../form/categorize-form/dynamic-example.tsx | 82 ++++---- .../agent/form/categorize-form/index.tsx | 9 +- .../form/categorize-form/use-form-schema.ts | 4 + .../variable-aggregator-form/name-input.tsx | 13 +- 9 files changed, 213 insertions(+), 187 deletions(-) diff --git a/web/src/components/edit-tag/index.tsx b/web/src/components/edit-tag/index.tsx index ccf10b8ef7d..decfbb968ba 100644 --- a/web/src/components/edit-tag/index.tsx +++ b/web/src/components/edit-tag/index.tsx @@ -16,7 +16,7 @@ interface EditTagsProps { } const EditTag = React.forwardRef( - ({ value = [], onChange, disabled }: EditTagsProps) => { + function EditTag({ value = [], onChange, disabled }, ref) { const [inputVisible, setInputVisible] = useState(false); const [inputValue, setInputValue] = useState(''); const inputRef = useRef(null); @@ -82,7 +82,7 @@ const EditTag = React.forwardRef( const tagChild = value?.map(forMap); return ( -
+
{inputVisible && ( = ({ - className, - value: initialValue, - onChange, - height, - min = 0, - max = Infinity, -}) => { - const [value, setValue] = useState(() => { - return initialValue ?? 0; - }); +const NumberInput = forwardRef( + function NumberInput( + { + className, + value: initialValue, + onChange, + height, + min = 0, + max = Infinity, + }, + ref, + ) { + const [value, setValue] = useState(() => { + return initialValue ?? 0; + }); - const valueRef = useRef(); + const valueRef = useRef(); - useEffect(() => { - if (initialValue !== undefined) { - setValue(initialValue); - } - }, [initialValue]); + useEffect(() => { + if (initialValue !== undefined) { + setValue(initialValue); + } + }, [initialValue]); - const handleDecrement = () => { - if (isNumber(value) && value > min) { - setValue(value - 1); - onChange?.(value - 1); - } - }; + const handleDecrement = () => { + if (isNumber(value) && value > min) { + setValue(value - 1); + onChange?.(value - 1); + } + }; - const handleIncrement = () => { - if (!isNumber(value)) { - return; - } - if (value > max - 1) { - return; - } - setValue(value + 1); - onChange?.(value + 1); - }; + const handleIncrement = () => { + if (!isNumber(value)) { + return; + } + if (value > max - 1) { + return; + } + setValue(value + 1); + onChange?.(value + 1); + }; - const handleChange = (e: React.ChangeEvent) => { - const currentValue = e.target.value; - const newValue = Number(currentValue); + const handleChange = (e: React.ChangeEvent) => { + const currentValue = e.target.value; + const newValue = Number(currentValue); - if (trim(currentValue) === '') { - if (isNumber(value)) { - valueRef.current = value; + if (trim(currentValue) === '') { + if (isNumber(value)) { + valueRef.current = value; + } + setValue(''); + return; } - setValue(''); - return; - } - if (!isNaN(newValue)) { - if (newValue > max || newValue < min) { - return; + if (!isNaN(newValue)) { + if (newValue > max || newValue < min) { + return; + } + setValue(newValue); + onChange?.(newValue); } - setValue(newValue); - onChange?.(newValue); - } - }; + }; - const handleBlur: FocusEventHandler = useCallback(() => { - if (isNumber(value)) { - onChange?.(value); - } else { - const previousValue = valueRef.current ?? min; - setValue(previousValue); - onChange?.(previousValue); - } - }, [min, onChange, value]); + const handleBlur: FocusEventHandler = useCallback(() => { + if (isNumber(value)) { + onChange?.(value); + } else { + const previousValue = valueRef.current ?? min; + setValue(previousValue); + onChange?.(previousValue); + } + }, [min, onChange, value]); - const style = useMemo( - () => ({ - height: height ? `${height.toString().replace('px', '')}px` : 'auto', - }), - [height], - ); - return ( -
- - - -
- ); -}; + + + +
+ ); + }, +); export default NumberInput; diff --git a/web/src/pages/agent/flow-tooltip.tsx b/web/src/pages/agent/flow-tooltip.tsx index 9386dd06b66..801b675b5ef 100644 --- a/web/src/pages/agent/flow-tooltip.tsx +++ b/web/src/pages/agent/flow-tooltip.tsx @@ -10,7 +10,7 @@ export const RunTooltip = ({ children }: PropsWithChildren) => { const { t } = useTranslation(); return ( - {children} + {children}

{t('flow.testRun')}

diff --git a/web/src/pages/agent/form-sheet/form-config-map.tsx b/web/src/pages/agent/form-sheet/form-config-map.tsx index a552412f13b..5ff04ce28c9 100644 --- a/web/src/pages/agent/form-sheet/form-config-map.tsx +++ b/web/src/pages/agent/form-sheet/form-config-map.tsx @@ -181,4 +181,10 @@ export const FormConfigMap = { [Operator.ExitLoop]: { component: () => <>, }, + [Operator.LoopStart]: { + component: () => <>, + }, + [Operator.ExcelProcessor]: { + component: () => <>, + }, }; diff --git a/web/src/pages/agent/form/categorize-form/dynamic-categorize.tsx b/web/src/pages/agent/form/categorize-form/dynamic-categorize.tsx index 0807f7bfa6b..582cb57ff3f 100644 --- a/web/src/pages/agent/form/categorize-form/dynamic-categorize.tsx +++ b/web/src/pages/agent/form/categorize-form/dynamic-categorize.tsx @@ -11,17 +11,19 @@ import { FormLabel, FormMessage, } from '@/components/ui/form'; -import { Input } from '@/components/ui/input'; +import { Input, InputProps } from '@/components/ui/input'; +import { Separator } from '@/components/ui/separator'; import { BlurTextarea } from '@/components/ui/textarea'; import { useTranslate } from '@/hooks/common-hooks'; import { PlusOutlined } from '@ant-design/icons'; import { useUpdateNodeInternals } from '@xyflow/react'; import humanId from 'human-id'; import trim from 'lodash/trim'; -import { ChevronsUpDown, X } from 'lucide-react'; +import { ChevronsUpDown, Trash2 } from 'lucide-react'; import { ChangeEventHandler, FocusEventHandler, + forwardRef, memo, useCallback, useEffect, @@ -32,7 +34,7 @@ import { v4 as uuid } from 'uuid'; import { z } from 'zod'; import useGraphStore from '../../store'; import DynamicExample from './dynamic-example'; -import { useCreateCategorizeFormSchema } from './use-form-schema'; +import { CreateCategorizeFormSchema } from './use-form-schema'; interface IProps { nodeId?: string; @@ -58,12 +60,10 @@ const getOtherFieldValues = ( x !== form.getValues(`${formListName}.${index}.${latestField}`), ); -const InnerNameInput = ({ - value, - onChange, - otherNames, - validate, -}: INameInputProps) => { +const InnerNameInput = forwardRef< + HTMLInputElement, + InputProps & INameInputProps +>(function InnerNameInput({ value, onChange, otherNames, validate }, ref) { const [name, setName] = useState(); const { t } = useTranslate('flow'); @@ -103,9 +103,10 @@ const InnerNameInput = ({ value={name} onChange={handleNameChange} onBlur={handleNameBlur} + ref={ref} > ); -}; +}); const NameInput = memo(InnerNameInput); @@ -127,7 +128,6 @@ const InnerFormSet = ({ index }: IProps & { index: number }) => { name={buildFieldName('name')} render={({ field }) => ( - {t('categoryName')} { const updateNodeInternals = useUpdateNodeInternals(); - const FormSchema = useCreateCategorizeFormSchema(); const deleteCategorizeCaseEdges = useGraphStore( (state) => state.deleteEdgesBySourceAndSourceHandle, ); - const form = useFormContext>(); + const form = useFormContext>(); const { t } = useTranslate('flow'); const { fields, remove, append } = useFieldArray({ name: 'items', @@ -208,41 +207,42 @@ const DynamicCategorize = ({ nodeId }: IProps) => { ); return ( -
+
{fields.map((field, index) => ( - -
-

- {form.getValues(`items.${index}.name`)} -

- -
- - -
-
-
- - - -
+
+ +
+ {form.getValues(`items.${index}.name`)} + +
+ + +
+
+
+ + + +
+ +
))} -
+ ); }; diff --git a/web/src/pages/agent/form/categorize-form/dynamic-example.tsx b/web/src/pages/agent/form/categorize-form/dynamic-example.tsx index 35d95cbc6c8..c8c0e99cc36 100644 --- a/web/src/pages/agent/form/categorize-form/dynamic-example.tsx +++ b/web/src/pages/agent/form/categorize-form/dynamic-example.tsx @@ -1,3 +1,4 @@ +import { Collapse } from '@/components/collapse'; import { Button } from '@/components/ui/button'; import { FormControl, @@ -7,7 +8,7 @@ import { FormMessage, } from '@/components/ui/form'; import { Textarea } from '@/components/ui/textarea'; -import { Plus, X } from 'lucide-react'; +import { Plus, Trash2 } from 'lucide-react'; import { memo } from 'react'; import { useFieldArray, useFormContext } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; @@ -24,44 +25,49 @@ const DynamicExample = ({ name }: DynamicExampleProps) => { }); return ( - - {t('flow.examples')} -
- {fields.map((field, index) => ( -
- ( - - - - - + {t('flow.examples')} + } + > + +
+ {fields.map((field, index) => ( +
+ ( + + + + + + )} + /> + {index === 0 ? ( + + ) : ( + )} - /> - {index === 0 ? ( - - ) : ( - - )} -
- ))} -
- -
+
+ ))} +
+ +
+ ); }; diff --git a/web/src/pages/agent/form/categorize-form/index.tsx b/web/src/pages/agent/form/categorize-form/index.tsx index f0e38a73354..de69830067b 100644 --- a/web/src/pages/agent/form/categorize-form/index.tsx +++ b/web/src/pages/agent/form/categorize-form/index.tsx @@ -1,7 +1,7 @@ -import { FormContainer } from '@/components/form-container'; import { LargeModelFormField } from '@/components/large-model-form-field'; import { MessageHistoryWindowSizeFormField } from '@/components/message-history-window-size-item'; import { Form } from '@/components/ui/form'; +import { Separator } from '@/components/ui/separator'; import { zodResolver } from '@hookform/resolvers/zod'; import { memo } from 'react'; import { useForm } from 'react-hook-form'; @@ -33,13 +33,12 @@ function CategorizeForm({ node }: INextOperatorForm) { return (
- - - - + + + diff --git a/web/src/pages/agent/form/categorize-form/use-form-schema.ts b/web/src/pages/agent/form/categorize-form/use-form-schema.ts index 9e56bb18b21..6ff507ee890 100644 --- a/web/src/pages/agent/form/categorize-form/use-form-schema.ts +++ b/web/src/pages/agent/form/categorize-form/use-form-schema.ts @@ -30,3 +30,7 @@ export function useCreateCategorizeFormSchema() { return FormSchema; } + +export type CreateCategorizeFormSchema = ReturnType< + typeof useCreateCategorizeFormSchema +>; diff --git a/web/src/pages/agent/form/variable-aggregator-form/name-input.tsx b/web/src/pages/agent/form/variable-aggregator-form/name-input.tsx index 5a0f14ba86b..4ed895e1415 100644 --- a/web/src/pages/agent/form/variable-aggregator-form/name-input.tsx +++ b/web/src/pages/agent/form/variable-aggregator-form/name-input.tsx @@ -1,6 +1,6 @@ -import { Input } from '@/components/ui/input'; +import { Input, InputProps } from '@/components/ui/input'; import { PenLine } from 'lucide-react'; -import { useCallback, useEffect, useRef, useState } from 'react'; +import { forwardRef, useCallback, useEffect, useRef, useState } from 'react'; import { useHandleNameChange } from './use-handle-name-change'; type NameInputProps = { @@ -8,7 +8,10 @@ type NameInputProps = { onChange: (value: string) => void; }; -export function NameInput({ value, onChange }: NameInputProps) { +export const NameInput = forwardRef< + HTMLInputElement, + InputProps & NameInputProps +>(function NameInput({ value, onChange }, ref) { const { name, handleNameBlur, handleNameChange } = useHandleNameChange(value); const inputRef = useRef(null); @@ -33,7 +36,7 @@ export function NameInput({ value, onChange }: NameInputProps) { }, [isEditingMode]); return ( -
+
{isEditingMode ? ( ); -} +}); From 7dfa85806ce2c138e35f8cc8a93fad8eade3c395 Mon Sep 17 00:00:00 2001 From: Yao Wei <251109226@qq.com> Date: Mon, 2 Mar 2026 19:05:50 +0800 Subject: [PATCH 103/565] Refa: Resume parsing module (architectural optimizations based on SmartResume Pipeline) (#13255) Core optimizations (refer to arXiv:2510.09722): 1. PDF text fusion: Metadata + OCR dual-path extraction and fusion 2. Page-aware reconstruction: YOLOv10 page segmentation + hierarchical sorting + line number indexing 3. Parallel task decomposition: Basic information/work experience/educational background three-way parallel LLM extraction 4. Index pointer mechanism: LLM returns a range of line numbers instead of generating the full text, reducing the illusion of full text. --------- Co-authored-by: Aron.Yao Co-authored-by: Aron.Yao Co-authored-by: Yingfeng --- rag/app/resume.py | 2659 ++++++++++++++++++++++++-- rag/prompts/resume_basic_info.md | 39 + rag/prompts/resume_basic_info_en.md | 39 + rag/prompts/resume_education.md | 31 + rag/prompts/resume_education_en.md | 31 + rag/prompts/resume_project_exp.md | 31 + rag/prompts/resume_project_exp_en.md | 31 + rag/prompts/resume_system.md | 3 + rag/prompts/resume_system_en.md | 3 + rag/prompts/resume_work_exp.md | 39 + rag/prompts/resume_work_exp_en.md | 38 + 11 files changed, 2805 insertions(+), 139 deletions(-) create mode 100644 rag/prompts/resume_basic_info.md create mode 100644 rag/prompts/resume_basic_info_en.md create mode 100644 rag/prompts/resume_education.md create mode 100644 rag/prompts/resume_education_en.md create mode 100644 rag/prompts/resume_project_exp.md create mode 100644 rag/prompts/resume_project_exp_en.md create mode 100644 rag/prompts/resume_system.md create mode 100644 rag/prompts/resume_system_en.md create mode 100644 rag/prompts/resume_work_exp.md create mode 100644 rag/prompts/resume_work_exp_en.md diff --git a/rag/app/resume.py b/rag/app/resume.py index b022f81b302..084f8c21b4a 100644 --- a/rag/app/resume.py +++ b/rag/app/resume.py @@ -14,167 +14,2548 @@ # limitations under the License. # -import logging -import base64 -import datetime +""" +Resume parsing module (aligned with SmartResume Pipeline architecture optimization) + +Key optimizations (ref: arXiv:2510.09722): + 1. PDF text fusion: metadata + OCR dual-path extraction and fusion + 2. Layout-aware reconstruction: YOLOv10 layout segmentation + hierarchical sorting + line indexing + 3. Parallel task decomposition: basic info / work experience / education - 3-way parallel LLM extraction + 4. Index pointer mechanism: LLM returns line number ranges instead of generating full text, reducing hallucination + 5. 
Four-stage post-processing: source text re-extraction, domain normalization, context deduplication, source text validation + +Compatibility: + - chunk(filename, binary, callback, **kwargs) signature remains unchanged + - Compatible with FACTORY[ParserType.RESUME.value] in task_executor.py +""" + import json import re -import pandas as pd -import requests -from api.db.services.knowledgebase_service import KnowledgebaseService +import random +import datetime +import unicodedata +import concurrent.futures +from io import BytesIO +from typing import Optional +import numpy as np + +# tiktoken for long random string filtering (ref: SmartResume should_remove strategy) +try: + import tiktoken + _tiktoken_encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") +except ImportError: + _tiktoken_encoding = None + +# Long random string pattern: 40+ char alphanumeric mixed strings (hash, token, tracking ID, etc.) +_LONG_RANDOM_PATTERN = re.compile(r'[a-zA-Z0-9\-~_]{40,}') + +import logging as logger from rag.nlp import rag_tokenizer -from deepdoc.parser.resume import refactor -from deepdoc.parser.resume import step_one, step_two -from common.string_utils import remove_redundant_spaces +from deepdoc.parser.utils import get_text + +# json_repair for fixing malformed JSON from LLM responses (ref: SmartResume fault-tolerance strategy) +try: + import json_repair +except ImportError: + json_repair = None + +# YOLOv10 layout detector (lazy initialization to avoid loading model when unused) +_layout_recognizer = None + + +def _get_layout_recognizer(): + """ + Get YOLOv10 layout detector singleton (lazy loading) + + Uses the existing deepdoc LayoutRecognizer based on layout.onnx model. -forbidden_select_fields4resume = [ - "name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd" + Returns: + LayoutRecognizer instance, or None if loading fails + """ + global _layout_recognizer + if _layout_recognizer is None: + try: + from deepdoc.vision import LayoutRecognizer + _layout_recognizer = LayoutRecognizer("layout") + logger.info("YOLOv10 layout detector loaded successfully") + except Exception as e: + logger.warning(f"YOLOv10 layout detector loading failed, falling back to heuristic sorting: {e}") + _layout_recognizer = False # Mark as failed to avoid repeated attempts + return _layout_recognizer if _layout_recognizer is not False else None + +# ==================== Constants ==================== + +# Fields forbidden from being used as select fields in resume +FORBIDDEN_SELECT_FIELDS = [ + "name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", + "sch_rank_kwd", "edu_fea_kwd" ] +# Field name to description mapping (bilingual versions for chunk construction) +FIELD_MAP_ZH = { + "name_kwd": "姓名/名字", + "name_pinyin_kwd": "姓名拼音/名字拼音", + "gender_kwd": "性别(男,女)", + "age_int": "年龄/岁/年纪", + "phone_kwd": "电话/手机/微信", + "email_tks": "email/e-mail/邮箱", + "position_name_tks": "职位/职能/岗位/职责", + "expect_city_names_tks": "期望城市", + "work_exp_flt": "工作年限/工作年份/N年经验/毕业了多少年", + "corporation_name_tks": "最近就职(上班)的公司/上一家公司", + "first_school_name_tks": "第一学历毕业学校", + "first_degree_kwd": "第一学历", + "highest_degree_kwd": "最高学历", + "first_major_tks": "第一学历专业", + "edu_first_fea_kwd": "第一学历标签", + "degree_kwd": "过往学历", + "major_tks": "学过的专业/过往专业", + "school_name_tks": "学校/毕业院校", + "sch_rank_kwd": "学校标签", + "edu_fea_kwd": "教育标签", + "corp_nm_tks": "就职过的公司/之前的公司/上过班的公司", + "edu_end_int": "毕业年份", + "industry_name_tks": "所在行业", + "birth_dt": "生日/出生年份", + "expect_position_name_tks": "期望职位/期望职能/期望岗位", + "skill_tks": 
"技能/技术栈/编程语言/框架/工具", + "language_tks": "语言能力/外语水平", + "certificate_tks": "证书/资质/认证", + "project_tks": "项目经验/项目名称", + "work_desc_tks": "工作职责/工作描述", + "project_desc_tks": "项目描述/项目职责", + "self_evaluation_tks": "自我评价/个人优势/个人总结", +} + +FIELD_MAP_EN = { + "name_kwd": "Name", + "name_pinyin_kwd": "Name Pinyin", + "gender_kwd": "Gender (Male, Female)", + "age_int": "Age", + "phone_kwd": "Phone/Mobile/WeChat", + "email_tks": "Email", + "position_name_tks": "Position/Title/Role", + "expect_city_names_tks": "Preferred City", + "work_exp_flt": "Years of Experience", + "corporation_name_tks": "Most Recent Company", + "first_school_name_tks": "First Degree School", + "first_degree_kwd": "First Degree", + "highest_degree_kwd": "Highest Degree", + "first_major_tks": "First Degree Major", + "edu_first_fea_kwd": "First Degree Tag", + "degree_kwd": "Past Degrees", + "major_tks": "Past Majors", + "school_name_tks": "School/University", + "sch_rank_kwd": "School Tag", + "edu_fea_kwd": "Education Tag", + "corp_nm_tks": "Past Companies", + "edu_end_int": "Graduation Year", + "industry_name_tks": "Industry", + "birth_dt": "Date of Birth", + "expect_position_name_tks": "Preferred Position/Role", + "skill_tks": "Skills/Tech Stack/Languages/Frameworks/Tools", + "language_tks": "Language Proficiency", + "certificate_tks": "Certificates/Qualifications", + "project_tks": "Project Experience/Project Name", + "work_desc_tks": "Job Responsibilities/Description", + "project_desc_tks": "Project Description/Responsibilities", + "self_evaluation_tks": "Self-Evaluation/Personal Strengths/Summary", +} + + +def _is_english(lang: str) -> bool: + """Determine if the language parameter indicates English""" + return lang.lower() in ("english", "en") + + +def get_field_map(lang: str) -> dict: + """Get the corresponding field mapping based on language parameter""" + return FIELD_MAP_EN if _is_english(lang) else FIELD_MAP_ZH + + +# Backward compatible: default to Chinese version +FIELD_MAP = FIELD_MAP_ZH + + +# ==================== Parallel LLM Extraction Prompt Templates ==================== +# Ref: SmartResume task decomposition strategy, splitting extraction into independent subtasks +# Each prompt ends with /no_think marker to suppress reasoning model's thinking output +# Prompts loaded from md files under rag/prompts/, supporting bilingual versions + +from rag.prompts.template import load_prompt + + +def _load_resume_prompt(name: str, lang: str) -> str: + """Load the corresponding version of resume prompt template based on language parameter + + Args: + name: Prompt name (without language suffix), e.g. "resume_system" + lang: Language parameter, e.g. 
"Chinese" or "English" + Returns: + Prompt template string + """ + suffix = "_en" if _is_english(lang) else "" + return load_prompt(f"{name}{suffix}") + + +def get_system_prompt(lang: str) -> str: + """Get system prompt""" + return _load_resume_prompt("resume_system", lang) + + +def get_basic_info_prompt(lang: str) -> str: + """Get basic info extraction prompt""" + return _load_resume_prompt("resume_basic_info", lang) + + +def get_work_exp_prompt(lang: str) -> str: + """Get work experience extraction prompt""" + return _load_resume_prompt("resume_work_exp", lang) + + +def get_education_prompt(lang: str) -> str: + """Get education background extraction prompt""" + return _load_resume_prompt("resume_education", lang) + + +def get_project_exp_prompt(lang: str) -> str: + """Get project experience extraction prompt""" + return _load_resume_prompt("resume_project_exp", lang) + + +# Backward compatible: default Chinese version constants (for possible external direct references) +SYSTEM_PROMPT = load_prompt("resume_system") +BASIC_INFO_PROMPT = load_prompt("resume_basic_info") +WORK_EXP_PROMPT = load_prompt("resume_work_exp") +EDUCATION_PROMPT = load_prompt("resume_education") +PROJECT_EXP_PROMPT = load_prompt("resume_project_exp") + +# LLM call max retry count (ref: SmartResume retry strategy) +_LLM_MAX_RETRIES = 2 + + +def _normalize_whitespace(text: str) -> str: + """ + Unicode whitespace normalization (ref: SmartResume _clean_text_content) + + Replaces various Unicode spaces (\u00A0 non-breaking space, \u3000 fullwidth space, + \u2000-\u200A various width spaces, etc.) with regular spaces, + then applies NFKC normalization (fullwidth to halfwidth) and merges consecutive spaces. + + Args: + text: Original text + Returns: + Normalized text + """ + if not text: + return "" + # NFKC normalization (fullwidth to halfwidth, etc.) + text = unicodedata.normalize('NFKC', text) + # Unify various Unicode spaces to regular space + text = re.sub( + r'[\u0020\u00A0\u1680\u2000-\u200A\u2028\u2029\u202F\u205F\u3000\u00A7]', + ' ', text + ) + # Merge consecutive spaces + text = re.sub(r' {2,}', ' ', text) + return text.strip() + + +def _should_remove_random_str(match: re.Match) -> bool: + """ + Determine if a matched long string is a meaningless random string (ref: SmartResume should_remove) + + Uses tiktoken encoding to judge: if token count exceeds 50% of original char count, + it indicates a meaningless random string (hash, token, tracking ID, etc.) that should be removed. + Normal English words have high token encoding efficiency, with token count far less than char count. 
+ + Args: + match: Regex match object + Returns: + True means it should be removed + """ + if _tiktoken_encoding is None: + # When tiktoken is unavailable, use simple heuristic: case/digit alternation frequency + s = match.group(0) + changes = sum( + 1 for i in range(1, len(s)) + if s[i].isdigit() != s[i-1].isdigit() + or (s[i].isalpha() and s[i-1].isalpha() and s[i].isupper() != s[i-1].isupper()) + ) + return changes / len(s) > 0.3 + encoded = _tiktoken_encoding.encode(match.group(0)) + return len(encoded) > len(match.group(0)) * 0.5 + + +def _clean_line_content(text: str) -> str: + """ + Clean single line text content (Unicode normalization + long random string filtering) + + Args: + text: Original line text + Returns: + Cleaned text + """ + if not text: + return "" + # Unicode whitespace normalization + text = _normalize_whitespace(text) + # Filter long random strings (hash, token and other meaningless content) + text = _LONG_RANDOM_PATTERN.sub( + lambda m: '' if _should_remove_random_str(m) else m.group(0), + text + ) + # Clean up extra spaces after filtering + text = re.sub(r' {2,}', ' ', text).strip() + return text + + +# ==================== Phase 1: PDF Text Fusion and Layout Reconstruction ==================== + + + + +def _is_noise_char(obj: dict) -> bool: + """ + Determine if a PDF character object is a decorative layer noise character + + Uses a "body text whitelist" strategy instead of enumerating noise features, + to handle noise patterns from different resume templates: + + Two reliable features of body text characters (either one means body text): + 1. Embedded font: Font name format is XXXXXX+FontName (contains '+'), + indicating the font is embedded in the PDF, chosen by the document author + 2. Structure tag: Has PDF Tagged Structure tags (e.g., Span, P, NonStruct, etc.), + indicating the character belongs to the document's semantic structure tree + + Common features of noise characters: + - Uses system fonts (e.g., Helvetica, Arial), font name doesn't contain '+' + - No structure tags (tag is None or non-semantic tags like 'OC') + - Common in resume template background decorations, watermarks, tracking marks + + Args: + obj: pdfplumber character/text object dictionary + Returns: + True means it's a noise character that should be filtered + """ + # Whitelist condition 1: Embedded font (font name contains '+' prefix) + fontname = obj.get("fontname", "") + if "+" in fontname: + return False # Embedded font = body content + + # Whitelist condition 2: Has PDF structure tag + tag = obj.get("tag") + if tag in ("Span", "NonStruct", "P", "H1", "H2", "H3", "H4", "H5", "H6", + "TD", "TH", "LI", "L", "Table", "TR", "Figure", "Caption"): + return False # Has semantic structure tag = body content + + # Doesn't meet any whitelist condition, treat as noise + return True + + + +def _extract_metadata_text(binary: bytes) -> list[dict]: + """ + Extract text blocks from PDF metadata (with coordinate info) + + Strategy: + 1. Use whitelist strategy to filter decorative layer noise chars (embedded font or structure tag = body text) + 2. Safe fallback: if filtered chars are less than 30% of original, skip filtering to avoid false positives + 3. Use extract_words for word-level extraction (with real coordinates) + 4. Aggregate adjacent words into line-level text blocks by Y coordinate + 5. 
Additionally extract table content (many resumes use table layouts) + + Args: + binary: PDF file binary content + Returns: + List of text blocks, each containing text, x0, top, x1, bottom, page fields + """ + try: + import pdfplumber + blocks = [] + with pdfplumber.open(BytesIO(binary)) as pdf: + for page_idx, page in enumerate(pdf.pages): + page_width = page.width or 600 + + # Filter decorative layer noise chars (whitelist strategy based on embedded font + structure tag) + # Safe fallback: if filtered chars are less than 30% of original, the PDF's body text + # may use non-embedded fonts without structure tags, skip filtering to avoid false positives + try: + original_char_count = len(page.chars) + filtered_page = page.filter( + lambda obj: not _is_noise_char(obj) + ) + filtered_char_count = len(filtered_page.chars) + if original_char_count > 0 and filtered_char_count < original_char_count * 0.3: + # Filtered out over 70% of chars, likely false positives, fall back to original page + filtered_page = page + except Exception: + filtered_page = page + + # Use extract_words for extraction (with real coordinates) + words = [] + try: + words = filtered_page.extract_words( + keep_blank_chars=False, use_text_flow=True + ) + except Exception: + pass + + if words: + # Aggregate adjacent words into line-level text blocks by Y coordinate + # Words on the same line: top coordinate difference within threshold + line_threshold = 5 # Y coordinate difference threshold (unit: PDF points) + current_line_words = [words[0]] + + def _flush_line(line_words): + """Merge words in a line into a single text block""" + # Sort by x0 to ensure left-to-right order + line_words.sort(key=lambda w: float(w.get("x0", 0))) + texts = [] + for w in line_words: + texts.append(w.get("text", "")) + merged_text = " ".join(texts) + if not merged_text.strip(): + return None + return { + "text": merged_text.strip(), + "x0": float(min(w.get("x0", 0) for w in line_words)), + "top": float(min(w.get("top", 0) for w in line_words)), + "x1": float(max(w.get("x1", 0) for w in line_words)), + "bottom": float(max(w.get("bottom", 0) for w in line_words)), + "page": page_idx, + } + + for w in words[1:]: + w_top = float(w.get("top", 0)) + cur_top = float(current_line_words[0].get("top", 0)) + if abs(w_top - cur_top) <= line_threshold: + current_line_words.append(w) + else: + block = _flush_line(current_line_words) + if block: + blocks.append(block) + current_line_words = [w] + + # Process the last line + if current_line_words: + block = _flush_line(current_line_words) + if block: + blocks.append(block) + else: + # Fall back to extract_text when extract_words fails + page_text = None + try: + page_text = page.extract_text() + except Exception: + pass + if page_text and page_text.strip(): + raw_lines = page_text.split("\n") + line_height = 16 + for i, line in enumerate(raw_lines): + cleaned = line.strip() + if not cleaned: + continue + blocks.append({ + "text": cleaned, + "x0": 0, + "top": i * line_height, + "x1": page_width, + "bottom": i * line_height + line_height - 2, + "page": page_idx, + }) + + # Extract table content from the page + # Many resumes use table layouts (e.g., personal info section), extract_words may miss table structure + try: + tables = page.extract_tables() + if tables: + page_blocks = [b for b in blocks if b["page"] == page_idx] + max_top = max((b["top"] for b in page_blocks), default=0) + 20 + row_height = 16 + + for table in tables: + for row in table: + if not row: + continue + cells = [str(c).strip() for c in row if c 
and str(c).strip()] + if not cells: + continue + row_text = " | ".join(cells) + # Dedup: check if table content was already extracted by extract_words + is_dup = False + for pb in page_blocks: + if all(c in pb["text"] for c in cells[:2]): + is_dup = True + break + if is_dup: + continue + blocks.append({ + "text": row_text, + "x0": 0, + "top": max_top, + "x1": page_width, + "bottom": max_top + row_height - 2, + "page": page_idx, + }) + max_top += row_height + except Exception as e: + logger.debug(f"PDF table extraction skipped (page {page_idx}): {e}") + return blocks + except Exception as e: + logger.warning(f"PDF metadata extraction failed: {e}") + return [] + +def _extract_ocr_text(binary: bytes, meta_blocks: list[dict] | None = None) -> list[dict]: + """ + Extract OCR text blocks using blackout strategy (with coordinate info). + + Strategy (ref: SmartResume): + 1. Render PDF pages to images + 2. Black out regions already extracted by metadata + 3. Run OCR on the blacked-out image, only recognizing content metadata missed + 4. Eliminates duplication at source, no IoU dedup needed downstream + + Args: + binary: PDF file binary content + meta_blocks: Text blocks from metadata extraction, used to black out existing text regions + Returns: + List of text blocks, each containing text, x0, top, x1, bottom, page fields + """ + if meta_blocks is None: + meta_blocks = [] + try: + import pdfplumber + from deepdoc.vision.ocr import OCR + import numpy as np + + ocr = OCR() + blocks = [] + + with pdfplumber.open(BytesIO(binary)) as pdf: + for page_idx, page in enumerate(pdf.pages): + # Render page to image (resolution=216 = 3x scale, since PDF default is 72 DPI) + img = page.to_image(resolution=216) + page_img = np.array(img.annotated) + + # Scale factor from PDF coordinates to image coordinates + pdf_to_img_scale = 216.0 / 72.0 # = 3.0 + + # Black out metadata-extracted text regions before OCR + page_meta_blocks = [b for b in meta_blocks if b.get("page") == page_idx] + if page_meta_blocks: + page_img = _blackout_text_regions(page_img, meta_blocks, page_idx, pdf_to_img_scale) + + ocr_result = ocr(page_img) + if not ocr_result: + continue + for box_info in ocr_result: + if isinstance(box_info, (list, tuple)) and len(box_info) >= 2: + coords = box_info[0] # Coordinate points + text_info = box_info[1] + text = text_info[0] if isinstance(text_info, (list, tuple)) else str(text_info) + if text.strip() and isinstance(coords, (list, tuple)) and len(coords) >= 4: + # Extract bounding box from four corner points + xs = [p[0] for p in coords if isinstance(p, (list, tuple))] + ys = [p[1] for p in coords if isinstance(p, (list, tuple))] + if xs and ys: + blocks.append({ + "text": text.strip(), + "x0": min(xs), "top": min(ys), + "x1": max(xs), "bottom": max(ys), + "page": page_idx, + }) + return blocks + except Exception as e: + logger.warning(f"OCR extraction failed: {e}") + return [] + + +def _fuse_text_blocks(meta_blocks: list[dict], ocr_blocks: list[dict]) -> list[dict]: + """ + Fuse PDF metadata text and OCR text (blackout strategy version). + + Since the OCR phase already blacks out metadata-extracted regions, OCR only recognizes + content that metadata missed. Therefore this function only needs to: + 1. Filter out garbled blocks from metadata + 2. 
Directly merge valid metadata blocks and OCR blocks (no IoU dedup needed) + + Args: + meta_blocks: Text blocks from metadata extraction + ocr_blocks: Text blocks from OCR extraction (already deduplicated via blackout strategy) + Returns: + Fused text block list + """ + if not ocr_blocks: + return meta_blocks + if not meta_blocks: + return ocr_blocks + + # Filter out garbled blocks from metadata + valid_meta = [] + garbled_count = 0 + for b in meta_blocks: + if _is_valid_line(b.get("text", "")): + valid_meta.append(b) + else: + garbled_count += 1 + + if garbled_count: + logger.info(f"Detected {garbled_count} garbled blocks in metadata, filtered out") + + # Under blackout strategy, OCR won't re-recognize existing text, just merge directly + fused = valid_meta + ocr_blocks + return fused + -def remote_call(filename, binary): - q = { - "header": { - "uid": 1, - "user": "kevinhu", - "log_id": filename - }, - "request": { - "p": { - "request_id": "1", - "encrypt_type": "base64", - "filename": filename, - "langtype": '', - "fileori": base64.b64encode(binary).decode('utf-8') - }, - "c": "resume_parse_module", - "m": "resume_parse" + + +def _layout_aware_reorder(blocks: list[dict]) -> list[dict]: + """ + Layout-aware hierarchical sorting (ref: SmartResume Hierarchical Re-ordering) + + Two-level sorting strategy: + 1. Inter-segment sorting: first by page number, then by Y coordinate (top to bottom), same row by X coordinate (left to right) + 2. Intra-segment sorting: within each logical segment, sort by reading order + + For multi-column resumes, detect column positions by clustering X coordinates, + then sort by column order. + + Args: + blocks: Text block list (with coordinate info) + Returns: + Sorted text block list + """ + if not blocks: + return blocks + + # Group by page + pages = {} + for b in blocks: + pg = b.get("page", 0) + pages.setdefault(pg, []).append(b) + + sorted_blocks = [] + for pg in sorted(pages.keys()): + page_blocks = pages[pg] + + # Detect multi-column layout: by X coordinate median + if len(page_blocks) > 5: + x_centers = [(b["x0"] + b["x1"]) / 2 for b in page_blocks] + x_min, x_max = min(x_centers), max(x_centers) + page_width = x_max - x_min if x_max > x_min else 1 + + # Simple two-column detection: if text blocks are clearly distributed on left and right sides + mid_x = (x_min + x_max) / 2 + left_count = sum(1 for x in x_centers if x < mid_x - page_width * 0.1) + right_count = sum(1 for x in x_centers if x > mid_x + page_width * 0.1) + + if left_count > 3 and right_count > 3: + # Multi-column layout: left column first then right column, each column top to bottom + left_blocks = [b for b in page_blocks if (b["x0"] + b["x1"]) / 2 < mid_x] + right_blocks = [b for b in page_blocks if (b["x0"] + b["x1"]) / 2 >= mid_x] + left_blocks.sort(key=lambda b: (b["top"], b["x0"])) + right_blocks.sort(key=lambda b: (b["top"], b["x0"])) + sorted_blocks.extend(left_blocks) + sorted_blocks.extend(right_blocks) + continue + + # Single-column layout: top to bottom, same row left to right + page_blocks.sort(key=lambda b: (b["top"], b["x0"])) + sorted_blocks.extend(page_blocks) + + return sorted_blocks + + +def _build_indexed_text(blocks: list[dict]) -> tuple[str, list[str], list[dict]]: + """ + + Build indexed text with line numbers (ref: SmartResume Indexed Linearization) + + Merges sorted text blocks into lines and adds a unique index number to each line. + Includes garbled line filtering logic and field label split repair. 
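+    The indexed result, which is what the downstream LLM extraction prompts
+    consume, looks like the sketch below (line contents are illustrative only):
+
+        [0]: 张三  高级后端工程师
+        [1]: 电话:138-0000-0000  邮箱:zhangsan@example.com
+        [2]: 工作经历
+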
+ Also preserves coordinate info for each line, used for writing position_int etc. to chunks. + + Args: + blocks: Sorted text block list + Returns: + (indexed_text, lines, line_positions) tuple: + - indexed_text: Text string with line numbers + - lines: Original line text list (without line numbers) + - line_positions: Coordinate info for each line, format: + """ + if not blocks: + return "", [], [] + + raw_lines = [] + raw_positions = [] + current_line_parts = [] + current_line_blocks = [] + current_top = blocks[0].get("top", 0) + current_layoutno = blocks[0].get("layoutno", "") + threshold = 10 + + def _merge_line_position(line_blocks: list[dict]) -> dict: + """Merge coordinates of all blocks in a line into outer bounding rectangle""" + return { + "page": line_blocks[0].get("page", 0), + "x0": min(b.get("x0", 0) for b in line_blocks), + "x1": max(b.get("x1", 0) for b in line_blocks), + "top": min(b.get("top", 0) for b in line_blocks), + "bottom": max(b.get("bottom", 0) for b in line_blocks), } + + for b in blocks: + b_layoutno = b.get("layoutno", "") + y_changed = abs(b.get("top", 0) - current_top) > threshold + layout_changed = b_layoutno != current_layoutno and current_layoutno and b_layoutno + if (y_changed or layout_changed) and current_line_parts: + raw_lines.append(" ".join(current_line_parts)) + raw_positions.append(_merge_line_position(current_line_blocks)) + current_line_parts = [] + current_line_blocks = [] + current_top = b.get("top", 0) + current_layoutno = b_layoutno + current_line_parts.append(b["text"]) + current_line_blocks.append(b) + + if current_line_parts: + raw_lines.append(" ".join(current_line_parts)) + raw_positions.append(_merge_line_position(current_line_blocks)) + + # Filter empty and garbled lines (sync filter coordinates) + lines = [] + line_positions = [] + for line, pos in zip(raw_lines, raw_positions): + # Unicode normalization + long random string filtering (ref: SmartResume _clean_text_content) + line = _clean_line_content(line) + if not line: + continue + # Garbled detection: skip if valid chars (Chinese/ASCII letters/digits/common punctuation) ratio is too low + if not _is_valid_line(line): + continue + lines.append(line) + line_positions.append(pos) + + # Fix field label split issues + # Coordinates are not affected, keep original positions + lines = _fix_split_labels(lines) + + # Build indexed text with line numbers + indexed_parts = [f"[{i}]: {line}" for i, line in enumerate(lines)] + indexed_text = "\n".join(indexed_parts) + + return indexed_text, lines, line_positions + +def _is_valid_line(line: str) -> bool: + """ + Check if a text line is valid content (not garbled) + + Multi-dimensional detection: + 1. Valid character ratio (Chinese, ASCII alphanumeric, common punctuation) + 2. Single-character spacing anomaly detection (PDF custom font mapping causing "O U W Z_W V 2" pattern) + 3. Consecutive meaningless alphanumeric sequence detection + + Args: + line: Text line to check + Returns: + True means valid line, False means garbled line + """ + if len(line) <= 3: + # Short lines may be valid content like names, keep them + return True + + cid_count = len(re.findall(r'\(cid:\d+\)', line)) + if cid_count >= 3: + return False + # Valid characters: Chinese (incl. 
extension), ASCII alphanumeric, common punctuation and spaces, fullwidth chars, CJK punctuation + valid_chars = re.findall( + r'[\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff' + r'a-zA-Z0-9\s@.,:;!?()()【】\-_/\\|·•' + r'、,。:;!?\u201c\u201d\u2018\u2019《》' + r'\uff01-\uff5e' + r'\u3000-\u303f' + r'#%&+=~`\u00b7\u2022\u2013\u2014' + r']', + line + ) + ratio = len(valid_chars) / len(line) if len(line) > 0 else 0 + if ratio < 0.5: + return False + + # Detect PDF custom font mapping causing single-character spacing anomaly pattern + # Feature: lots of "single letter space single letter space" sequences, e.g. "O U W Z_W V 2 X 3" + # Stats: ratio of space-separated single chars among non-space chars + spaced_singles = re.findall(r'(?:^|\s)([a-zA-Z0-9])(?:\s|$)', line) + non_space_len = len(line.replace(" ", "")) + if non_space_len > 5 and len(spaced_singles) > 0: + # If ratio of space-separated single chars to non-space chars is too high, classify as garbled + single_ratio = len(spaced_singles) / non_space_len + if single_ratio > 0.3: + return False + + # Detect consecutive meaningless mixed-case alphanumeric sequences (e.g. "UJqZX9V2") + # Normal English words don't have such frequent case alternation patterns + garbled_seqs = re.findall(r'[a-zA-Z0-9]{4,}', line.replace(" ", "")) + if garbled_seqs: + garbled_count = 0 + for seq in garbled_seqs: + # Count case alternations + case_changes = sum( + 1 for i in range(1, len(seq)) + if (seq[i].isupper() != seq[i-1].isupper() and seq[i].isalpha() and seq[i-1].isalpha()) + or (seq[i].isdigit() != seq[i-1].isdigit()) + ) + # Too high alternation frequency = garbled sequence (normal words like "Spring" have only 1 alternation) + if len(seq) >= 4 and case_changes / len(seq) > 0.5: + garbled_count += 1 + # If garbled sequence ratio is too high + if len(garbled_seqs) > 0 and garbled_count / len(garbled_seqs) > 0.4: + return False + + return True + + +def _fix_split_labels(lines: list[str]) -> list[str]: + """ + Fix field label split issues + + Some PDF layouts split field labels across line start/end, e.g.: + - "名:陈晓俐 姓" -> should be fixed to "姓名:陈晓俐" + - "别:男 性" -> should be fixed to "性别:男" + + Args: + lines: Original line text list + Returns: + Fixed line text list + """ + # Common split field label patterns: (line-end part, line-start part) -> full label + split_patterns = { + ("姓", "名"): "姓名", + ("性", "别"): "性别", + ("年", "龄"): "年龄", + ("电", "话"): "电话", + ("邮", "箱"): "邮箱", + ("学", "历"): "学历", + ("专", "业"): "专业", + ("地", "址"): "地址", + ("籍", "贯"): "籍贯", + ("民", "族"): "民族", } - for _ in range(3): + + fixed = [] + for line in lines: + # Detect in-line split patterns: "X:content Y" where (Y, X) is a split pair + for (suffix_char, prefix_char), full_label in split_patterns.items(): + # Pattern: "prefix_char:content suffix_char" (first half at line start, second half at line end) + pattern = rf'^({re.escape(prefix_char)})\s*[::]\s*(.+?)\s+{re.escape(suffix_char)}\s*$' + m = re.match(pattern, line) + if m: + content = m.group(2).strip() + line = f"{full_label}:{content}" + break + # Pattern: "suffix_char content prefix_char:" (second half at line start, first half at line end) + pattern2 = rf'^{re.escape(suffix_char)}\s*[::]?\s*(.+?)\s+{re.escape(prefix_char)}\s*$' + m2 = re.match(pattern2, line) + if m2: + content = m2.group(1).strip() + line = f"{full_label}:{content}" + break + fixed.append(line) + return fixed + + + + + +def extract_text(filename: str, binary: bytes) -> tuple[str, list[str], list[dict]]: + """ + Extract text content based on file type (Pipeline Phase 
1). + + PDF files use dual-path fusion + layout reconstruction + line indexing. + Other formats fall back to simple text extraction. + + Args: + filename: File name + binary: File binary content + Returns: + (indexed_text, lines, line_positions) tuple: + - indexed_text: Text with line number indices + - lines: List of original line texts + - line_positions: List of per-line coordinate info (empty list for non-PDF formats) + """ + fname_lower = filename.lower() + + try: + if fname_lower.endswith(".pdf"): + # Dual-path extraction + meta_blocks = _extract_metadata_text(binary) + ocr_blocks = [] + + # Determine whether OCR supplementation is needed: + # 1. Metadata text too short (< 100 chars) + # 2. High garbled text ratio in metadata (caused by custom font mapping) + meta_text_len = sum(len(b["text"]) for b in meta_blocks) + need_ocr = False + + if meta_text_len < 100: + logger.info("PDF metadata text too short, enabling OCR supplementation") + need_ocr = True + else: + # Check metadata text quality: calculate valid line ratio + # If many lines are judged as garbled by _is_valid_line, the PDF font mapping has issues + valid_line_count = 0 + total_line_count = 0 + for b in meta_blocks: + text = b.get("text", "").strip() + if not text: + continue + total_line_count += 1 + if _is_valid_line(text): + valid_line_count += 1 + if total_line_count > 0: + valid_ratio = valid_line_count / total_line_count + if valid_ratio < 0.6: + logger.info( + f"PDF metadata text quality low (valid line ratio {valid_ratio:.1%}), enabling OCR supplementation" + ) + need_ocr = True + + if need_ocr: + # Blackout strategy: black out metadata-extracted regions before OCR + ocr_blocks = _extract_ocr_text(binary, meta_blocks=meta_blocks) + + # Text fusion + fused_blocks = _fuse_text_blocks(meta_blocks, ocr_blocks) + + # Layout-aware sorting (prefer YOLOv10 layout detection, fall back to heuristic on failure) + sorted_blocks = _layout_detect_reorder(fused_blocks, binary) + + # Build line-indexed text (with coordinate info) + return _build_indexed_text(sorted_blocks) + + elif fname_lower.endswith(".docx"): + from docx import Document + doc = Document(BytesIO(binary)) + lines = [p.text.strip() for p in doc.paragraphs if p.text.strip()] + + # Extract table content from DOCX + # Reference: table handling in naive.py Docx class + # Many resumes use table layouts for personal info; iterating only paragraphs would miss this content + for table in doc.tables: + for row in table.rows: + cells = [] + for cell in row.cells: + cell_text = cell.text.strip() + if cell_text: + cells.append(cell_text) + if not cells: + continue + row_text = " | ".join(cells) + # Deduplicate: skip if this row text already exists in lines + if row_text not in lines: + lines.append(row_text) + + indexed = "\n".join(f"[{i}]: {line}" for i, line in enumerate(lines)) + # DOCX has no coordinate info, return empty list + return indexed, lines, [] + + else: + text = get_text(filename, binary) + lines = [line.strip() for line in text.split("\n") if line.strip()] + indexed = "\n".join(f"[{i}]: {line}" for i, line in enumerate(lines)) + return indexed, lines, [] + + except Exception: + logger.exception(f"Text extraction failed: {filename}") + return "", [], [] + + +# ==================== Phase 2: Parallel LLM Structured Extraction ==================== + + +def _clean_llm_json_response(response: str) -> str: + """ + Clean LLM JSON response. + + Uses SmartResume's lightweight string extraction strategy: + 1. Remove markdown code block markers + 2. Remove ... 
thinking tags (reasoning models may output these)
+    3. text.find("{") and text.rfind("}") to locate valid JSON block
+
+    Args:
+        response: Raw LLM response text
+    Returns:
+        Cleaned JSON string
+    """
+    text = response.strip()
+    # Remove markdown code block markers
+    text = text.replace("```json", "").replace("```", "").strip()
+    # Remove reasoning model thinking tags
+    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
+    # Clean escaped quotes (SmartResume's approach)
+    text = text.replace('\\"', '"')
+    # SmartResume strategy: locate first { and last }
+    start = text.find("{")
+    end = text.rfind("}")
+    if start != -1 and end != -1 and end > start:
+        return text[start:end + 1]
+    return text
+
+
+def _parse_json_with_repair(text: str) -> dict:
+    """
+    Parse JSON string, attempt repair on failure (ref SmartResume's json_repair strategy).
+
+    Repair strategies:
+    1. Standard json.loads
+    2. Replace Python-style booleans/None
+    3. Use json_repair library
+
+    Args:
+        text: JSON string
+    Returns:
+        Parsed dictionary
+    Raises:
+        json.JSONDecodeError: Raised when all repair strategies fail
+    """
+    # First attempt: standard parsing
+    try:
+        return json.loads(text)
+    except json.JSONDecodeError:
+        pass
+
+    # Second attempt: replace Python-style values (ref SmartResume)
+    repaired = text.replace("'", '"')
+    repaired = repaired.replace('True', 'true')
+    repaired = repaired.replace('False', 'false')
+    repaired = repaired.replace('None', 'null')
+    try:
+        return json.loads(repaired)
+    except json.JSONDecodeError:
+        pass
+
+    # Third attempt: use json_repair library
+    if json_repair is not None:
         try:
-            resume = requests.post(
-                "http://127.0.0.1:61670/tog",
-                data=json.dumps(q))
-            resume = resume.json()["response"]["results"]
-            resume = refactor(resume)
-            for k in ["education", "work", "project",
-                      "training", "skill", "certificate", "language"]:
-                if not resume.get(k) and k in resume:
-                    del resume[k]
-
-            resume = step_one.refactor(pd.DataFrame([{"resume_content": json.dumps(resume), "tob_resume_id": "x",
-                                                      "updated_at": datetime.datetime.now().strftime(
-                                                          "%Y-%m-%d %H:%M:%S")}]))
-            resume = step_two.parse(resume)
-            return resume
+            return json_repair.loads(text)
         except Exception:
-            logging.exception("Resume parser has not been supported yet!")
-    return {}
-
-
-def chunk(filename, binary=None, callback=None, **kwargs):
-    """
-    The supported file formats are pdf, docx and txt.
-    To maximize the effectiveness, parse the resume correctly, please concat us: https://github.com/infiniflow/ragflow
-    """
-    if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE):
-        raise NotImplementedError("file type not supported yet(pdf supported)")
-
-    if not binary:
-        with open(filename, "rb") as f:
-            binary = f.read()
-
-    callback(0.2, "Resume parsing is going on...")
-    resume = remote_call(filename, binary)
-    if len(resume.keys()) < 7:
-        callback(-1, "Resume is not successfully parsed.")
-        raise Exception("Resume parser remote call fail!")
-    callback(0.6, "Done parsing. 
Chunking...") - logging.debug("chunking resume: " + json.dumps(resume, ensure_ascii=False, indent=2)) - - field_map = { - "name_kwd": "姓名/名字", - "name_pinyin_kwd": "姓名拼音/名字拼音", - "gender_kwd": "性别(男,女)", - "age_int": "年龄/岁/年纪", - "phone_kwd": "电话/手机/微信", - "email_tks": "email/e-mail/邮箱", - "position_name_tks": "职位/职能/岗位/职责", - "expect_city_names_tks": "期望城市", - "work_exp_flt": "工作年限/工作年份/N年经验/毕业了多少年", - "corporation_name_tks": "最近就职(上班)的公司/上一家公司", - - "first_school_name_tks": "第一学历毕业学校", - "first_degree_kwd": "第一学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)", - "highest_degree_kwd": "最高学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)", - "first_major_tks": "第一学历专业", - "edu_first_fea_kwd": "第一学历标签(211,留学,双一流,985,海外知名,重点大学,中专,专升本,专科,本科,大专)", - - "degree_kwd": "过往学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)", - "major_tks": "学过的专业/过往专业", - "school_name_tks": "学校/毕业院校", - "sch_rank_kwd": "学校标签(顶尖学校,精英学校,优质学校,一般学校)", - "edu_fea_kwd": "教育标签(211,留学,双一流,985,海外知名,重点大学,中专,专升本,专科,本科,大专)", - - "corp_nm_tks": "就职过的公司/之前的公司/上过班的公司", - "edu_end_int": "毕业年份", - "industry_name_tks": "所在行业", - - "birth_dt": "生日/出生年份", - "expect_position_name_tks": "期望职位/期望职能/期望岗位", - } + pass + + # All strategies failed + raise json.JSONDecodeError("All JSON repair strategies failed", text, 0) + + +def _call_llm(prompt: str, tenant_id , lang: str) -> Optional[dict]: + """ + Call LLM and parse JSON response (ref SmartResume's retry + fault-tolerance strategy). + + Retry mechanism: + - Retry up to _LLM_MAX_RETRIES times + - On retry, increase temperature and randomize seed for output diversity + - Use json_repair on JSON parse failure + + Args: + prompt: User prompt + lang: Language + Returns: + Parsed dictionary, or None on failure + + """ + try: + from api.db.services.llm_service import LLMBundle + from common.constants import LLMType + + llm = LLMBundle(tenant_id, LLMType.CHAT, lang=lang) + + for attempt in range(_LLM_MAX_RETRIES + 1): + try: + # Increase temperature on retry for diversity (ref SmartResume) + temperature = 0.1 if attempt == 0 else 1.0 + gen_conf = {"temperature": temperature, "max_tokens": 2048} + if attempt > 0: + gen_conf["seed"] = random.randint(0, 1000000) + + response = llm.chat( + system=get_system_prompt(lang), + history=[{"role": "user", "content": prompt}], + gen_conf=gen_conf, + ) + cleaned = _clean_llm_json_response(response) + return _parse_json_with_repair(cleaned) + + except json.JSONDecodeError as e: + if attempt < _LLM_MAX_RETRIES: + logger.info(f"LLM JSON parse failed (attempt {attempt + 1}), retrying: {e}") + continue + else: + logger.warning(f"LLM JSON parse failed (retries exhausted): {e}") + return None + + except Exception as e: + logger.warning(f"LLM call failed: {e}") + return None + + +def _normalize_for_comparison(text: str) -> str: + """ + Normalize text for comparison (ref SmartResume's _normalize_for_comparison). + + Unify fullwidth/halfwidth, remove whitespace, Unicode normalization, + so that "阿里巴巴" and "阿 里 巴 巴" can match. + + Args: + text: Original text + Returns: + Normalized text + """ + if not text: + return "" + # Unicode NFKC normalization (fullwidth to halfwidth, etc.) + text = unicodedata.normalize("NFKC", text) + # Remove all whitespace characters + text = re.sub(r'\s+', '', text) + return text.lower() + +def _calc_single_exp_years(start_str: str, end_str: str) -> float: + """ + Calculate years for a single experience entry. + + Args: + start_str: Start date string + end_str: End date string ("至今" etc. 
means current) + Returns: + Years (float, 1 decimal place), returns 0 if unable to calculate + """ + from datetime import datetime - titles = [] - for n in ["name_kwd", "gender_kwd", "position_name_tks", "age_int"]: - v = resume.get(n, "") - if isinstance(v, list): - v = v[0] - if n.find("tks") > 0: - v = remove_redundant_spaces(v) - titles.append(str(v)) + start_str = str(start_str).strip() + end_str = str(end_str).strip() + if not start_str: + return 0 + + start_date = _parse_date_str(start_str) + if not start_date: + return 0 + + if end_str in ("至今", "现在", "present", "Present", "now", "Now", ""): + end_date = datetime.now() + else: + end_date = _parse_date_str(end_str) + if not end_date: + end_date = datetime.now() + + months = (end_date.year - start_date.year) * 12 + (end_date.month - start_date.month) + if months <= 0: + return 0 + return round(months / 12.0, 1) + + +def _calculate_work_years(experiences: list[dict]) -> float: + """ + Calculate total work years based on start/end dates of each work experience. + + Args: + experiences: List of work experiences, each containing start_date, end_date fields + Returns: + Total work years (float), returns 0 if unable to calculate + """ + total = 0.0 + for exp in experiences: + total += _calc_single_exp_years( + exp.get("start_date", ""), exp.get("end_date", "") + ) + return round(total, 1) + + +def _parse_date_str(date_str: str) -> Optional[datetime.datetime]: + """ + Parse date string, supporting multiple common formats. + + Supported formats: + - 2024.1 / 2024.01 + - 2024-1 / 2024-01 + - 2024/1 / 2024/01 + - 2024年1月 + - 2024 (year only, defaults to January) + + Args: + date_str: Date string + Returns: + datetime object, or None on parse failure + """ + from datetime import datetime + + date_str = date_str.strip() + # Try matching year.month / year-month / year/month / year(nian)month(yue) formats + patterns = [ + (r"((?:19|20)\d{2})[.\-/年](\d{1,2})", "%Y-%m"), + (r"^((?:19|20)\d{2})$", "%Y"), + ] + for pattern, _ in patterns: + m = re.search(pattern, date_str) + if m: + try: + year = int(m.group(1)) + month = int(m.group(2)) if len(m.groups()) > 1 else 1 + # Month range validation + if month < 1 or month > 12: + month = 1 + return datetime(year, month, 1) + except (ValueError, IndexError): + continue + return None + + + + +def _extract_description_from_range( + index_range: list, lines: list[str], + company: str = "", position: str = "" +) -> str: + """ + Extract description from original text by index range (ref SmartResume's _extract_description_from_range). 
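As a quick aside before the filtering details: below is a self-contained sketch of how these index pointers are consumed, with invented lines and range (the real logic, including the header-line filter, follows in `_extract_description_from_range`):

```python
# Minimal sketch of the index-pointer mechanism (hypothetical data; the actual
# implementation is _extract_description_from_range in this file).
import re
import unicodedata

def normalize(text: str) -> str:
    # NFKC folds fullwidth forms; stripping whitespace lets "阿 里 巴 巴" match "阿里巴巴"
    return re.sub(r"\s+", "", unicodedata.normalize("NFKC", text)).lower()

lines = [
    "阿里巴巴  后端开发工程师",   # header line: company and position on one line
    "负责订单服务的架构设计",
    "将接口 P99 延迟降低 40%",
]
desc_range = [0, 2]  # what the LLM would return as desc_lines
company, position = "阿里巴巴", "后端开发工程师"

picked = []
for line in lines[desc_range[0]:desc_range[1] + 1]:
    n = normalize(line)
    # drop lines containing both company and position: likely a header, not description
    if normalize(company) in n and normalize(position) in n:
        continue
    picked.append(line)

print("\n".join(picked))  # description text without the header line
```

The LLM only emits two integers per experience, so the description text itself is always copied verbatim from the source lines rather than regenerated.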
+ + Key improvement: + - Filter out lines containing both company name and position title (avoid mixing header lines into description) + - Boundary safety checks + + Args: + index_range: [start_line_number, end_line_number] + lines: List of original line texts + company: Company name (used to filter header lines) + position: Position title (used to filter header lines) + Returns: + Extracted description text + """ + if not index_range or len(index_range) != 2: + return "" + + start_idx, end_idx = int(index_range[0]), int(index_range[1]) + + # Boundary safety check + if start_idx < 0 or end_idx >= len(lines) or start_idx > end_idx: + return "" + + extracted_lines = lines[start_idx:end_idx + 1] + + # Filter out lines containing both company name and position title (ref SmartResume) + if company or position: + norm_company = _normalize_for_comparison(company) + norm_position = _normalize_for_comparison(position) + filtered = [] + for line in extracted_lines: + norm_line = _normalize_for_comparison(line) + # If a line contains both company name and position title, it's likely a header line, skip + if norm_company and norm_position and norm_company in norm_line and norm_position in norm_line: + continue + # If a line exactly equals company name or position title, also skip + if norm_line == norm_company or norm_line == norm_position: + continue + filtered.append(line) + extracted_lines = filtered + + if not extracted_lines: + return "" + + return "\n".join(line.strip() for line in extracted_lines if line.strip()) + + +def _extract_basic_info(indexed_text: str, tenant_id , lang: str) -> Optional[dict]: + """Extract basic info (subtask 1). + + Basic info is usually at the beginning of the resume, first 8000 chars suffice. + """ + prompt = get_basic_info_prompt(lang).format(indexed_text=indexed_text[:8000]) + return _call_llm(prompt,tenant_id, lang) + + +def _extract_work_experience(indexed_text: str, tenant_id , lang: str) -> Optional[dict]: + """Extract work experience (subtask 2, using index pointers). + + Work experience may span the middle-to-end of the resume, use full text to avoid truncation. + """ + prompt = get_work_exp_prompt(lang).format(indexed_text=indexed_text) + return _call_llm(prompt, tenant_id , lang) + + +def _extract_education(indexed_text: str, tenant_id , lang: str) -> Optional[dict]: + """Extract education background (subtask 3). + + Education is usually at the end of the resume, must use full text to avoid truncation. + Resume text is generally under 30K chars, within LLM context window. + """ + prompt = get_education_prompt(lang).format(indexed_text=indexed_text) + return _call_llm(prompt,tenant_id, lang) + + +def _extract_project_experience(indexed_text: str, tenant_id , lang: str) -> Optional[dict]: + """Extract project experience (subtask 4, using index pointers). + + Project experience may span the middle-to-end of the resume, use full text to avoid truncation. + """ + prompt = get_project_exp_prompt(lang).format(indexed_text=indexed_text) + return _call_llm(prompt, tenant_id , lang) + + +def parse_with_llm(indexed_text: str, lines: list[str], tenant_id , lang: str) -> Optional[dict]: + """ + Extract resume info using parallel task decomposition strategy (ref SmartResume Section 3.2). + + Decomposes extraction into four independent subtasks executed in parallel: + 1. Basic info (name, phone, skills, self-evaluation, etc.) + 2. Work experience (company, position, description line ranges) + 3. Education background (school, major, degree) + 4. 
Project experience (project name, role, description line ranges) + + Args: + indexed_text: Line-indexed resume text + lines: List of original line texts (for index-based extraction) + lang: Language + Returns: + Merged structured resume dictionary, or None on failure + """ + try: + # Execute four subtasks in parallel + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + future_basic = executor.submit(_extract_basic_info, indexed_text, tenant_id , lang) + future_work = executor.submit(_extract_work_experience, indexed_text, tenant_id , lang) + future_edu = executor.submit(_extract_education, indexed_text, tenant_id, lang) + future_project = executor.submit(_extract_project_experience, indexed_text, tenant_id , lang) + + basic_info = future_basic.result(timeout=60) + work_exp = future_work.result(timeout=60) + education = future_edu.result(timeout=60) + project_exp = future_project.result(timeout=60) + + # Merge results + resume = {} + + # Merge basic info + if basic_info: + resume.update(basic_info) + logger.info(f"Basic info extraction succeeded: {len(basic_info)} fields") + + # Process work experience (index pointer extraction) + if work_exp and "workExperience" in work_exp: + experiences = work_exp["workExperience"] + companies = [] + positions = [] + work_descs = [] + # Save detailed info for each experience (dates, years) for chunk generation + work_exp_details = [] + for exp in experiences: + company = exp.get("company", "") + position = exp.get("position", "") + start_date = exp.get("start_date", "") + end_date = exp.get("end_date", "") + # Calculate years for this experience entry + years = _calc_single_exp_years(start_date, end_date) + if company: + companies.append(company) + if position: + positions.append(position) + # Save detailed info for each experience entry + work_exp_details.append({ + "company": company, + "position": position, + "start_date": start_date, + "end_date": end_date, + "years": years, + }) + # Index pointer mechanism: extract description from original text by line range + # Use _extract_description_from_range to filter header lines (ref SmartResume) + desc_lines = exp.get("desc_lines", []) + if isinstance(desc_lines, list) and len(desc_lines) == 2: + desc = _extract_description_from_range( + desc_lines, lines, company=company, position=position + ) + if desc.strip(): + work_descs.append(desc.strip()) + + if companies: + resume["corp_nm_tks"] = companies + resume["corporation_name_tks"] = companies[0] + if positions: + resume["position_name_tks"] = positions + if work_descs: + resume["work_desc_tks"] = work_descs + # Save experience details for _build_chunk_document + if work_exp_details: + resume["_work_exp_details"] = work_exp_details + # Calculate total work years from each experience's dates (overrides LLM's guess in basic info) + calculated_years = _calculate_work_years(experiences) + if calculated_years > 0: + resume["work_exp_flt"] = calculated_years + logger.info(f"Work experience extraction succeeded: {len(experiences)} entries, calculated total years: {calculated_years}") + + # Process education background + if education and "education" in education: + edu_list = education["education"] + schools = [] + majors = [] + degrees = [] + for edu in edu_list: + if edu.get("school"): + schools.append(edu["school"]) + if edu.get("major"): + majors.append(edu["major"]) + if edu.get("degree"): + degrees.append(edu["degree"]) + # Extract graduation year + end_date = edu.get("end_date", "") + if end_date and not resume.get("edu_end_int"): + 
year_match = re.search(r"(19|20)\d{2}", str(end_date)) + if year_match: + resume["edu_end_int"] = int(year_match.group(0)) + + if schools: + resume["school_name_tks"] = schools + resume["first_school_name_tks"] = schools[-1] # Earliest school is usually last + if majors: + resume["major_tks"] = majors + resume["first_major_tks"] = majors[-1] + if degrees: + resume["degree_kwd"] = degrees + # Infer highest degree (supports both Chinese and English degree names) + degree_rank = { + "博士": 5, "PhD": 5, "Doctor": 5, + "硕士": 4, "Master": 4, "MBA": 4, "EMBA": 4, "MPA": 4, + "本科": 3, "Bachelor": 3, + "大专": 2, "专科": 2, "Associate": 2, "Diploma": 2, + "高中": 1, "High School": 1, + } + highest = max(degrees, key=lambda d: degree_rank.get(d, 0), default="") + if highest: + resume["highest_degree_kwd"] = highest + resume["first_degree_kwd"] = degrees[-1] if degrees else "" + logger.info(f"Education extraction succeeded: {len(edu_list)} entries") + + # Process project experience (index pointer extraction, similar to work experience) + if project_exp and "projectExperience" in project_exp: + projects = project_exp["projectExperience"] + project_names = [] + project_descs = [] + for proj in projects: + name = proj.get("project_name", "") + if name: + project_names.append(name) + # Index pointer mechanism: extract project description from original text by line range + desc_lines = proj.get("desc_lines", []) + if isinstance(desc_lines, list) and len(desc_lines) == 2: + desc = _extract_description_from_range( + desc_lines, lines, company=name, position=proj.get("role", "") + ) + if desc.strip(): + project_descs.append(desc.strip()) + + if project_names: + resume["project_tks"] = project_names + if project_descs: + resume["project_desc_tks"] = project_descs + logger.info(f"Project experience extraction succeeded: {len(projects)} entries") + + if not resume.get("name_kwd"): + resume["name_kwd"] = "Unknown" if _is_english(lang) else "未知" + + return resume if len(resume) > 2 else None + + except concurrent.futures.TimeoutError: + logger.warning("LLM parallel extraction timed out") + return None + except Exception as e: + logger.warning(f"LLM parallel extraction failed: {e}") + return None + + +# ==================== Phase 3: Regex Fallback Parsing ==================== + + + +def parse_with_regex(text: str, lang: str = "Chinese") -> dict: + """ + Parse resume text using regex (fallback strategy) + + When LLM parsing fails, use regex to extract basic structured info from text. 
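A toy run of the fallback on an invented snippet, using the same label-anchored and bare patterns that appear below, may make the strategy concrete:

```python
# Illustration only; the patterns mirror the ones used in parse_with_regex.
import re

text = """姓名:陈晓俐
性别:女
电话:13812345678
邮箱:chenxl@example.com"""

resume = {}
m = re.search(r'姓\s*名\s*[::]\s*([\u4e00-\u9fa5]{2,4})', text)
if m:
    resume["name_kwd"] = m.group(1)
phones = re.findall(r"1[3-9]\d{9}", text)  # mainland mobile numbers
if phones:
    resume["phone_kwd"] = phones[0]
emails = re.findall(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", text)
if emails:
    resume["email_tks"] = emails[0]

print(resume)
# {'name_kwd': '陈晓俐', 'phone_kwd': '13812345678', 'email_tks': 'chenxl@example.com'}
```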
+ + Args: + text: Resume text content (without line number index) + lang: Language parameter, default "Chinese" + Returns: + Structured resume info dictionary + """ + resume: dict = {} + lines = [line.strip() for line in text.split("\n") if line.strip()] + + # --- Extract Name --- + if _is_english(lang): + # English resume: extract from "Name: XXX" format + for line in lines[:30]: + name_match = re.search(r'(?:Name|Full\s*Name)\s*[::]\s*([A-Za-z][A-Za-z\s\-\.]{1,40})', line, re.IGNORECASE) + if name_match: + resume["name_kwd"] = name_match.group(1).strip() + break + # English resume strategy 2: first line if short text without digits, may be a name + if "name_kwd" not in resume and lines: + first = lines[0].strip() + if len(first) <= 40 and not re.search(r"\d", first) and re.match(r'^[A-Za-z][A-Za-z\s\-\.]+$', first): + resume["name_kwd"] = first + else: + # Chinese resume: extract from "姓名:XXX" format + for line in lines[:30]: + name_match = re.search(r'姓\s*名\s*[::]\s*([\u4e00-\u9fa5]{2,4})', line) + if name_match: + resume["name_kwd"] = name_match.group(1) + break + + # Strategy 2: search first 20 lines for standalone Chinese names (2-4 chars), excluding common title words + if "name_kwd" not in resume: + title_words = { + "个人", "简历", "求职", "应聘", "基本", "信息", "概述", "简介", + "教育", "工作", "经历", "经验", "技能", "项目", "自我", "评价", + "专业", "技术", "证书", "语言", "能力", "培训", "荣誉", "奖项", + } + for line in lines[:20]: + if any(w in line for w in title_words): + continue + if re.search(r'[::]', line) and len(line) > 6: + continue + cleaned = re.sub(r"^[A-Za-z_\-\d\s]+\s+", "", line) + cleaned = re.sub(r"\s+[A-Za-z_\-\d\s]+$", "", cleaned).strip() + if 2 <= len(cleaned) <= 4 and re.match(r"^[\u4e00-\u9fa5]{2,4}$", cleaned): + resume["name_kwd"] = cleaned + break + + # Strategy 3: first line if short without digits, may be a name + if "name_kwd" not in resume and lines: + first = lines[0].strip() + if len(first) <= 10 and not re.search(r"\d", first): + cn_part = re.findall(r'[\u4e00-\u9fa5]+', first) + if cn_part and 2 <= len(cn_part[0]) <= 4: + resume["name_kwd"] = cn_part[0] + + # --- Extract Phone Number --- + phones = re.findall(r"1[3-9]\d{9}", text) + if phones: + resume["phone_kwd"] = phones[0] + + # --- Extract Email --- + emails = re.findall(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", text) + if emails: + resume["email_tks"] = emails[0] + + # --- Extract Gender --- + if _is_english(lang): + # English resume: extract from "Gender: Male/Female" format + gender_label = re.search(r'(?:Gender|Sex)\s*[::]\s*(Male|Female|M|F)', text, re.IGNORECASE) + if gender_label: + raw = gender_label.group(1).strip().upper() + resume["gender_kwd"] = "Male" if raw in ("M", "MALE") else "Female" + else: + gender_match = re.search(r'\b(Male|Female)\b', text[:500], re.IGNORECASE) + if gender_match: + resume["gender_kwd"] = gender_match.group(1).capitalize() + else: + # Chinese resume: extract from "性别:男/女" format + gender_label = re.search(r'性\s*别\s*[::]\s*(男|女)', text) + if gender_label: + resume["gender_kwd"] = gender_label.group(1) + else: + gender_match = re.search(r"(男|女)", text[:500]) + if gender_match: + resume["gender_kwd"] = gender_match.group(1) + + # --- Extract Age --- + if _is_english(lang): + # English resume: match "25 years old" or "Age: 25" + age_match = re.search(r'(?:Age)\s*[::]\s*(\d{1,2})', text, re.IGNORECASE) + if not age_match: + age_match = re.search(r'(\d{1,2})\s*years?\s*old', text, re.IGNORECASE) + if age_match: + resume["age_int"] = int(age_match.group(1)) + else: + # Chinese resume: match 
"25岁" + age_match = re.search(r"(\d{1,2})\s*岁", text) + if age_match: + resume["age_int"] = int(age_match.group(1)) + + # --- Extract Date of Birth --- + if _is_english(lang): + # English resume: match "1990-01-15" or "Jan 15, 1990" etc. + birth_match = re.search(r'(?:Birth|DOB|Date\s*of\s*Birth)\s*[::]\s*(.{6,20})', text, re.IGNORECASE) + if birth_match: + resume["birth_dt"] = birth_match.group(1).strip() + else: + birth_match = re.search(r"(19|20)\d{2}[-/]\d{1,2}[-/]\d{1,2}", text) + if birth_match: + resume["birth_dt"] = birth_match.group(0) + else: + # Chinese resume: match "1990年1月15日" or "1990-01-15" + birth_match = re.search(r"(19|20)\d{2}[年/-]\d{1,2}[月/-]\d{1,2}", text) + if birth_match: + resume["birth_dt"] = birth_match.group(0) + + # --- Extract Education Level --- + degree_keywords_zh = ["博士", "硕士", "本科", "大专", "专科", "高中", "MBA", "EMBA", "MPA"] + degree_keywords_en = ["PhD", "Master", "Bachelor", "Associate", "Diploma", "High School", + "MBA", "EMBA", "MPA", "Doctor"] + degree_keywords = degree_keywords_en if _is_english(lang) else degree_keywords_zh + found_degrees = [d for d in degree_keywords if d in text] + if found_degrees: + resume["degree_kwd"] = found_degrees + + # --- Extract School --- + if _is_english(lang): + # English resume: match "University/College/Institute/School" keywords + schools = re.findall( + r'([A-Z][A-Za-z\s\-&]{2,40}(?:University|College|Institute|School|Academy))', + text + ) + # Remove extra whitespace + schools = [re.sub(r'\s+', ' ', s).strip() for s in schools] + else: + # Chinese resume: match "XX大学/学院/职业技术学院" + schools = re.findall(r"[\u4e00-\u9fa5]{2,15}(?:大学|学院|职业技术学院)", text) + if schools: + resume["school_name_tks"] = list(set(schools)) + resume["first_school_name_tks"] = schools[0] + + # --- Extract Major --- + if _is_english(lang): + # English resume: match "Major: XXX" / "Field of Study: XXX" / "Specialization: XXX" + majors = re.findall( + r'(?:Major|Field\s*of\s*Study|Specialization|Concentration)\s*[::]\s*([A-Za-z\s\-&,]{2,40})', + text, re.IGNORECASE + ) + majors = [m.strip() for m in majors if m.strip()] + else: + # Chinese resume: match "专业:XXX" + majors = re.findall(r"专业[::]\s*([\u4e00-\u9fa5]{2,20})", text) + if majors: + resume["major_tks"] = majors + resume["first_major_tks"] = majors[0] + + # --- Extract Company Names --- + if _is_english(lang): + # English resume: match common company suffixes + en_company_patterns = [ + r'([A-Z][A-Za-z\s\-&,\.]{2,40}(?:Inc\.|Corp\.|Ltd\.|LLC|Co\.|Company|Group|Technologies|Technology|Solutions|Consulting|Services|Bank))', + ] + companies = [] + for pattern in en_company_patterns: + companies.extend(re.findall(pattern, text)) + companies = [re.sub(r'\s+', ' ', c).strip() for c in companies] + else: + # Chinese resume: match "XX有限公司" format + company_patterns = [ + r"[\u4e00-\u9fa5]{2,20}[((][\u4e00-\u9fa5]{2,10}[))](?:科技|信息技术|网络科技)?(?:股份)?有限公司", + r"[\u4e00-\u9fa5]{4,20}(?:科技|信息技术|网络科技|银行)?(?:股份)?有限公司", + ] + companies = [] + for pattern in company_patterns: + companies.extend(re.findall(pattern, text)) + + unique_companies = [] + seen = set() + # Filter verb list (bilingual) + filter_verbs = ( + ["completed", "conducted", "implemented", "responsible", "participated", "developed"] + if _is_english(lang) + else ["完成", "进行", "实施", "负责", "参与", "开发"] + ) + min_len = 3 if _is_english(lang) else 6 + for c in companies: + if len(c) < min_len or any(v in c.lower() for v in filter_verbs) or c in seen: + continue + is_sub = False + for existing in list(unique_companies): + if c in existing: + is_sub = 
True + break + if existing in c: + unique_companies.remove(existing) + seen.discard(existing) + if not is_sub: + unique_companies.append(c) + seen.add(c) + + if unique_companies: + resume["corp_nm_tks"] = unique_companies + resume["corporation_name_tks"] = unique_companies[0] + + # --- Extract Position (improved: context constraints to reduce noise) --- + if _is_english(lang): + # English resume: Strategy 1 - extract from "Title: XXX" / "Position: XXX" / "Role: XXX" format + position_label_matches = re.findall( + r'(?:Title|Position|Role|Job\s*Title)\s*[::]\s*([A-Za-z\s\-/&]{2,30})', + text, re.IGNORECASE + ) + positions = [p.strip() for p in position_label_matches if p.strip()] + + # English resume: Strategy 2 - match common position suffix keywords + en_position_suffixes = [ + "Engineer", "Manager", "Director", "Supervisor", "Specialist", + "Designer", "Consultant", "Assistant", "Architect", "Analyst", + "Developer", "Lead", "Officer", "Coordinator", "Administrator", + "Intern", "VP", "President", + ] + for line in lines: + if len(line) > 60: + continue # Skip overly long lines (usually description text) + for suffix in en_position_suffixes: + match = re.search(rf'([A-Za-z\s\-]{{1,25}}{suffix})\b', line, re.IGNORECASE) + if match: + pos = match.group(1).strip() + # Filter out matches that are clearly not positions (contain verbs) + filter_pos_verbs = ["responsible", "participated", "completed", "developed", "designed"] + if not any(v in pos.lower() for v in filter_pos_verbs) and len(pos) > 3: + positions.append(pos) + else: + # Chinese resume: Strategy 1 - extract from "职位/岗位:XXX" format + position_label_matches = re.findall( + r'(?:职位|岗位|职务|职称|担任)\s*[::]\s*([\u4e00-\u9fa5a-zA-Z]{2,15})', + text + ) + positions = list(position_label_matches) + + # Chinese resume: Strategy 2 - extract from work experience paragraphs (company name followed by position) + for line in lines: + pos_match = re.search( + r'(?:有限公司|集团|银行)\s+([\u4e00-\u9fa5]{2,8}(?:工程师|经理|总监|主管|专员|设计师|顾问|助理|架构师|分析师|运营|产品))', + line + ) + if pos_match: + positions.append(pos_match.group(1)) + + # Chinese resume: Strategy 3 - position keywords in standalone lines (length-limited to avoid matching description text) + position_suffixes = ["工程师", "经理", "总监", "主管", "专员", "设计师", "顾问", + "助理", "架构师", "分析师", "开发者", "负责人"] + for line in lines: + if len(line) > 20: + continue # Skip overly long lines + for suffix in position_suffixes: + match = re.search(rf'([\u4e00-\u9fa5]{{1,6}}{suffix})', line) + if match: + pos = match.group(1) + if not any(v in pos for v in ["负责", "参与", "完成", "开发了", "设计了"]): + positions.append(pos) + + if positions: + # Deduplicate while preserving order + seen_pos = set() + unique_positions = [] + for p in positions: + if p not in seen_pos: + seen_pos.add(p) + unique_positions.append(p) + resume["position_name_tks"] = unique_positions + + # --- Extract Years of Experience --- + if _is_english(lang): + # English resume: match "5 years experience" / "5+ years of experience" + work_exp_match = re.search(r'(\d+)\+?\s*years?\s*(?:of\s*)?(?:experience|work)', text, re.IGNORECASE) + if work_exp_match: + resume["work_exp_flt"] = float(work_exp_match.group(1)) + else: + # Chinese resume: match "5年...经验" + work_exp_match = re.search(r"(\d+)\s*年.*?经验", text) + if work_exp_match: + resume["work_exp_flt"] = float(work_exp_match.group(1)) + + # --- Extract Graduation Year --- + if _is_english(lang): + # English resume: match "Graduated 2020" / "Graduation: 2020" / "Class of 2020" + grad_match = 
re.search(r'(?:Graduat(?:ed|ion)|Class\s*of)\s*[::]?\s*((?:19|20)\d{2})', text, re.IGNORECASE) + if grad_match: + resume["edu_end_int"] = int(grad_match.group(1)) + else: + # Chinese resume: match "2020年...毕业" + grad_match = re.search(r"((?:19|20)\d{2})\s*年.*?毕业", text) + if grad_match: + resume["edu_end_int"] = int(grad_match.group(1)) + + if "name_kwd" not in resume: + resume["name_kwd"] = "Unknown" if _is_english(lang) else "未知" + + return resume + + + +# ==================== Phase 4: Post-processing Pipeline ==================== + + +def _postprocess_resume(resume: dict, lines: list[str], lang: str = "Chinese") -> dict: + """ + Four-phase post-processing pipeline (ref: SmartResume Section 3.2.3) + + 1. Source text validation: check if key fields can be found in the original text + 2. Domain normalization: standardize date formats, clean company name suffix noise + 3. Contextual deduplication: remove duplicate company/school entries + 4. Field completion: ensure all required fields exist + + Args: + resume: Raw resume dictionary extracted by LLM + lines: Original line text list (for source text validation) + lang: Language parameter, default "Chinese" + Returns: + Post-processed resume dictionary + """ + _en = _is_english(lang) + full_text = "\n".join(lines) if lines else "" + # Normalize full text for comparison (ref: SmartResume _validate_fields_in_text) + norm_full_text = _normalize_for_comparison(full_text) + + # --- Phase 1: Source text validation (prune hallucinations, ref: SmartResume _validate_fields_in_text) --- + # Name validation: clear if not found in source text (SmartResume strategy: discard hallucinated fields) + _unknown_names = ("未知", "Unknown") + if resume.get("name_kwd") and resume["name_kwd"] not in _unknown_names: + norm_name = _normalize_for_comparison(resume["name_kwd"]) + if norm_full_text and norm_name and norm_name not in norm_full_text: + logger.warning(f"Name '{resume['name_kwd']}' not found in source text, classified as LLM hallucination, cleared") + resume["name_kwd"] = "" + + # Validate company names (strict matching: full name must appear in source text, no longer using loose 4-char prefix matching) + if resume.get("corp_nm_tks") and norm_full_text: + verified_companies = [] + for company in resume["corp_nm_tks"]: + norm_company = _normalize_for_comparison(company) + if norm_company and norm_company in norm_full_text: + verified_companies.append(company) + else: + logger.debug(f"Company '{company}' not found in source text, filtered out") + # Update even if all filtered out (SmartResume strategy: prefer missing over wrong) + resume["corp_nm_tks"] = verified_companies + if verified_companies: + resume["corporation_name_tks"] = verified_companies[0] + else: + resume["corporation_name_tks"] = "" + + # Validate school names (ref: SmartResume _validate_fields_in_text) + if resume.get("school_name_tks") and norm_full_text: + verified_schools = [] + for school in resume["school_name_tks"]: + norm_school = _normalize_for_comparison(school) + if norm_school and norm_school in norm_full_text: + verified_schools.append(school) + else: + logger.debug(f"School '{school}' not found in source text, filtered out") + resume["school_name_tks"] = verified_schools + if verified_schools: + if resume.get("first_school_name_tks"): + # Ensure first_school is also in the verified list + if resume["first_school_name_tks"] not in verified_schools: + resume["first_school_name_tks"] = verified_schools[-1] + else: + resume["first_school_name_tks"] = "" + + # Validate position names + 
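The position check below reuses the same normalized-substring test as the company and school checks above; a standalone sketch of the idea, with invented values:

```python
# Hallucination pruning in isolation: a field survives only if its normalized
# form is a substring of the normalized source text. Sample strings are invented.
import re
import unicodedata

def normalize(text: str) -> str:
    return re.sub(r"\s+", "", unicodedata.normalize("NFKC", text)).lower()

source = normalize("2019.03-2022.06 北京字节跳动 高级算法工程师")
for candidate in ["高级算法工程师", "资深产品经理"]:
    kept = normalize(candidate) in source
    print(candidate, "->", "kept" if kept else "dropped as hallucination")
```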
if resume.get("position_name_tks") and norm_full_text: + verified_positions = [] + for pos in resume["position_name_tks"]: + norm_pos = _normalize_for_comparison(pos) + if norm_pos and norm_pos in norm_full_text: + verified_positions.append(pos) + if verified_positions: + resume["position_name_tks"] = verified_positions + + # --- Phase 2: Domain normalization --- + # Standardize date format + if resume.get("birth_dt"): + resume["birth_dt"] = re.sub(r"[年月]", "-", str(resume["birth_dt"])).rstrip("-") + + # Clean non-digit characters from phone number (keep + sign) + if resume.get("phone_kwd"): + phone = re.sub(r"[^\d+]", "", str(resume["phone_kwd"])) + if phone: + resume["phone_kwd"] = phone + + # Standardize gender (output format determined by language parameter) + if resume.get("gender_kwd"): + gender = str(resume["gender_kwd"]).strip() + if gender in ("male", "Male", "M", "m", "男"): + resume["gender_kwd"] = "Male" if _en else "男" + elif gender in ("female", "Female", "F", "f", "女"): + resume["gender_kwd"] = "Female" if _en else "女" + + # --- Phase 3: Contextual deduplication --- + for list_field in ["corp_nm_tks", "school_name_tks", "major_tks", + "position_name_tks", "skill_tks"]: + if isinstance(resume.get(list_field), list): + # Order-preserving deduplication + seen = set() + deduped = [] + for item in resume[list_field]: + item_str = str(item).strip() + if item_str and item_str not in seen: + seen.add(item_str) + deduped.append(item_str) + resume[list_field] = deduped + + # --- Phase 4: Field completion --- + required_fields = [ + "name_kwd", "gender_kwd", "phone_kwd", "email_tks", + "position_name_tks", "school_name_tks", "major_tks", + ] + for field in required_fields: + if field not in resume: + if field.endswith("_tks"): + resume[field] = [] + elif field.endswith("_int") or field.endswith("_flt"): + resume[field] = 0 + else: + resume[field] = "" + + # Clean internal marker fields (already handled in Phase 1, this is a safety fallback) + resume.pop("_name_confidence", None) + + return resume + + +# ==================== Pipeline Orchestration & Chunk Construction ==================== + + +def parse_resume(filename: str, binary: bytes, tenant_id , lang: str = "Chinese") -> tuple[dict, list[str], list[dict]]: + """ + Resume parsing pipeline orchestration function + + Execution flow: + 1. Text extraction (dual-path fusion + layout reconstruction + line-number index) + 2. Parallel LLM structured extraction (three sub-tasks) + 3. Regex fallback parsing (when LLM fails) + 4. 
Four-phase post-processing + + Args: + filename: File name + binary: File binary content + lang: Language, default "Chinese" + Returns: + (resume, lines, line_positions) tuple: + - resume: Structured resume information dictionary + - lines: Original line text list (for chunk text matching and positioning) + - line_positions: Per-line coordinate info list (for writing chunk position_int fields) + """ + # Phase 1: Text extraction + indexed_text, lines, line_positions = extract_text(filename, binary) + if not indexed_text or not lines: + logger.warning(f"Text extraction returned empty: {filename}") + default_name = "Unknown" if _is_english(lang) else "未知" + return {"name_kwd": default_name}, [], [] + + # Phase 2: Parallel LLM structured extraction + resume = parse_with_llm(indexed_text, lines, tenant_id , lang) + + # Phase 3: Fallback to regex parsing when LLM fails + if not resume: + logger.info(f"LLM parsing failed, falling back to regex parsing: {filename}") + plain_text = "\n".join(lines) + resume = parse_with_regex(plain_text, lang) + + # Phase 4: Post-processing pipeline + resume = _postprocess_resume(resume, lines, lang) + + return resume, lines, line_positions + + +def _build_chunk_document(filename: str, resume: dict, + lang: str = "Chinese") -> list[dict]: + """ + Build a list of document chunks from structured resume information + + Each field generates an independent chunk containing tokenization results and metadata. + Compatible with the build_chunks flow in task_executor.py. + + Key design: Each chunk redundantly includes key identity fields (name, phone, email, etc.), + so that when any chunk is retrieved, the candidate's identity can be immediately identified. + The full resume can be fetched via doc_id to get all chunks for complete information. 
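To make the redundancy design concrete, here is the rough shape of one generated chunk; all values are invented and the tokenized `*_ltks` fields are omitted for brevity:

```python
# Hypothetical chunk produced for the education group of one resume.
chunk = {
    "docnm_kwd": "resume_chen.pdf",
    "content_with_weight": "教育背景\n学校/毕业院校: 厦门大学\n[姓名:陈晓俐 | 电话:13812345678]",
    # identity fields copied into every chunk, so a single hit identifies the candidate
    "name_kwd": "陈晓俐",
    "phone_kwd": "13812345678",
    "email_tks": "chenxl@example.com",
    # group-specific structured fields for filtered retrieval
    "school_name_tks": "厦门 大学",
}
print(chunk["name_kwd"])  # every chunk of this resume carries the same identity fields
```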
+ + Args: + filename: File name + resume: Structured resume information dictionary + lang: Language parameter, default "Chinese" + Returns: + Document chunk list, each chunk contains content_with_weight, content_ltks, + position_int, page_num_int, top_int and other fields + """ + chunks = [] + # Get the corresponding field map version based on language parameter + field_map = get_field_map(lang) doc = { "docnm_kwd": filename, - "title_tks": rag_tokenizer.tokenize("-".join(titles) + "-简历") + "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)), } doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"]) - pairs = [] - for n, m in field_map.items(): - if not resume.get(n): + + # Extract key identity fields, redundantly written to each chunk + # These fields are small in size but high in information density; once retrieved, the candidate can be immediately identified + _IDENTITY_FIELDS = ("name_kwd", "phone_kwd", "email_tks", "gender_kwd", + "highest_degree_kwd", "work_exp_flt", "corporation_name_tks") + identity_meta = {} + for ik in _IDENTITY_FIELDS: + iv = resume.get(ik) + if not iv: continue - v = resume[n] - if isinstance(v, list): - v = " ".join(v) - if n.find("tks") > 0: - v = remove_redundant_spaces(v) - pairs.append((m, str(v))) - - doc["content_with_weight"] = "\n".join( - ["{}: {}".format(re.sub(r"([^()]+)", "", k), v) for k, v in pairs]) - doc["content_ltks"] = rag_tokenizer.tokenize(doc["content_with_weight"]) - doc["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(doc["content_ltks"]) - for n, _ in field_map.items(): - if n not in resume: + if ik.endswith("_tks"): + identity_meta[ik] = rag_tokenizer.tokenize( + " ".join(iv) if isinstance(iv, list) else str(iv) + ) + elif ik.endswith("_kwd"): + identity_meta[ik] = iv if isinstance(iv, list) else str(iv) + elif ik.endswith("_flt"): + try: + identity_meta[ik] = float(iv) + except (ValueError, TypeError): + pass + else: + identity_meta[ik] = str(iv) + + # Build resume summary text, appended to each chunk's content to improve semantic retrieval recall + summary_parts = [] + _en = _is_english(lang) + if resume.get("name_kwd"): + summary_parts.append(f"{'Name' if _en else '姓名'}:{resume['name_kwd']}") + if resume.get("phone_kwd"): + summary_parts.append(f"{'Phone' if _en else '电话'}:{resume['phone_kwd']}") + if resume.get("corporation_name_tks"): + corp = resume["corporation_name_tks"] + summary_parts.append(f"{'Company' if _en else '公司'}:{corp if isinstance(corp, str) else ' '.join(corp)}") + if resume.get("highest_degree_kwd"): + summary_parts.append(f"{'Degree' if _en else '学历'}:{resume['highest_degree_kwd']}") + if resume.get("work_exp_flt"): + if _en: + summary_parts.append(f"Experience:{resume['work_exp_flt']}yrs") + else: + summary_parts.append(f"经验:{resume['work_exp_flt']}年") + resume_summary = " | ".join(summary_parts) if summary_parts else "" + + # List fields that need per-element splitting (each experience/project generates a separate chunk to avoid oversized merged chunks) + _SPLIT_LIST_FIELDS = {"work_desc_tks", "project_desc_tks"} + + # Basic info field set: these fields should be merged into one chunk to avoid splitting name, phone, email, etc. + _BASIC_INFO_FIELDS = { + "name_kwd", "name_pinyin_kwd", "gender_kwd", "age_int", + "phone_kwd", "email_tks", "birth_dt", "work_exp_flt", + "position_name_tks", "expect_city_names_tks", + "expect_position_name_tks", + } + + # Education field set: degree, school, major, tags, etc. 
should be merged into one chunk + _EDUCATION_FIELDS = { + "first_school_name_tks", "first_degree_kwd", "highest_degree_kwd", + "first_major_tks", "edu_first_fea_kwd", "degree_kwd", "major_tks", + "school_name_tks", "sch_rank_kwd", "edu_fea_kwd", "edu_end_int", + } + + # Skills & certificates field set: skills, languages, certificates are small, merge into one chunk + _SKILL_CERT_FIELDS = { + "skill_tks", "language_tks", "certificate_tks", + } + + # Work overview field set: company list, industry, most recent company merged into one chunk + _WORK_OVERVIEW_FIELDS = { + "corporation_name_tks", "corp_nm_tks", "industry_name_tks", + } + + # All merge groups: (field_set, group_title) tuple list + _MERGE_GROUPS = [ + (_BASIC_INFO_FIELDS, "Basic Info" if _en else "基本信息"), + (_EDUCATION_FIELDS, "Education" if _en else "教育背景"), + (_SKILL_CERT_FIELDS, "Skills & Certificates" if _en else "技能与证书"), + (_WORK_OVERVIEW_FIELDS, "Work Overview" if _en else "工作概况"), + ] + + # Collect all fields that need merge processing; skip them during individual iteration + _ALL_MERGED_FIELDS = set() + for fields_set, _ in _MERGE_GROUPS: + _ALL_MERGED_FIELDS.update(fields_set) + + # Merge fields by group, generating one chunk per group + for fields_set, group_title in _MERGE_GROUPS: + group_parts = [] + group_field_values = {} # Store structured values for each field, to be written into chunk + for field_key in field_map: + if field_key not in fields_set: + continue + value = resume.get(field_key) + if not value: + continue + field_desc = field_map[field_key] + if isinstance(value, list): + text_value = " ".join(str(v) for v in value if v) + else: + text_value = str(value) + if not text_value.strip(): + continue + group_parts.append(f"{field_desc}: {text_value}") + group_field_values[field_key] = value + + if not group_parts: continue - if isinstance(resume[n], list) and ( - len(resume[n]) == 1 or n not in forbidden_select_fields4resume): - resume[n] = resume[n][0] - if n.find("_tks") > 0: - resume[n] = rag_tokenizer.fine_grained_tokenize(resume[n]) - doc[n] = resume[n] - logging.debug("chunked resume to " + str(doc)) - KnowledgebaseService.update_parser_config( - kwargs["kb_id"], {"field_map": field_map}) - return [doc] + content = f"{group_title}\n" + "\n".join(group_parts) + if resume_summary: + content += f"\n[{resume_summary}]" + chunk = { + "content_with_weight": content, + "content_ltks": rag_tokenizer.tokenize(content), + "content_sm_ltks": rag_tokenizer.fine_grained_tokenize( + rag_tokenizer.tokenize(content) + ), + } + chunk.update(doc) + # Redundantly write identity fields + for mk, mv in identity_meta.items(): + chunk[mk] = mv + # Write each field's structured value into chunk (for structured retrieval) + for fk, fv in group_field_values.items(): + if fk.endswith("_tks"): + text_val = " ".join(str(v) for v in fv) if isinstance(fv, list) else str(fv) + chunk[fk] = rag_tokenizer.tokenize(text_val) + elif fk.endswith("_kwd"): + chunk[fk] = fv if isinstance(fv, list) else str(fv) + elif fk.endswith("_int"): + try: + chunk[fk] = int(fv) + except (ValueError, TypeError): + pass + elif fk.endswith("_flt"): + try: + chunk[fk] = float(fv) + except (ValueError, TypeError): + pass + else: + chunk[fk] = str(fv) + chunks.append(chunk) + # Iterate over field map, generating a chunk for each non-merged field with a value + for field_key, field_desc in field_map.items(): + # Skip fields already processed in merge groups + if field_key in _ALL_MERGED_FIELDS: + continue + value = resume.get(field_key) + if not value: + 
continue -if __name__ == "__main__": - import sys + # For work/project descriptions (long text lists), split into multiple chunks per element + if field_key in _SPLIT_LIST_FIELDS and isinstance(value, list): + # Get company name list to add context to each work description + corp_list = resume.get("corp_nm_tks", []) if field_key == "work_desc_tks" else [] + project_list = resume.get("project_tks", []) if field_key == "project_desc_tks" else [] + # Get detailed info for each work experience entry (time period, years) + work_details = resume.get("_work_exp_details", []) if field_key == "work_desc_tks" else [] + for idx, item in enumerate(value): + item_text = str(item).strip() + if not item_text: + continue - def dummy(a, b): - pass + # Add company/project name prefix to each description for context + if field_key == "work_desc_tks" and idx < len(work_details): + # Use detailed info to build prefix, including company, time range, years + detail = work_details[idx] + company = detail.get("company", "") + start_d = detail.get("start_date", "") + end_d = detail.get("end_date", "") + years = detail.get("years", 0) + # Build time range text + time_parts = [] + if start_d: + time_range = f"{start_d}-{end_d}" if end_d else str(start_d) + time_parts.append(time_range) + if years > 0: + time_parts.append(f"{years}{'yrs' if _en else '年'}") + time_text = " ".join(time_parts) + if company and time_text: + content_prefix = f"{field_desc}({company} {time_text})" + elif company: + content_prefix = f"{field_desc}({company})" + else: + content_prefix = f"{field_desc}({'#' if _en else '第'}{idx + 1}{'' if _en else '段'})" + elif field_key == "work_desc_tks" and idx < len(corp_list): + content_prefix = f"{field_desc}({corp_list[idx]})" + elif field_key == "project_desc_tks" and idx < len(project_list): + content_prefix = f"{field_desc}({project_list[idx]})" + else: + content_prefix = f"{field_desc}({'#' if _en else '第'}{idx + 1}{'' if _en else '段'})" + + if resume_summary: + content = f"{content_prefix}: {item_text}\n[{resume_summary}]" + else: + content = f"{content_prefix}: {item_text}" + + chunk = { + "content_with_weight": content, + "content_ltks": rag_tokenizer.tokenize(content), + "content_sm_ltks": rag_tokenizer.fine_grained_tokenize( + rag_tokenizer.tokenize(content) + ), + } + chunk.update(doc) + + # Redundantly write identity fields + for mk, mv in identity_meta.items(): + if mk != field_key: + chunk[mk] = mv + + # Tokenization result for current segment + chunk[field_key] = rag_tokenizer.tokenize(item_text) + chunks.append(chunk) + continue + + # Merge list values into text + if isinstance(value, list): + text_value = " ".join(str(v) for v in value if v) + else: + text_value = str(value) + + if not text_value.strip(): + continue + + # Build chunk content: "field_desc: field_value", append summary for semantic association + if resume_summary and field_key not in ("name_kwd", "phone_kwd"): + content = f"{field_desc}: {text_value}\n[{resume_summary}]" + else: + content = f"{field_desc}: {text_value}" + + chunk = { + "content_with_weight": content, + "content_ltks": rag_tokenizer.tokenize(content), + "content_sm_ltks": rag_tokenizer.fine_grained_tokenize( + rag_tokenizer.tokenize(content) + ), + } + chunk.update(doc) + + # Redundantly write identity fields (do not overwrite the current field's own value) + for mk, mv in identity_meta.items(): + if mk != field_key: + chunk[mk] = mv + + # Write resume field value into the chunk's corresponding field (for structured retrieval) + if 
field_key.endswith("_tks"): + chunk[field_key] = rag_tokenizer.tokenize(text_value) + elif field_key.endswith("_kwd"): + if isinstance(value, list): + chunk[field_key] = value + else: + chunk[field_key] = text_value + elif field_key.endswith("_int"): + try: + chunk[field_key] = int(value) + except (ValueError, TypeError): + pass + elif field_key.endswith("_flt"): + try: + chunk[field_key] = float(value) + except (ValueError, TypeError): + pass + else: + chunk[field_key] = text_value + + chunks.append(chunk) + + # If no chunks were generated, create at least one chunk containing the name + if not chunks: + name = resume.get("name_kwd", "Unknown" if _en else "未知") + content = f"{'Name' if _en else '姓名'}: {name}" + chunk = { + "content_with_weight": content, + "content_ltks": rag_tokenizer.tokenize(content), + "content_sm_ltks": rag_tokenizer.fine_grained_tokenize( + rag_tokenizer.tokenize(content) + ), + } + chunk.update(doc) + chunks.append(chunk) + + # Write coordinate info to each chunk (position_int, page_num_int, top_int) + # + # Resume chunks are split by semantic fields (basic info, education, work description, etc.), + # not by PDF physical regions. Field values may be scattered across multiple locations in the PDF, + # and using text matching to reverse-lookup coordinates would cause disordered sorting. + # + # Therefore, assign incrementing coordinates based on chunk generation order (i.e., semantic logical order), + # ensuring display order: basic info -> education -> skills/certs -> work overview -> work desc -> project desc... + # + # add_positions input format: [(page, left, right, top, bottom), ...] + # - page starts from 0, function internally stores +1 + # - task_executor sorts by page_num_int and top_int (page first, then Y coordinate) + from rag.nlp import add_positions + + for i, ck in enumerate(chunks): + # All chunks placed on page=0, top increments by index to ensure logical ordering + add_positions(ck, [[0, 0, 0, i, i]]) + + return chunks + +def _blackout_text_regions(image: "np.ndarray", meta_blocks: list[dict], page_idx: int, + pdf_to_img_scale: float) -> "np.ndarray": + """ + Black out metadata-extracted text regions on the page image to prevent OCR duplication. + + Ref: SmartResume blackout strategy — extract metadata text first, black out those regions, + then run OCR on the blacked-out image so it only recognizes content metadata missed. + More reliable than IoU-based deduplication. 
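A minimal sketch of the blackout idea on a synthetic image; it uses plain NumPy slicing instead of the cv2.rectangle call in the actual implementation, and assumes the 3x render scale used elsewhere in this file:

```python
# Blackout in isolation: zero out the pixels covered by a metadata text block so a
# later OCR pass only sees content the metadata extraction missed. Values invented.
import numpy as np

img = np.full((200, 150, 3), 255, dtype=np.uint8)  # white "page" render
block = {"x0": 10.0, "top": 20.0, "x1": 40.0, "bottom": 30.0}  # PDF points
scale = 3.0  # page rendered at resolution=72*3 -> 3 image pixels per PDF point

x0, y0 = int(block["x0"] * scale), int(block["top"] * scale)
x1, y1 = int(block["x1"] * scale), int(block["bottom"] * scale)
img[max(0, y0):min(img.shape[0], y1), max(0, x0):min(img.shape[1], x1)] = 0  # black out

print(img[70, 50], img[0, 0])  # [0 0 0] inside the box, [255 255 255] outside
```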
+ + Args: + image: Page image (numpy array) + meta_blocks: Text blocks from metadata extraction + page_idx: Current page number + pdf_to_img_scale: Scale factor from PDF coordinates to image coordinates + Returns: + Image with text regions blacked out + """ + import cv2 + blacked = image.copy() + page_blocks = [b for b in meta_blocks if b.get("page") == page_idx] + # Draw filled black rectangles over each metadata text block + padding = 2 # Extra pixels to ensure full coverage + for b in page_blocks: + x0 = int(b["x0"] * pdf_to_img_scale) - padding + y0 = int(b["top"] * pdf_to_img_scale) - padding + x1 = int(b["x1"] * pdf_to_img_scale) + padding + y1 = int(b["bottom"] * pdf_to_img_scale) + padding + # Clamp to image boundaries + x0 = max(0, x0) + y0 = max(0, y0) + x1 = min(blacked.shape[1], x1) + y1 = min(blacked.shape[0], y1) + cv2.rectangle(blacked, (x0, y0), (x1, y1), (0, 0, 0), -1) + return blacked + + + +def chunk(filename, binary, tenant_id, from_page=0, to_page=100000, + lang="Chinese", callback=None, **kwargs): + """ + Resume parsing entry function (compatible with task_executor.py) + + This function is the entry point registered as FACTORY[ParserType.RESUME.value], + with a signature consistent with other parsers (e.g., naive.chunk). + + Args: + filename: File name + binary: File binary content + from_page: Start page number (not used in resume parsing) + to_page: End page number (not used in resume parsing) + lang: Language, default "Chinese" + callback: Progress callback function, accepts (progress, message) parameters + **kwargs: Other parameters (parser_config, kb_id, tenant_id, etc.) + Returns: + Document chunk list + """ + if callback is None: + def callback(prog, msg): return None + + try: + callback(0.1, "Starting resume parsing...") + + # Parse resume + resume, lines, line_positions = parse_resume(filename, binary, tenant_id , lang) + callback(0.6, "Resume structured extraction complete") + + # Build document chunks (with coordinate info) + chunks = _build_chunk_document(filename, resume, lang) + callback(0.9, f"Document chunk construction complete, {len(chunks)} chunks total") + + callback(1.0, "Resume parsing complete") + return chunks + + except Exception as e: + logger.exception(f"Resume parsing exception: {filename}") + callback(-1, f"Resume parsing failed: {str(e)}") + return [] + + +def _resort_page_with_layout(page_blocks: list[dict], layout_regions: list[dict]) -> list[dict]: + if not page_blocks: + return [] + + if not layout_regions: + return sorted(page_blocks, key=lambda b: ( + (b.get("top", 0) + b.get("bottom", 0)) / 2, + (b.get("x0", 0) + b.get("x1", 0)) / 2, + )) + + type_groups: dict[str, list] = {} + for lt in layout_regions: + tp = lt.get("type", "") + type_groups.setdefault(tp, []).append(lt) + entries = [] + for tp, group in type_groups.items(): + for idx, lt in enumerate(group): + key = f"{tp}-{idx}" + x0, x1 = lt.get("x0", 0), lt.get("x1", 0) + top, bottom = lt.get("top", 0), lt.get("bottom", 0) + entries.append({ + "key": key, "type": tp, + "x0": x0, "top": top, "x1": x1, "bottom": bottom, + "cy": (top + bottom) / 2, "cx": (x0 + x1) / 2, + }) + + for b in page_blocks: + if b.get("layoutno"): + continue + b_cx = (b.get("x0", 0) + b.get("x1", 0)) / 2 + b_cy = (b.get("top", 0) + b.get("bottom", 0)) / 2 + for entry in entries: + if (entry["x0"] <= b_cx <= entry["x1"] + and entry["top"] <= b_cy <= entry["bottom"]): + b["layoutno"] = entry["key"] + b["layout_type"] = entry["type"] + break + + for entry in entries: + layout_key = entry["key"] + 
layout_area = (entry["x1"] - entry["x0"]) * (entry["bottom"] - entry["top"]) + if layout_area <= 0: + continue + layout_blocks = [b for b in page_blocks if b.get("layoutno") == layout_key] + if not layout_blocks: + continue + text_total_area = sum( + (b.get("x1", 0) - b.get("x0", 0)) * (b.get("bottom", 0) - b.get("top", 0)) + for b in layout_blocks + ) + if text_total_area / layout_area < 0.075: + for b in layout_blocks: + b["layoutno"] = "" + b["layout_type"] = "" + + entry_map = {e["key"]: e for e in entries} + for b in page_blocks: + b_cx = (b.get("x0", 0) + b.get("x1", 0)) / 2 + b_cy = (b.get("top", 0) + b.get("bottom", 0)) / 2 + b["_x_center"] = b_cx + b["_y_center"] = b_cy + layoutno = b.get("layoutno", "") + if layoutno and layoutno in entry_map: + b["_lx_center"] = entry_map[layoutno]["cx"] + b["_ly_center"] = entry_map[layoutno]["cy"] + else: + b["_lx_center"] = b_cx + b["_ly_center"] = b_cy + + active_keys = {b.get("layoutno") for b in page_blocks if b.get("layoutno")} + active_entries = [e for e in entries if e["key"] in active_keys] + + for b in page_blocks: + if b.get("layoutno"): + continue + if not active_entries: + continue + b_cx, b_cy = b["_x_center"], b["_y_center"] + min_dist = float("inf") + best_cx, best_cy = b_cx, b_cy + for ae in active_entries: + lx1, ly1, lx2, ly2 = ae["x0"], ae["top"], ae["x1"], ae["bottom"] + if b_cy < ly1: + dy = ly1 - b_cy + elif b_cy > ly2: + dy = b_cy - ly2 + else: + dy = 0 + if b_cx < lx1: + dx = lx1 - b_cx + elif b_cx > lx2: + dx = b_cx - lx2 + else: + dx = 0 + dist = (dx ** 2 + dy ** 2) ** 0.5 + if dist < min_dist: + min_dist = dist + best_cx, best_cy = ae["cx"], ae["cy"] + b["_lx_center"] = best_cx + b["_ly_center"] = best_cy + + sorted_blocks = sorted(page_blocks, key=lambda b: ( + b.get("_ly_center", 0), + b.get("_lx_center", 0), + b.get("_y_center", 0), + b.get("_x_center", 0), + )) + + for b in sorted_blocks: + b.pop("_ly_center", None) + b.pop("_lx_center", None) + b.pop("_y_center", None) + b.pop("_x_center", None) + + return sorted_blocks + + +def _layout_detect_reorder(blocks: list[dict], binary: bytes) -> list[dict]: + if not blocks: + return blocks + + recognizer = _get_layout_recognizer() + if recognizer is None: + logger.info("Layout detector unavailable, falling back to heuristic sorting") + return _layout_aware_reorder(blocks) + + try: + import pdfplumber + pages_blocks: dict[int, list[dict]] = {} + for b in blocks: + pg = b.get("page", 0) + pages_blocks.setdefault(pg, []).append(b) + + page_indices = sorted(pages_blocks.keys()) + image_list = [] + ocr_res_per_page = [] + + with pdfplumber.open(BytesIO(binary)) as pdf: + for pg in page_indices: + if pg >= len(pdf.pages): + continue + page = pdf.pages[pg] + pil_img = page.to_image(resolution=72 * 3).annotated + image_list.append(pil_img) + + page_bxs = [] + for b in pages_blocks[pg]: + page_bxs.append({ + "x0": float(b["x0"]), + "top": float(b["top"]), + "x1": float(b["x1"]), + "bottom": float(b["bottom"]), + "text": b["text"], + "page": pg, + }) + ocr_res_per_page.append(page_bxs) + + if not image_list: + return _layout_aware_reorder(blocks) + + tagged_blocks, page_layouts = recognizer( + image_list, ocr_res_per_page, scale_factor=3, thr=0.2, drop=False + ) + + if not tagged_blocks: + logger.warning("Layout detector unavailable, falling back to heuristic sorting") + return _layout_aware_reorder(blocks) + + tagged_per_page: dict[int, list[dict]] = {} + for b in tagged_blocks: + pg = b.get("page", 0) + tagged_per_page.setdefault(pg, []).append(b) + + sorted_all = [] + 
+        total_layout_count = 0
+        for pn, pg in enumerate(page_indices):
+            page_bxs = tagged_per_page.get(pg, [])
+            lts = page_layouts[pn] if pn < len(page_layouts) else []
+            total_layout_count += len(lts)
+            sorted_page = _resort_page_with_layout(page_bxs, lts)
+            sorted_all.extend(sorted_page)
+
+        for b in sorted_all:
+            if "page" not in b:
+                b["page"] = 0
+        logger.info(f"YOLOv10 detector completed, {len(sorted_all)} total chunks, "
+                    f"checked {total_layout_count} layout regions")
+        return sorted_all
 
-    chunk(sys.argv[1], callback=dummy)
+    except Exception as e:
+        logger.warning(f"Layout detector unavailable, falling back to heuristic sorting: {e}")
+        return _layout_aware_reorder(blocks)
diff --git a/rag/prompts/resume_basic_info.md b/rag/prompts/resume_basic_info.md
new file mode 100644
index 00000000000..7a3756813de
--- /dev/null
+++ b/rag/prompts/resume_basic_info.md
@@ -0,0 +1,39 @@
+请从以下带行号索引的简历文本中提取基本信息。
+
+{indexed_text}
+
+提取如下信息到 JSON,若某些字段不存在则输出 "" 空或 0:
+{{
+    "name_kwd": "",
+    "gender_kwd": "",
+    "age_int": 0,
+    "phone_kwd": "",
+    "email_tks": "",
+    "birth_dt": "",
+    "work_exp_flt": 0,
+    "current_location": "",
+    "expect_city_names_tks": [],
+    "expect_position_name_tks": [],
+    "skill_tks": [],
+    "language_tks": [],
+    "certificate_tks": [],
+    "self_evaluation_tks": ""
+}}
+
+字段说明:
+- name_kwd: 姓名,如"张三"
+- gender_kwd: 男/女,若不存在则不填
+- age_int: 当前年龄,整数
+- phone_kwd: 电话/手机,请保留原文中的形式,保留国家码区号括号
+- email_tks: 邮箱,如 "xxx@qq.com"
+- birth_dt: 出生年月,如 "1996-11"
+- work_exp_flt: 工作年限,浮点数
+- current_location: 现居地/当前城市,不要从工作经历中推测,要写明现居地
+- expect_city_names_tks: 期望工作城市列表,简历中需要明确说明是期望城市
+- expect_position_name_tks: 期望职位列表
+- skill_tks: 技能/技术栈列表
+- language_tks: 语言能力列表
+- certificate_tks: 证书/资质列表
+- self_evaluation_tks: 自我评价/个人优势/个人总结,完整提取原文内容
+
+只返回 JSON。 /no_think
\ No newline at end of file
diff --git a/rag/prompts/resume_basic_info_en.md b/rag/prompts/resume_basic_info_en.md
new file mode 100644
index 00000000000..7ea6dd0bc81
--- /dev/null
+++ b/rag/prompts/resume_basic_info_en.md
@@ -0,0 +1,39 @@
+Please extract basic information from the following line-indexed resume text.
+
+{indexed_text}
+
+Extract the following information into JSON. If a field does not exist, output "" or 0:
+{{
+    "name_kwd": "",
+    "gender_kwd": "",
+    "age_int": 0,
+    "phone_kwd": "",
+    "email_tks": "",
+    "birth_dt": "",
+    "work_exp_flt": 0,
+    "current_location": "",
+    "expect_city_names_tks": [],
+    "expect_position_name_tks": [],
+    "skill_tks": [],
+    "language_tks": [],
+    "certificate_tks": [],
+    "self_evaluation_tks": ""
+}}
+
+Field descriptions:
+- name_kwd: Full name, e.g. "John Smith"
+- gender_kwd: Male/Female, leave empty if not present
+- age_int: Current age, integer
+- phone_kwd: Phone number, keep original format including country code and brackets
+- email_tks: Email address, e.g. "xxx@gmail.com"
+- birth_dt: Date of birth, e.g. "1996-11"
+- work_exp_flt: Years of work experience, float
+- current_location: Current city/location, do not infer from work experience, must be explicitly stated
+- expect_city_names_tks: List of preferred work cities, must be explicitly stated in the resume
+- expect_position_name_tks: List of desired positions
+- skill_tks: List of skills/tech stack
+- language_tks: List of language proficiencies
+- certificate_tks: List of certificates/qualifications
+- self_evaluation_tks: Self-evaluation/personal strengths/summary, extract full original text
+
+Return JSON only.
/no_think \ No newline at end of file diff --git a/rag/prompts/resume_education.md b/rag/prompts/resume_education.md new file mode 100644 index 00000000000..95ff8eb4d6b --- /dev/null +++ b/rag/prompts/resume_education.md @@ -0,0 +1,31 @@ +请从以下带行号索引的简历文本中提取教育背景。 + +{indexed_text} + +提取为 JSON: +{{ + "education": [ + {{ + "school": "", + "major": "", + "degree": "", + "department": "", + "start_date": "", + "end_date": "", + "desc_lines": [start_index, end_index] + }} + ] +}} + +字段说明: +- school: 学校全称,如"厦门大学",中英文都可以 +- major: 专业,如"机械工程" +- degree: 学位,本科/硕士/博士/专科/高中/初中,若不存在则填"" +- department: 系/学院,如"信息工程系" +- start_date: 开始时间,格式为 %Y.%m 或 %Y +- end_date: 结束时间,若至今填写"至今",若不存在填写"" +- desc_lines: [起始行号, 结束行号],教育描述对应的行号范围(可选) + - 包括课程成绩、研究方向、GPA、荣誉奖项等 + - 不存在则填 [] + +只返回 JSON。 /no_think \ No newline at end of file diff --git a/rag/prompts/resume_education_en.md b/rag/prompts/resume_education_en.md new file mode 100644 index 00000000000..9d726b48b49 --- /dev/null +++ b/rag/prompts/resume_education_en.md @@ -0,0 +1,31 @@ +Please extract education background from the following line-indexed resume text. + +{indexed_text} + +Extract into JSON: +{{ + "education": [ + {{ + "school": "", + "major": "", + "degree": "", + "department": "", + "start_date": "", + "end_date": "", + "desc_lines": [start_index, end_index] + }} + ] +}} + +Field descriptions: +- school: Full school name, e.g. "Stanford University", both Chinese and English are acceptable +- major: Major/field of study, e.g. "Computer Science" +- degree: Degree level - Bachelor/Master/PhD/Associate/High School/Middle School, leave "" if not available +- department: Department/College, e.g. "School of Engineering" +- start_date: Start date, format %Y.%m or %Y +- end_date: End date, use "Present" if still enrolled, "" if not available +- desc_lines: [start_line, end_line], line number range for education description (optional) + - Includes coursework, research focus, GPA, honors/awards, etc. + - Use [] if not available + +Return JSON only. /no_think \ No newline at end of file diff --git a/rag/prompts/resume_project_exp.md b/rag/prompts/resume_project_exp.md new file mode 100644 index 00000000000..ed216deabab --- /dev/null +++ b/rag/prompts/resume_project_exp.md @@ -0,0 +1,31 @@ +请从以下带行号索引的简历文本中提取项目经验。 + +{indexed_text} + +提取为 JSON,每段项目经验包含: +{{ + "projectExperience": [ + {{ + "project_name": "", + "role": "", + "start_date": "", + "end_date": "", + "desc_lines": [start_index, end_index] + }} + ] +}} + +字段说明: +- project_name: 项目名称 +- role: 担任角色/职责,如"项目负责人"、"后端开发" +- start_date: 开始时间,格式为 %Y.%m 或 %Y +- end_date: 结束时间,若至今填写"至今",若不存在填写"" +- desc_lines: [起始行号, 结束行号],项目描述对应的行号范围(整数数组) + - 指项目描述的原文引用段落 index 范围,包括项目内容、技术栈、成果等 + - 不包括 project_name、role、start_date、end_date 所在行 + - 尽可能写全,直到下一段项目经验或其他段落标题为止 + - 遇到以下段落标题时必须截止,不要将其包含在 desc_lines 中: + 个人评价、自我评价、个人总结、个人优势、自我描述、技能特长、专业技能、教育背景、教育经历、工作经历、工作经验、证书资质、语言能力、兴趣爱好、求职意向 + - 如果不存在就写 [] + +只返回 JSON。 /no_think \ No newline at end of file diff --git a/rag/prompts/resume_project_exp_en.md b/rag/prompts/resume_project_exp_en.md new file mode 100644 index 00000000000..e33de88e5ce --- /dev/null +++ b/rag/prompts/resume_project_exp_en.md @@ -0,0 +1,31 @@ +Please extract project experience from the following line-indexed resume text. 
+ +{indexed_text} + +Extract into JSON, each project experience entry contains: +{{ + "projectExperience": [ + {{ + "project_name": "", + "role": "", + "start_date": "", + "end_date": "", + "desc_lines": [start_index, end_index] + }} + ] +}} + +Field descriptions: +- project_name: Project name +- role: Role/responsibility, e.g. "Project Lead", "Backend Developer" +- start_date: Start date, format %Y.%m or %Y +- end_date: End date, use "Present" if ongoing, "" if not available +- desc_lines: [start_line, end_line], line number range for project description (integer array) + - Refers to the original text reference range for project description, including project content, tech stack, achievements, etc. + - Does not include lines containing project_name, role, start_date, end_date + - Include as much as possible until the next project experience entry or other section heading + - STOP before these section headings (do not include them in desc_lines): + Self-evaluation, Personal Summary, Skills, Technical Skills, Education, Work Experience, Certificates, Languages, Hobbies, Career Objective + - Use [] if not available + +Return JSON only. /no_think \ No newline at end of file diff --git a/rag/prompts/resume_system.md b/rag/prompts/resume_system.md new file mode 100644 index 00000000000..9b3419f41ec --- /dev/null +++ b/rag/prompts/resume_system.md @@ -0,0 +1,3 @@ +你是一个专业的简历分析助手。你的任务是将给定的简历文本转换为 JSON 输出。 +(如果有中英文简历同时出现时,只关注中文简历) +严格按照 JSON 格式返回结果,不要有任何其他文字。 \ No newline at end of file diff --git a/rag/prompts/resume_system_en.md b/rag/prompts/resume_system_en.md new file mode 100644 index 00000000000..8d02488f26c --- /dev/null +++ b/rag/prompts/resume_system_en.md @@ -0,0 +1,3 @@ +You are a professional resume analysis assistant. Your task is to convert the given resume text into JSON output. +(If both Chinese and English resumes appear, focus only on the English resume) +Strictly return results in JSON format without any other text. \ No newline at end of file diff --git a/rag/prompts/resume_work_exp.md b/rag/prompts/resume_work_exp.md new file mode 100644 index 00000000000..2a7465c16ef --- /dev/null +++ b/rag/prompts/resume_work_exp.md @@ -0,0 +1,39 @@ +请从以下带行号索引的简历文本中提取工作经历。 + +{indexed_text} + +提取为 JSON,每段工作经历包含: +{{ + "workExperience": [ + {{ + "company": "", + "position": "", + "internship": 0, + "start_date": "", + "end_date": "", + "desc_lines": [start_index, end_index] + }} + ] +}} + +字段说明: +- company: 公司全称(含括号内地区信息),如"阿里巴巴(中国)有限公司" +- position: 职位名称,遵循原文不要编造或推测 +- internship: 该段经历是否是实习,是实习为1,不是为0 +- start_date: 入职时间,格式为 %Y.%m 或 %Y,如 "2024.1" +- end_date: 离职时间,若至今填写"至今",若不存在填写"" +- desc_lines: [起始行号, 结束行号],工作描述对应的行号范围(整数数组) + - 指工作经历描述的原文引用段落 index 范围,包括工作成果、业绩、主要工作、技术栈等 + - 不包括 company、position、start_date、end_date 所在行 + - 尽可能写全,直到下一段工作经历或其他段落标题为止 + - 遇到以下段落标题时必须截止,不要将其包含在 desc_lines 中: + 个人评价、自我评价、个人总结、个人优势、自我描述、技能特长、专业技能、教育背景、教育经历、项目经验、项目经历、证书资质、语言能力、兴趣爱好、求职意向 + - 如果不存在就写 [] + +示例: +[22]: 阿里巴巴 2021.11-2022.11 高级工程师 +[23]: 工作描述: 从事地推工作完成xx业绩 +[24]: 在地推任务中考核为A +则 desc_lines 应为 [23, 24] + +只返回 JSON。 /no_think \ No newline at end of file diff --git a/rag/prompts/resume_work_exp_en.md b/rag/prompts/resume_work_exp_en.md new file mode 100644 index 00000000000..46e4c9ac8b9 --- /dev/null +++ b/rag/prompts/resume_work_exp_en.md @@ -0,0 +1,38 @@ +Please extract work experience from the following line-indexed resume text. 
+ +{indexed_text} + +Extract into JSON, each work experience entry contains: +{{ + "workExperience": [ + {{ + "company": "", + "position": "", + "internship": 0, + "start_date": "", + "end_date": "", + "desc_lines": [start_index, end_index] + }} + ] +}} + +Field descriptions: +- company: Full company name (including region info in brackets), e.g. "Google Inc." +- position: Job title, follow original text, do not fabricate or guess +- internship: Whether this is an internship, 1 for yes, 0 for no +- start_date: Start date, format %Y.%m or %Y, e.g. "2024.1" +- end_date: End date, use "Present" if still employed, "" if not available +- desc_lines: [start_line, end_line], line number range for job description (integer array) + - Refers to the original text reference range for job description, including achievements, responsibilities, tech stack, etc. + - Include as much as possible until the next work experience entry or other section heading + - STOP before these section headings (do not include them in desc_lines): + Self-evaluation, Personal Summary, Skills, Technical Skills, Education, Project Experience, Certificates, Languages, Hobbies, Career Objective + - Use [] if not available + +Example: +[22]: Google Inc. 2021.11-2022.11 Senior Engineer +[23]: Job description: Responsible for backend development +[24]: Achieved 99.9% uptime for core services +Then desc_lines should be [23, 24] + +Return JSON only. /no_think \ No newline at end of file From c2662da5d504d00753b019d7c8a97844a6c046df Mon Sep 17 00:00:00 2001 From: Idriss Sbaaoui <112825897+6ba3i@users.noreply.github.com> Date: Mon, 2 Mar 2026 19:10:11 +0800 Subject: [PATCH 104/565] feat: enable Arabic in production UI and add complete Arabic documentation (#13315) ### What problem does this PR solve? This PR adds end-to-end Arabic support in production. It also adds a full Arabic README ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Documentation Update --- README.md | 1 + README_ar.md | 412 +++ README_fr.md | 1 + README_id.md | 1 + README_ja.md | 1 + README_ko.md | 1 + README_pt_br.md | 1 + README_tzh.md | 1 + README_zh.md | 1 + web/src/app.tsx | 15 +- web/src/components/empty/constant.tsx | 21 +- web/src/components/empty/empty.tsx | 12 +- web/src/components/ui/dropdown-menu.tsx | 18 +- web/src/components/ui/input.tsx | 15 +- web/src/components/ui/modal/modal.tsx | 4 +- web/src/components/ui/select.tsx | 6 +- web/src/components/ui/sheet.tsx | 5 +- web/src/constants/common.ts | 15 +- web/src/hooks/use-user-setting-request.tsx | 58 +- web/src/locales/BULGARIAN_LANGUAGE_CHANGES.md | 54 - web/src/locales/ar.ts | 2527 +++++++++++++++++ web/src/locales/config.ts | 33 +- web/src/locales/en.ts | 25 + .../components/metedata/manage-modal.tsx | 6 +- web/src/pages/dataset/dataset/index.tsx | 4 +- web/src/pages/memory/memory-message/index.tsx | 2 +- web/src/pages/next-search/index.tsx | 6 +- .../data-source-detail-page/index.tsx | 43 +- .../setting-locale/translation-table.tsx | 10 +- .../modal/verify-button/index.tsx | 7 +- 30 files changed, 3149 insertions(+), 157 deletions(-) create mode 100644 README_ar.md delete mode 100644 web/src/locales/BULGARIAN_LANGUAGE_CHANGES.md create mode 100644 web/src/locales/ar.ts diff --git a/README.md b/README.md index bda071a86da..e3c89bc6ad0 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ Bahasa Indonesia Português(Brasil) README en Français + README in Arabic

diff --git a/README_ar.md b/README_ar.md new file mode 100644 index 00000000000..d2e3f062c5b --- /dev/null +++ b/README_ar.md @@ -0,0 +1,412 @@ +

+ +ragflow logo + +
+ +

+ README in English + 简体中文版自述文件 + 繁體版中文自述文件 + 日本語のREADME + 한국어 + Bahasa Indonesia + Português(Brasil) + README en Français + README in Arabic +

+ +

+ + follow on X(Twitter) + + + Static Badge + + + docker pull infiniflow/ragflow:v0.24.0 + + + Latest Release + + + license + + + Ask DeepWiki + +

+ +

+ Document | + Roadmap | + Twitter | + Discord | + Demo +

+ +
+ +
+ +
+infiniflow%2Fragflow | Trendshift +
+ +
+📕 جدول المحتويات + +- 💡 [ما هو RAGFlow؟](#-what-is-ragflow) +- 🎮 [Demo](#-demo) +- 📌 [آخر التحديثات](#-latest-updates) +- 🌟 [الميزات الرئيسية](#-key-features) +- 🔎 [بنية النظام](#-system-architecture) +- 🎬 [ابدأ](#-get-started) +- 🔧 [التكوينات](#-configurations) +- 🔧 [إنشاء صورة Docker](#-build-a-docker-image) +- 🔨 [إطلاق الخدمة من المصدر للتطوير](#-launch-service-from-source-for-development) +- 📚 [التوثيق](#-documentation) +- 📜 [Roadmap](#-roadmap) +- 🏄 [المجتمع](#-community) +- 🙌 [مساهمة](#-contributing) + +
+ +## 💡 ما هو RAGFlow؟ + +يُعد مشروع [RAGFlow](https://ragflow.io/) محركًا رائدًا ومفتوح المصدر للاسترجاع المعزز بالتوليد (RAG)، ويجمع أحدث تقنيات RAG مع قدرات الوكلاء لبناء طبقة سياق متقدمة لنماذج LLMs. يوفّر سير عمل RAG مبسّطًا وقابلًا للتكيّف مع المؤسسات بمختلف أحجامها. وبالاعتماد على [محرك سياق موحّد](https://ragflow.io/basics/what-is-agent-context-engine) وقوالب وكلاء جاهزة، يتيح RAGFlow للمطورين تحويل البيانات المعقّدة إلى أنظمة AI عالية الدقة وجاهزة للإنتاج بكفاءة وموثوقية. + +## 🎮 Demo + +جرّب النسخة التجريبية على [https://demo.ragflow.io](https://demo.ragflow.io). + +
+ + +
+ +## 🔥 آخر التحديثات + +- 2025-12-26 يدعم ميزة "Memory" لوكلاء الذكاء الاصطناعي. +- 11-11-2025 يدعم Gemini 3 Pro. +- 12-11-2025 يدعم مزامنة البيانات من Confluence، S3، Notion، Discord، Google Drive. +- 23-10-2025 يدعم MinerU وDocling كطرق لتحليل المستندات. +- 15-10-2025 يدعم العرض الأوركسترالي pipeline. +- 08-08-2025 يدعم أحدث موديلات سلسلة OpenAI. +- 01-08-2025 يدعم سير العمل الوكيل وMCP. +- 23-05-2025 تمت إضافة مكون منفذ كود Python/JavaScript إلى Agent. +- 05-05-2025 يدعم الاستعلام بين اللغات. +- 19-03-2025 يدعم استخدام نموذج متعدد الوسائط لفهم الصور داخل ملفات PDF أو DOCX. + +## 🎉 تابعونا + +⭐️ قم بتمييز مستودعنا بنجمة لتبقى على اطلاع بالميزات والتحسينات الجديدة والمثيرة! احصل على إشعارات فورية بالجديد +الإصدارات! 🌟 + +
+ +
+ +## 🌟 الميزات الرئيسية + +### 🍭 **"الجودة في الداخل، الجودة في الخارج"** + +- [الفهم العميق للمستندات](./deepdoc/README.md) لاستخراج المعرفة من البيانات غير المنظمة + ذات التنسيقات المعقدة. +- يجد "إبرة في كومة قش بيانات" من الرموز غير المحدودة حرفيًا. + +### 🍱 **التقطيع القائم على القالب** + +- ذكي وقابل للتفسير. +- الكثير من خيارات القالب للاختيار من بينها. + +### 🌱 **استشهادات مؤرضة لتقليل الهلوسة** + +- تصور تقطيع النص للسماح بالتدخل البشري. +- عرض سريع للمراجع الرئيسية والاستشهادات التي يمكن تتبعها لدعم الإجابات المبنية على أسس سليمة. + +### 🍔 **التوافق مع مصادر البيانات غير المتجانسة** + +- يدعم Word، والشرائح، وExcel، وtxt، والصور، والنسخ الممسوحة ضوئيًا، والبيانات المنظمة، وصفحات الويب، والمزيد. + +### 🛀 **سير عمل RAG آلي وسهل** + +- تنسيق RAG مبسط يلبي احتياجات الشركات الشخصية والكبيرة على حد سواء. +- نماذج LLMs قابلة للتكوين بالإضافة إلى نماذج embedding. +- الاستدعاء المتعدد المقترن بإعادة التصنيف المدمجة. +- APIs بديهي للتكامل السلس مع الأعمال. + +## 🔎 هندسة النظام + +
+ +
+ +## 🎬 ابدأ + +### 📝 المتطلبات الأساسية + +- CPU >= 4 مراكز +- الرام >= 16 جيجا +- القرص >= 50 جيجا بايت +- Docker >= 24.0.0 & Docker Compose >= v2.26.1 +- [gVisor](https://gvisor.dev/docs/user_guide/install/): مطلوب فقط إذا كنت تنوي استخدام ميزة منفذ التعليمات البرمجية (وضع الحماية) لـ RAGFlow. + +> [!TIP] +> إذا لم تقم بتثبيت Docker على جهازك المحلي (Windows أو Mac أو Linux)، راجع [تثبيت Docker Engine](https://docs.docker.com/engine/install/). + +### 🚀 بدء تشغيل الخادم + +1. تأكد من `vm.max_map_count` >= 262144: + + > للتحقق من قيمة `vm.max_map_count`: + > + > ```bash + > $ sysctl vm.max_map_count + > ``` + > + > أعد تعيين `vm.max_map_count` إلى قيمة 262144 على الأقل إذا لم تكن كذلك. + > + > ```bash + > # In this case, we set it to 262144: + > $ sudo sysctl -w vm.max_map_count=262144 + > ``` + > + > سيتم إعادة ضبط هذا التغيير بعد إعادة تشغيل النظام. لضمان بقاء التغيير دائمًا، قم بإضافة أو تحديث + > `vm.max_map_count` القيمة في **/etc/sysctl.conf** وفقًا لذلك: + > + > ```bash + > vm.max_map_count=262144 + > ``` + > +2. استنساخ الريبو: + + ```bash + $ git clone https://github.com/infiniflow/ragflow.git + ``` +3. ابدأ تشغيل الخادم باستخدام صور Docker المعدة مسبقًا: + +> [!CAUTION] +> جميع الصور Docker مصممة لمنصات x86. لا نعرض حاليًا صور Docker لـ ARM64. +> إذا كنت تستخدم نظامًا أساسيًا ARM64، فاتبع [هذا الدليل](https://ragflow.io/docs/dev/build_docker_image) لإنشاء صورة Docker متوافقة مع نظامك. + +> يقوم الأمر أدناه بتنزيل إصدار `v0.24.0` من الصورة RAGFlow Docker. راجع الجدول التالي للحصول على أوصاف لإصدارات RAGFlow المختلفة. لتنزيل إصدار RAGFlow مختلف عن `v0.24.0`، قم بتحديث المتغير `RAGFLOW_IMAGE` وفقًا لذلك في **docker/.env** قبل استخدام `docker compose` لبدء تشغيل الخادم. + +```bash + $ cd ragflow/docker + + # git checkout v0.24.0 + # Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases) + # This step ensures the **entrypoint.sh** file in the code matches the Docker image version. + + # Use CPU for DeepDoc tasks: + $ docker compose -f docker-compose.yml up -d + + # To use GPU to accelerate DeepDoc tasks: + # sed -i '1i DEVICE=gpu' .env + # docker compose -f docker-compose.yml up -d +``` + +> ملاحظة: قبل `v0.22.0`، قدمنا ​​كلتا الصورتين بنماذج embedding وصورًا رفيعة بدون نماذج embedding. التفاصيل على النحو التالي: + +| RAGFlow علامة الصورة | حجم الصورة (جيجابايت) | هل لديه نماذج embedding؟ | مستقر؟ | +|-------------------|-----------------|-----------------------|----------------| +| v0.21.1 | ≈9 | ✔️ | إصدار مستقر | +| v0.21.1-slim | ≈2 | ❌ | إصدار مستقر | + +> بدءًا من `v0.22.0`، نقوم بشحن الإصدار النحيف فقط ولم نعد نلحق اللاحقة **-slim** بعلامة الصورة. + +4. التحقق من حالة الخادم بعد تشغيل الخادم: + + ```bash + $ docker logs -f docker-ragflow-cpu-1 + ``` + + _النتيجة التالية تؤكد الإطلاق الناجح للنظام:_ + + ```bash + + ____ ___ ______ ______ __ + / __ \ / | / ____// ____// /____ _ __ + / /_/ // /| | / / __ / /_ / // __ \| | /| / / + / _, _// ___ |/ /_/ // __/ / // /_/ /| |/ |/ / + /_/ |_|/_/ |_|\____//_/ /_/ \____/ |__/|__/ + + * Running on all addresses (0.0.0.0) + ``` + + > إذا تخطيت خطوة التأكيد هذه وقمت بتسجيل الدخول مباشرة إلى RAGFlow، فقد يعرض متصفحك تنبيه `network abnormal` + > خطأ لأنه في تلك اللحظة، قد لا تتم تهيئة RAGFlow بشكل كامل. + > +5. في متصفح الويب الخاص بك، أدخل عنوان IP الخاص بالخادم الخاص بك وقم بتسجيل الدخول إلى RAGFlow. + + > باستخدام الإعدادات الافتراضية، ما عليك سوى إدخال `http://IP_OF_YOUR_MACHINE` (**من دون** رقم المنفذ) كإعداد افتراضي + > HTTP يمكن حذف منفذ العرض `80` عند استخدام التكوينات الافتراضية. + > +6. 
في [service_conf.yaml.template](./docker/service_conf.yaml.template)، حدد المصنع LLM المطلوب في `user_default_llm` وقم بالتحديث + الحقل `API_KEY` مع مفتاح API المقابل. + + > راجع [llm_api_key_setup](https://ragflow.io/docs/dev/llm_api_key_setup) لمزيد من المعلومات. + > + + _العرض بدأ!_ + +## 🔧 التكوينات + +عندما يتعلق الأمر بتكوينات النظام، ستحتاج إلى إدارة الملفات التالية: + +- [.env](./docker/.env): يحتفظ بالإعدادات الأساسية للنظام، مثل `SVR_HTTP_PORT`، `MYSQL_PASSWORD`، و + `MINIO_PASSWORD`. +- [service_conf.yaml.template](./docker/service_conf.yaml.template): تكوين الخدمات الخلفية. سيتم ملء متغيرات البيئة في هذا الملف تلقائيًا عند بدء تشغيل الحاوية Docker. ستكون أي متغيرات بيئة تم تعيينها داخل حاوية Docker متاحة للاستخدام، مما يسمح لك بتخصيص سلوك الخدمة استنادًا إلى بيئة النشر. +- [docker-compose.yml](./docker/docker-compose.yml): يعتمد النظام على [docker-compose.yml](./docker/docker-compose.yml) لبدء التشغيل. + +> يوفر الملف [./docker/README](./docker/README.md) وصفًا تفصيليًا لإعدادات البيئة والخدمة +> التكوينات التي يمكن استخدامها كـ `${ENV_VARS}` في ملف [service_conf.yaml.template](./docker/service_conf.yaml.template). + +لتحديث منفذ العرض الافتراضي HTTP (80)، انتقل إلى [docker-compose.yml](./docker/docker-compose.yml) وقم بتغيير `80:80` +إلى `:80`. + +تتطلب تحديثات التكوينات المذكورة أعلاه إعادة تشغيل جميع الحاويات لتصبح سارية المفعول: + +> ```bash +> $ docker compose -f docker-compose.yml up -d +> ``` + +### تبديل محرك المستندات من Elasticsearch إلى Infinity + +RAGFlow يستخدم Elasticsearch بشكل افتراضي لتخزين النص الكامل والمتجهات. للتبديل إلى [Infinity](https://github.com/infiniflow/infinity/)، اتبع الخطوات التالية: + +1. إيقاف كافة الحاويات قيد التشغيل: + + ```bash + $ docker compose -f docker/docker-compose.yml down -v + ``` + +> [!WARNING] +> `-v` سوف يحذف docker وحدات تخزين الحاوية، وسيتم مسح البيانات الموجودة. + +2. اضبط `DOC_ENGINE` في **docker/.env** على `infinity`. +3. ابدأ الحاويات: + + ```bash + $ docker compose -f docker-compose.yml up -d + ``` + +> [!WARNING] +> التبديل إلى Infinity على جهاز Linux/arm64 غير مدعوم رسميًا بعد. + +## 🔧 أنشئ صورة Docker + +يبلغ حجم هذه الصورة حوالي 2 غيغابايت وتعتمد على خدمات LLM وembedding الخارجية. + +```bash +git clone https://github.com/infiniflow/ragflow.git +cd ragflow/ +docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly . +``` + +أو إذا كنت خلف وكيل، فيمكنك تمرير وسيطات الوكيل: + +```bash +docker build --platform linux/amd64 \ + --build-arg http_proxy=http://YOUR_PROXY:PORT \ + --build-arg https_proxy=http://YOUR_PROXY:PORT \ + -f Dockerfile -t infiniflow/ragflow:nightly . +``` + +## 🔨 إطلاق الخدمة من المصدر للتطوير + +1. قم بتثبيت `uv` و`pre-commit`، أو قم بتخطي هذه الخطوة إذا كانا مثبتين بالفعل: + + ```bash + pipx install uv pre-commit + ``` +2. استنساخ الكود المصدري وتثبيت تبعيات بايثون: + + ```bash + git clone https://github.com/infiniflow/ragflow.git + cd ragflow/ + uv sync --python 3.12 # install RAGFlow dependent python modules + uv run download_deps.py + pre-commit install + ``` +3. قم بتشغيل الخدمات التابعة (MinIO وElasticsearch وRedis وMySQL) باستخدام Docker Compose: + + ```bash + docker compose -f docker/docker-compose-base.yml up -d + ``` + + أضف السطر التالي إلى `/etc/hosts` لحل كافة المضيفين المحددين في **docker/.env** إلى `127.0.0.1`: + + ``` + 127.0.0.1 es01 infinity mysql minio redis sandbox-executor-manager + ``` +4. إذا لم تتمكن من الوصول إلى HuggingFace، فقم بتعيين متغير البيئة `HF_ENDPOINT` لاستخدام موقع مرآة: + + ```bash + export HF_ENDPOINT=https://hf-mirror.com + ``` +5. 
إذا كان نظام التشغيل لديك لا يحتوي على jemalloc، فيرجى تثبيته على النحو التالي: + + ```bash + # Ubuntu + sudo apt-get install libjemalloc-dev + # CentOS + sudo yum install jemalloc + # OpenSUSE + sudo zypper install jemalloc + # macOS + sudo brew install jemalloc + ``` +6. إطلاق الخدمة الخلفية: + + ```bash + source .venv/bin/activate + export PYTHONPATH=$(pwd) + bash docker/launch_backend_service.sh + ``` +7. تثبيت تبعيات الواجهة الأمامية: + + ```bash + cd web + npm install + ``` +8. إطلاق خدمة الواجهة الأمامية: + + ```bash + npm run dev + ``` + + _النتيجة التالية تؤكد الإطلاق الناجح للنظام:_ + + ![](https://github.com/user-attachments/assets/0daf462c-a24d-4496-a66f-92533534e187) +9. أوقف خدمة الواجهة الأمامية والخلفية RAGFlow بعد اكتمال التطوير: + + ```bash + pkill -f "ragflow_server.py|task_executor.py" + ``` + +## 📚 التوثيق + +- [البدء السريع](https://ragflow.io/docs/dev/) +- [التكوين](https://ragflow.io/docs/dev/configurations) +- [ملاحظات الإصدار](https://ragflow.io/docs/dev/release_notes) +- [أدلة المستخدم](https://ragflow.io/docs/dev/category/guides) +- [أدلة المطورين](https://ragflow.io/docs/dev/category/developers) +- [المراجع](https://ragflow.io/docs/dev/category/references) +- [الأسئلة الشائعة](https://ragflow.io/docs/dev/faq) + +## 📜 Roadmap + +راجع [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241) + +## 🏄 المجتمع + +- [Discord](https://discord.gg/NjYzJD3GM3) +- [Twitter](https://twitter.com/infiniflowai) +- [مناقشات جيثب](https://github.com/orgs/infiniflow/discussions) + +## 🙌 المساهمة + +RAGFlow يزدهر من خلال التعاون مفتوح المصدر. وبهذه الروح، فإننا نحتضن المساهمات المتنوعة من المجتمع. +إذا كنت ترغب في أن تكون جزءًا، فراجع [إرشادات المساهمة](https://ragflow.io/docs/dev/contributing) أولاً. diff --git a/README_fr.md b/README_fr.md index 4f311df0a04..3f8555edf63 100644 --- a/README_fr.md +++ b/README_fr.md @@ -13,6 +13,7 @@ Bahasa Indonesia Português(Brasil) README en Français + README in Arabic

diff --git a/README_id.md b/README_id.md index c1954dc5566..8a7d2715f77 100644 --- a/README_id.md +++ b/README_id.md @@ -13,6 +13,7 @@ Bahasa Indonesia Português(Brasil) README en Français + README in Arabic

diff --git a/README_ja.md b/README_ja.md index 9a0538809c3..03f6a5b477d 100644 --- a/README_ja.md +++ b/README_ja.md @@ -13,6 +13,7 @@ Bahasa Indonesia Português(Brasil) README en Français + README in Arabic

diff --git a/README_ko.md b/README_ko.md index 79703a62d67..55a03b2919a 100644 --- a/README_ko.md +++ b/README_ko.md @@ -13,6 +13,7 @@ Bahasa Indonesia Português(Brasil) README en Français + README in Arabic

diff --git a/README_pt_br.md b/README_pt_br.md index 8e75e8cb032..0d162ca672e 100644 --- a/README_pt_br.md +++ b/README_pt_br.md @@ -13,6 +13,7 @@ Bahasa Indonesia Português(Brasil) README en Français + README in Arabic

diff --git a/README_tzh.md b/README_tzh.md index 1b1508bee96..e7b21fe53ed 100644 --- a/README_tzh.md +++ b/README_tzh.md @@ -13,6 +13,7 @@ Bahasa Indonesia Português(Brasil) README en Français + README in Arabic

diff --git a/README_zh.md b/README_zh.md index 99a65cdd46e..ef64dc1786a 100644 --- a/README_zh.md +++ b/README_zh.md @@ -13,6 +13,7 @@ Bahasa Indonesia Português(Brasil) README en Français + README in Arabic

diff --git a/web/src/app.tsx b/web/src/app.tsx index 8a0fc46f6bf..4d84b8b3bb4 100644 --- a/web/src/app.tsx +++ b/web/src/app.tsx @@ -1,10 +1,11 @@ import { Toaster as Sonner } from '@/components/ui/sonner'; import { Toaster } from '@/components/ui/toaster'; -import i18n from '@/locales/config'; +import i18n, { changeLanguageAsync } from '@/locales/config'; import { QueryClient, QueryClientProvider } from '@tanstack/react-query'; import { configResponsive } from 'ahooks'; import { App, ConfigProvider, ConfigProviderProps, theme } from 'antd'; import pt_BR from 'antd/lib/locale/pt_BR'; +import arEG from 'antd/locale/ar_EG'; import deDE from 'antd/locale/de_DE'; import enUS from 'antd/locale/en_US'; import ru_RU from 'antd/locale/ru_RU'; @@ -54,6 +55,7 @@ const AntLanguageMap = { vi: vi_VN, 'pt-BR': pt_BR, de: deDE, + ar: arEG, }; if (process.env.NODE_ENV === 'development') { @@ -85,6 +87,12 @@ function Root({ children }: React.PropsWithChildren) { const { theme: themeragflow } = useTheme(); const getLocale = (lng: string) => AntLanguageMap[lng as keyof typeof AntLanguageMap] ?? enUS; + const updateDocumentLocale = (lng: string) => { + document.documentElement.lang = lng; + document.documentElement.dir = lng.toLowerCase().startsWith('ar') + ? 'rtl' + : 'ltr'; + }; const [locale, setLocal] = useState(getLocale(storage.getLanguage())); @@ -92,9 +100,10 @@ function Root({ children }: React.PropsWithChildren) { const handleLanguageChanged = (lng: string) => { storage.setLanguage(lng); setLocal(getLocale(lng)); - document.documentElement.lang = lng; + updateDocumentLocale(lng); }; + updateDocumentLocale(storage.getLanguage() || i18n.language || 'en'); i18n.on('languageChanged', handleLanguageChanged); return () => { @@ -130,7 +139,7 @@ const RootProvider = ({ children }: React.PropsWithChildren) => { useEffect(() => { const lng = storage.getLanguage(); if (lng) { - i18n.changeLanguage(lng); + void changeLanguageAsync(lng); } }, []); diff --git a/web/src/components/empty/constant.tsx b/web/src/components/empty/constant.tsx index 811c27a5b6b..641920041d8 100644 --- a/web/src/components/empty/constant.tsx +++ b/web/src/components/empty/constant.tsx @@ -1,4 +1,3 @@ -import { t } from 'i18next'; import { HomeIcon } from '../svg-icon'; export enum EmptyType { @@ -17,27 +16,27 @@ export enum EmptyCardType { export const EmptyCardData = { [EmptyCardType.Agent]: { icon: , - title: t('empty.agentTitle'), - notFound: t('empty.notFoundAgent'), + titleKey: 'empty.agentTitle', + notFoundKey: 'empty.notFoundAgent', }, [EmptyCardType.Dataset]: { icon: , - title: t('empty.datasetTitle'), - notFound: t('empty.notFoundDataset'), + titleKey: 'empty.datasetTitle', + notFoundKey: 'empty.notFoundDataset', }, [EmptyCardType.Chat]: { icon: , - title: t('empty.chatTitle'), - notFound: t('empty.notFoundChat'), + titleKey: 'empty.chatTitle', + notFoundKey: 'empty.notFoundChat', }, [EmptyCardType.Search]: { icon: , - title: t('empty.searchTitle'), - notFound: t('empty.notFoundSearch'), + titleKey: 'empty.searchTitle', + notFoundKey: 'empty.notFoundSearch', }, [EmptyCardType.Memory]: { icon: , - title: t('empty.memoryTitle'), - notFound: t('empty.notFoundMemory'), + titleKey: 'empty.memoryTitle', + notFoundKey: 'empty.notFoundMemory', }, }; diff --git a/web/src/components/empty/empty.tsx b/web/src/components/empty/empty.tsx index c6f6f1f6c9b..fbb97506f67 100644 --- a/web/src/components/empty/empty.tsx +++ b/web/src/components/empty/empty.tsx @@ -4,6 +4,7 @@ import { useIsDarkTheme } from '../theme-provider'; import { Plus } 
from 'lucide-react'; import { useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; import SvgIcon from '../svg-icon'; import { EmptyCardData, EmptyCardType, EmptyType } from './constant'; import { EmptyCardProps, EmptyProps } from './interface'; @@ -82,8 +83,13 @@ export const EmptyAppCard = (props: { testId?: string; }) => { const { type, showIcon, className, isSearch, children, testId } = props; + const { t } = useTranslation(); let defaultClass = ''; let style = {}; + const cardData = EmptyCardData[type]; + const title = t(cardData.titleKey); + const notFound = t(cardData.notFoundKey); + switch (props.size) { case 'small': style = { width: '256px' }; @@ -100,10 +106,8 @@ export const EmptyAppCard = (props: { return (

{children} - + )); DropdownMenuSubTrigger.displayName = @@ -85,7 +85,7 @@ const DropdownMenuItem = React.forwardRef< ref={ref} className={cn( 'relative flex cursor-default select-none items-center gap-2 rounded-sm px-2 py-1.5 text-sm outline-none transition-colors focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50 [&_svg]:pointer-events-none [&_svg]:size-4 [&_svg]:shrink-0', - inset && 'pl-8', + inset && 'ps-8', justifyBetween && 'flex justify-between', className, )} @@ -101,13 +101,13 @@ const DropdownMenuCheckboxItem = React.forwardRef< - + @@ -125,12 +125,12 @@ const DropdownMenuRadioItem = React.forwardRef< - + @@ -150,7 +150,7 @@ const DropdownMenuLabel = React.forwardRef< ref={ref} className={cn( 'px-2 py-1.5 text-sm font-semibold', - inset && 'pl-8', + inset && 'ps-8', className, )} {...props} @@ -176,7 +176,7 @@ const DropdownMenuShortcut = ({ }: React.HTMLAttributes) => { return ( ); diff --git a/web/src/components/ui/input.tsx b/web/src/components/ui/input.tsx index a84828d519d..f742b0256d7 100644 --- a/web/src/components/ui/input.tsx +++ b/web/src/components/ui/input.tsx @@ -80,8 +80,9 @@ const Input = React.forwardRef( className, )} style={{ - paddingLeft: !!prefix && prefixWidth ? `${prefixWidth}px` : '', - paddingRight: isPasswordInput + paddingInlineStart: + !!prefix && prefixWidth ? `${prefixWidth}px` : '', + paddingInlineEnd: isPasswordInput ? '40px' : !!suffix ? `${suffixWidth}px` @@ -109,7 +110,7 @@ const Input = React.forwardRef( {prefix && ( {prefix} @@ -118,8 +119,8 @@ const Input = React.forwardRef( {suffix && ( {suffix} @@ -130,7 +131,7 @@ const Input = React.forwardRef( type="button" className=" p-2 text-text-secondary - absolute border-0 right-1 top-[50%] translate-y-[-50%] + absolute border-0 end-1 top-[50%] translate-y-[-50%] dark:peer-autofill/input:text-text-secondary-inverse dark:peer-autofill/input:hover:text-text-primary-inverse dark:peer-autofill/input:focus-visible:text-text-primary-inverse @@ -165,7 +166,7 @@ const SearchInput = (props: InputProps) => { } + prefix={} /> ); }; diff --git a/web/src/components/ui/modal/modal.tsx b/web/src/components/ui/modal/modal.tsx index 6e8883ed0d8..4003f497837 100644 --- a/web/src/components/ui/modal/modal.tsx +++ b/web/src/components/ui/modal/modal.tsx @@ -164,7 +164,7 @@ const Modal: ModalType = ({ )} > {confirmLoading && ( - + )} {okText ?? t('modal.okText')} @@ -253,7 +253,7 @@ const Modal: ModalType = ({
- + { @@ -539,7 +539,7 @@ export const ManageMetadataModal = (props: IManageModalProps) => { )}
{metadataType === MetadataType.Manage && ( -
+
{t('knowledgeDetails.metadata.toMetadataSettingTip')}
)} diff --git a/web/src/pages/dataset/dataset/index.tsx b/web/src/pages/dataset/dataset/index.tsx index 51651e7f265..8f9b41e2523 100644 --- a/web/src/pages/dataset/dataset/index.tsx +++ b/web/src/pages/dataset/dataset/index.tsx @@ -142,7 +142,7 @@ export default function Dataset() {
)} {manageMetadataVisible && ( diff --git a/web/src/pages/memory/memory-message/index.tsx b/web/src/pages/memory/memory-message/index.tsx index 62e27678d37..6ca98f987d8 100644 --- a/web/src/pages/memory/memory-message/index.tsx +++ b/web/src/pages/memory/memory-message/index.tsx @@ -19,7 +19,7 @@ export default function MemoryMessage() { return (
{openSetting && ( }
-
+
{!isSearching && ( -
+
From 1518a6c12436681947d4e5e12267d61b0f22d776 Mon Sep 17 00:00:00 2001 From: balibabu Date: Wed, 4 Mar 2026 11:10:05 +0800 Subject: [PATCH 118/565] Fix: Change the background color of the message notification button. (#13344) ### What problem does this PR solve? Fix: Change the background color of the message notification button. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- web/src/components/message-item/group-button.tsx | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/web/src/components/message-item/group-button.tsx b/web/src/components/message-item/group-button.tsx index c23a045a271..f215699f906 100644 --- a/web/src/components/message-item/group-button.tsx +++ b/web/src/components/message-item/group-button.tsx @@ -102,7 +102,12 @@ export const AssistantGroupButton = ({ )} {prompt && ( - )} From cac84a44670044f4fc4355d3d322eff02d218070 Mon Sep 17 00:00:00 2001 From: Magicbook1108 Date: Wed, 4 Mar 2026 11:51:10 +0800 Subject: [PATCH 119/565] Fix: Correct PDF chunking parameter name in naive (#13357) ### What problem does this PR solve? Fix: Correct PDF chunking parameter name in naive #13325 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- rag/app/naive.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rag/app/naive.py b/rag/app/naive.py index 22606c3b32c..e7f6730d801 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -252,7 +252,7 @@ def by_plaintext(filename, binary=None, from_page=0, to_page=100000, callback=No "deepdoc": by_deepdoc, "mineru": by_mineru, "docling": by_docling, - "tcadp": by_tcadp, + "tcadp parser": by_tcadp, "paddleocr": by_paddleocr, "plaintext": by_plaintext, # default } @@ -854,7 +854,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca urls = extract_links_from_pdf(binary) if isinstance(layout_recognizer, bool): - layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text" + layout_recognizer = "DeepDOC" if layout_recognizer else "PlainText" name = layout_recognizer.strip().lower() parser = PARSERS.get(name, by_plaintext) From 8a838a8369b4d9ece52475c6bcaf3a98ece4f716 Mon Sep 17 00:00:00 2001 From: balibabu Date: Wed, 4 Mar 2026 12:48:35 +0800 Subject: [PATCH 120/565] Fix: The dropdown menu for large models does not automatically focus on the search box. #13313 (#13360) ### What problem does this PR solve? Fix: The dropdown menu for large models does not automatically focus on the search box. #13313 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- web/src/components/ui/command.tsx | 1 - 1 file changed, 1 deletion(-) diff --git a/web/src/components/ui/command.tsx b/web/src/components/ui/command.tsx index 856600c5a90..2a8996c805c 100644 --- a/web/src/components/ui/command.tsx +++ b/web/src/components/ui/command.tsx @@ -74,7 +74,6 @@ const CommandList = React.forwardRef< className, )} onWheel={(e) => e.stopPropagation()} - onMouseEnter={(e) => e.currentTarget.focus()} tabIndex={-1} {...props} /> From efaa9541fab78923e9565d224b603d1d8a1aa272 Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Wed, 4 Mar 2026 13:07:45 +0800 Subject: [PATCH 121/565] Supports login cross multiple RAGFlow servers (#13322) ### What problem does this PR solve? 1. Use redis to store the secret key. 2. During startup API server will read the secret from redis. If no such secret key, generate one and store it into redis, atomically. 
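For reference, the atomic get-or-create reduces to a single `SETNX` race. A minimal sketch of the pattern with redis-py (the connection parameters here are placeholders, not RAGFlow's actual configuration):

```python
import secrets

import redis

r = redis.Redis(host="localhost", port=6379, decode_responses=True)


def get_or_create_secret(key: str) -> str:
    existing = r.get(key)
    if existing is not None:
        return existing  # some server already created the key
    candidate = secrets.token_hex(32)
    # SETNX writes only when the key is absent, so exactly one concurrent
    # caller wins the race; every other caller reads the winner's value.
    if r.setnx(key, candidate):
        return candidate
    return r.get(key)


# Every API server derives the same session secret:
secret = get_or_create_secret("ragflow:system:secret_key")
```

Because all instances read the same key, a login issued by one RAGFlow server stays valid on the others; the shipped `RedisDB` helper additionally logs and retries if the key disappears between calls.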
### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Signed-off-by: Jin Hai --- .github/workflows/tests.yml | 2 +- common/settings.py | 23 ++++++++++++----------- rag/utils/redis_conn.py | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 12 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 934005edec3..6d370097f60 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -142,7 +142,7 @@ jobs: RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}} RAGFLOW_IMAGE=infiniflow/ragflow:${GITHUB_RUN_ID} echo "RAGFLOW_IMAGE=${RAGFLOW_IMAGE}" >> ${GITHUB_ENV} - sudo docker pull ubuntu:22.04 + sudo docker pull ubuntu:24.04 sudo DOCKER_BUILDKIT=1 docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -f Dockerfile -t ${RAGFLOW_IMAGE} . if [[ ${GITHUB_EVENT_NAME} == "schedule" ]]; then export HTTP_API_TEST_LEVEL=p3 diff --git a/common/settings.py b/common/settings.py index fe3d07b33cd..2b67dc34d72 100644 --- a/common/settings.py +++ b/common/settings.py @@ -16,7 +16,6 @@ import os import json import secrets -from datetime import date import logging from common.constants import RAG_FLOW_SERVICE_NAME from common.file_utils import get_project_base_directory @@ -34,6 +33,7 @@ from rag.utils.gcs_conn import RAGFlowGCS from rag.utils.minio_conn import RAGFlowMinio from rag.utils.opendal_conn import OpenDALStorage +from rag.utils.redis_conn import REDIS_CONN from rag.utils.s3_conn import RAGFlowS3 from rag.utils.oss_conn import RAGFlowOSS @@ -138,21 +138,22 @@ def get_svr_queue_names(): return [get_svr_queue_name(priority) for priority in [1, 0]] def _get_or_create_secret_key(): - secret_key = os.environ.get("RAGFLOW_SECRET_KEY") - if secret_key and len(secret_key) >= 32: - return secret_key - - # Check if there's a configured secret key - configured_key = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key") - if configured_key and configured_key != str(date.today()) and len(configured_key) >= 32: - return configured_key + # secret_key = os.environ.get("RAGFLOW_SECRET_KEY") + # if secret_key and len(secret_key) >= 32: + # return secret_key + # + # # Check if there's a configured secret key + # configured_key = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key") + # if configured_key and configured_key != str(date.today()) and len(configured_key) >= 32: + # return configured_key # Generate a new secure key and warn about it import logging - new_key = secrets.token_hex(32) + generated_key = secrets.token_hex(32) + secret_key = REDIS_CONN.get_or_create_secret_key("ragflow:system:secret_key", generated_key) logging.warning("SECURITY WARNING: Using auto-generated SECRET_KEY.") - return new_key + return secret_key class StorageFactory: storage_mapping = { diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py index d134f05331f..960e98af815 100644 --- a/rag/utils/redis_conn.py +++ b/rag/utils/redis_conn.py @@ -334,6 +334,42 @@ def generate_auto_increment_id(self, key_prefix: str = "id_generator", namespace self.__open__() return -1 + def get_or_create_secret_key(self, key_name: str, new_value: str) -> str: + """ + Atomically get an existing key or create a new one. + + This method guarantees that across multiple concurrent calls, only one + key will be created and all callers will receive the same key. 
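+
+        Example (illustrative):
+            generated = secrets.token_hex(32)
+            key = REDIS_CONN.get_or_create_secret_key("ragflow:system:secret_key", generated)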
+
+        Returns:
+            The secret key string
+
+        Raises:
+            redis.RedisError: If Redis operations fail
+        """
+        # First, try to get the existing key
+        existing_value = self.REDIS.get(key_name)
+        if existing_value is not None:
+            logging.debug("Retrieved existing key from Redis")
+            return existing_value
+
+        # Use SETNX to atomically set the key only if it doesn't exist
+        # SETNX returns True if the key was set, False if it already existed
+        if self.REDIS.setnx(key_name, new_value):
+            logging.info("Successfully created new secret key in Redis")
+            return new_value
+
+        # SETNX failed, meaning another process created the key concurrently
+        # Retrieve and return that key
+        final_key = self.REDIS.get(key_name)
+        if final_key is None:
+            # This should rarely happen, but retry if it does
+            logging.warning("Key disappeared during concurrent access, retrying...")
+            return self.get_or_create_secret_key(key_name, new_value)
+
+        logging.debug("Retrieved key created by another process")
+        return final_key
+
     def transaction(self, key, value, exp=3600):
         try:
             pipeline = self.REDIS.pipeline(transaction=True)

From 7ddebf51771c0881ecef91f30bc97b7b4b4f61c5 Mon Sep 17 00:00:00 2001
From: 少卿 <121151546+shaoqing404@users.noreply.github.com>
Date: Wed, 4 Mar 2026 13:23:37 +0800
Subject: [PATCH 122/565] Fix Dify external retrieval by providing metadata.document_id (#13337)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What problem does this PR solve?

## Summary
Dify's external retrieval expects `records[].metadata.document_id` to be a non-empty string.
RAGFlow currently only sets `metadata.doc_id`, which causes Dify validation to fail.
This PR adds `metadata.document_id` (mapped from `doc_id`) in the Dify-compatible retrieval response.

## Changes
- Add `meta["document_id"] = c["doc_id"]` in `api/apps/sdk/dify_retrieval.py`

## Testing
- Not run (logic-only change).

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---
 api/apps/sdk/dify_retrieval.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py
index 881614e5d97..d1e0b2d8b39 100644
--- a/api/apps/sdk/dify_retrieval.py
+++ b/api/apps/sdk/dify_retrieval.py
@@ -166,6 +166,8 @@ async def retrieval(tenant_id):
                 c.pop("vector", None)
                 meta = getattr(doc, 'meta_fields', {})
                 meta["doc_id"] = c["doc_id"]
+                # Dify expects metadata.document_id for external retrieval sources.
+                meta["document_id"] = c["doc_id"]
                 records.append({
                     "content": c["content_with_weight"],
                     "score": c["similarity"],

From f0d0de50ee5ba08c88ab34c13f7e7c3bf8f3af56 Mon Sep 17 00:00:00 2001
From: yiminghub2024 <482890@qq.com>
Date: Wed, 4 Mar 2026 13:54:20 +0800
Subject: [PATCH 123/565] Enhance local model deployment documentation with a GPUStack guide (#13339)

### Type of change

- [X] Documentation Update: Enhance local model deployment documentation with a GPUStack guide

---
 docs/guides/models/deploy_local_llm.mdx | 36 +++++++++++++++++++++++--
 1 file changed, 34 insertions(+), 2 deletions(-)

diff --git a/docs/guides/models/deploy_local_llm.mdx b/docs/guides/models/deploy_local_llm.mdx
index e7e3fbeaee3..0971925eded 100644
--- a/docs/guides/models/deploy_local_llm.mdx
+++ b/docs/guides/models/deploy_local_llm.mdx
@@ -9,11 +9,11 @@ sidebar_custom_props: {
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';
 
-Deploy and run local models using Ollama, Xinference, VLLM ,SGLANG or other frameworks. 
+Deploy and run local models using Ollama, Xinference, vLLM, SGLang, GPUStack, or other frameworks.
 
 ---
 
-RAGFlow supports deploying models locally using Ollama, Xinference, IPEX-LLM, or jina. If you have locally deployed models to leverage or wish to enable GPU or CUDA for inference acceleration, you can bind Ollama or Xinference into RAGFlow and use either of them as a local "server" for interacting with your local models.
+RAGFlow supports deploying models locally using Ollama, Xinference, IPEX-LLM, vLLM, SGLang, GPUStack, or jina. If you have locally deployed models to leverage or wish to enable GPU or CUDA for inference acceleration, you can bind Ollama or Xinference into RAGFlow and use either of them as a local "server" for interacting with your local models.
 
 RAGFlow seamlessly integrates with Ollama and Xinference, without the need for further environment configurations. You can use them to deploy two types of local models in RAGFlow: chat models and embedding models.
 
@@ -350,6 +350,38 @@ select vllm chat model as default llm model as follow:
 create chat->create conversations-chat as follow:
 ![chat](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/ragflow_vllm2.png)
 
+### 6. Deploy GPUStack
+
+Supported OS: Ubuntu 22.04/24.04
+
+### 6.1 Run GPUStack with Docker
+
+```bash
+sudo docker run -d --name gpustack \
+  --restart unless-stopped \
+  -p 80:80 \
+  -p 10161:10161 \
+  --volume gpustack-data:/var/lib/gpustack \
+  gpustack/gpustack
+```
+Check the container status:
+```bash
+docker ps
+```
+When you see the following output, the GPUStack service is ready for access:
+```bash
+root@gpustack-prod:~# docker ps
+CONTAINER ID   IMAGE               COMMAND                  CREATED       STATUS       PORTS                                                                                  NAMES
+abf59be84b1a   gpustack/gpustack   "/usr/bin/entrypoint…"   6 hours ago   Up 6 hours   0.0.0.0:80->80/tcp, [::]:80->80/tcp, 0.0.0.0:10161->10161/tcp, [::]:10161->10161/tcp   gpustack
+```
+### 6.2 Integrate RAGFlow with GPUStack chat/embedding/rerank models via the web UI
+
+Go to setting -> model providers -> search -> gpustack -> add, then configure as follows:
+
+![add gpustack](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/ragflow-gpustack11.png)
+
+Select the GPUStack chat model as the default LLM model as follows:
+![chat](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/ragflow-gpustack22.png)

From 8df3c165775b4b9de8a779183fb9525b6851d1cb Mon Sep 17 00:00:00 2001
From: Jin Hai
Date: Wed, 4 Mar 2026 16:36:42 +0800
Subject: [PATCH 124/565] Disable benchmark (#13370)

### What problem does this PR solve?

The benchmark always fails on the new CI machine. Please re-enable it after the issue is fixed.
### Type of change - [x] Other (please describe): disable benchmark Signed-off-by: Jin Hai --- .github/workflows/tests.yml | 200 ++++++++++++++++++------------------ 1 file changed, 100 insertions(+), 100 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6d370097f60..c0d27aa7656 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -421,106 +421,106 @@ jobs: done source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log - - name: RAGFlow CLI retrieval test Infinity - env: - PYTHONPATH: ${{ github.workspace }} - run: | - set -euo pipefail - source .venv/bin/activate - - export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" - - EMAIL="ci-${GITHUB_RUN_ID}@example.com" - PASS="ci-pass-${GITHUB_RUN_ID}" - DATASET="ci_dataset_${GITHUB_RUN_ID}" - - CLI="python admin/client/ragflow_cli.py" - - LOG_FILE="infinity_cli_test.log" - : > "${LOG_FILE}" - - ERROR_RE='Traceback|ModuleNotFoundError|ImportError|Parse error|Bad response|Fail to|code:\\s*[1-9]' - run_cli() { - local logfile="$1" - shift - local allow_re="" - if [[ "${1:-}" == "--allow" ]]; then - allow_re="$2" - shift 2 - fi - local cmd_display="$*" - echo "===== $(date -u +\"%Y-%m-%dT%H:%M:%SZ\") CMD: ${cmd_display} =====" | tee -a "${logfile}" - local tmp_log - tmp_log="$(mktemp)" - set +e - timeout 180s "$@" 2>&1 | tee "${tmp_log}" - local status=${PIPESTATUS[0]} - set -e - cat "${tmp_log}" >> "${logfile}" - if grep -qiE "${ERROR_RE}" "${tmp_log}"; then - if [[ -n "${allow_re}" ]] && grep -qiE "${allow_re}" "${tmp_log}"; then - echo "Allowed CLI error markers in ${logfile}" - rm -f "${tmp_log}" - return 0 - fi - echo "Detected CLI error markers in ${logfile}" - rm -f "${tmp_log}" - exit 1 - fi - rm -f "${tmp_log}" - return ${status} - } - - set -a - source docker/.env - set +a - - HOST_ADDRESS="http://host.docker.internal:${SVR_HTTP_PORT}" - USER_HOST="$(echo "${HOST_ADDRESS}" | sed -E 's#^https?://([^:/]+).*#\1#')" - USER_PORT="${SVR_HTTP_PORT}" - ADMIN_HOST="${USER_HOST}" - ADMIN_PORT="${ADMIN_SVR_HTTP_PORT}" - - until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do - echo "Waiting for service to be available..." - sleep 5 - done - - admin_ready=0 - for i in $(seq 1 30); do - if run_cli "${LOG_FILE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "ping"; then - admin_ready=1 - break - fi - sleep 1 - done - if [[ "${admin_ready}" -ne 1 ]]; then - echo "Admin service did not become ready" - exit 1 - fi - - run_cli "${LOG_FILE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "show version" - ALLOW_USER_EXISTS_RE='already exists|already exist|duplicate|already.*registered|exist(s)?' 
- run_cli "${LOG_FILE}" --allow "${ALLOW_USER_EXISTS_RE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "create user '$EMAIL' '$PASS'" - - user_ready=0 - for i in $(seq 1 30); do - if run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "ping"; then - user_ready=1 - break - fi - sleep 1 - done - if [[ "${user_ready}" -ne 1 ]]; then - echo "User service did not become ready" - exit 1 - fi - - run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "show version" - run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "create dataset '$DATASET' with embedding 'BAAI/bge-small-en-v1.5@Builtin' parser 'auto'" - run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "import 'test/benchmark/test_docs/Doc1.pdf,test/benchmark/test_docs/Doc2.pdf' into dataset '$DATASET'" - run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "parse dataset '$DATASET' sync" - run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "Benchmark 16 100 search 'what are these documents about' on datasets '$DATASET'" +# - name: RAGFlow CLI retrieval test Infinity +# env: +# PYTHONPATH: ${{ github.workspace }} +# run: | +# set -euo pipefail +# source .venv/bin/activate +# +# export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" +# +# EMAIL="ci-${GITHUB_RUN_ID}@example.com" +# PASS="ci-pass-${GITHUB_RUN_ID}" +# DATASET="ci_dataset_${GITHUB_RUN_ID}" +# +# CLI="python admin/client/ragflow_cli.py" +# +# LOG_FILE="infinity_cli_test.log" +# : > "${LOG_FILE}" +# +# ERROR_RE='Traceback|ModuleNotFoundError|ImportError|Parse error|Bad response|Fail to|code:\\s*[1-9]' +# run_cli() { +# local logfile="$1" +# shift +# local allow_re="" +# if [[ "${1:-}" == "--allow" ]]; then +# allow_re="$2" +# shift 2 +# fi +# local cmd_display="$*" +# echo "===== $(date -u +\"%Y-%m-%dT%H:%M:%SZ\") CMD: ${cmd_display} =====" | tee -a "${logfile}" +# local tmp_log +# tmp_log="$(mktemp)" +# set +e +# timeout 180s "$@" 2>&1 | tee "${tmp_log}" +# local status=${PIPESTATUS[0]} +# set -e +# cat "${tmp_log}" >> "${logfile}" +# if grep -qiE "${ERROR_RE}" "${tmp_log}"; then +# if [[ -n "${allow_re}" ]] && grep -qiE "${allow_re}" "${tmp_log}"; then +# echo "Allowed CLI error markers in ${logfile}" +# rm -f "${tmp_log}" +# return 0 +# fi +# echo "Detected CLI error markers in ${logfile}" +# rm -f "${tmp_log}" +# exit 1 +# fi +# rm -f "${tmp_log}" +# return ${status} +# } +# +# set -a +# source docker/.env +# set +a +# +# HOST_ADDRESS="http://host.docker.internal:${SVR_HTTP_PORT}" +# USER_HOST="$(echo "${HOST_ADDRESS}" | sed -E 's#^https?://([^:/]+).*#\1#')" +# USER_PORT="${SVR_HTTP_PORT}" +# ADMIN_HOST="${USER_HOST}" +# ADMIN_PORT="${ADMIN_SVR_HTTP_PORT}" +# +# until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do +# echo "Waiting for service to be available..." 
+# sleep 5 +# done +# +# admin_ready=0 +# for i in $(seq 1 30); do +# if run_cli "${LOG_FILE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "ping"; then +# admin_ready=1 +# break +# fi +# sleep 1 +# done +# if [[ "${admin_ready}" -ne 1 ]]; then +# echo "Admin service did not become ready" +# exit 1 +# fi +# +# run_cli "${LOG_FILE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "show version" +# ALLOW_USER_EXISTS_RE='already exists|already exist|duplicate|already.*registered|exist(s)?' +# run_cli "${LOG_FILE}" --allow "${ALLOW_USER_EXISTS_RE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "create user '$EMAIL' '$PASS'" +# +# user_ready=0 +# for i in $(seq 1 30); do +# if run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "ping"; then +# user_ready=1 +# break +# fi +# sleep 1 +# done +# if [[ "${user_ready}" -ne 1 ]]; then +# echo "User service did not become ready" +# exit 1 +# fi +# +# run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "show version" +# run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "create dataset '$DATASET' with embedding 'BAAI/bge-small-en-v1.5@Builtin' parser 'auto'" +# run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "import 'test/benchmark/test_docs/Doc1.pdf,test/benchmark/test_docs/Doc2.pdf' into dataset '$DATASET'" +# run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "parse dataset '$DATASET' sync" +# run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "Benchmark 16 100 search 'what are these documents about' on datasets '$DATASET'" - name: Stop ragflow to save coverage Infinity if: ${{ !cancelled() }} From 9e3959db7db3943588d4f62c15d6892c1730718d Mon Sep 17 00:00:00 2001 From: Good0987 Date: Wed, 4 Mar 2026 04:41:35 -0500 Subject: [PATCH 125/565] Test: add scenario for embedding_model update when chunk_count > 0 (#13351) ### What problem does this PR solve? Guard embedding_model change when dataset has existing chunks. API must return code 102 with message 'When chunk_num (N) > 0, embedding_model must remain ' to prevent silent embedding drift. 
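For reference, the guarded behavior the new test exercises can be sketched as follows (using the helper names from the test below):

```python
# The dataset already holds parsed chunks, so switching embedding models
# must be rejected rather than silently invalidating existing vectors.
payload = {"embedding_model": "embedding-3@ZHIPU-AI"}
res = update_dataset(HttpApiAuth, dataset_id, payload)
assert res["code"] == 102
assert res["message"].startswith("When chunk_num (")
```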
### Type of change - [x] Add Testcases Co-authored-by: Liu An --- .../test_update_dataset.py | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py b/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py index a502be2b523..0cafe1f6743 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py +++ b/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py @@ -26,7 +26,6 @@ from utils.file_utils import create_image_file from utils.hypothesis_utils import valid_names from configs import DEFAULT_PARSER_CONFIG -# TODO: Missing scenario for updating embedding_model with chunk_count != 0 class TestAuthorization: @@ -275,6 +274,30 @@ def test_embedding_model(self, HttpApiAuth, add_dataset_func, embedding_model): assert res["code"] == 0, res assert res["data"][0]["embedding_model"] == embedding_model, res + @pytest.mark.p1 + def test_embedding_model_with_existing_chunks(self, HttpApiAuth, add_chunks): + """Guard: embedding_model cannot change when dataset has chunks (chunk_count > 0).""" + dataset_id, _, _ = add_chunks + + res = list_datasets(HttpApiAuth, {"id": dataset_id}) + assert res["code"] == 0, res + assert res["data"], res + dataset = res["data"][0] + assert dataset.get("chunk_count", 0) > 0, res + + current_embedding = dataset["embedding_model"] + candidates = ["BAAI/bge-small-en-v1.5@Builtin", "embedding-3@ZHIPU-AI"] + new_embedding = candidates[0] if current_embedding != candidates[0] else candidates[1] + + payload = {"embedding_model": new_embedding} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 102, res + expected_message = ( + f"When chunk_num ({dataset['chunk_count']}) > 0, " + f"embedding_model must remain {current_embedding}" + ) + assert res["message"] == expected_message, res + @pytest.mark.p2 @pytest.mark.parametrize( "name, embedding_model", From 724cc6aa9ba80ecaa4faad4bfc0c1ea0c29549ba Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Wed, 4 Mar 2026 17:48:47 +0800 Subject: [PATCH 126/565] Update graspologic to gitee (#13362) ### What problem does this PR solve? 
Accelerate python module downloading ### Type of change - [x] Refactoring Signed-off-by: Jin Hai --- pyproject.toml | 2 +- uv.lock | 55 ++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 815377dfd63..53dc38cf8cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ dependencies = [ "google-cloud-storage>=2.19.0,<3.0.0", "google-genai>=1.41.0,<2.0.0", "google-search-results==2.4.2", - "graspologic @ git+https://github.com/yuzhichang/graspologic.git@38e680cab72bc9fb68a7992c3bcc2d53b24e42fd", + "graspologic @ git+https://gitee.com/infiniflow/graspologic.git@38e680cab72bc9fb68a7992c3bcc2d53b24e42fd", "groq==0.9.0", "grpcio-status==1.67.1", "html-text==0.6.2", diff --git a/uv.lock b/uv.lock index 70e96a0bdc2..6c545065650 100644 --- a/uv.lock +++ b/uv.lock @@ -2576,7 +2576,7 @@ wheels = [ [[package]] name = "graspologic" version = "0.1.dev847+g38e680cab" -source = { git = "https://github.com/yuzhichang/graspologic.git?rev=38e680cab72bc9fb68a7992c3bcc2d53b24e42fd#38e680cab72bc9fb68a7992c3bcc2d53b24e42fd" } +source = { git = "https://gitee.com/infiniflow/graspologic.git?rev=38e680cab72bc9fb68a7992c3bcc2d53b24e42fd#38e680cab72bc9fb68a7992c3bcc2d53b24e42fd" } dependencies = [ { name = "anytree" }, { name = "beartype" }, @@ -5825,6 +5825,19 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, ] +[[package]] +name = "pytest-base-url" +version = "2.1.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "pytest" }, + { name = "requests" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ae/1a/b64ac368de6b993135cb70ca4e5d958a5c268094a3a2a4cac6f0021b6c4f/pytest_base_url-2.1.0.tar.gz", hash = "sha256:02748589a54f9e63fcbe62301d6b0496da0d10231b753e950c63e03aee745d45", size = 6702, upload-time = "2024-01-31T22:43:00.81Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/98/1c/b00940ab9eb8ede7897443b771987f2f4a76f06be02f1b3f01eb7567e24a/pytest_base_url-2.1.0-py3-none-any.whl", hash = "sha256:3ad15611778764d451927b2a53240c1a7a591b521ea44cebfe45849d2d2812e6", size = 5302, upload-time = "2024-01-31T22:42:58.897Z" }, +] + [[package]] name = "pytest-cov" version = "7.0.0" @@ -5839,6 +5852,21 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, ] +[[package]] +name = "pytest-playwright" +version = "0.7.2" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "playwright" }, + { name = "pytest" }, + { name = "pytest-base-url" }, + { name = "python-slugify" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e8/6b/913e36aa421b35689ec95ed953ff7e8df3f2ee1c7b8ab2a3f1fd39d95faf/pytest_playwright-0.7.2.tar.gz", hash = "sha256:247b61123b28c7e8febb993a187a07e54f14a9aa04edc166f7a976d88f04c770", size = 16928, upload-time = "2025-11-24T03:43:22.53Z" } +wheels = [ + { url = 
"https://pypi.tuna.tsinghua.edu.cn/packages/76/61/4d333d8354ea2bea2c2f01bad0a4aa3c1262de20e1241f78e73360e9b620/pytest_playwright-0.7.2-py3-none-any.whl", hash = "sha256:8084e015b2b3ecff483c2160f1c8219b38b66c0d4578b23c0f700d1b0240ea38", size = 16881, upload-time = "2025-11-24T03:43:24.423Z" }, +] + [[package]] name = "pytest-xdist" version = "3.8.0" @@ -5999,6 +6027,18 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d9/4f/00be2196329ebbff56ce564aa94efb0fbc828d00de250b1980de1a34ab49/python_pptx-1.0.2-py3-none-any.whl", hash = "sha256:160838e0b8565a8b1f67947675886e9fea18aa5e795db7ae531606d68e785cba", size = 472788, upload-time = "2024-08-07T17:33:28.192Z" }, ] +[[package]] +name = "python-slugify" +version = "8.0.4" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "text-unidecode" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/87/c7/5e1547c44e31da50a460df93af11a535ace568ef89d7a811069ead340c4a/python-slugify-8.0.4.tar.gz", hash = "sha256:59202371d1d05b54a9e7720c5e038f928f45daaffe41dd10822f3907b937c856", size = 10921, upload-time = "2024-02-08T18:32:45.488Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a4/62/02da182e544a51a5c3ccf4b03ab79df279f9c60c5e82d5e8bec7ca26ac11/python_slugify-8.0.4-py2.py3-none-any.whl", hash = "sha256:276540b79961052b66b7d116620b36518847f52d5fd9e3a70164fc8c50faa6b8", size = 10051, upload-time = "2024-02-08T18:32:43.911Z" }, +] + [[package]] name = "pytz" version = "2025.2" @@ -6303,6 +6343,7 @@ test = [ { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, + { name = "pytest-playwright" }, { name = "pytest-xdist" }, { name = "python-docx" }, { name = "python-pptx" }, @@ -6351,7 +6392,7 @@ requires-dist = [ { name = "google-cloud-storage", specifier = ">=2.19.0,<3.0.0" }, { name = "google-genai", specifier = ">=1.41.0,<2.0.0" }, { name = "google-search-results", specifier = "==2.4.2" }, - { name = "graspologic", git = "https://github.com/yuzhichang/graspologic.git?rev=38e680cab72bc9fb68a7992c3bcc2d53b24e42fd" }, + { name = "graspologic", git = "https://gitee.com/infiniflow/graspologic.git?rev=38e680cab72bc9fb68a7992c3bcc2d53b24e42fd" }, { name = "groq", specifier = "==0.9.0" }, { name = "grpcio-status", specifier = "==1.67.1" }, { name = "html-text", specifier = "==0.6.2" }, @@ -6441,6 +6482,7 @@ test = [ { name = "pytest", specifier = ">=8.3.5" }, { name = "pytest-asyncio", specifier = ">=1.3.0" }, { name = "pytest-cov", specifier = ">=7.0.0" }, + { name = "pytest-playwright", specifier = ">=0.7.2" }, { name = "pytest-xdist", specifier = ">=3.8.0" }, { name = "python-docx", specifier = ">=1.1.2" }, { name = "python-pptx", specifier = ">=1.0.2" }, @@ -7639,6 +7681,15 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c5/db/daa85799b9af2aa50539b27eeb0d6a2a0ac35465f62683107847830dbe4d/tencentcloud_sdk_python-3.0.1478-py2.py3-none-any.whl", hash = "sha256:10ddee1c1348f49e2b54af606f978d4cb17fca656639e8d99b6527e6e4793833", size = 12984723, upload-time = "2025-10-20T20:54:27.767Z" }, ] +[[package]] +name = "text-unidecode" +version = "1.3" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ab/e2/e9a00f0ccb71718418230718b3d900e71a5d16e701a3dae079a21e9cd8f8/text-unidecode-1.3.tar.gz", hash = "sha256:bad6603bb14d279193107714b288be206cac565dfa49aa5b105294dd5c4aab93", size = 76885, upload-time = "2019-08-30T21:36:45.405Z" } +wheels = [ + { url = 
"https://pypi.tuna.tsinghua.edu.cn/packages/a6/a5/c0b6468d3824fe3fde30dbb5e1f687b291608f9473681bbf7dabbf5a87d7/text_unidecode-1.3-py2.py3-none-any.whl", hash = "sha256:1311f10e8b895935241623731c2ba64f4c455287888b18189350b67134a822e8", size = 78154, upload-time = "2019-08-30T21:37:03.543Z" }, +] + [[package]] name = "tf-playwright-stealth" version = "1.2.0" From 69cdede8ec97a38be84cd5b330016709fdb8a805 Mon Sep 17 00:00:00 2001 From: Stephen Hu <812791840@qq.com> Date: Wed, 4 Mar 2026 18:00:17 +0800 Subject: [PATCH 127/565] Refa:improve excel parser logic (#13372) ### What problem does this PR solve? improve excel parser logic ### Type of change - [x] Refactoring --- deepdoc/parser/excel_parser.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/deepdoc/parser/excel_parser.py b/deepdoc/parser/excel_parser.py index 2fe3420192c..b75d31f6a47 100644 --- a/deepdoc/parser/excel_parser.py +++ b/deepdoc/parser/excel_parser.py @@ -74,9 +74,16 @@ def clean_string(s): return df.apply(lambda col: col.map(clean_string)) + @staticmethod + def _fill_worksheet_from_dataframe(ws, df: pd.DataFrame): + for col_num, column_name in enumerate(df.columns, 1): + ws.cell(row=1, column=col_num, value=column_name) + for row_num, row in enumerate(df.values, 2): + for col_num, value in enumerate(row, 1): + ws.cell(row=row_num, column=col_num, value=value) + @staticmethod def _dataframe_to_workbook(df): - # if contains multiple sheets use _dataframes_to_workbook if isinstance(df, dict) and len(df) > 1: return RAGFlowExcelParser._dataframes_to_workbook(df) @@ -84,30 +91,19 @@ def _dataframe_to_workbook(df): wb = Workbook() ws = wb.active ws.title = "Data" - - for col_num, column_name in enumerate(df.columns, 1): - ws.cell(row=1, column=col_num, value=column_name) - - for row_num, row in enumerate(df.values, 2): - for col_num, value in enumerate(row, 1): - ws.cell(row=row_num, column=col_num, value=value) - + RAGFlowExcelParser._fill_worksheet_from_dataframe(ws, df) return wb - + @staticmethod def _dataframes_to_workbook(dfs: dict): wb = Workbook() default_sheet = wb.active wb.remove(default_sheet) - + for sheet_name, df in dfs.items(): df = RAGFlowExcelParser._clean_dataframe(df) ws = wb.create_sheet(title=sheet_name) - for col_num, column_name in enumerate(df.columns, 1): - ws.cell(row=1, column=col_num, value=column_name) - for row_num, row in enumerate(df.values, 2): - for col_num, value in enumerate(row, 1): - ws.cell(row=row_num, column=col_num, value=value) + RAGFlowExcelParser._fill_worksheet_from_dataframe(ws, df) return wb @staticmethod From 83a79cb1c9c7fb3bee67433b4d203f3618aa6a7c Mon Sep 17 00:00:00 2001 From: Idriss Sbaaoui <112825897+6ba3i@users.noreply.github.com> Date: Wed, 4 Mar 2026 19:01:41 +0800 Subject: [PATCH 128/565] benchmark fail in ci (#13377) ### What problem does this PR solve? 
ci fails in elastic search because of benchmark ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- .github/workflows/tests.yml | 202 ++++++++++++++++++------------------ 1 file changed, 101 insertions(+), 101 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c0d27aa7656..b1c4452c8ee 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -267,7 +267,7 @@ jobs: local tmp_log tmp_log="$(mktemp)" set +e - timeout 180s "$@" 2>&1 | tee "${tmp_log}" + timeout 300s "$@" 2>&1 | tee "${tmp_log}" local status=${PIPESTATUS[0]} set -e cat "${tmp_log}" >> "${logfile}" @@ -421,106 +421,106 @@ jobs: done source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log -# - name: RAGFlow CLI retrieval test Infinity -# env: -# PYTHONPATH: ${{ github.workspace }} -# run: | -# set -euo pipefail -# source .venv/bin/activate -# -# export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" -# -# EMAIL="ci-${GITHUB_RUN_ID}@example.com" -# PASS="ci-pass-${GITHUB_RUN_ID}" -# DATASET="ci_dataset_${GITHUB_RUN_ID}" -# -# CLI="python admin/client/ragflow_cli.py" -# -# LOG_FILE="infinity_cli_test.log" -# : > "${LOG_FILE}" -# -# ERROR_RE='Traceback|ModuleNotFoundError|ImportError|Parse error|Bad response|Fail to|code:\\s*[1-9]' -# run_cli() { -# local logfile="$1" -# shift -# local allow_re="" -# if [[ "${1:-}" == "--allow" ]]; then -# allow_re="$2" -# shift 2 -# fi -# local cmd_display="$*" -# echo "===== $(date -u +\"%Y-%m-%dT%H:%M:%SZ\") CMD: ${cmd_display} =====" | tee -a "${logfile}" -# local tmp_log -# tmp_log="$(mktemp)" -# set +e -# timeout 180s "$@" 2>&1 | tee "${tmp_log}" -# local status=${PIPESTATUS[0]} -# set -e -# cat "${tmp_log}" >> "${logfile}" -# if grep -qiE "${ERROR_RE}" "${tmp_log}"; then -# if [[ -n "${allow_re}" ]] && grep -qiE "${allow_re}" "${tmp_log}"; then -# echo "Allowed CLI error markers in ${logfile}" -# rm -f "${tmp_log}" -# return 0 -# fi -# echo "Detected CLI error markers in ${logfile}" -# rm -f "${tmp_log}" -# exit 1 -# fi -# rm -f "${tmp_log}" -# return ${status} -# } -# -# set -a -# source docker/.env -# set +a -# -# HOST_ADDRESS="http://host.docker.internal:${SVR_HTTP_PORT}" -# USER_HOST="$(echo "${HOST_ADDRESS}" | sed -E 's#^https?://([^:/]+).*#\1#')" -# USER_PORT="${SVR_HTTP_PORT}" -# ADMIN_HOST="${USER_HOST}" -# ADMIN_PORT="${ADMIN_SVR_HTTP_PORT}" -# -# until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do -# echo "Waiting for service to be available..." -# sleep 5 -# done -# -# admin_ready=0 -# for i in $(seq 1 30); do -# if run_cli "${LOG_FILE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "ping"; then -# admin_ready=1 -# break -# fi -# sleep 1 -# done -# if [[ "${admin_ready}" -ne 1 ]]; then -# echo "Admin service did not become ready" -# exit 1 -# fi -# -# run_cli "${LOG_FILE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "show version" -# ALLOW_USER_EXISTS_RE='already exists|already exist|duplicate|already.*registered|exist(s)?' 
-# run_cli "${LOG_FILE}" --allow "${ALLOW_USER_EXISTS_RE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "create user '$EMAIL' '$PASS'" -# -# user_ready=0 -# for i in $(seq 1 30); do -# if run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "ping"; then -# user_ready=1 -# break -# fi -# sleep 1 -# done -# if [[ "${user_ready}" -ne 1 ]]; then -# echo "User service did not become ready" -# exit 1 -# fi -# -# run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "show version" -# run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "create dataset '$DATASET' with embedding 'BAAI/bge-small-en-v1.5@Builtin' parser 'auto'" -# run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "import 'test/benchmark/test_docs/Doc1.pdf,test/benchmark/test_docs/Doc2.pdf' into dataset '$DATASET'" -# run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "parse dataset '$DATASET' sync" -# run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "Benchmark 16 100 search 'what are these documents about' on datasets '$DATASET'" + - name: RAGFlow CLI retrieval test Infinity + env: + PYTHONPATH: ${{ github.workspace }} + run: | + set -euo pipefail + source .venv/bin/activate + + export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" + + EMAIL="ci-${GITHUB_RUN_ID}@example.com" + PASS="ci-pass-${GITHUB_RUN_ID}" + DATASET="ci_dataset_${GITHUB_RUN_ID}" + + CLI="python admin/client/ragflow_cli.py" + + LOG_FILE="infinity_cli_test.log" + : > "${LOG_FILE}" + + ERROR_RE='Traceback|ModuleNotFoundError|ImportError|Parse error|Bad response|Fail to|code:\\s*[1-9]' + run_cli() { + local logfile="$1" + shift + local allow_re="" + if [[ "${1:-}" == "--allow" ]]; then + allow_re="$2" + shift 2 + fi + local cmd_display="$*" + echo "===== $(date -u +\"%Y-%m-%dT%H:%M:%SZ\") CMD: ${cmd_display} =====" | tee -a "${logfile}" + local tmp_log + tmp_log="$(mktemp)" + set +e + timeout 300s "$@" 2>&1 | tee "${tmp_log}" + local status=${PIPESTATUS[0]} + set -e + cat "${tmp_log}" >> "${logfile}" + if grep -qiE "${ERROR_RE}" "${tmp_log}"; then + if [[ -n "${allow_re}" ]] && grep -qiE "${allow_re}" "${tmp_log}"; then + echo "Allowed CLI error markers in ${logfile}" + rm -f "${tmp_log}" + return 0 + fi + echo "Detected CLI error markers in ${logfile}" + rm -f "${tmp_log}" + exit 1 + fi + rm -f "${tmp_log}" + return ${status} + } + + set -a + source docker/.env + set +a + + HOST_ADDRESS="http://host.docker.internal:${SVR_HTTP_PORT}" + USER_HOST="$(echo "${HOST_ADDRESS}" | sed -E 's#^https?://([^:/]+).*#\1#')" + USER_PORT="${SVR_HTTP_PORT}" + ADMIN_HOST="${USER_HOST}" + ADMIN_PORT="${ADMIN_SVR_HTTP_PORT}" + + until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do + echo "Waiting for service to be available..." 
+ sleep 5 + done + + admin_ready=0 + for i in $(seq 1 30); do + if run_cli "${LOG_FILE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "ping"; then + admin_ready=1 + break + fi + sleep 1 + done + if [[ "${admin_ready}" -ne 1 ]]; then + echo "Admin service did not become ready" + exit 1 + fi + + run_cli "${LOG_FILE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "show version" + ALLOW_USER_EXISTS_RE='already exists|already exist|duplicate|already.*registered|exist(s)?' + run_cli "${LOG_FILE}" --allow "${ALLOW_USER_EXISTS_RE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "create user '$EMAIL' '$PASS'" + + user_ready=0 + for i in $(seq 1 30); do + if run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "ping"; then + user_ready=1 + break + fi + sleep 1 + done + if [[ "${user_ready}" -ne 1 ]]; then + echo "User service did not become ready" + exit 1 + fi + + run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "show version" + run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "create dataset '$DATASET' with embedding 'BAAI/bge-small-en-v1.5@Builtin' parser 'auto'" + run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "import 'test/benchmark/test_docs/Doc1.pdf,test/benchmark/test_docs/Doc2.pdf' into dataset '$DATASET'" + run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "parse dataset '$DATASET' sync" + run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "Benchmark 16 100 search 'what are these documents about' on datasets '$DATASET'" - name: Stop ragflow to save coverage Infinity if: ${{ !cancelled() }} From eaa627c912d69c9d2dbc7b3afd4b750c964302b4 Mon Sep 17 00:00:00 2001 From: Idriss Sbaaoui <112825897+6ba3i@users.noreply.github.com> Date: Wed, 4 Mar 2026 19:10:06 +0800 Subject: [PATCH 129/565] Playwright : add new test for configuration tab in datasets (#13365) ### What problem does this PR solve? 
this pr adds new tests, for the full configuration tab in datasests ### Type of change - [x] Other (please describe): new tests --- api/db/services/dialog_service.py | 134 +++-- .../e2e/test_dataset_upload_parse.py | 471 +++++++++++++++++- test/playwright/helpers/datasets.py | 60 ++- ...t_dialog_service_use_sql_source_columns.py | 221 ++++++++ 4 files changed, 833 insertions(+), 53 deletions(-) create mode 100644 test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 22d38da2f68..1dcab82a7f0 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -778,6 +778,47 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N table_name = base_table logging.debug(f"use_sql: Using ES/OS table name: {table_name}") + expected_doc_name_column = "docnm" if doc_engine == "infinity" else "docnm_kwd" + + def has_source_columns(columns): + normalized_names = {str(col.get("name", "")).lower() for col in columns} + return "doc_id" in normalized_names and bool({"docnm_kwd", "docnm"} & normalized_names) + + def is_aggregate_sql(sql_text): + return bool(re.search(r"(count|sum|avg|max|min|distinct)\s*\(", (sql_text or "").lower())) + + def normalize_sql(sql): + logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}") + # Remove think blocks if present (format: ...) + sql = re.sub(r"\n.*?\n\s*", "", sql, flags=re.DOTALL) + sql = re.sub(r"思考\n.*?\n", "", sql, flags=re.DOTALL) + # Remove markdown code blocks (```sql ... ```) + sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE) + sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE) + # Remove trailing semicolon that ES SQL parser doesn't like + return sql.rstrip().rstrip(';').strip() + + def add_kb_filter(sql): + # Add kb_id filter for ES/OS only (Infinity already has it in table name) + if doc_engine == "infinity" or not kb_ids: + return sql + + # Build kb_filter: single KB or multiple KBs with OR + if len(kb_ids) == 1: + kb_filter = f"kb_id = '{kb_ids[0]}'" + else: + kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")" + + if "where " not in sql.lower(): + o = sql.lower().split("order by") + if len(o) > 1: + sql = o[0] + f" WHERE {kb_filter} order by " + o[1] + else: + sql += f" WHERE {kb_filter}" + elif "kb_id =" not in sql.lower() and "kb_id=" not in sql.lower(): + sql = re.sub(r"\bwhere\b ", f"where {kb_filter} and ", sql, flags=re.IGNORECASE) + return sql + def is_row_count_question(q: str) -> bool: q = (q or "").lower() if not re.search(r"\bhow many rows\b|\bnumber of rows\b|\brow count\b", q): @@ -881,38 +922,15 @@ def is_row_count_question(q: str) -> bool: tried_times = 0 - async def get_table(): + async def get_table(custom_user_prompt=None): nonlocal sys_prompt, user_prompt, question, tried_times, row_count_override - if row_count_override: + if row_count_override and custom_user_prompt is None: sql = row_count_override else: - sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06}) - logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}") - # Remove think blocks if present (format: ...) - sql = re.sub(r"\n.*?\n\s*", "", sql, flags=re.DOTALL) - sql = re.sub(r"思考\n.*?\n", "", sql, flags=re.DOTALL) - # Remove markdown code blocks (```sql ... 
```) - sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE) - sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE) - # Remove trailing semicolon that ES SQL parser doesn't like - sql = sql.rstrip().rstrip(';').strip() - - # Add kb_id filter for ES/OS only (Infinity already has it in table name) - if doc_engine != "infinity" and kb_ids: - # Build kb_filter: single KB or multiple KBs with OR - if len(kb_ids) == 1: - kb_filter = f"kb_id = '{kb_ids[0]}'" - else: - kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")" - - if "where " not in sql.lower(): - o = sql.lower().split("order by") - if len(o) > 1: - sql = o[0] + f" WHERE {kb_filter} order by " + o[1] - else: - sql += f" WHERE {kb_filter}" - elif "kb_id =" not in sql.lower() and "kb_id=" not in sql.lower(): - sql = re.sub(r"\bwhere\b ", f"where {kb_filter} and ", sql, flags=re.IGNORECASE) + prompt = custom_user_prompt if custom_user_prompt is not None else user_prompt + sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": prompt}], {"temperature": 0.06}) + sql = normalize_sql(sql) + sql = add_kb_filter(sql) logging.debug(f"{question} get SQL(refined): {sql}") tried_times += 1 @@ -924,6 +942,46 @@ async def get_table(): logging.debug(f"use_sql: SQL retrieval completed, got {len(tbl.get('rows', []))} rows") return tbl, sql + async def repair_table_for_missing_source_columns(previous_sql): + if doc_engine in ("infinity", "oceanbase"): + json_field_names = list(field_map.keys()) + repair_prompt = """Table name: {}; +JSON fields available in 'chunk_data' column (use exact names): +{} + +Question: {} +Previous SQL: +{} + +The previous SQL result is missing required source columns for citations. +Rewrite SQL to keep the same query intent and include doc_id and {} in the SELECT list. +For extracted JSON fields, use json_extract_string(chunk_data, '$.field_name'). +Return ONLY SQL.""".format( + table_name, + "\n".join([f" - {field}" for field in json_field_names]), + question, + previous_sql, + expected_doc_name_column + ) + else: + repair_prompt = """Table name: {} +Available fields: +{} + +Question: {} +Previous SQL: +{} + +The previous SQL result is missing required source columns for citations. +Rewrite SQL to keep the same query intent and include doc_id and docnm_kwd in the SELECT list. +Return ONLY SQL.""".format( + table_name, + "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), + question, + previous_sql + ) + return await get_table(custom_user_prompt=repair_prompt) + try: tbl, sql = await get_table() logging.debug(f"use_sql: Initial SQL execution SUCCESS. SQL: {sql}") @@ -977,6 +1035,22 @@ async def get_table(): logging.warning(f"use_sql: No rows returned from SQL query, returning None. SQL: {sql}") return None + if not is_aggregate_sql(sql) and not has_source_columns(tbl.get("columns", [])): + logging.warning(f"use_sql: Non-aggregate SQL missing required source columns; retrying once. SQL: {sql}") + try: + repaired_tbl, repaired_sql = await repair_table_for_missing_source_columns(sql) + if ( + repaired_tbl + and len(repaired_tbl.get("rows", [])) > 0 + and has_source_columns(repaired_tbl.get("columns", [])) + ): + tbl, sql = repaired_tbl, repaired_sql + logging.info(f"use_sql: Source-column SQL repair succeeded. SQL: {sql}") + else: + logging.warning(f"use_sql: Source-column SQL repair did not provide required columns. Repaired SQL: {repaired_sql}") + except Exception as e: + logging.warning(f"use_sql: Source-column SQL repair failed, returning best-effort answer. 
Error: {e}") + logging.debug(f"use_sql: Proceeding with {len(tbl['rows'])} rows to build answer") docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() == "doc_id"]) @@ -1072,7 +1146,7 @@ def map_column_name(col_name): logging.warning(f"use_sql: SQL missing required doc_id or docnm_kwd field. docid_idx={docid_idx}, doc_name_idx={doc_name_idx}. SQL: {sql}") # For aggregate queries (COUNT, SUM, AVG, MAX, MIN, DISTINCT), fetch doc_id, docnm_kwd separately # to provide source chunks, but keep the original table format answer - if re.search(r"(count|sum|avg|max|min|distinct)\s*\(", sql.lower()): + if is_aggregate_sql(sql): # Keep original table format as answer answer = "\n".join([columns, line, rows]) diff --git a/test/playwright/e2e/test_dataset_upload_parse.py b/test/playwright/e2e/test_dataset_upload_parse.py index 40c0af93cc4..29f3a399cfe 100644 --- a/test/playwright/e2e/test_dataset_upload_parse.py +++ b/test/playwright/e2e/test_dataset_upload_parse.py @@ -1,3 +1,4 @@ +import base64 import json import re import time @@ -25,6 +26,178 @@ RESULT_TIMEOUT_MS = 15000 +def make_test_png(path: Path) -> Path: + png_b64 = ( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8" + "/w8AAgMBAp6X6QAAAABJRU5ErkJggg==" + ) + path.write_bytes(base64.b64decode(png_b64)) + return path + + +def extract_dataset_id_from_url(url: str) -> str: + match = re.search(r"/(?:datasets|dataset/dataset)/([^/?#]+)", url or "") + if not match: + raise AssertionError(f"Unable to parse dataset id from url={url!r}") + return match.group(1) + + +def set_switch_state(page, test_id: str, desired_checked: bool) -> None: + switch = page.get_by_test_id(test_id).first + expect(switch).to_be_visible(timeout=RESULT_TIMEOUT_MS) + switch.scroll_into_view_if_needed() + current_checked = (switch.get_attribute("data-state") or "") == "checked" + if current_checked == desired_checked: + return + switch.click() + expect(switch).to_have_attribute( + "data-state", + "checked" if desired_checked else "unchecked", + timeout=RESULT_TIMEOUT_MS, + ) + + +def set_number_input(page, test_id: str, value: str | int | float) -> None: + number_input = page.get_by_test_id(test_id).first + expect(number_input).to_be_visible(timeout=RESULT_TIMEOUT_MS) + number_input.scroll_into_view_if_needed() + number_input.click() + try: + number_input.press("Control+a") + except Exception: + pass + number_input.fill(str(value)) + try: + number_input.press("Tab") + except Exception: + pass + + +def select_combobox_option( + page, + trigger_test_id: str, + preferred_text: str | None = None, +) -> str: + trigger = page.get_by_test_id(trigger_test_id).first + expect(trigger).to_be_visible(timeout=RESULT_TIMEOUT_MS) + trigger.scroll_into_view_if_needed() + current_text = "" + try: + current_text = trigger.inner_text().strip() + except Exception: + current_text = "" + trigger.click() + + options = page.get_by_test_id("combobox-option") + expect(options.first).to_be_visible(timeout=RESULT_TIMEOUT_MS) + + if preferred_text: + preferred_option = options.filter( + has_text=re.compile(rf"^{re.escape(preferred_text)}$", re.I) + ) + if preferred_option.count() > 0: + preferred_option.first.click() + return preferred_text + + selected_text = "" + option_count = options.count() + for idx in range(option_count): + option = options.nth(idx) + try: + if not option.is_visible(): + continue + except Exception: + continue + text = option.inner_text().strip() + if not text: + continue + if current_text and text.lower() == current_text.lower() and 
option_count > 1: + continue + option.click() + selected_text = text + break + + if not selected_text: + fallback = options.first + selected_text = fallback.inner_text().strip() + fallback.click() + return selected_text + + +def select_ragflow_option( + page, + trigger_test_id: str, + preferred_text: str | None = None, +) -> str: + trigger = page.get_by_test_id(trigger_test_id).first + expect(trigger).to_be_visible(timeout=RESULT_TIMEOUT_MS) + trigger.scroll_into_view_if_needed() + current_text = "" + try: + current_text = trigger.inner_text().strip() + except Exception: + current_text = "" + trigger.click() + + options = page.locator("[role='option']") + expect(options.first).to_be_visible(timeout=RESULT_TIMEOUT_MS) + + if preferred_text: + preferred_option = options.filter( + has_text=re.compile(rf"^{re.escape(preferred_text)}$", re.I) + ) + if preferred_option.count() > 0: + preferred_option.first.click() + return preferred_text + + selected_text = "" + option_count = options.count() + for idx in range(option_count): + option = options.nth(idx) + try: + if not option.is_visible(): + continue + except Exception: + continue + text = option.inner_text().strip() + if not text: + continue + if current_text and text.lower() == current_text.lower() and option_count > 1: + continue + option.click() + selected_text = text + break + + if not selected_text: + fallback = options.first + selected_text = fallback.inner_text().strip() + fallback.click() + return selected_text + + +def get_request_json_payload(response) -> dict: + payload = None + request = response.request + try: + post_data_json = request.post_data_json + payload = post_data_json() if callable(post_data_json) else post_data_json + except Exception: + payload = None + + if payload is None: + try: + post_data = request.post_data + raw = post_data() if callable(post_data) else post_data + if raw: + payload = json.loads(raw) + except Exception: + payload = None + + if not isinstance(payload, dict): + raise AssertionError(f"Expected JSON object payload for /v1/kb/update, got={payload!r}") + return payload + + def step_01_login( flow_page, flow_state, @@ -35,6 +208,8 @@ def step_01_login( snap, auth_click, seeded_user_credentials, + tmp_path, + ensure_dataset_ready, ): repo_root = Path(__file__).resolve().parents[3] file_paths = [ @@ -71,6 +246,8 @@ def step_02_open_datasets( snap, auth_click, seeded_user_credentials, + tmp_path, + ensure_dataset_ready, ): require(flow_state, "logged_in") page = flow_page @@ -97,11 +274,29 @@ def step_03_create_dataset( snap, auth_click, seeded_user_credentials, + tmp_path, + ensure_dataset_ready, ): require(flow_state, "logged_in") page = flow_page with step("open create dataset modal"): - modal = open_create_dataset_modal(page, expect, RESULT_TIMEOUT_MS) + try: + modal = open_create_dataset_modal(page, expect, RESULT_TIMEOUT_MS) + except AssertionError: + fallback_id = (ensure_dataset_ready or {}).get("kb_id") + fallback_name = (ensure_dataset_ready or {}).get("kb_name") + if not fallback_id or not fallback_name: + raise + page.goto( + urljoin(base_url.rstrip("/") + "/", f"/dataset/dataset/{fallback_id}"), + wait_until="domcontentloaded", + ) + wait_for_dataset_detail_ready(page, expect, timeout_ms=RESULT_TIMEOUT_MS * 2) + flow_state["dataset_name"] = fallback_name + flow_state["dataset_id"] = fallback_id + snap("dataset_created") + snap("dataset_detail_ready") + return snap("dataset_modal_open") dataset_name = f"qa-dataset-{int(time.time() * 1000)}" @@ -122,16 +317,256 @@ def step_03_create_dataset( if 
save_button is None or save_button.count() == 0: save_button = modal.locator("button", has_text=re.compile(r"^save$", re.I)).first expect(save_button).to_be_visible(timeout=RESULT_TIMEOUT_MS) - save_button.click() + created_kb_id = None + + def trigger(): + save_button.click() + + create_response = capture_response( + page, + trigger, + lambda resp: resp.request.method == "POST" and "/v1/kb/create" in resp.url, + timeout_ms=RESULT_TIMEOUT_MS * 2, + ) + try: + create_payload = create_response.json() + except Exception: + create_payload = {} + if isinstance(create_payload, dict): + data = create_payload.get("data") or {} + if isinstance(data, dict): + created_kb_id = data.get("id") or data.get("kb_id") + expect(modal).not_to_be_visible(timeout=RESULT_TIMEOUT_MS) - wait_for_dataset_detail(page, timeout_ms=RESULT_TIMEOUT_MS) - wait_for_dataset_detail_ready(page, expect, timeout_ms=RESULT_TIMEOUT_MS) + try: + wait_for_dataset_detail(page, timeout_ms=RESULT_TIMEOUT_MS * 2) + except Exception: + if created_kb_id: + page.goto( + urljoin( + base_url.rstrip("/") + "/", f"/dataset/dataset/{created_kb_id}" + ), + wait_until="domcontentloaded", + ) + else: + raise + wait_for_dataset_detail_ready(page, expect, timeout_ms=RESULT_TIMEOUT_MS * 2) + dataset_id = extract_dataset_id_from_url(page.url) flow_state["dataset_name"] = dataset_name + flow_state["dataset_id"] = dataset_id snap("dataset_created") snap("dataset_detail_ready") -def step_04_upload_files( +def step_04_set_dataset_settings( + flow_page, + flow_state, + base_url, + login_url, + active_auth_context, + step, + snap, + auth_click, + seeded_user_credentials, + tmp_path, + ensure_dataset_ready, +): + require(flow_state, "dataset_name", "dataset_id") + page = flow_page + dataset_id = flow_state["dataset_id"] + dataset_name = flow_state["dataset_name"] + metadata_field_key = "auto_meta_field" + + with step("open dataset settings page"): + page.goto( + urljoin( + base_url.rstrip("/") + "/", f"/dataset/dataset-setting/{dataset_id}" + ), + wait_until="domcontentloaded", + ) + expect(page.get_by_test_id("ds-settings-basic-name-input")).to_be_visible( + timeout=RESULT_TIMEOUT_MS + ) + expect(page.get_by_test_id("ds-settings-page-save-btn")).to_be_visible( + timeout=RESULT_TIMEOUT_MS + ) + snap("dataset_settings_open") + + with step("fill base settings"): + page.get_by_test_id("ds-settings-basic-name-input").fill( + f"{dataset_name}-cfg" + ) + select_combobox_option( + page, "ds-settings-basic-language-select", preferred_text="English" + ) + + avatar_path = make_test_png(tmp_path / "avatar-test.png") + page.get_by_test_id("ds-settings-basic-avatar-upload").set_input_files( + str(avatar_path) + ) + crop_modal = page.get_by_test_id("ds-settings-basic-avatar-crop-modal") + expect(crop_modal).to_be_visible(timeout=RESULT_TIMEOUT_MS) + page.get_by_test_id("ds-settings-basic-avatar-crop-confirm-btn").click() + expect(crop_modal).not_to_be_visible(timeout=RESULT_TIMEOUT_MS) + + page.get_by_test_id("ds-settings-basic-description-input").fill( + "Dataset setting playwright description" + ) + try: + select_combobox_option(page, "ds-settings-basic-permissions-select") + except Exception: + page.keyboard.press("Escape") + + embedding_trigger = page.get_by_test_id( + "ds-settings-basic-embedding-model-select" + ).first + expect(embedding_trigger).to_be_visible(timeout=RESULT_TIMEOUT_MS) + if not embedding_trigger.is_disabled(): + try: + select_combobox_option(page, "ds-settings-basic-embedding-model-select") + except Exception: + page.keyboard.press("Escape") + + 
with step("fill parser and metadata settings"): + set_number_input(page, "ds-settings-parser-page-rank-input", 12) + select_combobox_option( + page, "ds-settings-parser-pdf-parser-select", preferred_text="Plain Text" + ) + set_number_input(page, "ds-settings-parser-recommended-chunk-size-input", 640) + set_switch_state(page, "ds-settings-parser-child-chunk-switch", True) + expect( + page.get_by_test_id("ds-settings-parser-child-chunk-delimiter-input") + ).to_be_visible(timeout=RESULT_TIMEOUT_MS) + set_switch_state(page, "ds-settings-parser-page-index-switch", True) + set_number_input(page, "ds-settings-parser-image-table-context-window-input", 16) + set_switch_state(page, "ds-settings-metadata-switch", True) + + page.get_by_test_id("ds-settings-metadata-open-modal-btn").click() + metadata_modal = page.get_by_test_id("ds-settings-metadata-modal") + expect(metadata_modal).to_be_visible(timeout=RESULT_TIMEOUT_MS) + page.get_by_test_id("ds-settings-metadata-add-btn").click() + + nested_modal = page.get_by_test_id("ds-settings-metadata-add-modal") + expect(nested_modal).to_be_visible(timeout=RESULT_TIMEOUT_MS) + field_input = nested_modal.locator("input[name='field']") + if field_input.count() == 0: + field_input = nested_modal.locator("input") + expect(field_input.first).to_be_visible(timeout=RESULT_TIMEOUT_MS) + field_input.first.fill(metadata_field_key) + description_input = nested_modal.locator("textarea") + if description_input.count() > 0: + description_input.first.fill("auto metadata field from playwright") + confirm_btn = page.get_by_test_id("ds-settings-metadata-add-modal-confirm-btn") + confirm_btn.click() + try: + expect(nested_modal).not_to_be_visible(timeout=3000) + except AssertionError: + retry_field_input = nested_modal.locator("input[name='field']") + if retry_field_input.count() > 0: + retry_field_input.first.fill("auto_meta_field_retry") + confirm_btn.click() + expect(nested_modal).not_to_be_visible(timeout=RESULT_TIMEOUT_MS) + snap("dataset_settings_metadata_modal") + + page.get_by_test_id("ds-settings-metadata-modal-save-btn").click() + expect(metadata_modal).not_to_be_visible(timeout=RESULT_TIMEOUT_MS) + + overlap_slider = page.get_by_test_id( + "ds-settings-parser-overlapped-percent-slider" + ).first + expect(overlap_slider).to_be_visible(timeout=RESULT_TIMEOUT_MS) + overlap_slider.focus() + overlap_slider.press("ArrowRight") + set_number_input(page, "ds-settings-parser-auto-keyword-input", 3) + set_number_input(page, "ds-settings-parser-auto-question-input", 2) + set_switch_state(page, "ds-settings-parser-excel-to-html-switch", True) + + with step("fill graph and raptor settings"): + page.get_by_test_id("ds-settings-graph-entity-types-add-btn").click() + entity_input = page.get_by_test_id("ds-settings-graph-entity-types-input").first + expect(entity_input).to_be_visible(timeout=RESULT_TIMEOUT_MS) + entity_input.fill("playwright_entity") + entity_input.press("Enter") + select_ragflow_option( + page, "ds-settings-graph-method-select", preferred_text="General" + ) + set_switch_state(page, "ds-settings-graph-entity-resolution-switch", True) + set_switch_state(page, "ds-settings-graph-community-reports-switch", True) + + page.get_by_test_id("ds-settings-raptor-generation-scope-option-dataset").click() + page.get_by_test_id("ds-settings-raptor-prompt-textarea").fill( + "Playwright prompt for dataset settings" + ) + set_number_input(page, "ds-settings-raptor-max-token-input", 300) + set_number_input(page, "ds-settings-raptor-threshold-input", 0.3) + set_number_input(page, 
"ds-settings-raptor-max-cluster-input", 128) + set_number_input(page, "ds-settings-raptor-seed-input", 1234) + seed_input = page.get_by_test_id("ds-settings-raptor-seed-input").first + seed_before_randomize = seed_input.input_value() + page.get_by_test_id("ds-settings-raptor-seed-randomize-btn").click() + page.wait_for_function( + """([testId, previous]) => { + const node = document.querySelector(`[data-testid="${testId}"]`); + return !!node && String(node.value) !== String(previous); + }""", + arg=["ds-settings-raptor-seed-input", seed_before_randomize], + timeout=RESULT_TIMEOUT_MS, + ) + + with step("save dataset settings and assert update payload"): + try: + expect(page.locator("[data-sonner-toast]")).to_have_count(0, timeout=8000) + except AssertionError: + pass + save_btn = page.get_by_test_id("ds-settings-page-save-btn").first + expect(save_btn).to_be_visible(timeout=RESULT_TIMEOUT_MS) + + def trigger(): + save_btn.click() + + response = capture_response( + page, + trigger, + lambda resp: resp.request.method == "POST" and "/v1/kb/update" in resp.url, + timeout_ms=RESULT_TIMEOUT_MS * 2, + ) + assert 200 <= response.status < 400, f"Unexpected /v1/kb/update status={response.status}" + response_payload = response.json() + if isinstance(response_payload, dict): + assert response_payload.get("code") == 0, ( + f"/v1/kb/update response code={response_payload.get('code')} " + f"message={response_payload.get('message')}" + ) + + payload = get_request_json_payload(response) + assert payload.get("kb_id") == dataset_id, ( + f"Expected kb_id={dataset_id!r}, got {payload.get('kb_id')!r}" + ) + for key in ("name", "language", "parser_config"): + assert key in payload, f"Expected key {key!r} in /v1/kb/update payload" + parser_config = payload.get("parser_config") or {} + assert ( + parser_config.get("image_table_context_window") + == parser_config.get("image_context_size") + == parser_config.get("table_context_size") + ), "Expected image/table context window transform keys to be aligned" + expect(page.locator("[data-sonner-toast]").first).to_be_visible( + timeout=RESULT_TIMEOUT_MS + ) + + with step("return to dataset detail for upload"): + page.goto( + urljoin(base_url.rstrip("/") + "/", f"/dataset/dataset/{dataset_id}"), + wait_until="domcontentloaded", + ) + wait_for_dataset_detail_ready(page, expect, timeout_ms=RESULT_TIMEOUT_MS) + + flow_state["dataset_settings_done"] = True + flow_state["settings_update_payload"] = payload + snap("dataset_settings_saved") + + +def step_05_upload_files( flow_page, flow_state, base_url, @@ -141,8 +576,10 @@ def step_04_upload_files( snap, auth_click, seeded_user_credentials, + tmp_path, + ensure_dataset_ready, ): - require(flow_state, "dataset_name", "file_paths") + require(flow_state, "dataset_name", "dataset_settings_done", "file_paths") page = flow_page file_paths = [Path(path) for path in flow_state["file_paths"]] filenames = flow_state.get("filenames") or [path.name for path in file_paths] @@ -193,7 +630,7 @@ def trigger(): flow_state["uploads_done"] = True -def step_05_wait_parse_success( +def step_06_wait_parse_success( flow_page, flow_state, base_url, @@ -203,17 +640,20 @@ def step_05_wait_parse_success( snap, auth_click, seeded_user_credentials, + tmp_path, + ensure_dataset_ready, ): require(flow_state, "uploads_done", "filenames") page = flow_page + parse_timeout_ms = RESULT_TIMEOUT_MS * 8 for filename in flow_state["filenames"]: with step(f"wait for parse success {filename}"): - wait_for_success_dot(page, expect, filename, timeout_ms=RESULT_TIMEOUT_MS) 
+ wait_for_success_dot(page, expect, filename, timeout_ms=parse_timeout_ms) snap(f"parse_{filename}_success") flow_state["parse_complete"] = True -def step_06_delete_one_file( +def step_07_delete_one_file( flow_page, flow_state, base_url, @@ -223,6 +663,8 @@ def step_06_delete_one_file( snap, auth_click, seeded_user_credentials, + tmp_path, + ensure_dataset_ready, ): require(flow_state, "parse_complete", "filenames") page = flow_page @@ -247,9 +689,10 @@ def step_06_delete_one_file( ("01_login", step_01_login), ("02_open_datasets", step_02_open_datasets), ("03_create_dataset", step_03_create_dataset), - ("04_upload_files", step_04_upload_files), - ("05_wait_parse_success", step_05_wait_parse_success), - ("06_delete_one_file", step_06_delete_one_file), + ("04_set_dataset_settings", step_04_set_dataset_settings), + ("05_upload_files", step_05_upload_files), + ("06_wait_parse_success", step_06_wait_parse_success), + ("07_delete_one_file", step_07_delete_one_file), ] @@ -263,11 +706,13 @@ def test_dataset_upload_parse_and_delete_flow( base_url, login_url, ensure_model_provider_configured, + ensure_dataset_ready, active_auth_context, step, snap, auth_click, seeded_user_credentials, + tmp_path, ): step_fn( flow_page, @@ -279,4 +724,6 @@ def test_dataset_upload_parse_and_delete_flow( snap, auth_click, seeded_user_credentials, + tmp_path, + ensure_dataset_ready, ) diff --git a/test/playwright/helpers/datasets.py b/test/playwright/helpers/datasets.py index 124c1b4a254..89f832aa0a3 100644 --- a/test/playwright/helpers/datasets.py +++ b/test/playwright/helpers/datasets.py @@ -465,6 +465,31 @@ def _click_create_button_entrypoint() -> None: def delete_uploaded_file(page, expect, filename: str, timeout_ms: int) -> None: """Delete a document row by filename and confirm the modal.""" + + def visible_confirm_dialog(): + confirm = page.locator("[data-testid='confirm-delete-dialog']:visible") + if confirm.count() > 0: + return confirm.last + + confirm = page.locator("[role='alertdialog']:visible") + if confirm.count() > 0: + return confirm.last + + return page.locator("[role='alertdialog']").last + + def confirm_delete_button(confirm): + by_testid = confirm.get_by_test_id("confirm-delete-dialog-confirm-btn") + if by_testid.count() > 0: + return by_testid.first + + by_label = confirm.locator( + "button:visible", has_text=re.compile("^delete$", re.I) + ) + if by_label.count() > 0: + return by_label.first + + return confirm.locator("button:visible").last + row = page.locator( f"[data-testid='document-row'][data-doc-name={json.dumps(filename)}]" ) @@ -472,18 +497,31 @@ def delete_uploaded_file(page, expect, filename: str, timeout_ms: int) -> None: delete_button = row.locator("[data-testid='document-delete']") expect(delete_button).to_be_visible(timeout=timeout_ms) delete_button.click() - confirm = page.locator("[role='alertdialog']") - expect(confirm).to_be_visible() - confirm_delete = confirm.locator( - "button", has_text=re.compile("^delete$", re.I) - ).first + + confirm = visible_confirm_dialog() + expect(confirm).to_be_visible(timeout=timeout_ms) + confirm_delete = confirm_delete_button(confirm) expect(confirm_delete).to_be_visible(timeout=timeout_ms) try: - confirm_delete.click(timeout=timeout_ms) - except Exception: - # The confirm button can rerender during open/animation; reacquire and force. 
- confirm_delete = confirm.locator( - "button", has_text=re.compile("^delete$", re.I) - ).first confirm_delete.click(timeout=timeout_ms, force=True) + except Exception: + # The confirm action can rerender/detach during click. If delete already + # happened, avoid reopening flows and continue. + try: + expect(row).not_to_be_visible(timeout=2000) + return + except AssertionError: + pass + + confirm = visible_confirm_dialog() + if confirm.count() == 0: + # Re-open delete confirmation only when needed. + delete_button = row.locator("[data-testid='document-delete']") + if delete_button.count() > 0: + delete_button.first.click() + confirm = visible_confirm_dialog() + + if confirm.count() > 0: + confirm_delete = confirm_delete_button(confirm) + confirm_delete.click(timeout=timeout_ms, force=True) expect(row).not_to_be_visible(timeout=timeout_ms) diff --git a/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py b/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py new file mode 100644 index 00000000000..a79d9358178 --- /dev/null +++ b/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py @@ -0,0 +1,221 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import asyncio +import sys +import types +import warnings + +import pytest + +# xgboost imports pkg_resources and emits a deprecation warning that is promoted +# to error in our pytest configuration; ignore it for this unit test module. +warnings.filterwarnings( + "ignore", + message="pkg_resources is deprecated as an API.*", + category=UserWarning, +) + + +def _install_cv2_stub_if_unavailable(): + try: + import cv2 # noqa: F401 + return + except Exception: + pass + + stub = types.ModuleType("cv2") + + # Constants referenced by deepdoc import-time defaults. 
+ stub.INTER_LINEAR = 1 + stub.INTER_CUBIC = 2 + stub.BORDER_CONSTANT = 0 + stub.BORDER_REPLICATE = 1 + stub.COLOR_BGR2RGB = 0 + stub.COLOR_BGR2GRAY = 1 + stub.COLOR_GRAY2BGR = 2 + stub.IMREAD_IGNORE_ORIENTATION = 128 + stub.IMREAD_COLOR = 1 + stub.RETR_LIST = 1 + stub.CHAIN_APPROX_SIMPLE = 2 + + def _missing(*_args, **_kwargs): + raise RuntimeError("cv2 runtime call is unavailable in this test environment") + + def _module_getattr(name): + if name.isupper(): + return 0 + return _missing + + stub.__getattr__ = _module_getattr + sys.modules["cv2"] = stub + + +_install_cv2_stub_if_unavailable() + +from api.db.services import dialog_service + + +class _StubChatModel: + def __init__(self, outputs): + self._outputs = outputs + self.calls = [] + + async def async_chat(self, system_prompt, messages, llm_setting): + idx = len(self.calls) + if idx >= len(self._outputs): + raise AssertionError("async_chat called more times than expected") + self.calls.append( + { + "system_prompt": system_prompt, + "message": messages[0]["content"], + "llm_setting": llm_setting, + } + ) + return self._outputs[idx] + + +class _StubRetriever: + def __init__(self, results): + self._results = results + self.sql_calls = [] + + def sql_retrieval(self, sql, format="json"): + assert format == "json" + idx = len(self.sql_calls) + if idx >= len(self._results): + raise AssertionError("sql_retrieval called more times than expected") + self.sql_calls.append(sql) + return self._results[idx] + + +@pytest.fixture +def force_es_engine(monkeypatch): + monkeypatch.setattr(dialog_service.settings, "DOC_ENGINE_INFINITY", False) + monkeypatch.setattr(dialog_service.settings, "DOC_ENGINE_OCEANBASE", False) + + +@pytest.mark.p2 +def test_use_sql_repairs_missing_source_columns_for_non_aggregate(monkeypatch, force_es_engine): + retriever = _StubRetriever( + [ + { + "columns": [{"name": "product"}], + "rows": [["desk"], ["monitor"]], + }, + { + "columns": [{"name": "doc_id"}, {"name": "docnm_kwd"}, {"name": "product"}], + "rows": [["doc-1", "products.xlsx", "desk"], ["doc-2", "products.xlsx", "monitor"]], + }, + ] + ) + chat_model = _StubChatModel( + [ + "SELECT product FROM ragflow_tenant", + "SELECT doc_id, docnm_kwd, product FROM ragflow_tenant", + ] + ) + monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False) + + result = asyncio.run( + dialog_service.use_sql( + question="show me column of product", + field_map={"product": "product"}, + tenant_id="tenant-id", + chat_mdl=chat_model, + quota=True, + kb_ids=None, + ) + ) + + assert result is not None + assert "|product|Source|" in result["answer"] + assert len(chat_model.calls) == 2 + assert len(retriever.sql_calls) == 2 + + +@pytest.mark.p2 +def test_use_sql_keeps_aggregate_flow_without_source_repair(monkeypatch, force_es_engine): + retriever = _StubRetriever( + [ + { + "columns": [{"name": "count(star)"}], + "rows": [[6]], + }, + ] + ) + chat_model = _StubChatModel( + [ + "SELECT COUNT(*) FROM ragflow_tenant", + ] + ) + monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False) + + result = asyncio.run( + dialog_service.use_sql( + question="how many rows are there", + field_map={"product": "product"}, + tenant_id="tenant-id", + chat_mdl=chat_model, + quota=True, + kb_ids=None, + ) + ) + + assert result is not None + assert "|COUNT(*)|" in result["answer"] + assert "Source" not in result["answer"] + assert len(chat_model.calls) == 1 + assert len(retriever.sql_calls) == 1 + + +@pytest.mark.p2 +def 
test_use_sql_source_repair_is_bounded_to_single_retry(monkeypatch, force_es_engine): + retriever = _StubRetriever( + [ + { + "columns": [{"name": "product"}], + "rows": [["desk"]], + }, + { + "columns": [{"name": "product"}], + "rows": [["desk"]], + }, + ] + ) + chat_model = _StubChatModel( + [ + "SELECT product FROM ragflow_tenant", + "SELECT product FROM ragflow_tenant WHERE product IS NOT NULL", + ] + ) + monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False) + + result = asyncio.run( + dialog_service.use_sql( + question="show me column of product", + field_map={"product": "product"}, + tenant_id="tenant-id", + chat_mdl=chat_model, + quota=True, + kb_ids=None, + ) + ) + + assert result is not None + assert "|product|" in result["answer"] + assert "Source" not in result["answer"] + assert len(chat_model.calls) == 2 + assert len(retriever.sql_calls) == 2 From 07a33c13459f1c7a2e73ed92e538e54b5bc33eba Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Wed, 4 Mar 2026 19:17:16 +0800 Subject: [PATCH 130/565] RAGFlow go API server (#13240) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # RAGFlow Go Implementation Plan 🚀 This repository tracks the progress of porting RAGFlow to Go. We'll implement core features and provide performance comparisons between Python and Go versions. ## Implementation Checklist - [x] User Management APIs - [x] Dataset Management Operations - [x] Retrieval Test - [x] Chat Management Operations - [x] Infinity Go SDK --------- Signed-off-by: Jin Hai Co-authored-by: Yingfeng Zhang --- .github/workflows/tests.yml | 21 +- .gitignore | 2 + Dockerfile | 1 + admin/client/COMMAND.md | 779 ++ admin/client/parser.py | 7 + admin/client/ragflow_client.py | 38 + build.sh | 196 + cmd/ragflow_cli.go | 34 + cmd/server_main.go | 181 + docker/entrypoint.sh | 1 + go.mod | 69 + go.sum | 176 + internal/cache/redis.go | 996 +++ internal/cli/README.md | 87 + internal/cli/benchmark.go | 318 + internal/cli/cli.go | 140 + internal/cli/client.go | 496 ++ internal/cli/crypt.go | 106 + internal/cli/http_client.go | 248 + internal/cli/lexer.go | 301 + internal/cli/parser.go | 1568 ++++ internal/cli/table.go | 167 + internal/cli/types.go | 123 + internal/cpp/CMakeLists.txt | 138 + internal/cpp/Makefile | 81 + internal/cpp/analyzer.h | 88 + internal/cpp/dart_trie.h | 77 + internal/cpp/darts/darts.h | 1733 +++++ internal/cpp/darts_trie.cpp | 109 + internal/cpp/main.cpp | 442 ++ internal/cpp/opencc/config_reader.c | 289 + internal/cpp/opencc/config_reader.h | 46 + internal/cpp/opencc/converter.c | 590 ++ internal/cpp/opencc/converter.h | 48 + internal/cpp/opencc/dictionary/abstract.c | 94 + internal/cpp/opencc/dictionary/abstract.h | 45 + internal/cpp/opencc/dictionary/datrie.c | 250 + internal/cpp/opencc/dictionary/datrie.h | 45 + internal/cpp/opencc/dictionary/text.c | 232 + internal/cpp/opencc/dictionary/text.h | 36 + internal/cpp/opencc/dictionary_group.c | 177 + internal/cpp/opencc/dictionary_group.h | 57 + internal/cpp/opencc/dictionary_set.c | 73 + internal/cpp/opencc/dictionary_set.h | 37 + internal/cpp/opencc/encoding.c | 230 + internal/cpp/opencc/encoding.h | 36 + internal/cpp/opencc/opencc.c | 219 + internal/cpp/opencc/opencc.h | 116 + internal/cpp/opencc/opencc_types.h | 59 + internal/cpp/opencc/openccxx.cpp | 80 + internal/cpp/opencc/openccxx.h | 20 + internal/cpp/opencc/utils.c | 36 + internal/cpp/opencc/utils.h | 71 + internal/cpp/pcre2.h | 1079 +++ internal/cpp/pcre2posix.h | 184 + internal/cpp/rag_analyzer.cpp | 2431 
++++++ internal/cpp/rag_analyzer.h | 177 + internal/cpp/rag_analyzer_c_api.cpp | 225 + internal/cpp/rag_analyzer_c_api.h | 106 + internal/cpp/rag_analyzer_c_api_debug.cpp | 168 + internal/cpp/rag_analyzer_c_test.cpp | 120 + internal/cpp/re2/bitmap256.cc | 44 + internal/cpp/re2/bitmap256.h | 82 + internal/cpp/re2/bitstate.cc | 362 + internal/cpp/re2/compile.cc | 1221 ++++ internal/cpp/re2/dfa.cc | 1985 +++++ internal/cpp/re2/filtered_re2.cc | 118 + internal/cpp/re2/filtered_re2.h | 107 + internal/cpp/re2/mimics_pcre.cc | 192 + internal/cpp/re2/nfa.cc | 651 ++ internal/cpp/re2/onepass.cc | 577 ++ internal/cpp/re2/parse.cc | 2481 +++++++ internal/cpp/re2/perl_groups.cc | 118 + internal/cpp/re2/pod_array.h | 55 + internal/cpp/re2/prefilter.cc | 663 ++ internal/cpp/re2/prefilter.h | 130 + internal/cpp/re2/prefilter_tree.cc | 370 + internal/cpp/re2/prefilter_tree.h | 138 + internal/cpp/re2/prog.cc | 1158 +++ internal/cpp/re2/prog.h | 469 ++ internal/cpp/re2/re2.cc | 1326 ++++ internal/cpp/re2/re2.h | 991 +++ internal/cpp/re2/regexp.cc | 957 +++ internal/cpp/re2/regexp.h | 680 ++ internal/cpp/re2/set.cc | 159 + internal/cpp/re2/set.h | 84 + internal/cpp/re2/simplify.cc | 629 ++ internal/cpp/re2/sparse_array.h | 367 + internal/cpp/re2/sparse_set.h | 248 + internal/cpp/re2/stringpiece.cc | 69 + internal/cpp/re2/stringpiece.h | 189 + internal/cpp/re2/tostring.cc | 345 + internal/cpp/re2/unicode_casefold.cc | 591 ++ internal/cpp/re2/unicode_casefold.h | 78 + internal/cpp/re2/unicode_groups.cc | 6512 +++++++++++++++++ internal/cpp/re2/unicode_groups.h | 64 + internal/cpp/re2/walker-inl.h | 246 + internal/cpp/stemmer/api.cpp | 78 + internal/cpp/stemmer/api.h | 31 + internal/cpp/stemmer/header.h | 59 + internal/cpp/stemmer/stem_UTF_8_danish.cpp | 424 ++ internal/cpp/stemmer/stem_UTF_8_danish.h | 17 + internal/cpp/stemmer/stem_UTF_8_dutch.cpp | 792 ++ internal/cpp/stemmer/stem_UTF_8_dutch.h | 17 + internal/cpp/stemmer/stem_UTF_8_english.cpp | 1316 ++++ internal/cpp/stemmer/stem_UTF_8_english.h | 17 + internal/cpp/stemmer/stem_UTF_8_finnish.cpp | 958 +++ internal/cpp/stemmer/stem_UTF_8_finnish.h | 17 + internal/cpp/stemmer/stem_UTF_8_french.cpp | 1605 ++++ internal/cpp/stemmer/stem_UTF_8_french.h | 17 + internal/cpp/stemmer/stem_UTF_8_german.cpp | 626 ++ internal/cpp/stemmer/stem_UTF_8_german.h | 17 + internal/cpp/stemmer/stem_UTF_8_hungarian.cpp | 1353 ++++ internal/cpp/stemmer/stem_UTF_8_hungarian.h | 17 + internal/cpp/stemmer/stem_UTF_8_italian.cpp | 1288 ++++ internal/cpp/stemmer/stem_UTF_8_italian.h | 17 + internal/cpp/stemmer/stem_UTF_8_norwegian.cpp | 357 + internal/cpp/stemmer/stem_UTF_8_norwegian.h | 17 + internal/cpp/stemmer/stem_UTF_8_porter.cpp | 888 +++ internal/cpp/stemmer/stem_UTF_8_porter.h | 17 + .../cpp/stemmer/stem_UTF_8_portuguese.cpp | 1217 +++ internal/cpp/stemmer/stem_UTF_8_portuguese.h | 17 + internal/cpp/stemmer/stem_UTF_8_romanian.cpp | 1111 +++ internal/cpp/stemmer/stem_UTF_8_romanian.h | 17 + internal/cpp/stemmer/stem_UTF_8_russian.cpp | 774 ++ internal/cpp/stemmer/stem_UTF_8_russian.h | 17 + internal/cpp/stemmer/stem_UTF_8_spanish.cpp | 1319 ++++ internal/cpp/stemmer/stem_UTF_8_spanish.h | 17 + internal/cpp/stemmer/stem_UTF_8_swedish.cpp | 371 + internal/cpp/stemmer/stem_UTF_8_swedish.h | 17 + internal/cpp/stemmer/stem_UTF_8_turkish.cpp | 2978 ++++++++ internal/cpp/stemmer/stem_UTF_8_turkish.h | 17 + internal/cpp/stemmer/stemmer.cpp | 149 + internal/cpp/stemmer/stemmer.h | 58 + internal/cpp/stemmer/utilities.cpp | 509 ++ internal/cpp/string_utils.h | 476 ++ 
internal/cpp/term.cpp | 24 + internal/cpp/term.h | 72 + internal/cpp/tokenizer.cpp | 315 + internal/cpp/tokenizer.h | 113 + internal/cpp/util/logging.h | 111 + internal/cpp/util/mix.h | 41 + internal/cpp/util/mutex.h | 169 + internal/cpp/util/rune.cc | 246 + internal/cpp/util/strutil.cc | 166 + internal/cpp/util/strutil.h | 21 + internal/cpp/util/utf.h | 43 + internal/cpp/util/util.h | 44 + internal/cpp/wordnet_lemmatizer.cpp | 225 + internal/cpp/wordnet_lemmatizer.h | 52 + internal/dao/chat.go | 212 + internal/dao/chat_session.go | 85 + internal/dao/connector.go | 79 + internal/dao/database.go | 91 + internal/dao/document.go | 81 + internal/dao/file.go | 202 + internal/dao/file2document.go | 60 + internal/dao/kb.go | 149 + internal/dao/llm.go | 69 + internal/dao/model_provider.go | 123 + internal/dao/search.go | 127 + internal/dao/tenant.go | 90 + internal/dao/tenant_llm.go | 136 + internal/dao/user.go | 103 + internal/dao/user_canvas.go | 129 + internal/dao/user_tenant.go | 126 + internal/engine/README.md | 200 + internal/engine/elasticsearch/client.go | 103 + internal/engine/elasticsearch/document.go | 238 + internal/engine/elasticsearch/index.go | 144 + internal/engine/elasticsearch/search.go | 528 ++ internal/engine/engine.go | 67 + internal/engine/global.go | 70 + internal/engine/infinity/client.go | 59 + internal/engine/infinity/document.go | 47 + internal/engine/infinity/index.go | 37 + internal/engine/infinity/search.go | 205 + internal/engine/types/types.go | 54 + internal/go_binding/rag_analyzer.go | 265 + internal/handler/chat.go | 314 + internal/handler/chat_session.go | 377 + internal/handler/chunk.go | 180 + internal/handler/connector.go | 86 + internal/handler/document.go | 258 + internal/handler/error.go | 46 + internal/handler/file.go | 283 + internal/handler/kb.go | 158 + internal/handler/llm.go | 247 + internal/handler/search.go | 129 + internal/handler/system.go | 125 + internal/handler/tenant.go | 135 + internal/handler/user.go | 456 ++ internal/logger/README.md | 70 + internal/logger/logger.go | 138 + internal/model/api.go | 54 + internal/model/base.go | 79 + internal/model/canvas.go | 68 + internal/model/chat.go | 64 + internal/model/connector.go | 78 + internal/model/document.go | 51 + internal/model/evaluation.go | 87 + internal/model/file.go | 49 + internal/model/kb.go | 70 + internal/model/llm.go | 76 + internal/model/mcp.go | 35 + internal/model/memory.go | 42 + internal/model/pipeline.go | 49 + internal/model/search.go | 35 + internal/model/system.go | 30 + internal/model/task.go | 42 + internal/model/tenant.go | 39 + internal/model/tenant_llm.go | 36 + internal/model/types.go | 71 + internal/model/user.go | 45 + internal/model/user_tenant.go | 33 + internal/router/router.go | 194 + internal/server/config.go | 294 + internal/server/model_provider.go | 116 + internal/server/variable.go | 259 + internal/service/chat.go | 623 ++ internal/service/chat_session.go | 893 +++ internal/service/chunk.go | 465 ++ internal/service/connector.go | 69 + internal/service/document.go | 208 + internal/service/file.go | 220 + internal/service/kb.go | 82 + internal/service/llm.go | 248 + internal/service/model_bundle.go | 173 + internal/service/model_service.go | 117 + internal/service/models/deepseek_model.go | 33 + internal/service/models/factory.go | 58 + internal/service/models/gitee_model.go | 126 + internal/service/models/moonshot_model.go | 33 + .../models/openai_api_compatible_model.go | 33 + internal/service/models/openai_model.go | 123 + 
internal/service/models/siliconflow_model.go | 123 + internal/service/models/zhipu_model.go | 33 + internal/service/nlp/query_builder.go | 655 ++ internal/service/nlp/query_builder_test.go | 471 ++ internal/service/nlp/reranker.go | 471 ++ internal/service/nlp/synonym.go | 222 + internal/service/nlp/synonym_test.go | 444 ++ internal/service/nlp/term_weight.go | 496 ++ internal/service/nlp/term_weight_test.go | 832 +++ internal/service/nlp/wordnet.go | 572 ++ internal/service/nlp/wordnet_test.go | 285 + internal/service/search.go | 132 + internal/service/system.go | 56 + internal/service/tenant.go | 120 + internal/service/user.go | 621 ++ internal/tokenizer/tokenizer.go | 477 ++ .../tokenizer/tokenizer_concurrent_test.go | 493 ++ internal/utility/embedding_lru.go | 141 + internal/utility/token.go | 135 + internal/utility/version.go | 76 + internal/utility/version_test.go | 39 + web/vite.config.ts | 15 + 257 files changed, 80490 insertions(+), 6 deletions(-) create mode 100644 admin/client/COMMAND.md create mode 100755 build.sh create mode 100644 cmd/ragflow_cli.go create mode 100644 cmd/server_main.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/cache/redis.go create mode 100644 internal/cli/README.md create mode 100644 internal/cli/benchmark.go create mode 100644 internal/cli/cli.go create mode 100644 internal/cli/client.go create mode 100644 internal/cli/crypt.go create mode 100644 internal/cli/http_client.go create mode 100644 internal/cli/lexer.go create mode 100644 internal/cli/parser.go create mode 100644 internal/cli/table.go create mode 100644 internal/cli/types.go create mode 100644 internal/cpp/CMakeLists.txt create mode 100644 internal/cpp/Makefile create mode 100644 internal/cpp/analyzer.h create mode 100644 internal/cpp/dart_trie.h create mode 100644 internal/cpp/darts/darts.h create mode 100644 internal/cpp/darts_trie.cpp create mode 100644 internal/cpp/main.cpp create mode 100644 internal/cpp/opencc/config_reader.c create mode 100644 internal/cpp/opencc/config_reader.h create mode 100644 internal/cpp/opencc/converter.c create mode 100644 internal/cpp/opencc/converter.h create mode 100644 internal/cpp/opencc/dictionary/abstract.c create mode 100644 internal/cpp/opencc/dictionary/abstract.h create mode 100644 internal/cpp/opencc/dictionary/datrie.c create mode 100644 internal/cpp/opencc/dictionary/datrie.h create mode 100644 internal/cpp/opencc/dictionary/text.c create mode 100644 internal/cpp/opencc/dictionary/text.h create mode 100644 internal/cpp/opencc/dictionary_group.c create mode 100644 internal/cpp/opencc/dictionary_group.h create mode 100644 internal/cpp/opencc/dictionary_set.c create mode 100644 internal/cpp/opencc/dictionary_set.h create mode 100644 internal/cpp/opencc/encoding.c create mode 100644 internal/cpp/opencc/encoding.h create mode 100644 internal/cpp/opencc/opencc.c create mode 100644 internal/cpp/opencc/opencc.h create mode 100644 internal/cpp/opencc/opencc_types.h create mode 100644 internal/cpp/opencc/openccxx.cpp create mode 100644 internal/cpp/opencc/openccxx.h create mode 100644 internal/cpp/opencc/utils.c create mode 100644 internal/cpp/opencc/utils.h create mode 100644 internal/cpp/pcre2.h create mode 100644 internal/cpp/pcre2posix.h create mode 100644 internal/cpp/rag_analyzer.cpp create mode 100644 internal/cpp/rag_analyzer.h create mode 100644 internal/cpp/rag_analyzer_c_api.cpp create mode 100644 internal/cpp/rag_analyzer_c_api.h create mode 100644 internal/cpp/rag_analyzer_c_api_debug.cpp create mode 100644 
internal/cpp/rag_analyzer_c_test.cpp create mode 100644 internal/cpp/re2/bitmap256.cc create mode 100644 internal/cpp/re2/bitmap256.h create mode 100644 internal/cpp/re2/bitstate.cc create mode 100644 internal/cpp/re2/compile.cc create mode 100644 internal/cpp/re2/dfa.cc create mode 100644 internal/cpp/re2/filtered_re2.cc create mode 100644 internal/cpp/re2/filtered_re2.h create mode 100644 internal/cpp/re2/mimics_pcre.cc create mode 100644 internal/cpp/re2/nfa.cc create mode 100644 internal/cpp/re2/onepass.cc create mode 100644 internal/cpp/re2/parse.cc create mode 100644 internal/cpp/re2/perl_groups.cc create mode 100644 internal/cpp/re2/pod_array.h create mode 100644 internal/cpp/re2/prefilter.cc create mode 100644 internal/cpp/re2/prefilter.h create mode 100644 internal/cpp/re2/prefilter_tree.cc create mode 100644 internal/cpp/re2/prefilter_tree.h create mode 100644 internal/cpp/re2/prog.cc create mode 100644 internal/cpp/re2/prog.h create mode 100644 internal/cpp/re2/re2.cc create mode 100644 internal/cpp/re2/re2.h create mode 100644 internal/cpp/re2/regexp.cc create mode 100644 internal/cpp/re2/regexp.h create mode 100644 internal/cpp/re2/set.cc create mode 100644 internal/cpp/re2/set.h create mode 100644 internal/cpp/re2/simplify.cc create mode 100644 internal/cpp/re2/sparse_array.h create mode 100644 internal/cpp/re2/sparse_set.h create mode 100644 internal/cpp/re2/stringpiece.cc create mode 100644 internal/cpp/re2/stringpiece.h create mode 100644 internal/cpp/re2/tostring.cc create mode 100644 internal/cpp/re2/unicode_casefold.cc create mode 100644 internal/cpp/re2/unicode_casefold.h create mode 100644 internal/cpp/re2/unicode_groups.cc create mode 100644 internal/cpp/re2/unicode_groups.h create mode 100644 internal/cpp/re2/walker-inl.h create mode 100644 internal/cpp/stemmer/api.cpp create mode 100644 internal/cpp/stemmer/api.h create mode 100644 internal/cpp/stemmer/header.h create mode 100644 internal/cpp/stemmer/stem_UTF_8_danish.cpp create mode 100644 internal/cpp/stemmer/stem_UTF_8_danish.h create mode 100644 internal/cpp/stemmer/stem_UTF_8_dutch.cpp create mode 100644 internal/cpp/stemmer/stem_UTF_8_dutch.h create mode 100644 internal/cpp/stemmer/stem_UTF_8_english.cpp create mode 100644 internal/cpp/stemmer/stem_UTF_8_english.h create mode 100644 internal/cpp/stemmer/stem_UTF_8_finnish.cpp create mode 100644 internal/cpp/stemmer/stem_UTF_8_finnish.h create mode 100644 internal/cpp/stemmer/stem_UTF_8_french.cpp create mode 100644 internal/cpp/stemmer/stem_UTF_8_french.h create mode 100644 internal/cpp/stemmer/stem_UTF_8_german.cpp create mode 100644 internal/cpp/stemmer/stem_UTF_8_german.h create mode 100644 internal/cpp/stemmer/stem_UTF_8_hungarian.cpp create mode 100644 internal/cpp/stemmer/stem_UTF_8_hungarian.h create mode 100644 internal/cpp/stemmer/stem_UTF_8_italian.cpp create mode 100644 internal/cpp/stemmer/stem_UTF_8_italian.h create mode 100644 internal/cpp/stemmer/stem_UTF_8_norwegian.cpp create mode 100644 internal/cpp/stemmer/stem_UTF_8_norwegian.h create mode 100644 internal/cpp/stemmer/stem_UTF_8_porter.cpp create mode 100644 internal/cpp/stemmer/stem_UTF_8_porter.h create mode 100644 internal/cpp/stemmer/stem_UTF_8_portuguese.cpp create mode 100644 internal/cpp/stemmer/stem_UTF_8_portuguese.h create mode 100644 internal/cpp/stemmer/stem_UTF_8_romanian.cpp create mode 100644 internal/cpp/stemmer/stem_UTF_8_romanian.h create mode 100644 internal/cpp/stemmer/stem_UTF_8_russian.cpp create mode 100644 internal/cpp/stemmer/stem_UTF_8_russian.h create mode 100644 
internal/cpp/stemmer/stem_UTF_8_spanish.cpp create mode 100644 internal/cpp/stemmer/stem_UTF_8_spanish.h create mode 100644 internal/cpp/stemmer/stem_UTF_8_swedish.cpp create mode 100644 internal/cpp/stemmer/stem_UTF_8_swedish.h create mode 100644 internal/cpp/stemmer/stem_UTF_8_turkish.cpp create mode 100644 internal/cpp/stemmer/stem_UTF_8_turkish.h create mode 100644 internal/cpp/stemmer/stemmer.cpp create mode 100644 internal/cpp/stemmer/stemmer.h create mode 100644 internal/cpp/stemmer/utilities.cpp create mode 100644 internal/cpp/string_utils.h create mode 100644 internal/cpp/term.cpp create mode 100644 internal/cpp/term.h create mode 100644 internal/cpp/tokenizer.cpp create mode 100644 internal/cpp/tokenizer.h create mode 100644 internal/cpp/util/logging.h create mode 100644 internal/cpp/util/mix.h create mode 100644 internal/cpp/util/mutex.h create mode 100644 internal/cpp/util/rune.cc create mode 100644 internal/cpp/util/strutil.cc create mode 100644 internal/cpp/util/strutil.h create mode 100644 internal/cpp/util/utf.h create mode 100644 internal/cpp/util/util.h create mode 100644 internal/cpp/wordnet_lemmatizer.cpp create mode 100644 internal/cpp/wordnet_lemmatizer.h create mode 100644 internal/dao/chat.go create mode 100644 internal/dao/chat_session.go create mode 100644 internal/dao/connector.go create mode 100644 internal/dao/database.go create mode 100644 internal/dao/document.go create mode 100644 internal/dao/file.go create mode 100644 internal/dao/file2document.go create mode 100644 internal/dao/kb.go create mode 100644 internal/dao/llm.go create mode 100644 internal/dao/model_provider.go create mode 100644 internal/dao/search.go create mode 100644 internal/dao/tenant.go create mode 100644 internal/dao/tenant_llm.go create mode 100644 internal/dao/user.go create mode 100644 internal/dao/user_canvas.go create mode 100644 internal/dao/user_tenant.go create mode 100644 internal/engine/README.md create mode 100644 internal/engine/elasticsearch/client.go create mode 100644 internal/engine/elasticsearch/document.go create mode 100644 internal/engine/elasticsearch/index.go create mode 100644 internal/engine/elasticsearch/search.go create mode 100644 internal/engine/engine.go create mode 100644 internal/engine/global.go create mode 100644 internal/engine/infinity/client.go create mode 100644 internal/engine/infinity/document.go create mode 100644 internal/engine/infinity/index.go create mode 100644 internal/engine/infinity/search.go create mode 100644 internal/engine/types/types.go create mode 100644 internal/go_binding/rag_analyzer.go create mode 100644 internal/handler/chat.go create mode 100644 internal/handler/chat_session.go create mode 100644 internal/handler/chunk.go create mode 100644 internal/handler/connector.go create mode 100644 internal/handler/document.go create mode 100644 internal/handler/error.go create mode 100644 internal/handler/file.go create mode 100644 internal/handler/kb.go create mode 100644 internal/handler/llm.go create mode 100644 internal/handler/search.go create mode 100644 internal/handler/system.go create mode 100644 internal/handler/tenant.go create mode 100644 internal/handler/user.go create mode 100644 internal/logger/README.md create mode 100644 internal/logger/logger.go create mode 100644 internal/model/api.go create mode 100644 internal/model/base.go create mode 100644 internal/model/canvas.go create mode 100644 internal/model/chat.go create mode 100644 internal/model/connector.go create mode 100644 internal/model/document.go create mode 
100644 internal/model/evaluation.go create mode 100644 internal/model/file.go create mode 100644 internal/model/kb.go create mode 100644 internal/model/llm.go create mode 100644 internal/model/mcp.go create mode 100644 internal/model/memory.go create mode 100644 internal/model/pipeline.go create mode 100644 internal/model/search.go create mode 100644 internal/model/system.go create mode 100644 internal/model/task.go create mode 100644 internal/model/tenant.go create mode 100644 internal/model/tenant_llm.go create mode 100644 internal/model/types.go create mode 100644 internal/model/user.go create mode 100644 internal/model/user_tenant.go create mode 100644 internal/router/router.go create mode 100644 internal/server/config.go create mode 100644 internal/server/model_provider.go create mode 100644 internal/server/variable.go create mode 100644 internal/service/chat.go create mode 100644 internal/service/chat_session.go create mode 100644 internal/service/chunk.go create mode 100644 internal/service/connector.go create mode 100644 internal/service/document.go create mode 100644 internal/service/file.go create mode 100644 internal/service/kb.go create mode 100644 internal/service/llm.go create mode 100644 internal/service/model_bundle.go create mode 100644 internal/service/model_service.go create mode 100644 internal/service/models/deepseek_model.go create mode 100644 internal/service/models/factory.go create mode 100644 internal/service/models/gitee_model.go create mode 100644 internal/service/models/moonshot_model.go create mode 100644 internal/service/models/openai_api_compatible_model.go create mode 100644 internal/service/models/openai_model.go create mode 100644 internal/service/models/siliconflow_model.go create mode 100644 internal/service/models/zhipu_model.go create mode 100644 internal/service/nlp/query_builder.go create mode 100644 internal/service/nlp/query_builder_test.go create mode 100644 internal/service/nlp/reranker.go create mode 100644 internal/service/nlp/synonym.go create mode 100644 internal/service/nlp/synonym_test.go create mode 100644 internal/service/nlp/term_weight.go create mode 100644 internal/service/nlp/term_weight_test.go create mode 100644 internal/service/nlp/wordnet.go create mode 100644 internal/service/nlp/wordnet_test.go create mode 100644 internal/service/search.go create mode 100644 internal/service/system.go create mode 100644 internal/service/tenant.go create mode 100644 internal/service/user.go create mode 100644 internal/tokenizer/tokenizer.go create mode 100644 internal/tokenizer/tokenizer_concurrent_test.go create mode 100644 internal/utility/embedding_lru.go create mode 100644 internal/utility/token.go create mode 100644 internal/utility/version.go create mode 100644 internal/utility/version_test.go diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b1c4452c8ee..72dab7c6cdc 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -129,13 +129,14 @@ jobs: fi fi - - name: Run unit test + - name: Build ragflow go server run: | - uv sync --python 3.12 --group test --frozen - source .venv/bin/activate - which pytest || echo "pytest not in PATH" - echo "Start to run unit test" - python3 run_tests.py + BUILDER_CONTAINER=ragflow_build_$(od -An -N4 -tx4 /dev/urandom | tr -d ' ') + echo "BUILDER_CONTAINER=${BUILDER_CONTAINER}" >> ${GITHUB_ENV} + TZ=${TZ:-$(readlink -f /etc/localtime | awk -F '/zoneinfo/' '{print $2}')} + sudo docker run --privileged -d --name ${BUILDER_CONTAINER} -e TZ=${TZ} -e 
UV_INDEX=https://pypi.tuna.tsinghua.edu.cn/simple -v ${PWD}:/ragflow -v ${PWD}/internal/cpp/resource:/usr/share/infinity/resource infiniflow/infinity_builder:ubuntu22_clang20 + sudo docker exec ${BUILDER_CONTAINER} bash -c "git config --global safe.directory \"*\" && cd /ragflow && ./build.sh --cpp" + ./build.sh --go - name: Build ragflow:nightly run: | @@ -152,6 +153,14 @@ echo "HTTP_API_TEST_LEVEL=${HTTP_API_TEST_LEVEL}" >> ${GITHUB_ENV} echo "RAGFLOW_CONTAINER=${GITHUB_RUN_ID}-ragflow-cpu-1" >> ${GITHUB_ENV} + - name: Run unit test + run: | + uv sync --python 3.12 --group test --frozen + source .venv/bin/activate + which pytest || echo "pytest not in PATH" + echo "Start to run unit test" + python3 run_tests.py + - name: Start ragflow:nightly run: | # Determine runner number (default to 1 if not found) diff --git a/.gitignore b/.gitignore index 0aa8576b993..0baacf87d7b 100644 --- a/.gitignore +++ b/.gitignore @@ -205,6 +205,8 @@ ragflow_cli.egg-info backup +*huqie.txt + .hypothesis diff --git a/Dockerfile b/Dockerfile index 4be231ba911..957bb74a703 100644 --- a/Dockerfile +++ b/Dockerfile @@ -202,6 +202,7 @@ COPY pyproject.toml uv.lock ./ COPY mcp mcp COPY common common COPY memory memory +COPY bin bin COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template COPY docker/entrypoint.sh ./ diff --git a/admin/client/COMMAND.md b/admin/client/COMMAND.md new file mode 100644 index 00000000000..cd8e376c4db --- /dev/null +++ b/admin/client/COMMAND.md @@ -0,0 +1,779 @@ +# RAGFlow CLI User Command Reference + +This document describes the user commands available in RAGFlow CLI. All commands must end with a semicolon (`;`). + +## Command List + +### ping_server + +**Description** +Tests the connection status to the server. + +**Usage** +``` +PING; +``` + +**Parameters** +No parameters. + +**Example** +``` +ragflow> PING; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### show_current_user + +**Description** +Displays information about the currently logged-in user. + +**Usage** +``` +SHOW CURRENT USER; +``` + +**Parameters** +No parameters. + +**Example** +``` +ragflow> SHOW CURRENT USER; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### create_model_provider + +**Description** +Creates a new model provider. + +**Usage** +``` +CREATE MODEL PROVIDER <provider_name> <provider_key>; +``` + +**Parameters** +- `provider_name`: Provider name, quoted string. +- `provider_key`: Provider key, quoted string. + +**Example** +``` +ragflow> CREATE MODEL PROVIDER 'openai' 'sk-...'; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### drop_model_provider + +**Description** +Deletes a model provider. + +**Usage** +``` +DROP MODEL PROVIDER <provider_name>; +``` + +**Parameters** +- `provider_name`: Name of the provider to delete, quoted string. + +**Example** +``` +ragflow> DROP MODEL PROVIDER 'openai'; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### set_default_llm + +**Description** +Sets the default LLM (Large Language Model). + +**Usage** +``` +SET DEFAULT LLM <llm_id>; +``` + +**Parameters** +- `llm_id`: LLM identifier, quoted string. + +**Example** +``` +ragflow> SET DEFAULT LLM 'gpt-4'; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### set_default_vlm + +**Description** +Sets the default VLM (Vision Language Model). + +**Usage** +``` +SET DEFAULT VLM <vlm_id>; +``` + +**Parameters** +- `vlm_id`: VLM identifier, quoted string.
+ +**Example** +``` +ragflow> SET DEFAULT VLM 'clip-vit-large'; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### set_default_embedding + +**Description** +Sets the default embedding model. + +**Usage** +``` +SET DEFAULT EMBEDDING <embedding_id>; +``` + +**Parameters** +- `embedding_id`: Embedding model identifier, quoted string. + +**Example** +``` +ragflow> SET DEFAULT EMBEDDING 'text-embedding-ada-002'; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### set_default_reranker + +**Description** +Sets the default reranker model. + +**Usage** +``` +SET DEFAULT RERANKER <reranker_id>; +``` + +**Parameters** +- `reranker_id`: Reranker model identifier, quoted string. + +**Example** +``` +ragflow> SET DEFAULT RERANKER 'bge-reranker-large'; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### set_default_asr + +**Description** +Sets the default ASR (Automatic Speech Recognition) model. + +**Usage** +``` +SET DEFAULT ASR <asr_id>; +``` + +**Parameters** +- `asr_id`: ASR model identifier, quoted string. + +**Example** +``` +ragflow> SET DEFAULT ASR 'whisper-large'; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### set_default_tts + +**Description** +Sets the default TTS (Text-to-Speech) model. + +**Usage** +``` +SET DEFAULT TTS <tts_id>; +``` + +**Parameters** +- `tts_id`: TTS model identifier, quoted string. + +**Example** +``` +ragflow> SET DEFAULT TTS 'tts-1'; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### reset_default_llm + +**Description** +Resets the default LLM to system default. + +**Usage** +``` +RESET DEFAULT LLM; +``` + +**Parameters** +No parameters. + +**Example** +``` +ragflow> RESET DEFAULT LLM; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### reset_default_vlm + +**Description** +Resets the default VLM to system default. + +**Usage** +``` +RESET DEFAULT VLM; +``` + +**Parameters** +No parameters. + +**Example** +``` +ragflow> RESET DEFAULT VLM; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### reset_default_embedding + +**Description** +Resets the default embedding model to system default. + +**Usage** +``` +RESET DEFAULT EMBEDDING; +``` + +**Parameters** +No parameters. + +**Example** +``` +ragflow> RESET DEFAULT EMBEDDING; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### reset_default_reranker + +**Description** +Resets the default reranker model to system default. + +**Usage** +``` +RESET DEFAULT RERANKER; +``` + +**Parameters** +No parameters. + +**Example** +``` +ragflow> RESET DEFAULT RERANKER; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### reset_default_asr + +**Description** +Resets the default ASR model to system default. + +**Usage** +``` +RESET DEFAULT ASR; +``` + +**Parameters** +No parameters. + +**Example** +``` +ragflow> RESET DEFAULT ASR; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### reset_default_tts + +**Description** +Resets the default TTS model to system default. + +**Usage** +``` +RESET DEFAULT TTS; +``` + +**Parameters** +No parameters. + +**Example** +``` +ragflow> RESET DEFAULT TTS; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### create_user_dataset_with_parser + +**Description** +Creates a user dataset with the specified parser.
+ +**Usage** +``` +CREATE DATASET <dataset_name> WITH EMBEDDING <embedding> PARSER <parser_type>; +``` + +**Parameters** +- `dataset_name`: Dataset name, quoted string. +- `embedding`: Embedding model name, quoted string. +- `parser_type`: Parser type, quoted string. + +**Example** +``` +ragflow> CREATE DATASET 'my_dataset' WITH EMBEDDING 'text-embedding-ada-002' PARSER 'pdf'; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### create_user_dataset_with_pipeline + +**Description** +Creates a user dataset with the specified pipeline. + +**Usage** +``` +CREATE DATASET <dataset_name> WITH EMBEDDING <embedding> PIPELINE <pipeline>; +``` + +**Parameters** +- `dataset_name`: Dataset name, quoted string. +- `embedding`: Embedding model name, quoted string. +- `pipeline`: Pipeline name, quoted string. + +**Example** +``` +ragflow> CREATE DATASET 'my_dataset' WITH EMBEDDING 'text-embedding-ada-002' PIPELINE 'standard'; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### drop_user_dataset + +**Description** +Deletes a user dataset. + +**Usage** +``` +DROP DATASET <dataset_name>; +``` + +**Parameters** +- `dataset_name`: Name of the dataset to delete, quoted string. + +**Example** +``` +ragflow> DROP DATASET 'my_dataset'; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### list_user_datasets + +**Description** +Lists all datasets for the current user. + +**Usage** +``` +LIST DATASETS; +``` + +**Parameters** +No parameters. + +**Example** +``` +ragflow> LIST DATASETS; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### list_user_dataset_files + +**Description** +Lists all files in the specified dataset. + +**Usage** +``` +LIST FILES OF DATASET <dataset_name>; +``` + +**Parameters** +- `dataset_name`: Dataset name, quoted string. + +**Example** +``` +ragflow> LIST FILES OF DATASET 'my_dataset'; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### list_user_agents + +**Description** +Lists all agents for the current user. + +**Usage** +``` +LIST AGENTS; +``` + +**Parameters** +No parameters. + +**Example** +``` +ragflow> LIST AGENTS; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### list_user_chats + +**Description** +Lists all chat sessions for the current user. + +**Usage** +``` +LIST CHATS; +``` + +**Parameters** +No parameters. + +**Example** +``` +ragflow> LIST CHATS; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### create_user_chat + +**Description** +Creates a new chat session. + +**Usage** +``` +CREATE CHAT <chat_name>; +``` + +**Parameters** +- `chat_name`: Chat session name, quoted string. + +**Example** +``` +ragflow> CREATE CHAT 'my_chat'; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### drop_user_chat + +**Description** +Deletes a chat session. + +**Usage** +``` +DROP CHAT <chat_name>; +``` + +**Parameters** +- `chat_name`: Name of the chat session to delete, quoted string. + +**Example** +``` +ragflow> DROP CHAT 'my_chat'; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### list_user_model_providers + +**Description** +Lists all model providers for the current user. + +**Usage** +``` +LIST MODEL PROVIDERS; +``` + +**Parameters** +No parameters. + +**Example** +``` +ragflow> LIST MODEL PROVIDERS; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### list_user_default_models + +**Description** +Lists all default model settings for the current user.
+ +**Usage** +``` +LIST DEFAULT MODELS; +``` + +**Parameters** +No parameters. + +**Example** +``` +ragflow> LIST DEFAULT MODELS; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### import_docs_into_dataset + +**Description** +Imports documents into the specified dataset. + +**Usage** +``` +IMPORT <document_list> INTO DATASET <dataset_name>; +``` + +**Parameters** +- `document_list`: List of document paths, multiple paths can be separated by commas, or as a space-separated quoted string. +- `dataset_name`: Target dataset name, quoted string. + +**Example** +``` +ragflow> IMPORT '/path/to/doc1.pdf,/path/to/doc2.pdf' INTO DATASET 'my_dataset'; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### search_on_datasets + +**Description** +Searches in one or more specified datasets. + +**Usage** +``` +SEARCH <question> ON DATASETS <dataset_list>; +``` + +**Parameters** +- `question`: Search question, quoted string. +- `dataset_list`: List of dataset names, multiple names can be separated by commas, or as a space-separated quoted string. + +**Example** +``` +ragflow> SEARCH 'What is RAG?' ON DATASETS 'dataset1,dataset2'; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### parse_dataset_docs + +**Description** +Parses specified documents in a dataset. + +**Usage** +``` +PARSE <document_names> OF DATASET <dataset_name>; +``` + +**Parameters** +- `document_names`: List of document names, multiple names can be separated by commas, or as a space-separated quoted string. +- `dataset_name`: Dataset name, quoted string. + +**Example** +``` +ragflow> PARSE 'doc1.pdf,doc2.pdf' OF DATASET 'my_dataset'; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### parse_dataset_sync + +**Description** +Synchronously parses the entire dataset. + +**Usage** +``` +PARSE DATASET <dataset_name> SYNC; +``` + +**Parameters** +- `dataset_name`: Dataset name, quoted string. + +**Example** +``` +ragflow> PARSE DATASET 'my_dataset' SYNC; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### parse_dataset_async + +**Description** +Asynchronously parses the entire dataset. + +**Usage** +``` +PARSE DATASET <dataset_name> ASYNC; +``` + +**Parameters** +- `dataset_name`: Dataset name, quoted string. + +**Example** +``` +ragflow> PARSE DATASET 'my_dataset' ASYNC; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +### benchmark + +**Description** +Performs performance benchmark testing on the specified user command. + +**Usage** +``` +BENCHMARK <concurrency> <iterations> <user_command>; +``` + +**Parameters** +- `concurrency`: Concurrency level, a positive integer. +- `iterations`: Number of iterations, a positive integer. +- `user_command`: User command to test (must be a valid user command, such as `PING;`). + +**Example** +``` +ragflow> BENCHMARK 5 10 PING; +``` + +**Display Effect** +(Sample output will be provided by the user) + +--- + +**Notes** +- All string parameters (such as names, IDs, paths) must be enclosed in single quotes (`'`) or double quotes (`"`). +- Commands must end with a semicolon (`;`). +- The prompt is `ragflow>`.
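+ +**Illustrative session** + +The session below is illustrative only: the dataset name, file path, and question are hypothetical placeholders, and actual output will follow the per-command samples above once provided. It chains dataset creation, import, parsing, and search using the quoting and semicolon rules from the Notes above: + +``` +ragflow> CREATE DATASET 'papers' WITH EMBEDDING 'text-embedding-ada-002' PARSER 'pdf'; +ragflow> IMPORT '/path/to/rag_survey.pdf' INTO DATASET 'papers'; +ragflow> PARSE DATASET 'papers' ASYNC; +ragflow> SEARCH 'What is RAG?' ON DATASETS 'papers'; +```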
diff --git a/admin/client/parser.py b/admin/client/parser.py index e2912b9e16f..11adfa81574 100644 --- a/admin/client/parser.py +++ b/admin/client/parser.py @@ -92,6 +92,7 @@ | drop_chat_session | list_chat_sessions | chat_on_session + | list_server_configs | benchmark // meta command definition @@ -176,6 +177,7 @@ PING: "PING"i SESSION: "SESSION"i SESSIONS: "SESSIONS"i +SERVER: "SERVER"i login_user: LOGIN USER quoted_string ";" list_services: LIST SERVICES ";" @@ -221,6 +223,8 @@ list_configs: LIST CONFIGS ";" list_environments: LIST ENVS ";" +list_server_configs: LIST SERVER CONFIGS ";" + benchmark: BENCHMARK NUMBER NUMBER user_statement user_statement: ping_server @@ -473,6 +477,9 @@ def list_configs(self, items): def list_environments(self, items): return {"type": "list_environments"} + def list_server_configs(self, items): + return {"type": "list_server_configs"} + def create_model_provider(self, items): provider_name = items[3].children[0].strip("'\"") provider_key = items[4].children[0].strip("'\"") diff --git a/admin/client/ragflow_client.py b/admin/client/ragflow_client.py index 6927aac9077..480d320f107 100644 --- a/admin/client/ragflow_client.py +++ b/admin/client/ragflow_client.py @@ -583,6 +583,42 @@ def list_environments(self, command): else: print(f"Fail to list variables, code: {res_json['code']}, message: {res_json['message']}") + def list_server_configs(self, command): + """List server configs by calling /system/configs API and flattening the JSON response.""" + response = self.http_client.request("GET", "/system/configs", use_api_base=False, auth_kind="web") + res_json = response.json() + if res_json.get("code") != 0: + print(f"Fail to list server configs, code: {res_json.get('code')}, message: {res_json.get('message')}") + return + + data = res_json.get("data", {}) + if not data: + print("No server configs found") + return + + # Flatten nested JSON with a.b.c notation + def flatten(obj, parent_key=""): + items = [] + if isinstance(obj, dict): + for k, v in obj.items(): + new_key = f"{parent_key}.{k}" if parent_key else k + if isinstance(v, (dict, list)) and v: + items.extend(flatten(v, new_key)) + else: + items.append({"name": new_key, "value": v}) + elif isinstance(obj, list): + for i, v in enumerate(obj): + new_key = f"{parent_key}[{i}]" + if isinstance(v, (dict, list)) and v: + items.extend(flatten(v, new_key)) + else: + items.append({"name": new_key, "value": v}) + return items + + # Reconstruct flattened data and print using _print_table_simple + flattened = flatten(data) + self._print_table_simple(flattened) + def handle_list_datasets(self, command): if self.server_type != "admin": print("This command is only allowed in ADMIN mode") @@ -1478,6 +1514,8 @@ def run_command(client: RAGFlowClient, command_dict: dict): client.list_configs(command_dict) case "list_environments": client.list_environments(command_dict) + case "list_server_configs": + client.list_server_configs(command_dict) case "create_model_provider": client.create_model_provider(command_dict) case "drop_model_provider": diff --git a/build.sh b/build.sh new file mode 100755 index 00000000000..70fe162437b --- /dev/null +++ b/build.sh @@ -0,0 +1,196 @@ +#!/bin/bash +set -e + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# Get script directory +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$SCRIPT_DIR" + +# Build directories +CPP_DIR="$PROJECT_ROOT/internal/cpp" +BUILD_DIR="$CPP_DIR/cmake-build-release" 
+OUTPUT_BINARY="$PROJECT_ROOT/bin/server_main" + +echo -e "${GREEN}=== RAGFlow Go Server Build Script ===${NC}" + +# Function to print section headers +print_section() { + echo -e "\n${YELLOW}>>> $1${NC}" +} + +# Check dependencies +check_cpp_deps() { + print_section "Checking C++ dependencies" + + command -v cmake >/dev/null 2>&1 || { echo -e "${RED}Error: cmake is required but not installed.${NC}"; exit 1; } + command -v g++ >/dev/null 2>&1 || { echo -e "${RED}Error: g++ is required but not installed.${NC}"; exit 1; } + + # Check for pcre2 library + if [ -f "/usr/lib/x86_64-linux-gnu/libpcre2-8.a" ] || [ -f "/usr/local/lib/libpcre2-8.a" ]; then + echo "✓ pcre2 library found" + else + echo -e "${YELLOW}Warning: libpcre2-8.a not found. You may need to install libpcre2-dev:${NC}" + echo " sudo apt-get install libpcre2-dev" + fi + + echo "✓ Required tools are available" +} + +check_go_deps() { + print_section "Checking Go dependencies" + + command -v go >/dev/null 2>&1 || { echo -e "${RED}Error: go is required but not installed.${NC}"; exit 1; } + + echo "✓ Required tools are available" +} + +# Build C++ static library +build_cpp() { + print_section "Building C++ static library" + + mkdir -p "$BUILD_DIR" + cd "$BUILD_DIR" + + echo "Running cmake..." + cmake .. -DCMAKE_BUILD_TYPE=Release + + echo "Building librag_tokenizer_c_api.a..." + make rag_tokenizer_c_api -j$(nproc) + + if [ ! -f "$BUILD_DIR/librag_tokenizer_c_api.a" ]; then + echo -e "${RED}Error: Failed to build C++ static library${NC}" + exit 1 + fi + + echo -e "${GREEN}✓ C++ static library built successfully${NC}" +} + +# Build Go server +build_go() { + print_section "Building Go server" + + cd "$PROJECT_ROOT" + + # Check if C++ library exists + if [ ! -f "$BUILD_DIR/librag_tokenizer_c_api.a" ]; then + echo -e "${RED}Error: C++ static library not found. Run with --cpp first.${NC}" + exit 1 + fi + + # Check for pcre2 library; install it automatically if missing + if [ -f "/usr/lib/x86_64-linux-gnu/libpcre2-8.a" ] || [ -f "/usr/local/lib/libpcre2-8.a" ]; then + echo "✓ pcre2 library found" + else + echo -e "${YELLOW}Warning: libpcre2-8.a not found. Installing libpcre2-dev:${NC}" + sudo apt -y install libpcre2-dev + fi + + echo "Building Go binary: $OUTPUT_BINARY" + GOPROXY=${GOPROXY:-https://goproxy.cn,https://proxy.golang.org,direct} CGO_ENABLED=1 go build -o "$OUTPUT_BINARY" ./cmd/server_main.go + + if [ ! -f "$OUTPUT_BINARY" ]; then + echo -e "${RED}Error: Failed to build Go binary${NC}" + exit 1 + fi + + echo -e "${GREEN}✓ Go server built successfully: $OUTPUT_BINARY${NC}" +} + +# Clean build artifacts +clean() { + print_section "Cleaning build artifacts" + + rm -rf "$BUILD_DIR" + rm -f "$OUTPUT_BINARY" + + echo -e "${GREEN}✓ Build artifacts cleaned${NC}" +} + +# Run the server +run() { + if [ ! -f "$OUTPUT_BINARY" ]; then + echo -e "${RED}Error: Binary not found. Build first with --all or --go${NC}" + exit 1 + fi + + print_section "Starting server" + cd "$PROJECT_ROOT" + "$OUTPUT_BINARY" +} + +# Show help +show_help() { + cat << EOF +Usage: $0 [OPTIONS] + +Build script for RAGFlow Go server with C++ bindings.
+ +OPTIONS: + --all, -a Build everything (C++ library + Go server) [default] + --cpp, -c Build only C++ static library + --go, -g Build only Go server (requires C++ library to be built) + --clean, -C Clean all build artifacts + --run, -r Build and run the server + --help, -h Show this help message + +EXAMPLES: + $0 # Build everything + $0 --cpp # Build only C++ library + $0 --go # Build only Go server + $0 --run # Build and run + $0 --clean # Clean build artifacts + +DEPENDENCIES: + - cmake >= 4.0 + - go >= 1.24 + - g++ with C++17/23 support + - libpcre2-dev +EOF +} + +# Main function +main() { + case "${1:-}" in + --cpp|-c) + check_cpp_deps + build_cpp + ;; + --go|-g) + check_go_deps + build_go + ;; + --clean|-C) + clean + ;; + --run|-r) + check_cpp_deps + check_go_deps + build_cpp + build_go + run + ;; + --help|-h) + show_help + ;; + --all|-a|"") + check_cpp_deps + check_go_deps + build_cpp + build_go + echo -e "\n${GREEN}=== Build completed successfully! ===${NC}" + echo "Binary: $OUTPUT_BINARY" + ;; + *) + echo -e "${RED}Unknown option: $1${NC}" + show_help + exit 1 + ;; + esac +} + +main "$@" diff --git a/cmd/ragflow_cli.go b/cmd/ragflow_cli.go new file mode 100644 index 00000000000..7af88e3acd3 --- /dev/null +++ b/cmd/ragflow_cli.go @@ -0,0 +1,34 @@ +package main + +import ( + "fmt" + "os" + "os/signal" + "syscall" + + "ragflow/internal/cli" +) + +func main() { + // Create CLI instance + cliApp, err := cli.NewCLI() + if err != nil { + fmt.Printf("Failed to create CLI: %v\n", err) + os.Exit(1) + } + + // Handle interrupt signal + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigChan + cliApp.Cleanup() + os.Exit(0) + }() + + // Run CLI + if err := cliApp.Run(); err != nil { + fmt.Printf("CLI error: %v\n", err) + os.Exit(1) + } +} diff --git a/cmd/server_main.go b/cmd/server_main.go new file mode 100644 index 00000000000..e079371e331 --- /dev/null +++ b/cmd/server_main.go @@ -0,0 +1,181 @@ +package main + +import ( + "context" + "fmt" + "net/http" + "os" + "os/signal" + "ragflow/internal/server" + "syscall" + "time" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" + + "ragflow/internal/cache" + "ragflow/internal/dao" + "ragflow/internal/engine" + "ragflow/internal/handler" + "ragflow/internal/logger" + "ragflow/internal/router" + "ragflow/internal/service" + "ragflow/internal/service/nlp" + "ragflow/internal/tokenizer" +) + +func main() { + // Initialize logger with default level + // logger.Init("info"); // set debug log level + if err := logger.Init("debug"); err != nil { + panic(fmt.Sprintf("Failed to initialize logger: %v", err)) + } + + // Initialize configuration + if err := server.Init(""); err != nil { + logger.Fatal("Failed to initialize config", zap.Error(err)) + } + + // Load model providers configuration + if err := server.LoadModelProviders(""); err != nil { + logger.Fatal("Failed to load model providers", zap.Error(err)) + } + logger.Info("Model providers loaded", zap.Int("count", len(server.GetModelProviders()))) + + cfg := server.GetConfig() + + // Reinitialize logger with configured level if different + if cfg.Log.Level != "" && cfg.Log.Level != "info" { + if err := logger.Init(cfg.Log.Level); err != nil { + logger.Error("Failed to reinitialize logger with configured level", err) + } + } + server.SetLogger(logger.Logger) + + logger.Info("Server mode", zap.String("mode", cfg.Server.Mode)) + + // Print all configuration settings + server.PrintAll() + + // Set Gin mode + if cfg.Server.Mode == "release" { + 
gin.SetMode(gin.ReleaseMode) + } else { + gin.SetMode(gin.DebugMode) + } + + // Initialize database + if err := dao.InitDB(); err != nil { + logger.Fatal("Failed to initialize database", zap.Error(err)) + } + + // Initialize doc engine + if err := engine.Init(&cfg.DocEngine); err != nil { + logger.Fatal("Failed to initialize doc engine", zap.Error(err)) + } + defer engine.Close() + + // Initialize Redis cache + if err := cache.Init(&cfg.Redis); err != nil { + logger.Fatal("Failed to initialize Redis", zap.Error(err)) + } + defer cache.Close() + + // Initialize server variables (runtime variables that can change during operation) + // This must be done after Cache is initialized + if err := server.InitVariables(cache.Get()); err != nil { + logger.Warn("Failed to initialize server variables from Redis, using defaults", zap.String("error", err.Error())) + } + + // Initialize tokenizer (rag_analyzer) + tokenizerCfg := &tokenizer.PoolConfig{ + DictPath: "/usr/share/infinity/resource", + } + if err := tokenizer.Init(tokenizerCfg); err != nil { + logger.Fatal("Failed to initialize tokenizer", zap.Error(err)) + } + defer tokenizer.Close() + + // Initialize global QueryBuilder using tokenizer's DictPath + // This ensures the Synonym uses the same wordnet directory as tokenizer + if err := nlp.InitQueryBuilderFromTokenizer(tokenizerCfg.DictPath); err != nil { + logger.Fatal("Failed to initialize query builder", zap.Error(err)) + } + + // Initialize service layer + userService := service.NewUserService() + documentService := service.NewDocumentService() + kbService := service.NewKnowledgebaseService() + chunkService := service.NewChunkService() + llmService := service.NewLLMService() + tenantService := service.NewTenantService() + chatService := service.NewChatService() + chatSessionService := service.NewChatSessionService() + systemService := service.NewSystemService() + connectorService := service.NewConnectorService() + searchService := service.NewSearchService() + fileService := service.NewFileService() + + // Initialize handler layer + userHandler := handler.NewUserHandler(userService) + tenantHandler := handler.NewTenantHandler(tenantService, userService) + documentHandler := handler.NewDocumentHandler(documentService) + systemHandler := handler.NewSystemHandler(systemService) + kbHandler := handler.NewKnowledgebaseHandler(kbService, userService) + chunkHandler := handler.NewChunkHandler(chunkService, userService) + llmHandler := handler.NewLLMHandler(llmService, userService) + chatHandler := handler.NewChatHandler(chatService, userService) + chatSessionHandler := handler.NewChatSessionHandler(chatSessionService, userService) + connectorHandler := handler.NewConnectorHandler(connectorService, userService) + searchHandler := handler.NewSearchHandler(searchService, userService) + fileHandler := handler.NewFileHandler(fileService, userService) + + // Initialize router + r := router.NewRouter(userHandler, tenantHandler, documentHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler) + + // Create Gin engine + ginEngine := gin.New() + + // Middleware + if cfg.Server.Mode == "debug" { + ginEngine.Use(gin.Logger()) + } + ginEngine.Use(gin.Recovery()) + + // Setup routes + r.Setup(ginEngine) + + // Create HTTP server + addr := fmt.Sprintf(":%d", cfg.Server.Port) + srv := &http.Server{ + Addr: addr, + Handler: ginEngine, + } + + // Start server in a goroutine + go func() { + logger.Info(fmt.Sprintf("Server starting on 
port: %d", cfg.Server.Port)) + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.Fatal("Failed to start server", zap.Error(err)) + } + }() + + // Wait for interrupt signal to gracefully shutdown + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGUSR2) + sig := <-quit + + logger.Info("Received signal", zap.String("signal", sig.String())) + logger.Info("Shutting down server...") + + // Create context with timeout for graceful shutdown + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Shutdown server + if err := srv.Shutdown(ctx); err != nil { + logger.Fatal("Server forced to shutdown", zap.Error(err)) + } + + logger.Info("Server exited") +} diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 4fb5cbde3dd..5dd300a78d1 100755 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -224,6 +224,7 @@ if [[ "${ENABLE_WEBSERVER}" -eq 1 ]]; then echo "Starting ragflow_server..." while true; do "$PY" api/ragflow_server.py ${INIT_SUPERUSER_ARGS} & + bin/server_main & wait; sleep 1; done & diff --git a/go.mod b/go.mod new file mode 100644 index 00000000000..256f066ac63 --- /dev/null +++ b/go.mod @@ -0,0 +1,69 @@ +module ragflow + +go 1.24.0 + +require ( + github.com/elastic/go-elasticsearch/v8 v8.19.1 + github.com/gin-gonic/gin v1.9.1 + github.com/google/uuid v1.4.0 + github.com/iromli/go-itsdangerous v0.0.0-20220223194502-9c8bef8dac6a + github.com/redis/go-redis/v9 v9.18.0 + github.com/siongui/gojianfan v0.0.0-20210926212422-2f175ac615de + github.com/spf13/viper v1.18.2 + go.uber.org/zap v1.27.1 + golang.org/x/crypto v0.47.0 + gorm.io/driver/mysql v1.5.2 + gorm.io/gorm v1.25.5 +) + +require ( + github.com/bytedance/sonic v1.9.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/elastic/elastic-transport-go/v8 v8.8.0 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-logr/logr v1.4.2 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.16.0 // indirect + github.com/go-sql-driver/mysql v1.7.0 // indirect + github.com/goccy/go-json v0.10.2 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.2.4 // indirect + github.com/leodido/go-urn v1.2.4 // indirect + github.com/magiconair/properties v1.8.7 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.1.1 // indirect + github.com/sagikazarmark/locafero v0.4.0 // indirect + github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.11.0 // indirect + github.com/spf13/cast v1.6.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + 
github.com/subosito/gotenv v1.6.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.12 // indirect + go.opentelemetry.io/otel v1.28.0 // indirect + go.opentelemetry.io/otel/metric v1.28.0 // indirect + go.opentelemetry.io/otel/trace v1.28.0 // indirect + go.uber.org/atomic v1.11.0 // indirect + go.uber.org/multierr v1.10.0 // indirect + golang.org/x/arch v0.6.0 // indirect + golang.org/x/exp v0.0.0-20231226003508-02704c960a9b // indirect + golang.org/x/net v0.48.0 // indirect + golang.org/x/sys v0.40.0 // indirect + golang.org/x/text v0.33.0 // indirect + google.golang.org/protobuf v1.32.0 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000000..6b405659dbb --- /dev/null +++ b/go.sum @@ -0,0 +1,176 @@ +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= +github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/elastic/elastic-transport-go/v8 v8.8.0 h1:7k1Ua+qluFr6p1jfJjGDl97ssJS/P7cHNInzfxgBQAo= +github.com/elastic/elastic-transport-go/v8 v8.8.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk= +github.com/elastic/go-elasticsearch/v8 v8.19.1 h1:0iEGt5/Ds9MNVxEp3hqLsXdbe6SjleaVHONg/FuR09Q= +github.com/elastic/go-elasticsearch/v8 v8.19.1/go.mod h1:tHJQdInFa6abmDbDCEH2LJja07l/SIpaGpJcm13nt7s= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= +github.com/gabriel-vasile/mimetype 
v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= +github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE= +github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= +github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/iromli/go-itsdangerous v0.0.0-20220223194502-9c8bef8dac6a h1:Inib12UR9HAfBubrGNraPjKt/Cu8xPbTJbC50+0wP5U= +github.com/iromli/go-itsdangerous v0.0.0-20220223194502-9c8bef8dac6a/go.mod h1:8N0Hlye5Lzw+H/yHWpZMkT0QLA+iOHG7KLdvAm95DZg= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= +github.com/klauspost/cpuid/v2 v2.2.4/go.mod 
h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= +github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= +github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOStowAI= +github.com/pelletier/go-toml/v2 v2.1.1/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs= +github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= +github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= +github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= +github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/siongui/gojianfan v0.0.0-20210926212422-2f175ac615de h1:1/P9CcR8iENN9ybbSRWohRd3rsPp9tEWlTS/7ygvjHE= +github.com/siongui/gojianfan v0.0.0-20210926212422-2f175ac615de/go.mod h1:TRwEEJlrSIv+jc66k48huOZ2aKVBPL8V29ZcsjUIH70= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= +github.com/spf13/cast 
v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= +github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= +github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= +github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= +go.opentelemetry.io/otel v1.28.0 h1:/SqNcYk+idO0CxKEUOtKQClMK/MimZihKYMruSMViUo= +go.opentelemetry.io/otel v1.28.0/go.mod h1:q68ijF8Fc8CnMHKyzqL6akLO46ePnjkgfIMIjUIX9z4= +go.opentelemetry.io/otel/metric v1.28.0 h1:f0HGvSl1KRAU1DLgLGFjrwVyismPlnuU6JD6bOeuA5Q= +go.opentelemetry.io/otel/metric v1.28.0/go.mod h1:Fb1eVBFZmLVTMb6PPohq3TO9IIhUisDsbJoL/+uQW4s= +go.opentelemetry.io/otel/sdk v1.21.0 h1:FTt8qirL1EysG6sTQRZ5TokkU8d0ugCj8htOgThZXQ8= +go.opentelemetry.io/otel/sdk v1.21.0/go.mod h1:Nna6Yv7PWTdgJHVRD9hIYywQBRx7pbox6nwBnZIxl/E= +go.opentelemetry.io/otel/trace v1.28.0 h1:GhQ9cUuQGmNDd5BTCP2dAvv75RdMxEfTmYejp+lkx9g= +go.opentelemetry.io/otel/trace v1.28.0/go.mod h1:jPyXzNPg6da9+38HEwElrQiHlVMTnVfM3/yv2OlIHaI= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= +go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +golang.org/x/arch 
v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.6.0 h1:S0JTfE48HbRj80+4tbvZDYsJ3tGv6BUU3XxyZ7CirAc= +golang.org/x/arch v0.6.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= +golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/exp v0.0.0-20231226003508-02704c960a9b h1:kLiC65FbiHWFAOu+lxwNPujcsl8VYyTYYEZnsOO1WK4= +golang.org/x/exp v0.0.0-20231226003508-02704c960a9b/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= +golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= +google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.5.2 h1:QC2HRskSE75wBuOxe0+iCkyJZ+RqpudsQtqkp+IMuXs= +gorm.io/driver/mysql v1.5.2/go.mod h1:pQLhh1Ut/WUAySdTHwBpBv6+JKcj+ua4ZFx1QQTBzb8= +gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= +gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/internal/cache/redis.go b/internal/cache/redis.go new file mode 100644 index 00000000000..36270e8b646 --- /dev/null +++ b/internal/cache/redis.go @@ -0,0 +1,996 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package cache + +import ( + "context" + "encoding/json" + "fmt" + "math" + "math/rand" + "strconv" + "sync" + "time" + + "github.com/google/uuid" + "github.com/redis/go-redis/v9" + "go.uber.org/zap" + + "ragflow/internal/logger" + "ragflow/internal/server" +) + +var ( + globalClient *RedisClient + once sync.Once +) + +// RedisClient wraps go-redis client with additional utility methods +type RedisClient struct { + client *redis.Client + luaDeleteIfEqual *redis.Script + luaTokenBucket *redis.Script + luaAutoIncrement *redis.Script + config *server.RedisConfig +} + +// RedisMsg represents a message from Redis Stream +type RedisMsg struct { + consumer *redis.Client + queueName string + groupName string + msgID string + message map[string]interface{} +} + +// Lua scripts +const ( + luaDeleteIfEqualScript = ` + local current_value = redis.call('get', KEYS[1]) + if current_value and current_value == ARGV[1] then + redis.call('del', KEYS[1]) + return 1 + end + return 0 + ` + + luaTokenBucketScript = ` + local key = KEYS[1] + local capacity = tonumber(ARGV[1]) + local rate = tonumber(ARGV[2]) + local now = tonumber(ARGV[3]) + local cost = tonumber(ARGV[4]) + + local data = redis.call("HMGET", key, "tokens", "timestamp") + local tokens = tonumber(data[1]) + local last_ts = tonumber(data[2]) + + if tokens == nil then + tokens = capacity + last_ts = now + end + + local delta = math.max(0, now - last_ts) + tokens = math.min(capacity, tokens + delta * rate) + + if tokens < cost then + return {0, tokens} + end + + tokens = tokens - cost + + redis.call("HMSET", key, + "tokens", tokens, + "timestamp", now + ) + + redis.call("EXPIRE", key, math.ceil(capacity / rate * 2)) + + return {1, tokens} + ` +) + +// Init initializes Redis client +func Init(cfg *server.RedisConfig) error { + var initErr error + once.Do(func() { + if cfg.Host == "" { + logger.Info("Redis host not configured, skipping Redis initialization") + return + } + + client := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), + Password: cfg.Password, + DB: cfg.DB, + }) + + // Test connection + ctx, cancel := context.WithTimeout(context.Background(), server.DefaultConnectTimeout) + defer cancel() + + if err := client.Ping(ctx).Err(); err != nil { + initErr = fmt.Errorf("failed to connect to Redis: %w", err) + return + } + + globalClient = &RedisClient{ + client: client, + config: cfg, + luaDeleteIfEqual: redis.NewScript(luaDeleteIfEqualScript), + luaTokenBucket: redis.NewScript(luaTokenBucketScript), + } + + logger.Info("Redis client initialized", + zap.String("host", cfg.Host), + zap.Int("port", cfg.Port), + zap.Int("db", cfg.DB), + ) + }) + return initErr +} + +// Get gets global Redis client instance +func Get() *RedisClient { + return globalClient +} + +// Close closes Redis client +func Close() error { + if globalClient != nil && globalClient.client != nil { + return globalClient.client.Close() + } + return nil +} + +// IsEnabled checks if Redis is enabled (configured and initialized) +func IsEnabled() bool { + return globalClient != nil && globalClient.client != nil +} + +// Health checks if Redis is healthy +func (r *RedisClient) Health() bool { + if r.client == nil { + return false + } + ctx := context.Background() + if err := r.client.Ping(ctx).Err(); err != nil { + return false + } + + testKey := "health_check_" + uuid.New().String() + testValue := "yy" + if err := r.client.Set(ctx, testKey, testValue, 3*time.Second).Err(); err != nil { + return false + } + + val, err := r.client.Get(ctx, 
testKey).Result() + if err != nil || val != testValue { + return false + } + return true +} + +// Info returns Redis server information +func (r *RedisClient) Info() map[string]interface{} { + if r.client == nil { + return nil + } + ctx := context.Background() + infoStr, err := r.client.Info(ctx).Result() + if err != nil { + logger.Warn("Failed to get Redis info", zap.Error(err)) + return nil + } + + // Parse info string to map + info := make(map[string]string) + lines := splitLines(infoStr) + for _, line := range lines { + if line == "" || line[0] == '#' { + continue + } + parts := splitN(line, ":", 2) + if len(parts) == 2 { + info[parts[0]] = parts[1] + } + } + + result := map[string]interface{}{ + "redis_version": info["redis_version"], + "server_mode": getServerMode(info), + "used_memory": info["used_memory_human"], + "total_system_memory": info["total_system_memory_human"], + "mem_fragmentation_ratio": info["mem_fragmentation_ratio"], + "connected_clients": parseInt(info["connected_clients"]), + "blocked_clients": parseInt(info["blocked_clients"]), + "instantaneous_ops_per_sec": parseInt(info["instantaneous_ops_per_sec"]), + "total_commands_processed": parseInt(info["total_commands_processed"]), + } + return result +} + +func getServerMode(info map[string]string) string { + if mode, ok := info["server_mode"]; ok { + return mode + } + return info["redis_mode"] +} + +func splitLines(s string) []string { + var lines []string + start := 0 + for i := 0; i < len(s); i++ { + if s[i] == '\n' { + lines = append(lines, s[start:i]) + start = i + 1 + } + } + if start < len(s) { + lines = append(lines, s[start:]) + } + return lines +} + +func splitN(s, sep string, n int) []string { + if n <= 0 { + return []string{s} + } + idx := -1 + for i := 0; i < len(s)-len(sep)+1; i++ { + if s[i:i+len(sep)] == sep { + idx = i + break + } + } + if idx == -1 { + return []string{s} + } + return []string{s[:idx], s[idx+len(sep):]} +} + +func parseInt(s string) int { + v, _ := strconv.Atoi(s) + return v +} + +// IsAlive checks if Redis client is alive +func (r *RedisClient) IsAlive() bool { + return r.client != nil +} + +// Exist checks if key exists +func (r *RedisClient) Exist(key string) (bool, error) { + if r.client == nil { + return false, nil + } + ctx := context.Background() + exists, err := r.client.Exists(ctx, key).Result() + if err != nil { + logger.Warn("Redis Exist error", zap.String("key", key), zap.Error(err)) + return false, err + } + return exists > 0, nil +} + +// Get gets value by key +func (r *RedisClient) Get(key string) (string, error) { + if r.client == nil { + return "", nil + } + ctx := context.Background() + val, err := r.client.Get(ctx, key).Result() + if err == redis.Nil { + return "", nil + } + if err != nil { + logger.Warn("Redis Get error", zap.String("key", key), zap.Error(err)) + return "", err + } + return val, nil +} + +// SetObj sets object with JSON serialization +func (r *RedisClient) SetObj(key string, obj interface{}, exp time.Duration) bool { + if r.client == nil { + return false + } + ctx := context.Background() + data, err := json.Marshal(obj) + if err != nil { + logger.Warn("Redis SetObj marshal error", zap.String("key", key), zap.Error(err)) + return false + } + if err := r.client.Set(ctx, key, data, exp).Err(); err != nil { + logger.Warn("Redis SetObj error", zap.String("key", key), zap.Error(err)) + return false + } + return true +} + +// GetObj gets and unmarshals object from Redis +func (r *RedisClient) GetObj(key string, dest interface{}) bool { + if r.client == nil 
{ + return false + } + ctx := context.Background() + data, err := r.client.Get(ctx, key).Result() + if err == redis.Nil { + return false + } + if err != nil { + logger.Warn("Redis GetObj error", zap.String("key", key), zap.Error(err)) + return false + } + if err := json.Unmarshal([]byte(data), dest); err != nil { + logger.Warn("Redis GetObj unmarshal error", zap.String("key", key), zap.Error(err)) + return false + } + return true +} + +// Set sets value with expiration +func (r *RedisClient) Set(key string, value string, exp time.Duration) bool { + if r.client == nil { + return false + } + ctx := context.Background() + if err := r.client.Set(ctx, key, value, exp).Err(); err != nil { + logger.Warn("Redis Set error", zap.String("key", key), zap.Error(err)) + return false + } + return true +} + +// SetNX sets value only if key does not exist +func (r *RedisClient) SetNX(key string, value string, exp time.Duration) bool { + if r.client == nil { + return false + } + ctx := context.Background() + ok, err := r.client.SetNX(ctx, key, value, exp).Result() + if err != nil { + logger.Warn("Redis SetNX error", zap.String("key", key), zap.Error(err)) + return false + } + return ok +} + +// GetOrCreateKey atomically retrieves an existing key or creates a new one +// Uses Redis SETNX command to ensure atomicity across multiple goroutines/processes +func (r *RedisClient) GetOrCreateKey(key string, value string) (string, error) { + if r.client == nil { + return "", nil + } + ctx := context.Background() + // First, try to get the existing key + existingKey, err := r.client.Get(ctx, key).Result() + if err == nil { + // Successfully retrieved existing key + return existingKey, nil + } + if err != redis.Nil { + // A real connection/command error, not just a missing key + logger.Warn("Redis GetOrCreateKey get error", zap.String("key", key), zap.Error(err)) + } + + // Use SETNX to atomically set the key only if it doesn't exist + // SETNX returns true if the key was set, false if it already existed + success, err := r.client.SetNX(ctx, key, value, 0).Result() + if err != nil { + return "", fmt.Errorf("failed to set key in Redis: %v", err) + } + + if success { + // This goroutine successfully set the key + return value, nil + } + + // SETNX failed, meaning another goroutine set the key concurrently + // Retrieve and return that key + finalKey, err := r.client.Get(ctx, key).Result() + if err != nil { + return "", fmt.Errorf("failed to get key set by another process: %v", err) + } + + return finalKey, nil +} + +// SAdd adds member to set +func (r *RedisClient) SAdd(key string, member string) bool { + if r.client == nil { + return false + } + ctx := context.Background() + if err := r.client.SAdd(ctx, key, member).Err(); err != nil { + logger.Warn("Redis SAdd error", zap.String("key", key), zap.Error(err)) + return false + } + return true +} + +// SRem removes member from set +func (r *RedisClient) SRem(key string, member string) bool { + if r.client == nil { + return false + } + ctx := context.Background() + if err := r.client.SRem(ctx, key, member).Err(); err != nil { + logger.Warn("Redis SRem error", zap.String("key", key), zap.Error(err)) + return false + } + return true +} + +// SMembers returns all members of a set +func (r *RedisClient) SMembers(key string) ([]string, error) { + if r.client == nil { + return nil, nil + } + ctx := context.Background() + members, err := r.client.SMembers(ctx, key).Result() + if err != nil { + logger.Warn("Redis SMembers error", zap.String("key", key), zap.Error(err)) + return nil, err + } + return members, nil +} + +// SIsMember checks if member exists in set +func (r 
*RedisClient) SIsMember(key string, member string) bool { + if r.client == nil { + return false + } + ctx := context.Background() + ok, err := r.client.SIsMember(ctx, key, member).Result() + if err != nil { + logger.Warn("Redis SIsMember error", zap.String("key", key), zap.Error(err)) + return false + } + return ok +} + +// ZAdd adds member with score to sorted set +func (r *RedisClient) ZAdd(key string, member string, score float64) bool { + if r.client == nil { + return false + } + ctx := context.Background() + if err := r.client.ZAdd(ctx, key, redis.Z{Score: score, Member: member}).Err(); err != nil { + logger.Warn("Redis ZAdd error", zap.String("key", key), zap.Error(err)) + return false + } + return true +} + +// ZCount returns count of members with score in range +func (r *RedisClient) ZCount(key string, min, max float64) int64 { + if r.client == nil { + return 0 + } + ctx := context.Background() + count, err := r.client.ZCount(ctx, key, fmt.Sprintf("%f", min), fmt.Sprintf("%f", max)).Result() + if err != nil { + logger.Warn("Redis ZCount error", zap.String("key", key), zap.Error(err)) + return 0 + } + return count +} + +// ZPopMin pops minimum score members from sorted set +func (r *RedisClient) ZPopMin(key string, count int) ([]redis.Z, error) { + if r.client == nil { + return nil, nil + } + ctx := context.Background() + members, err := r.client.ZPopMin(ctx, key, int64(count)).Result() + if err != nil { + logger.Warn("Redis ZPopMin error", zap.String("key", key), zap.Error(err)) + return nil, err + } + return members, nil +} + +// ZRangeByScore returns members with score in range +func (r *RedisClient) ZRangeByScore(key string, min, max float64) ([]string, error) { + if r.client == nil { + return nil, nil + } + ctx := context.Background() + members, err := r.client.ZRangeByScore(ctx, key, &redis.ZRangeBy{ + Min: fmt.Sprintf("%f", min), + Max: fmt.Sprintf("%f", max), + }).Result() + if err != nil { + logger.Warn("Redis ZRangeByScore error", zap.String("key", key), zap.Error(err)) + return nil, err + } + return members, nil +} + +// ZRemRangeByScore removes members with score in range +func (r *RedisClient) ZRemRangeByScore(key string, min, max float64) int64 { + if r.client == nil { + return 0 + } + ctx := context.Background() + count, err := r.client.ZRemRangeByScore(ctx, key, fmt.Sprintf("%f", min), fmt.Sprintf("%f", max)).Result() + if err != nil { + logger.Warn("Redis ZRemRangeByScore error", zap.String("key", key), zap.Error(err)) + return 0 + } + return count +} + +// IncrBy increments key by increment +func (r *RedisClient) IncrBy(key string, increment int64) (int64, error) { + if r.client == nil { + return 0, nil + } + ctx := context.Background() + val, err := r.client.IncrBy(ctx, key, increment).Result() + if err != nil { + logger.Warn("Redis IncrBy error", zap.String("key", key), zap.Error(err)) + return 0, err + } + return val, nil +} + +// DecrBy decrements key by decrement +func (r *RedisClient) DecrBy(key string, decrement int64) (int64, error) { + if r.client == nil { + return 0, nil + } + ctx := context.Background() + val, err := r.client.DecrBy(ctx, key, decrement).Result() + if err != nil { + logger.Warn("Redis DecrBy error", zap.String("key", key), zap.Error(err)) + return 0, err + } + return val, nil +} + +// GenerateAutoIncrementID generates auto-increment ID +func (r *RedisClient) GenerateAutoIncrementID(keyPrefix string, namespace string, increment int64, ensureMinimum *int64) int64 { + if r.client == nil { + return -1 + } + if keyPrefix == "" { + keyPrefix = 
"id_generator" + } + if namespace == "" { + namespace = "default" + } + if increment == 0 { + increment = 1 + } + + redisKey := fmt.Sprintf("%s:%s", keyPrefix, namespace) + ctx := context.Background() + + // Check if key exists + exists, err := r.client.Exists(ctx, redisKey).Result() + if err != nil { + logger.Warn("Redis GenerateAutoIncrementID error", zap.Error(err)) + return -1 + } + + if exists == 0 && ensureMinimum != nil { + startID := int64(math.Max(1, float64(*ensureMinimum))) + r.client.Set(ctx, redisKey, startID, 0) + return startID + } + + // Get current value + if ensureMinimum != nil { + current, err := r.client.Get(ctx, redisKey).Int64() + if err == nil && current < *ensureMinimum { + r.client.Set(ctx, redisKey, *ensureMinimum, 0) + return *ensureMinimum + } + } + + // Increment + nextID, err := r.client.IncrBy(ctx, redisKey, increment).Result() + if err != nil { + logger.Warn("Redis GenerateAutoIncrementID increment error", zap.Error(err)) + return -1 + } + + return nextID +} + +// Transaction sets key with NX flag (transaction-like behavior) +func (r *RedisClient) Transaction(key string, value string, exp time.Duration) bool { + if r.client == nil { + return false + } + ctx := context.Background() + pipe := r.client.Pipeline() + pipe.SetNX(ctx, key, value, exp) + _, err := pipe.Exec(ctx) + if err != nil { + logger.Warn("Redis Transaction error", zap.String("key", key), zap.Error(err)) + return false + } + return true +} + +// QueueProduct produces a message to Redis Stream +func (r *RedisClient) QueueProduct(queue string, message interface{}) bool { + if r.client == nil { + return false + } + ctx := context.Background() + + for i := 0; i < 3; i++ { + data, err := json.Marshal(message) + if err != nil { + logger.Warn("Redis QueueProduct marshal error", zap.Error(err)) + return false + } + + _, err = r.client.XAdd(ctx, &redis.XAddArgs{ + Stream: queue, + Values: map[string]interface{}{"message": string(data)}, + }).Result() + if err == nil { + return true + } + logger.Warn("Redis QueueProduct error", zap.String("queue", queue), zap.Error(err)) + time.Sleep(100 * time.Millisecond) + } + return false +} + +// QueueConsumer consumes a message from Redis Stream +func (r *RedisClient) QueueConsumer(queueName, groupName, consumerName string, msgID string) (*RedisMsg, error) { + if r.client == nil { + return nil, nil + } + ctx := context.Background() + + for i := 0; i < 3; i++ { + // Create consumer group if not exists + groups, err := r.client.XInfoGroups(ctx, queueName).Result() + if err != nil && err.Error() != "no such key" { + logger.Warn("Redis QueueConsumer XInfoGroups error", zap.Error(err)) + } + + groupExists := false + for _, g := range groups { + if g.Name == groupName { + groupExists = true + break + } + } + + if !groupExists { + err = r.client.XGroupCreateMkStream(ctx, queueName, groupName, "0").Err() + if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" { + logger.Warn("Redis QueueConsumer XGroupCreate error", zap.Error(err)) + } + } + + if msgID == "" { + msgID = ">" + } + + messages, err := r.client.XReadGroup(ctx, &redis.XReadGroupArgs{ + Group: groupName, + Consumer: consumerName, + Streams: []string{queueName, msgID}, + Count: 1, + Block: 5 * time.Second, + }).Result() + + if err == redis.Nil { + return nil, nil + } + if err != nil { + logger.Warn("Redis QueueConsumer XReadGroup error", zap.Error(err)) + time.Sleep(100 * time.Millisecond) + continue + } + + if len(messages) == 0 || len(messages[0].Messages) == 0 { + return nil, nil + } 
+ + msg := messages[0].Messages[0] + var messageData map[string]interface{} + if msgStr, ok := msg.Values["message"].(string); ok { + json.Unmarshal([]byte(msgStr), &messageData) + } + + return &RedisMsg{ + consumer: r.client, + queueName: queueName, + groupName: groupName, + msgID: msg.ID, + message: messageData, + }, nil + } + return nil, nil +} + +// Ack acknowledges the message +func (m *RedisMsg) Ack() bool { + if m.consumer == nil { + return false + } + ctx := context.Background() + err := m.consumer.XAck(ctx, m.queueName, m.groupName, m.msgID).Err() + if err != nil { + logger.Warn("RedisMsg Ack error", zap.Error(err)) + return false + } + return true +} + +// GetMessage returns the message data +func (m *RedisMsg) GetMessage() map[string]interface{} { + return m.message +} + +// GetMsgID returns the message ID +func (m *RedisMsg) GetMsgID() string { + return m.msgID +} + +// GetPendingMsg gets pending messages +func (r *RedisClient) GetPendingMsg(queue, groupName string) ([]redis.XPendingExt, error) { + if r.client == nil { + return nil, nil + } + ctx := context.Background() + msgs, err := r.client.XPendingExt(ctx, &redis.XPendingExtArgs{ + Stream: queue, + Group: groupName, + Start: "-", + End: "+", + Count: 10, + }).Result() + if err != nil { + if err.Error() != "No such key" { + logger.Warn("Redis GetPendingMsg error", zap.Error(err)) + } + return nil, err + } + return msgs, nil +} + +// RequeueMsg requeues a message +func (r *RedisClient) RequeueMsg(queue, groupName, msgID string) { + if r.client == nil { + return + } + ctx := context.Background() + + for i := 0; i < 3; i++ { + msgs, err := r.client.XRange(ctx, queue, msgID, msgID).Result() + if err != nil { + logger.Warn("Redis RequeueMsg XRange error", zap.Error(err)) + time.Sleep(100 * time.Millisecond) + continue + } + if len(msgs) == 0 { + return + } + + r.client.XAdd(ctx, &redis.XAddArgs{ + Stream: queue, + Values: msgs[0].Values, + }) + r.client.XAck(ctx, queue, groupName, msgID) + return + } +} + +// QueueInfo returns queue group info +func (r *RedisClient) QueueInfo(queue, groupName string) (map[string]interface{}, error) { + if r.client == nil { + return nil, nil + } + ctx := context.Background() + + for i := 0; i < 3; i++ { + groups, err := r.client.XInfoGroups(ctx, queue).Result() + if err != nil { + logger.Warn("Redis QueueInfo error", zap.Error(err)) + time.Sleep(100 * time.Millisecond) + continue + } + + for _, g := range groups { + if g.Name == groupName { + return map[string]interface{}{ + "name": g.Name, + "consumers": g.Consumers, + "pending": g.Pending, + "last_delivered": g.LastDeliveredID, + }, nil + } + } + return nil, nil + } + return nil, nil +} + +// DeleteIfEqual deletes key if its value equals expected value (atomic) +func (r *RedisClient) DeleteIfEqual(key, expectedValue string) bool { + if r.client == nil { + return false + } + ctx := context.Background() + result, err := r.luaDeleteIfEqual.Run(ctx, r.client, []string{key}, expectedValue).Result() + if err != nil { + logger.Warn("Redis DeleteIfEqual error", zap.Error(err)) + return false + } + return result.(int64) == 1 +} + +// Delete deletes a key +func (r *RedisClient) Delete(key string) bool { + if r.client == nil { + return false + } + ctx := context.Background() + if err := r.client.Del(ctx, key).Err(); err != nil { + logger.Warn("Redis Delete error", zap.String("key", key), zap.Error(err)) + return false + } + return true +} + +// Expire sets expiration on a key +func (r *RedisClient) Expire(key string, exp time.Duration) bool { + if r.client 
== nil { + return false + } + ctx := context.Background() + if err := r.client.Expire(ctx, key, exp).Err(); err != nil { + logger.Warn("Redis Expire error", zap.String("key", key), zap.Error(err)) + return false + } + return true +} + +// TTL gets remaining time to live of a key +func (r *RedisClient) TTL(key string) time.Duration { + if r.client == nil { + return -2 + } + ctx := context.Background() + ttl, err := r.client.TTL(ctx, key).Result() + if err != nil { + logger.Warn("Redis TTL error", zap.String("key", key), zap.Error(err)) + return -2 + } + return ttl +} + +// DistributedLock distributed lock implementation +type DistributedLock struct { + client *RedisClient + lockKey string + lockValue string + timeout time.Duration + blockingTimeout time.Duration +} + +// NewDistributedLock creates a new distributed lock +func NewDistributedLock(lockKey string, lockValue string, timeout time.Duration, blockingTimeout time.Duration) *DistributedLock { + if globalClient == nil { + return nil + } + if lockValue == "" { + lockValue = uuid.New().String() + } + return &DistributedLock{ + client: globalClient, + lockKey: lockKey, + lockValue: lockValue, + timeout: timeout, + blockingTimeout: blockingTimeout, + } +} + +// Acquire acquires the lock +func (l *DistributedLock) Acquire() bool { + if l.client == nil { + return false + } + // Delete if stale + l.client.DeleteIfEqual(l.lockKey, l.lockValue) + return l.client.SetNX(l.lockKey, l.lockValue, l.timeout) +} + +// SpinAcquire keeps trying to acquire the lock +func (l *DistributedLock) SpinAcquire(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + l.client.DeleteIfEqual(l.lockKey, l.lockValue) + if l.client.SetNX(l.lockKey, l.lockValue, l.timeout) { + return nil + } + time.Sleep(10 * time.Second) + } + } +} + +// Release releases the lock +func (l *DistributedLock) Release() bool { + if l.client == nil { + return false + } + return l.client.DeleteIfEqual(l.lockKey, l.lockValue) +} + +// TokenBucket token bucket rate limiter +type TokenBucket struct { + client *RedisClient + key string + capacity float64 + rate float64 +} + +// NewTokenBucket creates a new token bucket +func NewTokenBucket(key string, capacity, rate float64) *TokenBucket { + if globalClient == nil { + return nil + } + return &TokenBucket{ + client: globalClient, + key: key, + capacity: capacity, + rate: rate, + } +} + +// Allow checks if request is allowed +func (tb *TokenBucket) Allow(cost float64) (bool, float64) { + if tb.client == nil || tb.client.client == nil { + return true, 0 + } + ctx := context.Background() + now := float64(time.Now().Unix()) + + result, err := tb.client.luaTokenBucket.Run(ctx, tb.client.client, []string{tb.key}, + tb.capacity, tb.rate, now, cost).Result() + if err != nil { + logger.Warn("TokenBucket Allow error", zap.Error(err)) + return true, 0 + } + + values := result.([]interface{}) + allowed := values[0].(int64) == 1 + tokens := values[1].(int64) + return allowed, float64(tokens) +} + +// GetClient returns the underlying go-redis client for advanced usage +func (r *RedisClient) GetClient() *redis.Client { + return r.client +} + +// RandomSleep sleeps for random duration between min and max milliseconds +func RandomSleep(minMs, maxMs int) { + if maxMs <= minMs { + // rand.Intn panics on a non-positive argument, so fall back to the minimum + time.Sleep(time.Duration(minMs) * time.Millisecond) + return + } + duration := time.Duration(rand.Intn(maxMs-minMs)+minMs) * time.Millisecond + time.Sleep(duration) +} diff --git a/internal/cli/README.md b/internal/cli/README.md new file mode 100644 index 00000000000..4f71a37de5b --- /dev/null +++ b/internal/cli/README.md @@ -0,0 
+1,87 @@ +# RAGFlow CLI (Go Version) + +This is the Go implementation of the RAGFlow command-line interface, compatible with the Python version's syntax. + +## Features + +- Interactive mode only +- Full compatibility with Python CLI syntax +- Recursive descent parser for SQL-like commands +- Support for all major commands: + - User management: LOGIN, REGISTER, CREATE USER, DROP USER, LIST USERS, etc. + - Service management: LIST SERVICES, SHOW SERVICE, STARTUP/SHUTDOWN/RESTART SERVICE + - Role management: CREATE ROLE, DROP ROLE, LIST ROLES, GRANT/REVOKE PERMISSION + - Dataset management: CREATE DATASET, DROP DATASET, LIST DATASETS + - Model management: SET/RESET DEFAULT LLM/VLM/EMBEDDING/etc. + - And more... + +## Usage + +Build and run: + +```bash +go build -o ragflow_cli ./cmd/ragflow_cli.go +./ragflow_cli +``` + +## Architecture + +``` +internal/cli/ +├── cli.go # Main CLI loop and interaction +├── parser/ # Command parser package +│ ├── types.go # Token and Command types +│ ├── lexer.go # Lexical analyzer +│ └── parser.go # Recursive descent parser +``` + +## Command Examples + +```sql +-- Authentication +LOGIN USER 'admin@example.com'; + +-- User management +REGISTER USER 'john' AS 'John Doe' PASSWORD 'secret'; +CREATE USER 'jane' 'password123'; +DROP USER 'jane'; +LIST USERS; +SHOW USER 'john'; + +-- Service management +LIST SERVICES; +SHOW SERVICE 1; +STARTUP SERVICE 1; +SHUTDOWN SERVICE 1; +RESTART SERVICE 1; +PING; + +-- Role management +CREATE ROLE admin DESCRIPTION 'Administrator role'; +LIST ROLES; +GRANT read,write ON datasets TO ROLE admin; + +-- Dataset management +CREATE DATASET 'my_dataset' WITH EMBEDDING 'text-embedding-ada-002' PARSER 'naive'; +LIST DATASETS; +DROP DATASET 'my_dataset'; + +-- Model configuration +SET DEFAULT LLM 'gpt-4'; +SET DEFAULT EMBEDDING 'text-embedding-ada-002'; +RESET DEFAULT LLM; + +-- Meta commands +\? -- Show help +\q -- Quit +\c -- Clear screen +``` + +## Parser Implementation + +The parser uses a hand-written recursive descent approach instead of go-yacc for: +- Better control over error messages +- Easier to extend and maintain +- No code generation step required + +The parser structure follows the grammar defined in the Python version, ensuring full syntax compatibility. diff --git a/internal/cli/benchmark.go b/internal/cli/benchmark.go new file mode 100644 index 00000000000..872c830e3e4 --- /dev/null +++ b/internal/cli/benchmark.go @@ -0,0 +1,318 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package cli + +import ( + "fmt" + "strings" + "sync" + "time" +) + +// BenchmarkResult holds the result of a benchmark run +type BenchmarkResult struct { + Duration float64 + TotalCommands int + SuccessCount int + FailureCount int + QPS float64 + ResponseList []*Response +} + +// RunBenchmark runs a benchmark with the given concurrency and iterations +func (c *RAGFlowClient) RunBenchmark(cmd *Command) error { + concurrency, ok := cmd.Params["concurrency"].(int) + if !ok { + concurrency = 1 + } + + iterations, ok := cmd.Params["iterations"].(int) + if !ok { + iterations = 1 + } + + nestedCmd, ok := cmd.Params["command"].(*Command) + if !ok { + return fmt.Errorf("benchmark command not found") + } + + if concurrency < 1 { + return fmt.Errorf("concurrency must be greater than 0") + } + + // Add iterations to the nested command + nestedCmd.Params["iterations"] = iterations + + if concurrency == 1 { + return c.runBenchmarkSingle(concurrency, iterations, nestedCmd) + } + return c.runBenchmarkConcurrent(concurrency, iterations, nestedCmd) +} + +// runBenchmarkSingle runs benchmark with single concurrency (sequential execution) +func (c *RAGFlowClient) runBenchmarkSingle(concurrency, iterations int, nestedCmd *Command) error { + commandType := nestedCmd.Type + + startTime := time.Now() + responseList := make([]*Response, 0, iterations) + + // For search_on_datasets, convert dataset names to IDs first + if commandType == "search_on_datasets" && iterations > 1 { + datasets, _ := nestedCmd.Params["datasets"].(string) + datasetNames := strings.Split(datasets, ",") + datasetIDs := make([]string, 0, len(datasetNames)) + for _, name := range datasetNames { + name = strings.TrimSpace(name) + id, err := c.getDatasetID(name) + if err != nil { + return err + } + datasetIDs = append(datasetIDs, id) + } + nestedCmd.Params["dataset_ids"] = datasetIDs + } + + // Check if command supports native benchmark (iterations > 1) + supportsNative := false + if iterations > 1 { + result, err := c.ExecuteCommand(nestedCmd) + if err == nil && result != nil { + // Command supports benchmark natively + supportsNative = true + duration, _ := result["duration"].(float64) + respList, _ := result["response_list"].([]*Response) + responseList = respList + + // Calculate and print results + successCount := 0 + for _, resp := range responseList { + if isSuccess(resp, commandType) { + successCount++ + } + } + + qps := float64(0) + if duration > 0 { + qps = float64(iterations) / duration + } + + fmt.Printf("command: %s, Concurrency: %d, iterations: %d\n", commandType, concurrency, iterations) + fmt.Printf("total duration: %.4fs, QPS: %.2f, COMMAND_COUNT: %d, SUCCESS: %d, FAILURE: %d\n", + duration, qps, iterations, successCount, iterations-successCount) + return nil + } + } + + // Manual execution: run iterations times + if !supportsNative { + // Remove iterations param to avoid native benchmark + delete(nestedCmd.Params, "iterations") + + for i := 0; i < iterations; i++ { + singleResult, err := c.ExecuteCommand(nestedCmd) + if err != nil { + // Command failed, add a failed response + responseList = append(responseList, &Response{StatusCode: 0}) + continue + } + + // For commands that return a single response (like ping with iterations=1) + if singleResult != nil { + if respList, ok := singleResult["response_list"].([]*Response); ok { + responseList = append(responseList, respList...) 
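+				// respList holds one *Response per iteration; flattening it into
+				// responseList keeps the success/failure tally below uniform with
+				// the manual execution path.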
+ } + } else { + // Command executed successfully but returned no data + // Mark as success for now + responseList = append(responseList, &Response{StatusCode: 200, Body: []byte("pong")}) + } + } + } + + duration := time.Since(startTime).Seconds() + + successCount := 0 + for _, resp := range responseList { + if isSuccess(resp, commandType) { + successCount++ + } + } + + qps := float64(0) + if duration > 0 { + qps = float64(iterations) / duration + } + + // Print results + fmt.Printf("command: %s, Concurrency: %d, iterations: %d\n", commandType, concurrency, iterations) + fmt.Printf("total duration: %.4fs, QPS: %.2f, COMMAND_COUNT: %d, SUCCESS: %d, FAILURE: %d\n", + duration, qps, iterations, successCount, iterations-successCount) + + return nil +} + +// runBenchmarkConcurrent runs benchmark with multiple concurrent workers +func (c *RAGFlowClient) runBenchmarkConcurrent(concurrency, iterations int, nestedCmd *Command) error { + results := make([]map[string]interface{}, concurrency) + var wg sync.WaitGroup + + // For search_on_datasets, convert dataset names to IDs first + if nestedCmd.Type == "search_on_datasets" { + datasets, _ := nestedCmd.Params["datasets"].(string) + datasetNames := strings.Split(datasets, ",") + datasetIDs := make([]string, 0, len(datasetNames)) + for _, name := range datasetNames { + name = strings.TrimSpace(name) + id, err := c.getDatasetID(name) + if err != nil { + return err + } + datasetIDs = append(datasetIDs, id) + } + nestedCmd.Params["dataset_ids"] = datasetIDs + } + + startTime := time.Now() + + // Launch concurrent workers + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + // Create a new client for each goroutine to avoid race conditions + workerClient := NewRAGFlowClient(c.ServerType) + workerClient.HTTPClient = c.HTTPClient // Share the same HTTP client config + + // Execute benchmark silently (no output) + responseList := workerClient.executeBenchmarkSilent(nestedCmd, iterations) + + results[idx] = map[string]interface{}{ + "duration": 0.0, + "response_list": responseList, + } + }(i) + } + + wg.Wait() + endTime := time.Now() + + totalDuration := endTime.Sub(startTime).Seconds() + successCount := 0 + commandType := nestedCmd.Type + + for _, result := range results { + if result == nil { + continue + } + responseList, _ := result["response_list"].([]*Response) + for _, resp := range responseList { + if isSuccess(resp, commandType) { + successCount++ + } + } + } + + totalCommands := iterations * concurrency + qps := float64(0) + if totalDuration > 0 { + qps = float64(totalCommands) / totalDuration + } + + // Print results + fmt.Printf("command: %s, Concurrency: %d, iterations: %d\n", commandType, concurrency, iterations) + fmt.Printf("total duration: %.4fs, QPS: %.2f, COMMAND_COUNT: %d, SUCCESS: %d, FAILURE: %d\n", + totalDuration, qps, totalCommands, successCount, totalCommands-successCount) + + return nil +} + +// executeBenchmarkSilent executes a command for benchmark without printing output +func (c *RAGFlowClient) executeBenchmarkSilent(cmd *Command, iterations int) []*Response { + responseList := make([]*Response, 0, iterations) + + for i := 0; i < iterations; i++ { + var resp *Response + var err error + + switch cmd.Type { + case "ping_server": + resp, err = c.HTTPClient.Request("GET", "/system/ping", false, "web", nil, nil) + case "list_user_datasets": + resp, err = c.HTTPClient.Request("POST", "/kb/list", false, "web", nil, nil) + case "list_datasets": + userName, _ := cmd.Params["user_name"].(string) + 
resp, err = c.HTTPClient.Request("GET", fmt.Sprintf("/admin/users/%s/datasets", userName), true, "admin", nil, nil) + case "search_on_datasets": + question, _ := cmd.Params["question"].(string) + datasetIDs, _ := cmd.Params["dataset_ids"].([]string) + payload := map[string]interface{}{ + "kb_id": datasetIDs, + "question": question, + "similarity_threshold": 0.2, + "vector_similarity_weight": 0.3, + } + resp, err = c.HTTPClient.Request("POST", "/chunk/retrieval_test", false, "web", nil, payload) + default: + // For other commands, we would need to add specific handling + // For now, mark as failed + resp = &Response{StatusCode: 0} + } + + if err != nil { + resp = &Response{StatusCode: 0} + } + + responseList = append(responseList, resp) + } + + return responseList +} + +// isSuccess checks if a response is successful based on command type +func isSuccess(resp *Response, commandType string) bool { + if resp == nil { + return false + } + + switch commandType { + case "ping_server": + return resp.StatusCode == 200 && string(resp.Body) == "pong" + case "list_user_datasets", "list_datasets", "search_on_datasets": + // Check status code and JSON response code for dataset commands + if resp.StatusCode != 200 { + return false + } + resJSON, err := resp.JSON() + if err != nil { + return false + } + code, ok := resJSON["code"].(float64) + return ok && code == 0 + default: + // For other commands, check status code and response code + if resp.StatusCode != 200 { + return false + } + resJSON, err := resp.JSON() + if err != nil { + return false + } + code, ok := resJSON["code"].(float64) + return ok && code == 0 + } +} diff --git a/internal/cli/cli.go b/internal/cli/cli.go new file mode 100644 index 00000000000..14edea1b2c5 --- /dev/null +++ b/internal/cli/cli.go @@ -0,0 +1,140 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package cli + +import ( + "bufio" + "fmt" + "os" + "strings" +) + +// CLI represents the command line interface +type CLI struct { + parser *Parser + client *RAGFlowClient + prompt string + running bool +} + +// NewCLI creates a new CLI instance +func NewCLI() (*CLI, error) { + return &CLI{ + prompt: "RAGFlow> ", + client: NewRAGFlowClient("user"), // Default to user mode + }, nil +} + +// Run starts the interactive CLI +func (c *CLI) Run() error { + c.running = true + scanner := bufio.NewScanner(os.Stdin) + + fmt.Println("Welcome to RAGFlow CLI") + fmt.Println("Type \\? 
for help, \\q to quit") + fmt.Println() + + for c.running { + fmt.Print(c.prompt) + + if !scanner.Scan() { + break + } + + input := scanner.Text() + input = strings.TrimSpace(input) + + if input == "" { + continue + } + + if err := c.execute(input); err != nil { + fmt.Printf("Error: %v\n", err) + } + } + + return scanner.Err() +} + +func (c *CLI) execute(input string) error { + p := NewParser(input) + cmd, err := p.Parse() + if err != nil { + return err + } + + if cmd == nil { + return nil + } + + // Handle meta commands + if cmd.Type == "meta" { + return c.handleMetaCommand(cmd) + } + + // Execute the command using the client + _, err = c.client.ExecuteCommand(cmd) + return err +} + +func (c *CLI) handleMetaCommand(cmd *Command) error { + command := cmd.Params["command"].(string) + + switch command { + case "q", "quit", "exit": + fmt.Println("Goodbye!") + c.running = false + case "?", "h", "help": + c.printHelp() + case "c", "clear": + // Clear screen (simple approach) + fmt.Print("\033[H\033[2J") + default: + return fmt.Errorf("unknown meta command: \\%s", command) + } + return nil +} + +func (c *CLI) printHelp() { + // Raw string literal: single backslashes print as-is, so \? renders as \? + help := ` +RAGFlow CLI Help +================ + +SQL Commands: + LOGIN USER 'email'; - Login as user + REGISTER USER 'name' AS 'nickname' PASSWORD 'pwd'; - Register new user + SHOW VERSION; - Show version info + SHOW CURRENT USER; - Show current user + LIST USERS; - List all users + LIST SERVICES; - List services + PING; - Ping server + ... and many more + +Meta Commands: + \? or \h - Show this help + \q or \quit - Exit CLI + \c or \clear - Clear screen + +For more information, see documentation. +` + fmt.Println(help) +} + +// Cleanup performs cleanup before exit +func (c *CLI) Cleanup() { + fmt.Println("\nCleaning up...") +} diff --git a/internal/cli/client.go b/internal/cli/client.go new file mode 100644 index 00000000000..d4cd8dc2c37 --- /dev/null +++ b/internal/cli/client.go @@ -0,0 +1,496 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+// + +package cli + +import ( + "bufio" + "fmt" + "os" + "os/exec" + "strings" + "syscall" + "unsafe" +) + +// RAGFlowClient handles API interactions with the RAGFlow server +type RAGFlowClient struct { + HTTPClient *HTTPClient + ServerType string // "admin" or "user" +} + +// NewRAGFlowClient creates a new RAGFlow client +func NewRAGFlowClient(serverType string) *RAGFlowClient { + return &RAGFlowClient{ + HTTPClient: NewHTTPClient(), + ServerType: serverType, + } +} + +// LoginUser performs user login +func (c *RAGFlowClient) LoginUser(cmd *Command) error { + // First, ping the server to check if it's available + resp, err := c.HTTPClient.Request("GET", "/system/ping", false, "web", nil, nil) + if err != nil { + fmt.Printf("Error: %v\n", err) + fmt.Println("Can't access server for login (connection failed)") + return err + } + + if resp.StatusCode != 200 || string(resp.Body) != "pong" { + fmt.Println("Server is down") + return fmt.Errorf("server is down") + } + + email, ok := cmd.Params["email"].(string) + if !ok { + return fmt.Errorf("email not provided") + } + + // Get password from user input (hidden) + fmt.Printf("password for %s: ", email) + password, err := readPassword() + if err != nil { + return fmt.Errorf("failed to read password: %w", err) + } + password = strings.TrimSpace(password) + + // Login + token, err := c.loginUser(email, password) + if err != nil { + fmt.Printf("Error: %v\n", err) + fmt.Println("Can't access server for login (connection failed)") + return err + } + + c.HTTPClient.LoginToken = token + fmt.Printf("Login user %s successfully\n", email) + return nil +} + +// loginUser performs the actual login request +func (c *RAGFlowClient) loginUser(email, password string) (string, error) { + // Encrypt password using scrypt (same as Python implementation) + encryptedPassword, err := EncryptPassword(password) + if err != nil { + return "", fmt.Errorf("failed to encrypt password: %w", err) + } + + payload := map[string]interface{}{ + "email": email, + "password": encryptedPassword, + } + + var path string + if c.ServerType == "admin" { + path = "/admin/login" + } else { + path = "/user/login" + } + + resp, err := c.HTTPClient.Request("POST", path, c.ServerType == "admin", "", nil, payload) + if err != nil { + return "", err + } + + resJSON, err := resp.JSON() + if err != nil { + return "", fmt.Errorf("login failed: invalid JSON response (%w)", err) + } + + code, ok := resJSON["code"].(float64) + if !ok || code != 0 { + msg, _ := resJSON["message"].(string) + return "", fmt.Errorf("login failed: %s", msg) + } + + token := resp.Headers.Get("Authorization") + if token == "" { + return "", fmt.Errorf("login failed: missing Authorization header") + } + + return token, nil +} + +// PingServer pings the server to check if it's alive +// Returns benchmark result map if iterations > 1, otherwise prints status +func (c *RAGFlowClient) PingServer(cmd *Command) (map[string]interface{}, error) { + // Get iterations from command params (for benchmark) + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + if iterations > 1 { + // Benchmark mode: multiple iterations + result, err := c.HTTPClient.RequestWithIterations("GET", "/system/ping", false, "web", nil, nil, iterations) + if err != nil { + return nil, err + } + return result, nil + } + + // Single ping mode + resp, err := c.HTTPClient.Request("GET", "/system/ping", false, "web", nil, nil) + if err != nil { + fmt.Printf("Error: %v\n", err) + fmt.Println("Server is down") + 
return nil, err + } + + if resp.StatusCode == 200 && string(resp.Body) == "pong" { + fmt.Println("Server is alive") + } else { + fmt.Printf("Error: %d\n", resp.StatusCode) + } + return nil, nil +} + +// ListUserDatasets lists datasets for current user (user mode) +// Returns (result_map, error) - result_map is non-nil for benchmark mode +func (c *RAGFlowClient) ListUserDatasets(cmd *Command) (map[string]interface{}, error) { + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + // Check for benchmark iterations + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + if iterations > 1 { + // Benchmark mode - return raw result for benchmark stats + return c.HTTPClient.RequestWithIterations("POST", "/kb/list", false, "web", nil, nil, iterations) + } + + // Normal mode + resp, err := c.HTTPClient.Request("POST", "/kb/list", false, "web", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list datasets: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list datasets: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + resJSON, err := resp.JSON() + if err != nil { + return nil, fmt.Errorf("invalid JSON response: %w", err) + } + + code, ok := resJSON["code"].(float64) + if !ok || code != 0 { + msg, _ := resJSON["message"].(string) + return nil, fmt.Errorf("failed to list datasets: %s", msg) + } + + data, ok := resJSON["data"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid response format") + } + + kbs, ok := data["kbs"].([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid response format: kbs not found") + } + + // Convert to slice of maps + tableData := make([]map[string]interface{}, 0, len(kbs)) + for _, kb := range kbs { + if kbMap, ok := kb.(map[string]interface{}); ok { + // Remove avatar field + delete(kbMap, "avatar") + tableData = append(tableData, kbMap) + } + } + + PrintTableSimple(tableData) + return nil, nil +} + +// ListDatasets lists datasets for a specific user (admin mode) +// Returns (result_map, error) - result_map is non-nil for benchmark mode +func (c *RAGFlowClient) ListDatasets(cmd *Command) (map[string]interface{}, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + userName, ok := cmd.Params["user_name"].(string) + if !ok { + return nil, fmt.Errorf("user_name not provided") + } + + // Check for benchmark iterations + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + if iterations > 1 { + // Benchmark mode - return raw result for benchmark stats + return c.HTTPClient.RequestWithIterations("GET", fmt.Sprintf("/admin/users/%s/datasets", userName), true, "admin", nil, nil, iterations) + } + + fmt.Printf("Listing all datasets of user: %s\n", userName) + + resp, err := c.HTTPClient.Request("GET", fmt.Sprintf("/admin/users/%s/datasets", userName), true, "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list datasets: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list datasets: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + resJSON, err := resp.JSON() + if err != nil { + return nil, fmt.Errorf("invalid JSON response: %w", err) + } + + data, ok := resJSON["data"].([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid response format") + } + + // Convert to slice of maps and remove avatar 
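+	// The avatar field carries an inline (typically base64-encoded) image
+	// payload that would dominate the printed table, so it is dropped.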
+ tableData := make([]map[string]interface{}, 0, len(data)) + for _, item := range data { + if itemMap, ok := item.(map[string]interface{}); ok { + delete(itemMap, "avatar") + tableData = append(tableData, itemMap) + } + } + + PrintTableSimple(tableData) + return nil, nil +} + +// readPassword reads password from terminal without echoing +func readPassword() (string, error) { + // Check if stdin is a terminal by trying to get terminal size + if isTerminal() { + // Use stty to disable echo + cmd := exec.Command("stty", "-echo") + cmd.Stdin = os.Stdin + if err := cmd.Run(); err != nil { + // Fallback: read normally + return readPasswordFallback() + } + defer func() { + // Re-enable echo + cmd := exec.Command("stty", "echo") + cmd.Stdin = os.Stdin + cmd.Run() + }() + + reader := bufio.NewReader(os.Stdin) + password, err := reader.ReadString('\n') + fmt.Println() // New line after password input + if err != nil { + return "", err + } + return strings.TrimSpace(password), nil + } + + // Fallback for non-terminal input (e.g., piped input) + return readPasswordFallback() +} + +// isTerminal checks if stdin is a terminal +func isTerminal() bool { + var termios syscall.Termios + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, os.Stdin.Fd(), syscall.TCGETS, uintptr(unsafe.Pointer(&termios)), 0, 0, 0) + return err == 0 +} + +// readPasswordFallback reads password as plain text (fallback mode) +func readPasswordFallback() (string, error) { + reader := bufio.NewReader(os.Stdin) + password, err := reader.ReadString('\n') + if err != nil { + return "", err + } + return strings.TrimSpace(password), nil +} + +// getDatasetID gets dataset ID by name +func (c *RAGFlowClient) getDatasetID(datasetName string) (string, error) { + resp, err := c.HTTPClient.Request("POST", "/kb/list", false, "web", nil, nil) + if err != nil { + return "", fmt.Errorf("failed to list datasets: %w", err) + } + + if resp.StatusCode != 200 { + return "", fmt.Errorf("failed to list datasets: HTTP %d", resp.StatusCode) + } + + resJSON, err := resp.JSON() + if err != nil { + return "", fmt.Errorf("invalid JSON response: %w", err) + } + + code, ok := resJSON["code"].(float64) + if !ok || code != 0 { + msg, _ := resJSON["message"].(string) + return "", fmt.Errorf("failed to list datasets: %s", msg) + } + + data, ok := resJSON["data"].(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid response format") + } + + kbs, ok := data["kbs"].([]interface{}) + if !ok { + return "", fmt.Errorf("invalid response format: kbs not found") + } + + for _, kb := range kbs { + if kbMap, ok := kb.(map[string]interface{}); ok { + if name, _ := kbMap["name"].(string); name == datasetName { + if id, _ := kbMap["id"].(string); id != "" { + return id, nil + } + } + } + } + + return "", fmt.Errorf("dataset '%s' not found", datasetName) +} + +// SearchOnDatasets searches for chunks in specified datasets +// Returns (result_map, error) - result_map is non-nil for benchmark mode +func (c *RAGFlowClient) SearchOnDatasets(cmd *Command) (map[string]interface{}, error) { + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + question, ok := cmd.Params["question"].(string) + if !ok { + return nil, fmt.Errorf("question not provided") + } + + datasets, ok := cmd.Params["datasets"].(string) + if !ok { + return nil, fmt.Errorf("datasets not provided") + } + + // Parse dataset names (comma-separated) and convert to IDs + datasetNames := strings.Split(datasets, ",") + datasetIDs := make([]string, 0, 
len(datasetNames)) + for _, name := range datasetNames { + name = strings.TrimSpace(name) + id, err := c.getDatasetID(name) + if err != nil { + return nil, err + } + datasetIDs = append(datasetIDs, id) + } + + // Check for benchmark iterations + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + payload := map[string]interface{}{ + "kb_id": datasetIDs, + "question": question, + "similarity_threshold": 0.2, + "vector_similarity_weight": 0.3, + } + + if iterations > 1 { + // Benchmark mode - return raw result for benchmark stats + return c.HTTPClient.RequestWithIterations("POST", "/chunk/retrieval_test", false, "web", nil, payload, iterations) + } + + // Normal mode + resp, err := c.HTTPClient.Request("POST", "/chunk/retrieval_test", false, "web", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to search on datasets: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to search on datasets: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + resJSON, err := resp.JSON() + if err != nil { + return nil, fmt.Errorf("invalid JSON response: %w", err) + } + + code, ok := resJSON["code"].(float64) + if !ok || code != 0 { + msg, _ := resJSON["message"].(string) + return nil, fmt.Errorf("failed to search on datasets: %s", msg) + } + + data, ok := resJSON["data"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid response format") + } + + chunks, ok := data["chunks"].([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid response format: chunks not found") + } + + // Convert to slice of maps for printing + tableData := make([]map[string]interface{}, 0, len(chunks)) + for _, chunk := range chunks { + if chunkMap, ok := chunk.(map[string]interface{}); ok { + row := map[string]interface{}{ + "id": chunkMap["chunk_id"], + "content": chunkMap["content_with_weight"], + "document_id": chunkMap["doc_id"], + "dataset_id": chunkMap["kb_id"], + "docnm_kwd": chunkMap["docnm_kwd"], + "image_id": chunkMap["image_id"], + "similarity": chunkMap["similarity"], + "term_similarity": chunkMap["term_similarity"], + "vector_similarity": chunkMap["vector_similarity"], + } + tableData = append(tableData, row) + } + } + + PrintTableSimple(tableData) + return nil, nil +} + +// ExecuteCommand executes a parsed command +// Returns benchmark result map for commands that support it (e.g., ping_server with iterations > 1) +func (c *RAGFlowClient) ExecuteCommand(cmd *Command) (map[string]interface{}, error) { + switch cmd.Type { + case "login_user": + return nil, c.LoginUser(cmd) + case "ping_server": + return c.PingServer(cmd) + case "benchmark": + return nil, c.RunBenchmark(cmd) + case "list_user_datasets": + return c.ListUserDatasets(cmd) + case "list_datasets": + return c.ListDatasets(cmd) + case "search_on_datasets": + return c.SearchOnDatasets(cmd) + // TODO: Implement other commands + default: + return nil, fmt.Errorf("command '%s' would be executed with API", cmd.Type) + } +} diff --git a/internal/cli/crypt.go b/internal/cli/crypt.go new file mode 100644 index 00000000000..4da5f18484a --- /dev/null +++ b/internal/cli/crypt.go @@ -0,0 +1,106 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package cli + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "fmt" + "os" + "path/filepath" +) + +// EncryptPassword encrypts a password using RSA public key +// This matches the Python implementation in api/utils/crypt.py +func EncryptPassword(password string) (string, error) { + // Read public key from conf/public.pem + publicKeyPath := filepath.Join(getProjectBaseDirectory(), "conf", "public.pem") + publicKeyPEM, err := os.ReadFile(publicKeyPath) + if err != nil { + return "", fmt.Errorf("failed to read public key: %w", err) + } + + // Parse public key + block, _ := pem.Decode(publicKeyPEM) + if block == nil { + return "", fmt.Errorf("failed to parse public key PEM") + } + + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + // Try parsing as PKCS1 + pub, err = x509.ParsePKCS1PublicKey(block.Bytes) + if err != nil { + return "", fmt.Errorf("failed to parse public key: %w", err) + } + } + + rsaPub, ok := pub.(*rsa.PublicKey) + if !ok { + return "", fmt.Errorf("not an RSA public key") + } + + // Step 1: Base64 encode the password + passwordBase64 := base64.StdEncoding.EncodeToString([]byte(password)) + + // Step 2: Encrypt using RSA PKCS1v15 + encrypted, err := rsa.EncryptPKCS1v15(rand.Reader, rsaPub, []byte(passwordBase64)) + if err != nil { + return "", fmt.Errorf("failed to encrypt password: %w", err) + } + + // Step 3: Base64 encode the encrypted data + return base64.StdEncoding.EncodeToString(encrypted), nil +} + +// getProjectBaseDirectory returns the project base directory +func getProjectBaseDirectory() string { + // Try to find the project root by looking for go.mod or conf directory + // Start from current working directory and go up + cwd, err := os.Getwd() + if err != nil { + return "." + } + + dir := cwd + for { + // Check if conf directory exists + confDir := filepath.Join(dir, "conf") + if info, err := os.Stat(confDir); err == nil && info.IsDir() { + return dir + } + + // Check for go.mod + goMod := filepath.Join(dir, "go.mod") + if _, err := os.Stat(goMod); err == nil { + return dir + } + + // Go up one directory + parent := filepath.Dir(dir) + if parent == dir { + // Reached root + break + } + dir = parent + } + + return cwd +} diff --git a/internal/cli/http_client.go b/internal/cli/http_client.go new file mode 100644 index 00000000000..eb08b4ff634 --- /dev/null +++ b/internal/cli/http_client.go @@ -0,0 +1,248 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package cli + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +// HTTPClient handles HTTP requests to the RAGFlow server +type HTTPClient struct { + Host string + Port int + APIVersion string + APIKey string + LoginToken string + ConnectTimeout time.Duration + ReadTimeout time.Duration + VerifySSL bool + client *http.Client +} + +// NewHTTPClient creates a new HTTP client +func NewHTTPClient() *HTTPClient { + transport := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + return &HTTPClient{ + Host: "127.0.0.1", + Port: 9382, + APIVersion: "v1", + ConnectTimeout: 5 * time.Second, + ReadTimeout: 60 * time.Second, + VerifySSL: false, + client: &http.Client{ + Transport: transport, + Timeout: 60 * time.Second, + }, + } +} + +// APIBase returns the API base URL +func (c *HTTPClient) APIBase() string { + return fmt.Sprintf("%s:%d/api/%s", c.Host, c.Port, c.APIVersion) +} + +// NonAPIBase returns the non-API base URL +func (c *HTTPClient) NonAPIBase() string { + return fmt.Sprintf("%s:%d/%s", c.Host, c.Port, c.APIVersion) +} + +// BuildURL builds the full URL for a given path +func (c *HTTPClient) BuildURL(path string, useAPIBase bool) string { + base := c.APIBase() + if !useAPIBase { + base = c.NonAPIBase() + } + if c.VerifySSL { + return fmt.Sprintf("https://%s%s", base, path) + } + return fmt.Sprintf("http://%s%s", base, path) +} + +// Headers builds the request headers +func (c *HTTPClient) Headers(authKind string, extra map[string]string) map[string]string { + headers := make(map[string]string) + switch authKind { + case "api": + if c.APIKey != "" { + headers["Authorization"] = fmt.Sprintf("Bearer %s", c.APIKey) + } + case "web", "admin": + if c.LoginToken != "" { + headers["Authorization"] = c.LoginToken + } + } + for k, v := range extra { + headers[k] = v + } + return headers +} + +// Response represents an HTTP response +type Response struct { + StatusCode int + Body []byte + Headers http.Header +} + +// JSON parses the response body as JSON +func (r *Response) JSON() (map[string]interface{}, error) { + var result map[string]interface{} + if err := json.Unmarshal(r.Body, &result); err != nil { + return nil, err + } + return result, nil +} + +// Request makes an HTTP request +func (c *HTTPClient) Request(method, path string, useAPIBase bool, authKind string, headers map[string]string, jsonBody map[string]interface{}) (*Response, error) { + url := c.BuildURL(path, useAPIBase) + mergedHeaders := c.Headers(authKind, headers) + + var body io.Reader + if jsonBody != nil { + jsonData, err := json.Marshal(jsonBody) + if err != nil { + return nil, err + } + body = bytes.NewReader(jsonData) + if mergedHeaders == nil { + mergedHeaders = make(map[string]string) + } + mergedHeaders["Content-Type"] = "application/json" + } + + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, err + } + + for k, v := range mergedHeaders { + req.Header.Set(k, v) + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return &Response{ + StatusCode: resp.StatusCode, + Body: respBody, + Headers: resp.Header.Clone(), + }, nil +} + +// RequestWithIterations makes multiple HTTP requests for benchmarking +// Returns a map with "duration" (total time in seconds) and "response_list" +func (c *HTTPClient) RequestWithIterations(method, path string, useAPIBase bool, authKind string, 
headers map[string]string, jsonBody map[string]interface{}, iterations int) (map[string]interface{}, error) {
+	if iterations <= 1 {
+		resp, err := c.Request(method, path, useAPIBase, authKind, headers, jsonBody)
+		if err != nil {
+			return nil, err
+		}
+		return map[string]interface{}{
+			"duration":      0.0,
+			"response_list": []*Response{resp},
+		}, nil
+	}
+
+	url := c.BuildURL(path, useAPIBase)
+	mergedHeaders := c.Headers(authKind, headers)
+
+	// Marshal the request body once up front; each iteration reads the same
+	// bytes through a fresh reader instead of re-marshalling per request
+	var jsonData []byte
+	if jsonBody != nil {
+		var err error
+		jsonData, err = json.Marshal(jsonBody)
+		if err != nil {
+			return nil, err
+		}
+		if mergedHeaders == nil {
+			mergedHeaders = make(map[string]string)
+		}
+		mergedHeaders["Content-Type"] = "application/json"
+	}
+
+	responseList := make([]*Response, 0, iterations)
+	var totalDuration float64
+
+	for i := 0; i < iterations; i++ {
+		start := time.Now()
+
+		var reqBody io.Reader
+		if jsonData != nil {
+			reqBody = bytes.NewReader(jsonData)
+		}
+
+		req, err := http.NewRequest(method, url, reqBody)
+		if err != nil {
+			return nil, err
+		}
+
+		for k, v := range mergedHeaders {
+			req.Header.Set(k, v)
+		}
+
+		resp, err := c.client.Do(req)
+		if err != nil {
+			return nil, err
+		}
+
+		respBody, err := io.ReadAll(resp.Body)
+		resp.Body.Close()
+		if err != nil {
+			return nil, err
+		}
+
+		responseList = append(responseList, &Response{
+			StatusCode: resp.StatusCode,
+			Body:       respBody,
+			Headers:    resp.Header.Clone(),
+		})
+
+		totalDuration += time.Since(start).Seconds()
+	}
+
+	return map[string]interface{}{
+		"duration":      totalDuration,
+		"response_list": responseList,
+	}, nil
+}
+
+// RequestJSON makes an HTTP request and returns JSON response
+func (c *HTTPClient) RequestJSON(method, path string, useAPIBase bool, authKind string, headers map[string]string, jsonBody map[string]interface{}) (map[string]interface{}, error) {
+	resp, err := c.Request(method, path, useAPIBase, authKind, headers, jsonBody)
+	if err != nil {
+		return nil, err
+	}
+	return resp.JSON()
+}
diff --git a/internal/cli/lexer.go b/internal/cli/lexer.go
new file mode 100644
index 00000000000..214285b65fa
--- /dev/null
+++ b/internal/cli/lexer.go
@@ -0,0 +1,301 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// + +package cli + +import ( + "strings" + "unicode" +) + +// Lexer performs lexical analysis of the input +type Lexer struct { + input string + pos int + readPos int + ch byte +} + +// NewLexer creates a new lexer for the given input +func NewLexer(input string) *Lexer { + l := &Lexer{input: input} + l.readChar() + return l +} + +func (l *Lexer) readChar() { + if l.readPos >= len(l.input) { + l.ch = 0 + } else { + l.ch = l.input[l.readPos] + } + l.pos = l.readPos + l.readPos++ +} + +func (l *Lexer) peekChar() byte { + if l.readPos >= len(l.input) { + return 0 + } + return l.input[l.readPos] +} + +func (l *Lexer) skipWhitespace() { + for l.ch == ' ' || l.ch == '\t' || l.ch == '\n' || l.ch == '\r' { + l.readChar() + } +} + +// NextToken returns the next token from the input +func (l *Lexer) NextToken() Token { + var tok Token + + l.skipWhitespace() + + switch l.ch { + case ';': + tok = newToken(TokenSemicolon, l.ch) + l.readChar() + case ',': + tok = newToken(TokenComma, l.ch) + l.readChar() + case '\'': + tok.Type = TokenQuotedString + tok.Value = l.readQuotedString('\'') + case '"': + tok.Type = TokenQuotedString + tok.Value = l.readQuotedString('"') + case '\\': + // Meta command: backslash followed by command name + tok.Type = TokenIdentifier + tok.Value = l.readMetaCommand() + case 0: + tok.Type = TokenEOF + tok.Value = "" + default: + if isLetter(l.ch) { + ident := l.readIdentifier() + return l.lookupIdent(ident) + } else if isDigit(l.ch) { + tok.Type = TokenNumber + tok.Value = l.readNumber() + return tok + } else { + tok = newToken(TokenIllegal, l.ch) + l.readChar() + } + } + + return tok +} + +func (l *Lexer) readMetaCommand() string { + start := l.pos + l.readChar() // consume backslash + for isLetter(l.ch) || l.ch == '?' { + l.readChar() + } + return l.input[start:l.pos] +} + +func newToken(tokenType int, ch byte) Token { + return Token{Type: tokenType, Value: string(ch)} +} + +func (l *Lexer) readIdentifier() string { + start := l.pos + for isLetter(l.ch) || isDigit(l.ch) || l.ch == '_' || l.ch == '-' || l.ch == '.' 
{ + l.readChar() + } + return l.input[start:l.pos] +} + +func (l *Lexer) readNumber() string { + start := l.pos + for isDigit(l.ch) { + l.readChar() + } + return l.input[start:l.pos] +} + +func (l *Lexer) readQuotedString(quote byte) string { + l.readChar() // skip opening quote + start := l.pos + for l.ch != quote && l.ch != 0 { + l.readChar() + } + str := l.input[start:l.pos] + if l.ch == quote { + l.readChar() // skip closing quote + } + return str +} + +func (l *Lexer) lookupIdent(ident string) Token { + upper := strings.ToUpper(ident) + switch upper { + case "LOGIN": + return Token{Type: TokenLogin, Value: ident} + case "REGISTER": + return Token{Type: TokenRegister, Value: ident} + case "LIST": + return Token{Type: TokenList, Value: ident} + case "SERVICES": + return Token{Type: TokenServices, Value: ident} + case "SHOW": + return Token{Type: TokenShow, Value: ident} + case "CREATE": + return Token{Type: TokenCreate, Value: ident} + case "SERVICE": + return Token{Type: TokenService, Value: ident} + case "SHUTDOWN": + return Token{Type: TokenShutdown, Value: ident} + case "STARTUP": + return Token{Type: TokenStartup, Value: ident} + case "RESTART": + return Token{Type: TokenRestart, Value: ident} + case "USERS": + return Token{Type: TokenUsers, Value: ident} + case "DROP": + return Token{Type: TokenDrop, Value: ident} + case "USER": + return Token{Type: TokenUser, Value: ident} + case "ALTER": + return Token{Type: TokenAlter, Value: ident} + case "ACTIVE": + return Token{Type: TokenActive, Value: ident} + case "ADMIN": + return Token{Type: TokenAdmin, Value: ident} + case "PASSWORD": + return Token{Type: TokenPassword, Value: ident} + case "DATASET": + return Token{Type: TokenDataset, Value: ident} + case "DATASETS": + return Token{Type: TokenDatasets, Value: ident} + case "OF": + return Token{Type: TokenOf, Value: ident} + case "AGENTS": + return Token{Type: TokenAgents, Value: ident} + case "ROLE": + return Token{Type: TokenRole, Value: ident} + case "ROLES": + return Token{Type: TokenRoles, Value: ident} + case "DESCRIPTION": + return Token{Type: TokenDescription, Value: ident} + case "GRANT": + return Token{Type: TokenGrant, Value: ident} + case "REVOKE": + return Token{Type: TokenRevoke, Value: ident} + case "ALL": + return Token{Type: TokenAll, Value: ident} + case "PERMISSION": + return Token{Type: TokenPermission, Value: ident} + case "TO": + return Token{Type: TokenTo, Value: ident} + case "FROM": + return Token{Type: TokenFrom, Value: ident} + case "FOR": + return Token{Type: TokenFor, Value: ident} + case "RESOURCES": + return Token{Type: TokenResources, Value: ident} + case "ON": + return Token{Type: TokenOn, Value: ident} + case "SET": + return Token{Type: TokenSet, Value: ident} + case "RESET": + return Token{Type: TokenReset, Value: ident} + case "VERSION": + return Token{Type: TokenVersion, Value: ident} + case "VAR": + return Token{Type: TokenVar, Value: ident} + case "VARS": + return Token{Type: TokenVars, Value: ident} + case "CONFIGS": + return Token{Type: TokenConfigs, Value: ident} + case "ENVS": + return Token{Type: TokenEnvs, Value: ident} + case "KEY": + return Token{Type: TokenKey, Value: ident} + case "KEYS": + return Token{Type: TokenKeys, Value: ident} + case "GENERATE": + return Token{Type: TokenGenerate, Value: ident} + case "MODEL": + return Token{Type: TokenModel, Value: ident} + case "MODELS": + return Token{Type: TokenModels, Value: ident} + case "PROVIDER": + return Token{Type: TokenProvider, Value: ident} + case "PROVIDERS": + return Token{Type: 
TokenProviders, Value: ident} + case "DEFAULT": + return Token{Type: TokenDefault, Value: ident} + case "CHATS": + return Token{Type: TokenChats, Value: ident} + case "CHAT": + return Token{Type: TokenChat, Value: ident} + case "FILES": + return Token{Type: TokenFiles, Value: ident} + case "AS": + return Token{Type: TokenAs, Value: ident} + case "PARSE": + return Token{Type: TokenParse, Value: ident} + case "IMPORT": + return Token{Type: TokenImport, Value: ident} + case "INTO": + return Token{Type: TokenInto, Value: ident} + case "WITH": + return Token{Type: TokenWith, Value: ident} + case "PARSER": + return Token{Type: TokenParser, Value: ident} + case "PIPELINE": + return Token{Type: TokenPipeline, Value: ident} + case "SEARCH": + return Token{Type: TokenSearch, Value: ident} + case "CURRENT": + return Token{Type: TokenCurrent, Value: ident} + case "LLM": + return Token{Type: TokenLLM, Value: ident} + case "VLM": + return Token{Type: TokenVLM, Value: ident} + case "EMBEDDING": + return Token{Type: TokenEmbedding, Value: ident} + case "RERANKER": + return Token{Type: TokenReranker, Value: ident} + case "ASR": + return Token{Type: TokenASR, Value: ident} + case "TTS": + return Token{Type: TokenTTS, Value: ident} + case "ASYNC": + return Token{Type: TokenAsync, Value: ident} + case "SYNC": + return Token{Type: TokenSync, Value: ident} + case "BENCHMARK": + return Token{Type: TokenBenchmark, Value: ident} + case "PING": + return Token{Type: TokenPing, Value: ident} + default: + return Token{Type: TokenIdentifier, Value: ident} + } +} + +func isLetter(ch byte) bool { + return unicode.IsLetter(rune(ch)) +} + +func isDigit(ch byte) bool { + return unicode.IsDigit(rune(ch)) +} diff --git a/internal/cli/parser.go b/internal/cli/parser.go new file mode 100644 index 00000000000..bd336566039 --- /dev/null +++ b/internal/cli/parser.go @@ -0,0 +1,1568 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package cli + +import ( + "fmt" + "strconv" + "strings" +) + +// Parser implements a recursive descent parser for RAGFlow CLI commands +type Parser struct { + lexer *Lexer + curToken Token + peekToken Token +} + +// NewParser creates a new parser +func NewParser(input string) *Parser { + l := NewLexer(input) + p := &Parser{lexer: l} + // Read two tokens to initialize curToken and peekToken + p.nextToken() + p.nextToken() + return p +} + +func (p *Parser) nextToken() { + p.curToken = p.peekToken + p.peekToken = p.lexer.NextToken() +} + +// Parse parses the input and returns a Command +func (p *Parser) Parse() (*Command, error) { + if p.curToken.Type == TokenEOF { + return nil, nil + } + + // Check for meta commands (backslash commands) + if p.curToken.Type == TokenIdentifier && strings.HasPrefix(p.curToken.Value, "\\") { + return p.parseMetaCommand() + } + + // Parse SQL-like command + return p.parseSQLCommand() +} + +func (p *Parser) parseMetaCommand() (*Command, error) { + cmd := NewCommand("meta") + cmdName := strings.TrimPrefix(p.curToken.Value, "\\") + cmd.Params["command"] = strings.ToLower(cmdName) + + // Parse arguments + var args []string + p.nextToken() + for p.curToken.Type != TokenEOF { + args = append(args, p.curToken.Value) + p.nextToken() + } + cmd.Params["args"] = args + + return cmd, nil +} + +func (p *Parser) parseSQLCommand() (*Command, error) { + if p.curToken.Type != TokenIdentifier && !isKeyword(p.curToken.Type) { + return nil, fmt.Errorf("expected command, got %s", p.curToken.Value) + } + + switch p.curToken.Type { + case TokenLogin: + return p.parseLoginUser() + case TokenPing: + return p.parsePingServer() + case TokenList: + return p.parseListCommand() + case TokenShow: + return p.parseShowCommand() + case TokenCreate: + return p.parseCreateCommand() + case TokenDrop: + return p.parseDropCommand() + case TokenAlter: + return p.parseAlterCommand() + case TokenGrant: + return p.parseGrantCommand() + case TokenRevoke: + return p.parseRevokeCommand() + case TokenSet: + return p.parseSetCommand() + case TokenReset: + return p.parseResetCommand() + case TokenGenerate: + return p.parseGenerateCommand() + case TokenImport: + return p.parseImportCommand() + case TokenSearch: + return p.parseSearchCommand() + case TokenParse: + return p.parseParseCommand() + case TokenBenchmark: + return p.parseBenchmarkCommand() + case TokenRegister: + return p.parseRegisterCommand() + case TokenStartup: + return p.parseStartupCommand() + case TokenShutdown: + return p.parseShutdownCommand() + case TokenRestart: + return p.parseRestartCommand() + default: + return nil, fmt.Errorf("unknown command: %s", p.curToken.Value) + } +} + +func (p *Parser) expectPeek(tokenType int) error { + if p.peekToken.Type != tokenType { + return fmt.Errorf("expected %s, got %s", tokenTypeToString(tokenType), p.peekToken.Value) + } + p.nextToken() + return nil +} + +func (p *Parser) expectSemicolon() error { + if p.curToken.Type == TokenSemicolon { + return nil + } + if p.peekToken.Type == TokenSemicolon { + p.nextToken() + return nil + } + return fmt.Errorf("expected semicolon") +} + +func isKeyword(tokenType int) bool { + return tokenType >= TokenLogin && tokenType <= TokenPing +} + +// Helper functions for parsing +func (p *Parser) parseQuotedString() (string, error) { + if p.curToken.Type != TokenQuotedString { + return "", fmt.Errorf("expected quoted string, got %s", p.curToken.Value) + } + return p.curToken.Value, nil +} + +func (p *Parser) parseIdentifier() (string, error) { + if p.curToken.Type != 
TokenIdentifier {
+		return "", fmt.Errorf("expected identifier, got %s", p.curToken.Value)
+	}
+	return p.curToken.Value, nil
+}
+
+func (p *Parser) parseNumber() (int, error) {
+	if p.curToken.Type != TokenNumber {
+		return 0, fmt.Errorf("expected number, got %s", p.curToken.Value)
+	}
+	return strconv.Atoi(p.curToken.Value)
+}
+
+// Command parsers
+func (p *Parser) parseLoginUser() (*Command, error) {
+	cmd := NewCommand("login_user")
+
+	p.nextToken() // consume LOGIN
+	if p.curToken.Type != TokenUser {
+		return nil, fmt.Errorf("expected USER after LOGIN")
+	}
+
+	p.nextToken()
+	email, err := p.parseQuotedString()
+	if err != nil {
+		return nil, err
+	}
+	cmd.Params["email"] = email
+
+	p.nextToken()
+	if err := p.expectSemicolon(); err != nil {
+		return nil, err
+	}
+
+	return cmd, nil
+}
+
+func (p *Parser) parsePingServer() (*Command, error) {
+	cmd := NewCommand("ping_server")
+	p.nextToken()
+	if err := p.expectSemicolon(); err != nil {
+		return nil, err
+	}
+	return cmd, nil
+}
+
+func (p *Parser) parseRegisterCommand() (*Command, error) {
+	cmd := NewCommand("register_user")
+
+	p.nextToken() // consume REGISTER
+	if p.curToken.Type != TokenUser {
+		return nil, fmt.Errorf("expected USER after REGISTER")
+	}
+
+	p.nextToken()
+	userName, err := p.parseQuotedString()
+	if err != nil {
+		return nil, err
+	}
+	cmd.Params["user_name"] = userName
+
+	p.nextToken()
+	if p.curToken.Type != TokenAs {
+		return nil, fmt.Errorf("expected AS")
+	}
+
+	p.nextToken()
+	nickname, err := p.parseQuotedString()
+	if err != nil {
+		return nil, err
+	}
+	cmd.Params["nickname"] = nickname
+
+	p.nextToken()
+	if p.curToken.Type != TokenPassword {
+		return nil, fmt.Errorf("expected PASSWORD")
+	}
+
+	p.nextToken()
+	password, err := p.parseQuotedString()
+	if err != nil {
+		return nil, err
+	}
+	cmd.Params["password"] = password
+
+	p.nextToken()
+	if err := p.expectSemicolon(); err != nil {
+		return nil, err
+	}
+
+	return cmd, nil
+}
+
+func (p *Parser) parseListCommand() (*Command, error) {
+	p.nextToken() // consume LIST
+
+	switch p.curToken.Type {
+	case TokenServices:
+		p.nextToken()
+		if err := p.expectSemicolon(); err != nil {
+			return nil, err
+		}
+		return NewCommand("list_services"), nil
+	case TokenUsers:
+		p.nextToken()
+		if err := p.expectSemicolon(); err != nil {
+			return nil, err
+		}
+		return NewCommand("list_users"), nil
+	case TokenRoles:
+		p.nextToken()
+		if err := p.expectSemicolon(); err != nil {
+			return nil, err
+		}
+		return NewCommand("list_roles"), nil
+	case TokenVars:
+		p.nextToken()
+		if err := p.expectSemicolon(); err != nil {
+			return nil, err
+		}
+		return NewCommand("list_variables"), nil
+	case TokenConfigs:
+		p.nextToken()
+		if err := p.expectSemicolon(); err != nil {
+			return nil, err
+		}
+		return NewCommand("list_configs"), nil
+	case TokenEnvs:
+		p.nextToken()
+		if err := p.expectSemicolon(); err != nil {
+			return nil, err
+		}
+		return NewCommand("list_environments"), nil
+	case TokenDatasets:
+		return p.parseListDatasets()
+	case TokenAgents:
+		return p.parseListAgents()
+	case TokenKeys:
+		return p.parseListKeys()
+	case TokenModel:
+		return p.parseListModelProviders()
+	case TokenDefault:
+		return p.parseListDefaultModels()
+	case TokenChats:
+		p.nextToken()
+		if err := p.expectSemicolon(); err != nil {
+			return nil, err
+		}
+		return NewCommand("list_user_chats"), nil
+	case TokenFiles:
+		return p.parseListFiles()
+	default:
+		return nil, fmt.Errorf("unknown LIST target: %s", p.curToken.Value)
+	}
+}
+
+func (p *Parser) parseListDatasets() (*Command, error) {
+	cmd := 
NewCommand("list_user_datasets") + p.nextToken() // consume DATASETS + + if p.curToken.Type == TokenSemicolon { + return cmd, nil + } + + if p.curToken.Type == TokenOf { + p.nextToken() + userName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + cmd = NewCommand("list_datasets") + cmd.Params["user_name"] = userName + p.nextToken() + } + + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseListAgents() (*Command, error) { + p.nextToken() // consume AGENTS + + if p.curToken.Type == TokenSemicolon { + return NewCommand("list_user_agents"), nil + } + + if p.curToken.Type != TokenOf { + return nil, fmt.Errorf("expected OF") + } + p.nextToken() + + userName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("list_agents") + cmd.Params["user_name"] = userName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseListKeys() (*Command, error) { + p.nextToken() // consume KEYS + if p.curToken.Type != TokenOf { + return nil, fmt.Errorf("expected OF") + } + p.nextToken() + + userName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("list_keys") + cmd.Params["user_name"] = userName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseListModelProviders() (*Command, error) { + p.nextToken() // consume MODEL + if p.curToken.Type != TokenProviders { + return nil, fmt.Errorf("expected PROVIDERS") + } + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return NewCommand("list_user_model_providers"), nil +} + +func (p *Parser) parseListDefaultModels() (*Command, error) { + p.nextToken() // consume DEFAULT + if p.curToken.Type != TokenModels { + return nil, fmt.Errorf("expected MODELS") + } + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return NewCommand("list_user_default_models"), nil +} + +func (p *Parser) parseListFiles() (*Command, error) { + p.nextToken() // consume FILES + if p.curToken.Type != TokenOf { + return nil, fmt.Errorf("expected OF") + } + p.nextToken() + if p.curToken.Type != TokenDataset { + return nil, fmt.Errorf("expected DATASET") + } + p.nextToken() + + datasetName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("list_user_dataset_files") + cmd.Params["dataset_name"] = datasetName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseShowCommand() (*Command, error) { + p.nextToken() // consume SHOW + + switch p.curToken.Type { + case TokenVersion: + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return NewCommand("show_version"), nil + case TokenCurrent: + p.nextToken() + if p.curToken.Type != TokenUser { + return nil, fmt.Errorf("expected USER after CURRENT") + } + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return NewCommand("show_current_user"), nil + case TokenUser: + return p.parseShowUser() + case TokenRole: + return p.parseShowRole() + case TokenVar: + return p.parseShowVariable() + case TokenService: + return p.parseShowService() + default: + return nil, fmt.Errorf("unknown SHOW target: %s", p.curToken.Value) + } +} + +func (p *Parser) parseShowUser() (*Command, error) { + p.nextToken() // 
consume USER + + // Check for PERMISSION + if p.curToken.Type == TokenPermission { + p.nextToken() + userName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + cmd := NewCommand("show_user_permission") + cmd.Params["user_name"] = userName + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil + } + + userName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("show_user") + cmd.Params["user_name"] = userName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseShowRole() (*Command, error) { + p.nextToken() // consume ROLE + roleName, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + cmd := NewCommand("show_role") + cmd.Params["role_name"] = roleName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseShowVariable() (*Command, error) { + p.nextToken() // consume VAR + varName, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + cmd := NewCommand("show_variable") + cmd.Params["var_name"] = varName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseShowService() (*Command, error) { + p.nextToken() // consume SERVICE + serviceNum, err := p.parseNumber() + if err != nil { + return nil, err + } + + cmd := NewCommand("show_service") + cmd.Params["number"] = serviceNum + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseCreateCommand() (*Command, error) { + p.nextToken() // consume CREATE + + switch p.curToken.Type { + case TokenUser: + return p.parseCreateUser() + case TokenRole: + return p.parseCreateRole() + case TokenModel: + return p.parseCreateModelProvider() + case TokenDataset: + return p.parseCreateDataset() + case TokenChat: + return p.parseCreateChat() + default: + return nil, fmt.Errorf("unknown CREATE target: %s", p.curToken.Value) + } +} + +func (p *Parser) parseCreateUser() (*Command, error) { + p.nextToken() // consume USER + userName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + p.nextToken() + password, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("create_user") + cmd.Params["user_name"] = userName + cmd.Params["password"] = password + cmd.Params["role"] = "user" + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseCreateRole() (*Command, error) { + p.nextToken() // consume ROLE + roleName, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + cmd := NewCommand("create_role") + cmd.Params["role_name"] = roleName + + p.nextToken() + if p.curToken.Type == TokenDescription { + p.nextToken() + description, err := p.parseQuotedString() + if err != nil { + return nil, err + } + cmd.Params["description"] = description + p.nextToken() + } + + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseCreateModelProvider() (*Command, error) { + p.nextToken() // consume MODEL + if p.curToken.Type != TokenProvider { + return nil, fmt.Errorf("expected PROVIDER") + } + p.nextToken() + + providerName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + p.nextToken() + providerKey, err := 
p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("create_model_provider") + cmd.Params["provider_name"] = providerName + cmd.Params["provider_key"] = providerKey + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseCreateDataset() (*Command, error) { + p.nextToken() // consume DATASET + datasetName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + p.nextToken() + if p.curToken.Type != TokenWith { + return nil, fmt.Errorf("expected WITH") + } + p.nextToken() + if p.curToken.Type != TokenEmbedding { + return nil, fmt.Errorf("expected EMBEDDING") + } + p.nextToken() + + embedding, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + p.nextToken() + cmd := NewCommand("create_user_dataset") + cmd.Params["dataset_name"] = datasetName + cmd.Params["embedding"] = embedding + + if p.curToken.Type == TokenParser { + p.nextToken() + parserType, err := p.parseQuotedString() + if err != nil { + return nil, err + } + cmd.Params["parser_type"] = parserType + p.nextToken() + } else if p.curToken.Type == TokenPipeline { + p.nextToken() + pipeline, err := p.parseQuotedString() + if err != nil { + return nil, err + } + cmd.Params["pipeline"] = pipeline + p.nextToken() + } else { + return nil, fmt.Errorf("expected PARSER or PIPELINE") + } + + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseCreateChat() (*Command, error) { + p.nextToken() // consume CHAT + chatName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("create_user_chat") + cmd.Params["chat_name"] = chatName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseDropCommand() (*Command, error) { + p.nextToken() // consume DROP + + switch p.curToken.Type { + case TokenUser: + return p.parseDropUser() + case TokenRole: + return p.parseDropRole() + case TokenModel: + return p.parseDropModelProvider() + case TokenDataset: + return p.parseDropDataset() + case TokenChat: + return p.parseDropChat() + case TokenKey: + return p.parseDropKey() + default: + return nil, fmt.Errorf("unknown DROP target: %s", p.curToken.Value) + } +} + +func (p *Parser) parseDropUser() (*Command, error) { + p.nextToken() // consume USER + userName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("drop_user") + cmd.Params["user_name"] = userName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseDropRole() (*Command, error) { + p.nextToken() // consume ROLE + roleName, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + cmd := NewCommand("drop_role") + cmd.Params["role_name"] = roleName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseDropModelProvider() (*Command, error) { + p.nextToken() // consume MODEL + if p.curToken.Type != TokenProvider { + return nil, fmt.Errorf("expected PROVIDER") + } + p.nextToken() + + providerName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("drop_model_provider") + cmd.Params["provider_name"] = providerName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) 
parseDropDataset() (*Command, error) { + p.nextToken() // consume DATASET + datasetName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("drop_user_dataset") + cmd.Params["dataset_name"] = datasetName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseDropChat() (*Command, error) { + p.nextToken() // consume CHAT + chatName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("drop_user_chat") + cmd.Params["chat_name"] = chatName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseDropKey() (*Command, error) { + p.nextToken() // consume KEY + key, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + p.nextToken() + if p.curToken.Type != TokenOf { + return nil, fmt.Errorf("expected OF") + } + p.nextToken() + + userName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("drop_key") + cmd.Params["key"] = key + cmd.Params["user_name"] = userName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseAlterCommand() (*Command, error) { + p.nextToken() // consume ALTER + + switch p.curToken.Type { + case TokenUser: + return p.parseAlterUser() + case TokenRole: + return p.parseAlterRole() + default: + return nil, fmt.Errorf("unknown ALTER target: %s", p.curToken.Value) + } +} + +func (p *Parser) parseAlterUser() (*Command, error) { + p.nextToken() // consume USER + + if p.curToken.Type == TokenActive { + return p.parseActivateUser() + } + + if p.curToken.Type == TokenPassword { + p.nextToken() + userName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + p.nextToken() + password, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("alter_user") + cmd.Params["user_name"] = userName + cmd.Params["password"] = password + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil + } + + userName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + p.nextToken() + if p.curToken.Type != TokenSet { + return nil, fmt.Errorf("expected SET") + } + p.nextToken() + if p.curToken.Type != TokenRole { + return nil, fmt.Errorf("expected ROLE") + } + p.nextToken() + + roleName, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + cmd := NewCommand("alter_user_role") + cmd.Params["user_name"] = userName + cmd.Params["role_name"] = roleName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseActivateUser() (*Command, error) { + p.nextToken() // consume ACTIVE + userName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + p.nextToken() + status, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + cmd := NewCommand("activate_user") + cmd.Params["user_name"] = userName + cmd.Params["activate_status"] = status + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseAlterRole() (*Command, error) { + p.nextToken() // consume ROLE + roleName, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + p.nextToken() + if p.curToken.Type != TokenSet { + return nil, fmt.Errorf("expected SET") + } + 
p.nextToken() + if p.curToken.Type != TokenDescription { + return nil, fmt.Errorf("expected DESCRIPTION") + } + p.nextToken() + + description, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("alter_role") + cmd.Params["role_name"] = roleName + cmd.Params["description"] = description + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseGrantCommand() (*Command, error) { + p.nextToken() // consume GRANT + + if p.curToken.Type == TokenAdmin { + return p.parseGrantAdmin() + } + + return p.parseGrantPermission() +} + +func (p *Parser) parseGrantAdmin() (*Command, error) { + p.nextToken() // consume ADMIN + userName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("grant_admin") + cmd.Params["user_name"] = userName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseGrantPermission() (*Command, error) { + actions, err := p.parseIdentifierList() + if err != nil { + return nil, err + } + + if p.curToken.Type != TokenOn { + return nil, fmt.Errorf("expected ON") + } + p.nextToken() + + resource, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + p.nextToken() + if p.curToken.Type != TokenTo { + return nil, fmt.Errorf("expected TO") + } + p.nextToken() + if p.curToken.Type != TokenRole { + return nil, fmt.Errorf("expected ROLE") + } + p.nextToken() + + roleName, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + cmd := NewCommand("grant_permission") + cmd.Params["actions"] = actions + cmd.Params["resource"] = resource + cmd.Params["role_name"] = roleName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseRevokeCommand() (*Command, error) { + p.nextToken() // consume REVOKE + + if p.curToken.Type == TokenAdmin { + return p.parseRevokeAdmin() + } + + return p.parseRevokePermission() +} + +func (p *Parser) parseRevokeAdmin() (*Command, error) { + p.nextToken() // consume ADMIN + userName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("revoke_admin") + cmd.Params["user_name"] = userName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseRevokePermission() (*Command, error) { + actions, err := p.parseIdentifierList() + if err != nil { + return nil, err + } + + if p.curToken.Type != TokenOn { + return nil, fmt.Errorf("expected ON") + } + p.nextToken() + + resource, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + p.nextToken() + if p.curToken.Type != TokenFrom { + return nil, fmt.Errorf("expected FROM") + } + p.nextToken() + if p.curToken.Type != TokenRole { + return nil, fmt.Errorf("expected ROLE") + } + p.nextToken() + + roleName, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + cmd := NewCommand("revoke_permission") + cmd.Params["actions"] = actions + cmd.Params["resource"] = resource + cmd.Params["role_name"] = roleName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseIdentifierList() ([]string, error) { + var list []string + + ident, err := p.parseIdentifier() + if err != nil { + return nil, err + } + list = append(list, ident) + p.nextToken() + + for p.curToken.Type == TokenComma { + 
p.nextToken() + ident, err := p.parseIdentifier() + if err != nil { + return nil, err + } + list = append(list, ident) + p.nextToken() + } + + return list, nil +} + +func (p *Parser) parseSetCommand() (*Command, error) { + p.nextToken() // consume SET + + if p.curToken.Type == TokenVar { + return p.parseSetVariable() + } + if p.curToken.Type == TokenDefault { + return p.parseSetDefault() + } + + return nil, fmt.Errorf("unknown SET target: %s", p.curToken.Value) +} + +func (p *Parser) parseSetVariable() (*Command, error) { + p.nextToken() // consume VAR + varName, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + p.nextToken() + varValue, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + cmd := NewCommand("set_variable") + cmd.Params["var_name"] = varName + cmd.Params["var_value"] = varValue + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseSetDefault() (*Command, error) { + p.nextToken() // consume DEFAULT + + var modelType, modelID string + + switch p.curToken.Type { + case TokenLLM: + modelType = "llm_id" + case TokenVLM: + modelType = "img2txt_id" + case TokenEmbedding: + modelType = "embd_id" + case TokenReranker: + modelType = "reranker_id" + case TokenASR: + modelType = "asr_id" + case TokenTTS: + modelType = "tts_id" + default: + return nil, fmt.Errorf("unknown model type: %s", p.curToken.Value) + } + + p.nextToken() + id, err := p.parseQuotedString() + if err != nil { + return nil, err + } + modelID = id + + cmd := NewCommand("set_default_model") + cmd.Params["model_type"] = modelType + cmd.Params["model_id"] = modelID + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseResetCommand() (*Command, error) { + p.nextToken() // consume RESET + + if p.curToken.Type != TokenDefault { + return nil, fmt.Errorf("expected DEFAULT") + } + p.nextToken() + + var modelType string + switch p.curToken.Type { + case TokenLLM: + modelType = "llm_id" + case TokenVLM: + modelType = "img2txt_id" + case TokenEmbedding: + modelType = "embd_id" + case TokenReranker: + modelType = "reranker_id" + case TokenASR: + modelType = "asr_id" + case TokenTTS: + modelType = "tts_id" + default: + return nil, fmt.Errorf("unknown model type: %s", p.curToken.Value) + } + + cmd := NewCommand("reset_default_model") + cmd.Params["model_type"] = modelType + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseGenerateCommand() (*Command, error) { + p.nextToken() // consume GENERATE + if p.curToken.Type != TokenKey { + return nil, fmt.Errorf("expected KEY") + } + p.nextToken() + if p.curToken.Type != TokenFor { + return nil, fmt.Errorf("expected FOR") + } + p.nextToken() + if p.curToken.Type != TokenUser { + return nil, fmt.Errorf("expected USER") + } + p.nextToken() + + userName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("generate_key") + cmd.Params["user_name"] = userName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseImportCommand() (*Command, error) { + p.nextToken() // consume IMPORT + documentPaths, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + p.nextToken() + if p.curToken.Type != TokenInto { + return nil, fmt.Errorf("expected INTO") + } + p.nextToken() + if p.curToken.Type != 
TokenDataset { + return nil, fmt.Errorf("expected DATASET") + } + p.nextToken() + + datasetName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("import_docs_into_dataset") + cmd.Params["document_paths"] = documentPaths + cmd.Params["dataset_name"] = datasetName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseSearchCommand() (*Command, error) { + p.nextToken() // consume SEARCH + question, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + p.nextToken() + if p.curToken.Type != TokenOn { + return nil, fmt.Errorf("expected ON") + } + p.nextToken() + if p.curToken.Type != TokenDatasets { + return nil, fmt.Errorf("expected DATASETS") + } + p.nextToken() + + datasets, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("search_on_datasets") + cmd.Params["question"] = question + cmd.Params["datasets"] = datasets + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseParseCommand() (*Command, error) { + p.nextToken() // consume PARSE + + if p.curToken.Type == TokenDataset { + return p.parseParseDataset() + } + + return p.parseParseDocs() +} + +func (p *Parser) parseParseDataset() (*Command, error) { + p.nextToken() // consume DATASET + datasetName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + p.nextToken() + var method string + if p.curToken.Type == TokenSync { + method = "sync" + } else if p.curToken.Type == TokenAsync { + method = "async" + } else { + return nil, fmt.Errorf("expected SYNC or ASYNC") + } + + cmd := NewCommand("parse_dataset") + cmd.Params["dataset_name"] = datasetName + cmd.Params["method"] = method + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseParseDocs() (*Command, error) { + documentNames, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + p.nextToken() + if p.curToken.Type != TokenOf { + return nil, fmt.Errorf("expected OF") + } + p.nextToken() + if p.curToken.Type != TokenDataset { + return nil, fmt.Errorf("expected DATASET") + } + p.nextToken() + + datasetName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("parse_dataset_docs") + cmd.Params["document_names"] = documentNames + cmd.Params["dataset_name"] = datasetName + + p.nextToken() + if err := p.expectSemicolon(); err != nil { + return nil, err + } + return cmd, nil +} + +func (p *Parser) parseBenchmarkCommand() (*Command, error) { + cmd := NewCommand("benchmark") + + p.nextToken() // consume BENCHMARK + concurrency, err := p.parseNumber() + if err != nil { + return nil, err + } + cmd.Params["concurrency"] = concurrency + + p.nextToken() + iterations, err := p.parseNumber() + if err != nil { + return nil, err + } + cmd.Params["iterations"] = iterations + + p.nextToken() + // Parse user_statement + nestedCmd, err := p.parseUserStatement() + if err != nil { + return nil, err + } + cmd.Params["command"] = nestedCmd + + return cmd, nil +} + +func (p *Parser) parseUserStatement() (*Command, error) { + switch p.curToken.Type { + case TokenPing: + return p.parsePingServer() + case TokenShow: + return p.parseShowCommand() + case TokenCreate: + return p.parseCreateCommand() + case TokenDrop: + return p.parseDropCommand() + case TokenSet: + return p.parseSetCommand() + case TokenReset: + 
return p.parseResetCommand()
+	case TokenList:
+		return p.parseListCommand()
+	case TokenParse:
+		return p.parseParseCommand()
+	case TokenImport:
+		return p.parseImportCommand()
+	case TokenSearch:
+		return p.parseSearchCommand()
+	default:
+		return nil, fmt.Errorf("invalid user statement: %s", p.curToken.Value)
+	}
+}
+
+func (p *Parser) parseStartupCommand() (*Command, error) {
+	p.nextToken() // consume STARTUP
+	if p.curToken.Type != TokenService {
+		return nil, fmt.Errorf("expected SERVICE")
+	}
+	p.nextToken()
+
+	serviceNum, err := p.parseNumber()
+	if err != nil {
+		return nil, err
+	}
+
+	cmd := NewCommand("startup_service")
+	cmd.Params["number"] = serviceNum
+
+	p.nextToken()
+	if err := p.expectSemicolon(); err != nil {
+		return nil, err
+	}
+	return cmd, nil
+}
+
+func (p *Parser) parseShutdownCommand() (*Command, error) {
+	p.nextToken() // consume SHUTDOWN
+	if p.curToken.Type != TokenService {
+		return nil, fmt.Errorf("expected SERVICE")
+	}
+	p.nextToken()
+
+	serviceNum, err := p.parseNumber()
+	if err != nil {
+		return nil, err
+	}
+
+	cmd := NewCommand("shutdown_service")
+	cmd.Params["number"] = serviceNum
+
+	p.nextToken()
+	if err := p.expectSemicolon(); err != nil {
+		return nil, err
+	}
+	return cmd, nil
+}
+
+func (p *Parser) parseRestartCommand() (*Command, error) {
+	p.nextToken() // consume RESTART
+	if p.curToken.Type != TokenService {
+		return nil, fmt.Errorf("expected SERVICE")
+	}
+	p.nextToken()
+
+	serviceNum, err := p.parseNumber()
+	if err != nil {
+		return nil, err
+	}
+
+	cmd := NewCommand("restart_service")
+	cmd.Params["number"] = serviceNum
+
+	p.nextToken()
+	if err := p.expectSemicolon(); err != nil {
+		return nil, err
+	}
+	return cmd, nil
+}
+
+// tokenTypeToString renders a token type for error messages. Keyword tokens
+// carry their source text in Token.Value, so this fallback reports only the
+// numeric type.
+func tokenTypeToString(t int) string {
+	return fmt.Sprintf("token(%d)", t)
+}
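+
+// For reference, the statement forms accepted by the parsers above
+// (illustrative; the exact keyword set is defined in types.go):
+//
+//	SET VAR name value;
+//	SET DEFAULT LLM "model-id";
+//	RESET DEFAULT EMBEDDING;
+//	GENERATE KEY FOR USER "alice";
+//	IMPORT "/path/doc.pdf" INTO DATASET "kb";
+//	SEARCH "a question" ON DATASETS "kb";
+//	PARSE DATASET "kb" SYNC;
+//	PARSE "doc.pdf" OF DATASET "kb";
+//	BENCHMARK 4 100 SEARCH "a question" ON DATASETS "kb";
+//	STARTUP SERVICE 0;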
diff --git a/internal/cli/table.go b/internal/cli/table.go
new file mode 100644
index 00000000000..7baef5d5aef
--- /dev/null
+++ b/internal/cli/table.go
@@ -0,0 +1,167 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+package cli
+
+import (
+	"fmt"
+	"sort"
+	"strings"
+	"unicode"
+)
+
+// PrintTableSimple prints data in a simple table format
+// Similar to Python's _print_table_simple
+func PrintTableSimple(data []map[string]interface{}) {
+	if len(data) == 0 {
+		fmt.Println("No data to print")
+		return
+	}
+
+	// Collect all column names
+	columnSet := make(map[string]bool)
+	for _, item := range data {
+		for key := range item {
+			columnSet[key] = true
+		}
+	}
+
+	// Sort columns alphabetically; callers that need a specific column
+	// ordering should arrange it upstream.
+	columns := make([]string, 0, len(columnSet))
+	for col := range columnSet {
+		columns = append(columns, col)
+	}
+	sort.Strings(columns)
+
+	// Calculate column widths
+	colWidths := make(map[string]int)
+	for _, col := range columns {
+		maxWidth := getStringWidth(col)
+		for _, item := range data {
+			value := fmt.Sprintf("%v", item[col])
+			valueWidth := getStringWidth(value)
+			if valueWidth > maxWidth {
+				maxWidth = valueWidth
+			}
+		}
+		if maxWidth < 2 {
+			maxWidth = 2
+		}
+		colWidths[col] = maxWidth
+	}
+
+	// Generate separator
+	separatorParts := make([]string, 0, len(columns))
+	for _, col := range columns {
+		separatorParts = append(separatorParts, strings.Repeat("-", colWidths[col]+2))
+	}
+	separator := "+" + strings.Join(separatorParts, "+") + "+"
+
+	// Print header
+	fmt.Println(separator)
+	headerParts := make([]string, 0, len(columns))
+	for _, col := range columns {
+		headerParts = append(headerParts, " "+col+strings.Repeat(" ", colWidths[col]-getStringWidth(col))+" ")
+	}
+	fmt.Println("|" + strings.Join(headerParts, "|") + "|")
+	fmt.Println(separator)
+
+	// Print data rows
+	for _, item := range data {
+		rowParts := make([]string, 0, len(columns))
+		for _, col := range columns {
+			value := fmt.Sprintf("%v", item[col])
+			valueWidth := getStringWidth(value)
+			// Truncate if too long
+			if valueWidth > colWidths[col] {
+				value = truncateString([]rune(value), colWidths[col])
+				valueWidth = getStringWidth(value)
+			}
+			// Pad with explicit spaces: fmt's %-*s width counts runes, not
+			// display columns, so it misaligns rows containing wide runes.
+			pad := colWidths[col] - valueWidth
+			if pad < 0 {
+				pad = 0
+			}
+			rowParts = append(rowParts, " "+value+strings.Repeat(" ", pad)+" ")
+		}
+		fmt.Println("|" + strings.Join(rowParts, "|") + "|")
+	}
+
+	fmt.Println(separator)
+}
+
+// getStringWidth calculates the display width of a string
+// Treats CJK characters as width 2
+func getStringWidth(text string) int {
+	width := 0
+	for _, r := range text {
+		if isHalfWidth(r) {
+			width++
+		} else {
+			width += 2
+		}
+	}
+	return width
+}
+
+// isHalfWidth checks if a rune is half-width
+func isHalfWidth(r rune) bool {
+	// ASCII printable characters and common whitespace
+	if r >= 0x20 && r <= 0x7E {
+		return true
+	}
+	if r == '\t' || r == '\n' || r == '\r' {
+		return true
+	}
+	return false
+}
+
+// truncateString truncates a string to fit within maxWidth display width
+func truncateString(runes []rune, maxWidth int) string {
+	width := 0
+	for i, r := range runes {
+		if isHalfWidth(r) {
+			width++
+		} else {
+			width += 2
+		}
+		if width > maxWidth-3 {
+			return string(runes[:i]) + "..."
+		}
+	}
+	return string(runes)
+}
+
+// getMax returns the maximum of two integers
+func getMax(a, b int) int {
+	if a > b {
+		return a
+	}
+	return b
+}
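+
+// Example (illustrative):
+//
+//	PrintTableSimple([]map[string]interface{}{
+//		{"name": "kb1", "docs": 12},
+//		{"name": "知识库", "docs": 3},
+//	})
+//
+// prints a bordered table whose column widths account for double-width CJK
+// runes via getStringWidth.
+
+// isWideChar checks if a character is wide (CJK, etc.)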
+func isWideChar(r rune) bool { + return unicode.Is(unicode.Han, r) || + unicode.Is(unicode.Hiragana, r) || + unicode.Is(unicode.Katakana, r) || + unicode.Is(unicode.Hangul, r) +} diff --git a/internal/cli/types.go b/internal/cli/types.go new file mode 100644 index 00000000000..b9d11b8b369 --- /dev/null +++ b/internal/cli/types.go @@ -0,0 +1,123 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package cli + +// Command represents a parsed command from the CLI +type Command struct { + Type string + Params map[string]interface{} +} + +// Token types for the lexer +const ( + // Keywords + TokenLogin = iota + TokenRegister + TokenList + TokenServices + TokenShow + TokenCreate + TokenService + TokenShutdown + TokenStartup + TokenRestart + TokenUsers + TokenDrop + TokenUser + TokenAlter + TokenActive + TokenAdmin + TokenPassword + TokenDataset + TokenDatasets + TokenOf + TokenAgents + TokenRole + TokenRoles + TokenDescription + TokenGrant + TokenRevoke + TokenAll + TokenPermission + TokenTo + TokenFrom + TokenFor + TokenResources + TokenOn + TokenSet + TokenReset + TokenVersion + TokenVar + TokenVars + TokenConfigs + TokenEnvs + TokenKey + TokenKeys + TokenGenerate + TokenModel + TokenModels + TokenProvider + TokenProviders + TokenDefault + TokenChats + TokenChat + TokenFiles + TokenAs + TokenParse + TokenImport + TokenInto + TokenWith + TokenParser + TokenPipeline + TokenSearch + TokenCurrent + TokenLLM + TokenVLM + TokenEmbedding + TokenReranker + TokenASR + TokenTTS + TokenAsync + TokenSync + TokenBenchmark + TokenPing + + // Literals + TokenIdentifier + TokenQuotedString + TokenNumber + + // Special + TokenSemicolon + TokenComma + TokenEOF + TokenIllegal +) + +// Token represents a lexical token +type Token struct { + Type int + Value string +} + +// NewCommand creates a new command with the given type +func NewCommand(cmdType string) *Command { + return &Command{ + Type: cmdType, + Params: make(map[string]interface{}), + } +} diff --git a/internal/cpp/CMakeLists.txt b/internal/cpp/CMakeLists.txt new file mode 100644 index 00000000000..9c4b4f5e299 --- /dev/null +++ b/internal/cpp/CMakeLists.txt @@ -0,0 +1,138 @@ +cmake_minimum_required(VERSION 4.0) +project(rag_tokenizer) + +set(CMAKE_CXX_STANDARD 23) + +# Option to enable AddressSanitizer +option(ENABLE_ASAN "Enable AddressSanitizer" OFF) + +if(ENABLE_ASAN) + message(STATUS "AddressSanitizer enabled") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer -g") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address -fno-omit-frame-pointer -g") + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fsanitize=address") + set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -fsanitize=address") +endif() + +file(GLOB_RECURSE + stemmer_src + CONFIGURE_DEPENDS + stemmer/*.cpp + stemmer/*.cc + stemmer/*.c + stemmer/*.h +) + +file(GLOB_RECURSE + opencc_src + CONFIGURE_DEPENDS + opencc/*.cpp + opencc/*.cc + opencc/*.c + 
opencc/*.h +) + +file(GLOB_RECURSE + util_src + CONFIGURE_DEPENDS + util/*.cpp + util/*.cc + util/*.c + util/*.h +) + +file(GLOB_RECURSE + re2_src + CONFIGURE_DEPENDS + re2/*.cpp + re2/*.cc + re2/*.c + re2/*.h +) + +file(GLOB_RECURSE + darts_src + CONFIGURE_DEPENDS + darts/*.h +) + +file(GLOB + main_src + CONFIGURE_DEPENDS + *.cpp + *.cc + *.c + *.h +) + +# Filter out C API files from main_src +list(FILTER main_src EXCLUDE REGEX "rag_analyzer_c_api") + +add_executable(rag_tokenizer + main.cpp + rag_analyzer.cpp + rag_analyzer.h + dart_trie.h + darts_trie.cpp + wordnet_lemmatizer.cpp + wordnet_lemmatizer.h + string_utils.h + term.h + term.cpp + tokenizer.cpp + tokenizer.h + analyzer.h + ${stemmer_src} + ${opencc_src} + ${util_src} + ${darts_src} + ${re2_src}) + +target_link_libraries(rag_tokenizer stdc++ m libpcre2-8.a) +target_include_directories(rag_tokenizer PUBLIC "${CMAKE_SOURCE_DIR}") +set_target_properties(rag_tokenizer PROPERTIES + CXX_STANDARD 20 + CXX_STANDARD_REQUIRED ON +) + +# Build C API static library for CGO +add_library(rag_tokenizer_c_api STATIC + rag_analyzer_c_api.cpp + rag_analyzer_c_api.h + rag_analyzer.cpp + rag_analyzer.h + dart_trie.h + darts_trie.cpp + wordnet_lemmatizer.cpp + wordnet_lemmatizer.h + string_utils.h + term.h + term.cpp + tokenizer.cpp + tokenizer.h + analyzer.h + ${stemmer_src} + ${opencc_src} + ${util_src} + ${darts_src} + ${re2_src} +) + +target_link_libraries(rag_tokenizer_c_api stdc++ libm.a libpcre2-8.a) +target_include_directories(rag_tokenizer_c_api PUBLIC "${CMAKE_SOURCE_DIR}") +set_target_properties(rag_tokenizer_c_api PROPERTIES + CXX_STANDARD 20 + CXX_STANDARD_REQUIRED ON +) + +# Test executable for C API +add_executable(rag_analyzer_c_test + rag_analyzer_c_test.cpp +) + +target_link_libraries(rag_analyzer_c_test rag_tokenizer_c_api stdc++ libm.a libpcre2-8.a) +target_include_directories(rag_analyzer_c_test PUBLIC "${CMAKE_SOURCE_DIR}") +set_target_properties(rag_analyzer_c_test PROPERTIES + CXX_STANDARD 20 + CXX_STANDARD_REQUIRED ON +) diff --git a/internal/cpp/Makefile b/internal/cpp/Makefile new file mode 100644 index 00000000000..cbf66ac70ff --- /dev/null +++ b/internal/cpp/Makefile @@ -0,0 +1,81 @@ +# Makefile for RAG Tokenizer with CGO bindings + +.PHONY: all clean build c_api c_api_debug c_api_asan test_go test_memory valgrind asan + +BUILD_DIR := build +ASAN_BUILD_DIR := build-asan +C_API_LIB := $(BUILD_DIR)/librag_tokenizer_c_api.a +C_API_ASAN_LIB := $(ASAN_BUILD_DIR)/librag_tokenizer_c_api.a +C_API_DEBUG_LIB := $(BUILD_DIR)/librag_tokenizer_c_api_debug.a + +all: build c_api + +# Create build directory +$(BUILD_DIR): + mkdir -p $(BUILD_DIR) + +$(ASAN_BUILD_DIR): + mkdir -p $(ASAN_BUILD_DIR) + +# Build the main executable and C API library +build: $(BUILD_DIR) + cd $(BUILD_DIR) && cmake .. && make -j$$(nproc) + +# Build only the C API library +c_api: $(BUILD_DIR) + cd $(BUILD_DIR) && cmake .. && make rag_tokenizer_c_api -j$$(nproc) + +# Build C API library with AddressSanitizer +c_api_asan: $(ASAN_BUILD_DIR) + cd $(ASAN_BUILD_DIR) && cmake .. -DENABLE_ASAN=ON && make rag_tokenizer_c_api -j$$(nproc) + @echo "ASan library built: $(C_API_ASAN_LIB)" + +# Build debug version of C API library with memory tracking +c_api_debug: $(BUILD_DIR) + cd $(BUILD_DIR) && \ + g++ -std=c++17 -static-libgcc -static-libstdc++ -DMEMORY_DEBUG \ + -I.. 
\ + ../rag_analyzer_c_api_debug.cpp \ + ../rag_analyzer.cpp \ + ../darts_trie.cpp \ + ../wordnet_lemmatizer.cpp \ + ../term.cpp \ + ../tokenizer.cpp \ + ../stemmer/*.cpp \ + ../opencc/*.c ../opencc/*.cpp \ + ../util/*.cc \ + ../re2/*.cc \ + -o librag_tokenizer_c_api_debug.a \ + -lstdc++ -lm -lpthread -lpcre2-8 + @echo "Debug library built: $(C_API_DEBUG_LIB)" + +# Test the Go bindings +test_go: c_api + cd go_bindings/example && go run main.go ../../$(BUILD_DIR) "这是一个测试文本。This is a test." + +# Run memory test +test_memory: c_api + cd go_bindings/example && go run memory_leak_check.go + +# Run with valgrind +valgrind: c_api + cd go_bindings/example && bash run_valgrind.sh + +# Run with AddressSanitizer +asan: c_api_asan + @echo "Running with AddressSanitizer..." + cd go_bindings/example && \ + ASAN_OPTIONS=detect_leaks=1:print_stats=1:verbosity=0 \ + go run memory_leak_check.go + +# Install the C API library (optional) +install: c_api + sudo cp $(C_API_LIB) /usr/local/lib/ + sudo ldconfig + +# Clean build artifacts +clean: + rm -rf $(BUILD_DIR) + rm -rf $(ASAN_BUILD_DIR) + rm -f go_bindings/example/valgrind.log + rm -f go_bindings/example/memory_test_bin diff --git a/internal/cpp/analyzer.h b/internal/cpp/analyzer.h new file mode 100644 index 00000000000..73c2fd638bd --- /dev/null +++ b/internal/cpp/analyzer.h @@ -0,0 +1,88 @@ +// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
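+
+// Analyzer (below) is a hook-based interface: a subclass implements
+// AnalyzeImpl() and reports each token through the HookType callback, while
+// the default hook, AppendTermList, drops or placeholder-replaces special
+// characters before appending to the output TermList. A minimal subclass
+// sketch (illustrative only):
+//
+//     class WhitespaceAnalyzer : public Analyzer {
+//     protected:
+//         int AnalyzeImpl(const Term &input, void *data, bool fine_grained,
+//                         bool enable_position, HookType func) const override {
+//             // Split the input on whitespace, then for each token call:
+//             // func(data, text, len, offset, end_offset,
+//             //      /*is_special_char=*/false, /*payload=*/0);
+//             return 0;
+//         }
+//     };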
+
+#pragma once
+
+#include "tokenizer.h"
+#include "term.h"
+
+enum class CutGrain {
+    kCoarse,
+    kFine,
+};
+
+class Analyzer {
+public:
+    Analyzer() = default;
+
+    virtual ~Analyzer() = default;
+
+    void SetExtractSpecialChar(bool extract_special_char, bool convert_to_placeholder = true) {
+        extract_special_char_ = extract_special_char;
+        convert_to_placeholder_ = convert_to_placeholder;
+    }
+
+    void SetCharOffset(bool set) { get_char_offset_ = set; }
+
+    void SetTokenizerConfig(const TokenizeConfig &conf) { tokenizer_.SetConfig(conf); }
+
+    int Analyze(const Term &input, TermList &output, bool fine_grained = false, bool enable_position = false) {
+        void *array[2] = {&output, this};
+        return AnalyzeImpl(input, &array, fine_grained, enable_position, Analyzer::AppendTermList);
+    }
+
+protected:
+    typedef void (*HookType)(void *data,
+                             const char *text,
+                             const uint32_t len,
+                             const uint32_t offset,
+                             const uint32_t end_offset,
+                             const bool is_special_char,
+                             const uint16_t payload);
+
+    virtual int AnalyzeImpl(const Term &input, void *data, bool fine_grained, bool enable_position, HookType func) const { return -1; }
+
+    static void AppendTermList(void *data,
+                               const char *text,
+                               const uint32_t len,
+                               const uint32_t offset,
+                               const uint32_t end_offset,
+                               const bool is_special_char,
+                               const uint16_t payload) {
+        void **parameters = (void **)data;
+        TermList *output = (TermList *)parameters[0];
+        Analyzer *analyzer = (Analyzer *)parameters[1];
+
+        if (is_special_char && !analyzer->extract_special_char_)
+            return;
+        if (is_special_char && analyzer->convert_to_placeholder_) {
+            if (output->empty() || output->back().text_.compare(PLACE_HOLDER) != 0)
+                output->Add(PLACE_HOLDER.c_str(), PLACE_HOLDER.length(), offset, end_offset, payload);
+        } else {
+            output->Add(text, len, offset, end_offset, payload);
+        }
+    }
+
+    Tokenizer tokenizer_;
+
+    /// Whether to include special characters (e.g. punctuation) in the result.
+    bool extract_special_char_{false};
+
+    /// Whether to convert special characters (e.g. punctuation) into a particular
+    /// placeholder symbol in the result.
+    /// Takes effect only when extract_special_char_ is set.
+    bool convert_to_placeholder_{false};
+
+    bool get_char_offset_{false};
+};
diff --git a/internal/cpp/dart_trie.h b/internal/cpp/dart_trie.h
new file mode 100644
index 00000000000..f4919592056
--- /dev/null
+++ b/internal/cpp/dart_trie.h
@@ -0,0 +1,77 @@
+// Copyright(C) 2024 InfiniFlow, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
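+
+// Usage sketch (illustrative):
+//
+//     DartsTrie trie;
+//     trie.Add("中华", 1);
+//     trie.Add("中华人民", 2);
+//     trie.Build();
+//     int v = trie.Get("中华");                // exact-match value: 1
+//     bool hit = trie.HasKeysWithPrefix("中"); // true
+//
+// Build() must be called before lookups; persistence goes through Save() and
+// Load().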
+
+#pragma once
+
+#include "darts/darts.h"
+
+#include <cstdint>
+#include <map>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <vector>
+
+class POSTable
+{
+public:
+    POSTable(const std::string& path);
+
+    ~POSTable() = default;
+
+    int32_t Load();
+
+    const char* GetPOS(int32_t index) const;
+
+    int32_t GetPOSIndex(const std::string& tag) const;
+
+private:
+    std::string file_;
+    int32_t table_size_{0};
+    std::vector<std::string> pos_vec_;
+    std::map<std::string, int32_t> pos_map_;
+};
+
+using DartsCore = Darts::DoubleArrayImpl<void, void, int, void>;
+
+struct DartsTuple
+{
+    DartsTuple(const std::string& k, const int& v) : key_(k), value_(v)
+    {
+    }
+
+    std::string key_;
+    int value_;
+};
+
+class DartsTrie
+{
+    std::unique_ptr<DartsCore> darts_;
+    std::vector<DartsTuple> buffer_;
+
+public:
+    DartsTrie();
+
+    void Add(const std::string& key, const int& value);
+
+    void Build();
+
+    void Load(const std::string& file_name);
+
+    void Save(const std::string& file_name);
+
+    bool HasKeysWithPrefix(std::string_view key) const;
+
+    int Traverse(const char* key, std::size_t& node_pos, std::size_t& key_pos, std::size_t length) const;
+
+    int Get(std::string_view key) const;
+};
diff --git a/internal/cpp/darts/darts.h b/internal/cpp/darts/darts.h
new file mode 100644
index 00000000000..107af203413
--- /dev/null
+++ b/internal/cpp/darts/darts.h
@@ -0,0 +1,1733 @@
+#ifndef DARTS_H_
+#define DARTS_H_
+
+#include <cstdio>
+#include <exception>
+#include <new>
+
+#define DARTS_VERSION "0.32"
+
+// DARTS_THROW() throws a <Darts::Exception> whose message starts with the
+// file name and the line number. For example, DARTS_THROW("error message") at
+// line 123 of "darts.h" throws a <Darts::Exception> which has a pointer to
+// "darts.h:123: exception: error message". The message is available by using
+// what() as well as that of <std::exception>.
+#define DARTS_INT_TO_STR(value) #value
+#define DARTS_LINE_TO_STR(line) DARTS_INT_TO_STR(line)
+#define DARTS_LINE_STR DARTS_LINE_TO_STR(__LINE__)
+#define DARTS_THROW(msg) throw Darts::Details::Exception(__FILE__ ":" DARTS_LINE_STR ": exception: " msg)
+
+namespace Darts {
+
+// The following namespace hides the internal types and classes.
+namespace Details {
+
+// This header assumes that <int> and <unsigned int> are 32-bit integer types.
+//
+// Darts-clone keeps values associated with keys. The type of the values is
+// <value_type>. Note that the values must be positive integers because the
+// most significant bit (MSB) of each value is used to represent whether the
+// corresponding unit is a leaf or not. Also, the keys are represented by
+// sequences of <char_type>s. <uchar_type> is the unsigned type of <char_type>.
+typedef char char_type;
+typedef unsigned char uchar_type;
+typedef int value_type;
+
+// The main structure of Darts-clone is an array of <unit_type>s, and the
+// unit type is actually a wrapper of <id_type>.
+typedef unsigned int id_type;
+
+// <progress_func_type> is the type of callback functions for reporting the
+// progress of building a dictionary. See also build() of <DoubleArray>.
+// The 1st argument receives the progress value and the 2nd argument receives
+// the maximum progress value. A usage example is to show the progress
+// percentage, 100.0 * (the 1st argument) / (the 2nd argument).
+typedef int (*progress_func_type)(std::size_t, std::size_t);
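+
+// Unit bit layout used below: bit 31 marks a value unit, bit 8 is the
+// has-leaf flag, bits 0-7 hold the transition label, and the offset is stored
+// from bit 10 upward (bit 9 selects an extra <<8 shift for large offsets).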
+
+// <DoubleArrayUnit> is the type of double-array units and it is a wrapper of
+// <id_type> in practice.
+class DoubleArrayUnit {
+public:
+  DoubleArrayUnit() : unit_() {}
+
+  // has_leaf() returns whether a leaf unit is immediately derived from the
+  // unit (true) or not (false).
+  bool has_leaf() const { return ((unit_ >> 8) & 1) == 1; }
+  // value() returns the value stored in the unit, and thus value() is
+  // available when and only when the unit is a leaf unit.
+  value_type value() const { return static_cast<value_type>(unit_ & ((1U << 31) - 1)); }
+
+  // label() returns the label associated with the unit. Note that a leaf unit
+  // always returns an invalid label. For this feature, leaf unit's label()
+  // returns an <id_type> that has the MSB of 1.
+  id_type label() const { return unit_ & ((1U << 31) | 0xFF); }
+  // offset() returns the offset from the unit to its derived units.
+  id_type offset() const { return (unit_ >> 10) << ((unit_ & (1U << 9)) >> 6); }
+
+private:
+  id_type unit_;
+
+  // Copyable.
+};
+
+// Darts-clone throws an <Exception> for memory allocation failure, invalid
+// arguments or a too large offset. The last case means that there are too many
+// keys in the given set of keys. Note that the `msg' of <Exception> must be a
+// constant or static string because an <Exception> keeps only a pointer to
+// that string.
+class Exception : public std::exception {
+public:
+  explicit Exception(const char *msg = NULL) throw() : msg_(msg) {}
+  Exception(const Exception &rhs) throw() : msg_(rhs.msg_) {}
+  virtual ~Exception() throw() {}
+
+  // overrides what() of <std::exception>.
+  virtual const char *what() const throw() { return (msg_ != NULL) ? msg_ : ""; }
+
+private:
+  const char *msg_;
+
+  // Disallows operator=.
+  Exception &operator=(const Exception &);
+};
+
+}  // namespace Details
+
+// <DoubleArray> is the interface of Darts-clone. Note that other
+// classes should not be accessed from outside.
+//
+// <DoubleArray> has 4 template arguments but only the 3rd one is used as
+// the type of values. Note that the given <T> is used only from outside, and
+// the internal value type is not changed from <Darts::Details::value_type>.
+// In build(), given values are casted from <T> to <Darts::Details::value_type>
+// by using static_cast. On the other hand, values are casted from
+// <Darts::Details::value_type> to <T> in searching dictionaries.
+template <typename, typename, typename T, typename>
+class DoubleArrayImpl {
+public:
+  // Even if this <value_type> is changed, the internal value type is still
+  // <Darts::Details::value_type>. Other types, such as 64-bit integer types
+  // and floating-point number types, should not be used.
+  typedef T value_type;
+  // A key is represented by a sequence of <key_type>s. For example,
+  // exactMatchSearch() takes a <const key_type *>.
+  typedef Details::char_type key_type;
+  // In searching dictionaries, the values associated with the matched keys are
+  // stored into or returned as <result_type>s.
+  typedef value_type result_type;
+
+  // <result_pair_type> enables applications to get the lengths of the matched
+  // keys in addition to the values.
+  struct result_pair_type {
+    value_type value;
+    std::size_t length;
+  };
+
+  // The constructor initializes member variables with 0 and NULLs.
+  DoubleArrayImpl() : size_(0), array_(NULL), buf_(NULL) {}
+  // The destructor frees memory allocated for units and then initializes
+  // member variables with 0 and NULLs.
+  virtual ~DoubleArrayImpl() { clear(); }
+
+  // <DoubleArrayImpl> has 2 kinds of set_result()s. The 1st set_result() is to
+  // set a value to a <value_type>. The 2nd set_result() is to set a value and
+  // a length to a <result_pair_type>. By using set_result()s, search methods
+  // can return the 2 kinds of results in the same way.
+  // Why the set_result()s are non-static? It is for compatibility.
+  //
+  // The 1st set_result() takes a length as the 3rd argument but it is not
+  // used. If a compiler does a good job, codes for getting the length may be
+  // removed.
+  void set_result(value_type *result, value_type value, std::size_t) const { *result = value; }
+  // The 2nd set_result() uses both `value' and `length'.
+  void set_result(result_pair_type *result, value_type value, std::size_t length) const {
+    result->value = value;
+    result->length = length;
+  }
+
+  // set_array() calls clear() in order to free memory allocated to the old
+  // array and then sets a new array. This function is useful to set a memory-
+  // mapped array. Note that the array set by set_array() is not freed in
+  // clear() and the destructor of <DoubleArrayImpl>.
+  // set_array() can also set the size of the new array but the size is not
+  // used in search methods. So it works well even if the 2nd argument is 0 or
+  // omitted. Remember that size() and total_size() returns 0 in such a case.
+  void set_array(const void *ptr, std::size_t size = 0) {
+    clear();
+    array_ = static_cast<const unit_type *>(ptr);
+    size_ = size;
+  }
+  // array() returns a pointer to the array of units.
+  const void *array() const { return array_; }
+
+  // clear() frees memory allocated to units and then initializes member
+  // variables with 0 and NULLs. Note that clear() does not free memory if the
+  // array of units was set by set_array(). In such a case, `array_' is not
+  // NULL and `buf_' is NULL.
+  void clear() {
+    size_ = 0;
+    array_ = NULL;
+    if (buf_ != NULL) {
+      delete[] buf_;
+      buf_ = NULL;
+    }
+  }
+
+  // unit_size() returns the size of each unit. The size must be 4 bytes.
+  std::size_t unit_size() const { return sizeof(unit_type); }
+  // size() returns the number of units. It can be 0 if set_array() is used.
+  std::size_t size() const { return size_; }
+  // total_size() returns the number of bytes allocated to the array of units.
+  // It can be 0 if set_array() is used.
+  std::size_t total_size() const { return unit_size() * size(); }
+  // nonzero_size() exists for compatibility. It always returns the number of
+  // units because it takes long time to count the number of non-zero units.
+  std::size_t nonzero_size() const { return size(); }
+
+  // build() constructs a dictionary from given key-value pairs. If `lengths'
+  // is NULL, `keys' is handled as an array of zero-terminated strings. If
+  // `values' is NULL, the index in `keys' is associated with each key, i.e.
+  // the ith key has (i - 1) as its value.
+  // Note that the key-value pairs must be arranged in key order and the values
+  // must not be negative. Also, if there are duplicate keys, only the first
+  // pair will be stored in the resultant dictionary.
+  // `progress_func' is a pointer to a callback function. If it is not NULL,
+  // it will be called in build() so that the caller can check the progress of
+  // dictionary construction. For details, please see the definition of
+  // <Darts::Details::progress_func_type>.
+  // The return value of build() is 0, and it indicates the success of the
+  // operation. Otherwise, build() throws a <Darts::Exception>, which is a
+  // derived class of <std::exception>.
+  // build() uses another construction algorithm if `values' is not NULL. In
+  // this case, Darts-clone uses a Directed Acyclic Word Graph (DAWG) instead
+  // of a trie because a DAWG is likely to be more compact than a trie.
+  int build(std::size_t num_keys,
+            const key_type *const *keys,
+            const std::size_t *lengths = NULL,
+            const value_type *values = NULL,
+            Details::progress_func_type progress_func = NULL);
+
+  // open() reads an array of units from the specified file. And if it goes
+  // well, the old array will be freed and replaced with the new array read
+  // from the file. `offset' specifies the number of bytes to be skipped before
+  // reading an array. `size' specifies the number of bytes to be read from the
+  // file. If the `size' is 0, the whole file will be read.
+  // open() returns 0 iff the operation succeeds. Otherwise, it returns a
+  // non-zero value or throws a <Darts::Exception>. The exception is thrown
+  // when and only when a memory allocation fails.
+  int open(const char *file_name, const char *mode = "rb", std::size_t offset = 0, std::size_t size = 0);
+  // save() writes the array of units into the specified file. `offset'
+  // specifies the number of bytes to be skipped before writing the array.
+  // save() returns 0 iff the operation succeeds. Otherwise, it returns a
+  // non-zero value.
+  int save(const char *file_name, const char *mode = "wb", std::size_t offset = 0) const;
+
+  // The 1st exactMatchSearch() tests whether the given key exists or not, and
+  // if it exists, its value and length are set to `result'. Otherwise, the
+  // value and the length of `result' are set to -1 and 0 respectively.
+  // Note that if `length' is 0, `key' is handled as a zero-terminated string.
+  // `node_pos' specifies the start position of matching. This argument enables
+  // the combination of exactMatchSearch() and traverse(). For example, if you
+  // want to test "xyzA", "xyzBC", and "xyzDE", you can use traverse() to get
+  // the node position corresponding to "xyz" and then you can use
+  // exactMatchSearch() to test "A", "BC", and "DE" from that position.
+  // Note that the length of `result' indicates the length from the `node_pos'.
+  // In the above example, the lengths are { 1, 2, 2 }, not { 4, 5, 5 }.
+  template <class U>
+  void exactMatchSearch(const key_type *key, U &result, std::size_t length = 0, std::size_t node_pos = 0) const {
+    result = exactMatchSearch<U>(key, length, node_pos);
+  }
+  // The 2nd exactMatchSearch() returns a result instead of updating the 2nd
+  // argument. So, the following exactMatchSearch() has only 3 arguments.
+  template <class U>
+  inline U exactMatchSearch(const key_type *key, std::size_t length = 0, std::size_t node_pos = 0) const;
+
+  // commonPrefixSearch() searches for keys which match a prefix of the given
+  // string. If `length' is 0, `key' is handled as a zero-terminated string.
+  // The values and the lengths of at most `max_num_results' matched keys are
+  // stored in `results'. commonPrefixSearch() returns the number of matched
+  // keys. Note that the return value can be larger than `max_num_results' if
+  // there are more than `max_num_results' matches. If you want to get all the
+  // results, allocate more spaces and call commonPrefixSearch() again.
+  // `node_pos' works as well as in exactMatchSearch().
+  template <class U>
+  inline std::size_t
+  commonPrefixSearch(const key_type *key, U *results, std::size_t max_num_results, std::size_t length = 0, std::size_t node_pos = 0) const;
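+
+  // Usage sketch (illustrative):
+  //
+  //   Darts::DoubleArray da;
+  //   if (da.open("tokens.da") == 0) {  // hypothetical dictionary file
+  //     Darts::DoubleArray::result_pair_type hits[16];
+  //     std::size_t n = da.commonPrefixSearch("abc", hits, 16);
+  //     // hits[i].value / hits[i].length are valid for i < min(n, 16);
+  //     // n itself may exceed 16 if there are more matches.
+  //   }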
+
+  // In Darts-clone, a dictionary is a deterministic finite-state automaton
+  // (DFA) and traverse() tests transitions on the DFA. The initial state is
+  // `node_pos' and traverse() chooses transitions labeled key[key_pos],
+  // key[key_pos + 1], ... in order. If there is not a transition labeled
+  // key[key_pos + i], traverse() terminates the transitions at that state and
+  // returns -2. Otherwise, traverse() ends without a termination and returns
+  // -1 or a nonnegative value, -1 indicates that the final state was not an
+  // accept state. When a nonnegative value is returned, it is the value
+  // associated with the final accept state. That is, traverse() returns the
+  // value associated with the given key if it exists. Note that traverse()
+  // updates `node_pos' and `key_pos' after each transition.
+  inline value_type traverse(const key_type *key, std::size_t &node_pos, std::size_t &key_pos, std::size_t length = 0) const;
+
+private:
+  typedef Details::uchar_type uchar_type;
+  typedef Details::id_type id_type;
+  typedef Details::DoubleArrayUnit unit_type;
+
+  std::size_t size_;
+  const unit_type *array_;
+  unit_type *buf_;
+
+  // Disallows copy and assignment.
+  DoubleArrayImpl(const DoubleArrayImpl &);
+  DoubleArrayImpl &operator=(const DoubleArrayImpl &);
+};
+
+// <DoubleArray> is the typical instance of <DoubleArrayImpl>. It uses <int>
+// as the type of values and it is suitable for most cases.
+typedef DoubleArrayImpl<void, void, int, void> DoubleArray;
+
+// The interface section ends here. For using Darts-clone, there is no need
+// to read the remaining section, which gives the implementation of
+// Darts-clone.
+
+//
+// Member functions of DoubleArrayImpl (except build()).
+//
+
+template <typename A, typename B, typename T, typename C>
+int DoubleArrayImpl<A, B, T, C>::open(const char *file_name, const char *mode, std::size_t offset, std::size_t size) {
+#ifdef _MSC_VER
+  std::FILE *file;
+  if (::fopen_s(&file, file_name, mode) != 0) {
+    return -1;
+  }
+#else
+  std::FILE *file = std::fopen(file_name, mode);
+  if (file == NULL) {
+    return -1;
+  }
+#endif
+
+  if (size == 0) {
+    if (std::fseek(file, 0, SEEK_END) != 0) {
+      std::fclose(file);
+      return -1;
+    }
+    size = std::ftell(file) - offset;
+  }
+
+  size /= unit_size();
+  if (size < 256 || (size & 0xFF) != 0) {
+    std::fclose(file);
+    return -1;
+  }
+
+  if (std::fseek(file, offset, SEEK_SET) != 0) {
+    std::fclose(file);
+    return -1;
+  }
+
+  unit_type units[256];
+  if (std::fread(units, unit_size(), 256, file) != 256) {
+    std::fclose(file);
+    return -1;
+  }
+
+  if (units[0].label() != '\0' || units[0].has_leaf() || units[0].offset() == 0 || units[0].offset() >= 512) {
+    std::fclose(file);
+    return -1;
+  }
+  for (id_type i = 1; i < 256; ++i) {
+    if (units[i].label() <= 0xFF && units[i].offset() >= size) {
+      std::fclose(file);
+      return -1;
+    }
+  }
+
+  unit_type *buf;
+  try {
+    buf = new unit_type[size];
+    for (id_type i = 0; i < 256; ++i) {
+      buf[i] = units[i];
+    }
+  } catch (const std::bad_alloc &) {
+    std::fclose(file);
+    DARTS_THROW("failed to open double-array: std::bad_alloc");
+  }
+
+  if (size > 256) {
+    if (std::fread(buf + 256, unit_size(), size - 256, file) != size - 256) {
+      std::fclose(file);
+      delete[] buf;
+      return -1;
+    }
+  }
+  std::fclose(file);
+
+  clear();
+
+  size_ = size;
+  array_ = buf;
+  buf_ = buf;
+  return 0;
+}
+
+template <typename A, typename B, typename T, typename C>
+int DoubleArrayImpl<A, B, T, C>::save(const char *file_name, const char *mode, std::size_t offset) const {
+  if (size() == 0) {
+    return -1;
+  }
+
+#ifdef _MSC_VER
+  std::FILE *file;
+  if (::fopen_s(&file, file_name, mode) != 0) {
+    return -1;
+  }
+#else
+  std::FILE *file = std::fopen(file_name, mode);
+  if (file == NULL) {
+    return -1;
+  }
+#endif
+
+  if (std::fseek(file, offset, SEEK_SET) != 0) {
+    std::fclose(file);
+    return -1;
+  }
+
+  if (std::fwrite(array_, unit_size(), size(), file) != size()) {
+    std::fclose(file);
+    return -1;
+  }
+  std::fclose(file);
+  return 0;
+}
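+
+// The search routines below walk the double array with XOR transitions: from
+// the unit at node_pos, the child labeled c lives at node_pos ^ offset() ^ c,
+// and the move is valid iff that child's label() equals c. A value hangs off
+// a final '\0'-labeled transition, signalled by has_leaf().
+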
+template <typename A, typename B, typename T, typename C>
+template <typename U>
+inline U DoubleArrayImpl<A, B, T, C>::exactMatchSearch(const key_type *key, std::size_t length, std::size_t node_pos) const {
+  U result;
+  set_result(&result, static_cast<value_type>(-1), 0);
+
+  unit_type unit = array_[node_pos];
+  if (length != 0) {
+    for (std::size_t i = 0; i < length; ++i) {
+      node_pos ^= unit.offset() ^ static_cast<uchar_type>(key[i]);
+      unit = array_[node_pos];
+      if (unit.label() != static_cast<uchar_type>(key[i])) {
+        return result;
+      }
+    }
+  } else {
+    for (; key[length] != '\0'; ++length) {
+      node_pos ^= unit.offset() ^ static_cast<uchar_type>(key[length]);
+      unit = array_[node_pos];
+      if (unit.label() != static_cast<uchar_type>(key[length])) {
+        return result;
+      }
+    }
+  }
+
+  if (!unit.has_leaf()) {
+    return result;
+  }
+  unit = array_[node_pos ^ unit.offset()];
+  set_result(&result, static_cast<value_type>(unit.value()), length);
+  return result;
+}
+
+template <typename A, typename B, typename T, typename C>
+template <typename U>
+inline std::size_t DoubleArrayImpl<A, B, T, C>::commonPrefixSearch(const key_type *key,
+                                                                   U *results,
+                                                                   std::size_t max_num_results,
+                                                                   std::size_t length,
+                                                                   std::size_t node_pos) const {
+  std::size_t num_results = 0;
+
+  unit_type unit = array_[node_pos];
+  node_pos ^= unit.offset();
+  if (length != 0) {
+    for (std::size_t i = 0; i < length; ++i) {
+      node_pos ^= static_cast<uchar_type>(key[i]);
+      unit = array_[node_pos];
+      if (unit.label() != static_cast<uchar_type>(key[i])) {
+        return num_results;
+      }
+
+      node_pos ^= unit.offset();
+      if (unit.has_leaf()) {
+        if (num_results < max_num_results) {
+          set_result(&results[num_results], static_cast<value_type>(array_[node_pos].value()), i + 1);
+        }
+        ++num_results;
+      }
+    }
+  } else {
+    for (; key[length] != '\0'; ++length) {
+      node_pos ^= static_cast<uchar_type>(key[length]);
+      unit = array_[node_pos];
+      if (unit.label() != static_cast<uchar_type>(key[length])) {
+        return num_results;
+      }
+
+      node_pos ^= unit.offset();
+      if (unit.has_leaf()) {
+        if (num_results < max_num_results) {
+          set_result(&results[num_results], static_cast<value_type>(array_[node_pos].value()), length + 1);
+        }
+        ++num_results;
+      }
+    }
+  }
+
+  return num_results;
+}
+
+template <typename A, typename B, typename T, typename C>
+inline typename DoubleArrayImpl<A, B, T, C>::value_type
+DoubleArrayImpl<A, B, T, C>::traverse(const key_type *key, std::size_t &node_pos, std::size_t &key_pos, std::size_t length) const {
+  id_type id = static_cast<id_type>(node_pos);
+  unit_type unit = array_[id];
+
+  if (length != 0) {
+    for (; key_pos < length; ++key_pos) {
+      id ^= unit.offset() ^ static_cast<uchar_type>(key[key_pos]);
+      unit = array_[id];
+      if (unit.label() != static_cast<uchar_type>(key[key_pos])) {
+        return static_cast<value_type>(-2);
+      }
+      node_pos = id;
+    }
+  } else {
+    for (; key[key_pos] != '\0'; ++key_pos) {
+      id ^= unit.offset() ^ static_cast<uchar_type>(key[key_pos]);
+      unit = array_[id];
+      if (unit.label() != static_cast<uchar_type>(key[key_pos])) {
+        return static_cast<value_type>(-2);
+      }
+      node_pos = id;
+    }
+  }
+
+  if (!unit.has_leaf()) {
+    return static_cast<value_type>(-1);
+  }
+  unit = array_[id ^ unit.offset()];
+  return static_cast<value_type>(unit.value());
+}
+
+namespace Details {
+
+//
+// Memory management of array.
+//
+
+template <typename T>
+class AutoArray {
+public:
+  explicit AutoArray(T *array = NULL) : array_(array) {}
+  ~AutoArray() { clear(); }
+
+  const T &operator[](std::size_t id) const { return array_[id]; }
+  T &operator[](std::size_t id) { return array_[id]; }
+
+  bool empty() const { return array_ == NULL; }
+
+  void clear() {
+    if (array_ != NULL) {
+      delete[] array_;
+      array_ = NULL;
+    }
+  }
+  void swap(AutoArray *array) {
+    T *temp = array_;
+    array_ = array->array_;
+    array->array_ = temp;
+  }
+  void reset(T *array = NULL) { AutoArray(array).swap(this); }
+
+private:
+  T *array_;
+
+  // Disallows copy and assignment.
+  AutoArray(const AutoArray &);
+  AutoArray &operator=(const AutoArray &);
+};
+
+//
+// Memory management of resizable array.
+// + +template +class AutoPool { +public: + AutoPool() : buf_(), size_(0), capacity_(0) {} + ~AutoPool() { clear(); } + + const T &operator[](std::size_t id) const { return *(reinterpret_cast(&buf_[0]) + id); } + T &operator[](std::size_t id) { return *(reinterpret_cast(&buf_[0]) + id); } + + bool empty() const { return size_ == 0; } + std::size_t size() const { return size_; } + + void clear() { + resize(0); + buf_.clear(); + size_ = 0; + capacity_ = 0; + } + + void push_back(const T &value) { append(value); } + void pop_back() { (*this)[--size_].~T(); } + + void append() { + if (size_ == capacity_) + resize_buf(size_ + 1); + new (&(*this)[size_++]) T; + } + void append(const T &value) { + if (size_ == capacity_) + resize_buf(size_ + 1); + new (&(*this)[size_++]) T(value); + } + + void resize(std::size_t size) { + while (size_ > size) { + (*this)[--size_].~T(); + } + if (size > capacity_) { + resize_buf(size); + } + while (size_ < size) { + new (&(*this)[size_++]) T; + } + } + void resize(std::size_t size, const T &value) { + while (size_ > size) { + (*this)[--size_].~T(); + } + if (size > capacity_) { + resize_buf(size); + } + while (size_ < size) { + new (&(*this)[size_++]) T(value); + } + } + + void reserve(std::size_t size) { + if (size > capacity_) { + resize_buf(size); + } + } + +private: + AutoArray buf_; + std::size_t size_; + std::size_t capacity_; + + // Disallows copy and assignment. + AutoPool(const AutoPool &); + AutoPool &operator=(const AutoPool &); + + void resize_buf(std::size_t size); +}; + +template +void AutoPool::resize_buf(std::size_t size) { + std::size_t capacity; + if (size >= capacity_ * 2) { + capacity = size; + } else { + capacity = 1; + while (capacity < size) { + capacity <<= 1; + } + } + + AutoArray buf; + try { + buf.reset(new char[sizeof(T) * capacity]); + } catch (const std::bad_alloc &) { + DARTS_THROW("failed to resize pool: std::bad_alloc"); + } + + if (size_ > 0) { + T *src = reinterpret_cast(&buf_[0]); + T *dest = reinterpret_cast(&buf[0]); + for (std::size_t i = 0; i < size_; ++i) { + new (&dest[i]) T(src[i]); + src[i].~T(); + } + } + + buf_.swap(&buf); + capacity_ = capacity; +} + +// +// Memory management of stack. +// + +template +class AutoStack { +public: + AutoStack() : pool_() {} + ~AutoStack() { clear(); } + + const T &top() const { return pool_[size() - 1]; } + T &top() { return pool_[size() - 1]; } + + bool empty() const { return pool_.empty(); } + std::size_t size() const { return pool_.size(); } + + void push(const T &value) { pool_.push_back(value); } + void pop() { pool_.pop_back(); } + + void clear() { pool_.clear(); } + +private: + AutoPool pool_; + + // Disallows copy and assignment. + AutoStack(const AutoStack &); + AutoStack &operator=(const AutoStack &); +}; + +// +// Succinct bit vector. 
+// + +class BitVector { +public: + BitVector() : units_(), ranks_(), num_ones_(0), size_(0) {} + ~BitVector() { clear(); } + + bool operator[](std::size_t id) const { return (units_[id / UNIT_SIZE] >> (id % UNIT_SIZE) & 1) == 1; } + + id_type rank(std::size_t id) const { + std::size_t unit_id = id / UNIT_SIZE; + return ranks_[unit_id] + pop_count(units_[unit_id] & (~0U >> (UNIT_SIZE - (id % UNIT_SIZE) - 1))); + } + + void set(std::size_t id, bool bit) { + if (bit) { + units_[id / UNIT_SIZE] |= 1U << (id % UNIT_SIZE); + } else { + units_[id / UNIT_SIZE] &= ~(1U << (id % UNIT_SIZE)); + } + } + + bool empty() const { return units_.empty(); } + std::size_t num_ones() const { return num_ones_; } + std::size_t size() const { return size_; } + + void append() { + if ((size_ % UNIT_SIZE) == 0) { + units_.append(0); + } + ++size_; + } + void build(); + + void clear() { + units_.clear(); + ranks_.clear(); + } + +private: + enum { UNIT_SIZE = sizeof(id_type) * 8 }; + + AutoPool units_; + AutoArray ranks_; + std::size_t num_ones_; + std::size_t size_; + + // Disallows copy and assignment. + BitVector(const BitVector &); + BitVector &operator=(const BitVector &); + + static id_type pop_count(id_type unit) { + unit = ((unit & 0xAAAAAAAA) >> 1) + (unit & 0x55555555); + unit = ((unit & 0xCCCCCCCC) >> 2) + (unit & 0x33333333); + unit = ((unit >> 4) + unit) & 0x0F0F0F0F; + unit += unit >> 8; + unit += unit >> 16; + return unit & 0xFF; + } +}; + +inline void BitVector::build() { + try { + ranks_.reset(new id_type[units_.size()]); + } catch (const std::bad_alloc &) { + DARTS_THROW("failed to build rank index: std::bad_alloc"); + } + + num_ones_ = 0; + for (std::size_t i = 0; i < units_.size(); ++i) { + ranks_[i] = num_ones_; + num_ones_ += pop_count(units_[i]); + } +} + +// +// Keyset. +// + +template +class Keyset { +public: + Keyset(std::size_t num_keys, const char_type *const *keys, const std::size_t *lengths, const T *values) + : num_keys_(num_keys), keys_(keys), lengths_(lengths), values_(values) {} + + std::size_t num_keys() const { return num_keys_; } + const char_type *keys(std::size_t id) const { return keys_[id]; } + uchar_type keys(std::size_t key_id, std::size_t char_id) const { + if (has_lengths() && char_id >= lengths_[key_id]) + return '\0'; + return keys_[key_id][char_id]; + } + + bool has_lengths() const { return lengths_ != NULL; } + std::size_t lengths(std::size_t id) const { + if (has_lengths()) { + return lengths_[id]; + } + std::size_t length = 0; + while (keys_[id][length] != '\0') { + ++length; + } + return length; + } + + bool has_values() const { return values_ != NULL; } + value_type values(std::size_t id) const { + if (has_values()) { + return static_cast(values_[id]); + } + return static_cast(id); + } + +private: + std::size_t num_keys_; + const char_type *const *keys_; + const std::size_t *lengths_; + const T *values_; + + // Disallows copy and assignment. + Keyset(const Keyset &); + Keyset &operator=(const Keyset &); +}; + +// +// Node of Directed Acyclic Word Graph (DAWG). 
+// + +class DawgNode { +public: + DawgNode() : child_(0), sibling_(0), label_('\0'), is_state_(false), has_sibling_(false) {} + + void set_child(id_type child) { child_ = child; } + void set_sibling(id_type sibling) { sibling_ = sibling; } + void set_value(value_type value) { child_ = value; } + void set_label(uchar_type label) { label_ = label; } + void set_is_state(bool is_state) { is_state_ = is_state; } + void set_has_sibling(bool has_sibling) { has_sibling_ = has_sibling; } + + id_type child() const { return child_; } + id_type sibling() const { return sibling_; } + value_type value() const { return static_cast(child_); } + uchar_type label() const { return label_; } + bool is_state() const { return is_state_; } + bool has_sibling() const { return has_sibling_; } + + id_type unit() const { + if (label_ == '\0') { + return (child_ << 1) | (has_sibling_ ? 1 : 0); + } + return (child_ << 2) | (is_state_ ? 2 : 0) | (has_sibling_ ? 1 : 0); + } + +private: + id_type child_; + id_type sibling_; + uchar_type label_; + bool is_state_; + bool has_sibling_; + + // Copyable. +}; + +// +// Fixed unit of Directed Acyclic Word Graph (DAWG). +// + +class DawgUnit { +public: + explicit DawgUnit(id_type unit = 0) : unit_(unit) {} + DawgUnit(const DawgUnit &unit) : unit_(unit.unit_) {} + + DawgUnit &operator=(id_type unit) { + unit_ = unit; + return *this; + } + + id_type unit() const { return unit_; } + + id_type child() const { return unit_ >> 2; } + bool has_sibling() const { return (unit_ & 1) == 1; } + value_type value() const { return static_cast(unit_ >> 1); } + bool is_state() const { return (unit_ & 2) == 2; } + +private: + id_type unit_; + + // Copyable. +}; + +// +// Directed Acyclic Word Graph (DAWG) builder. +// + +class DawgBuilder { +public: + DawgBuilder() : nodes_(), units_(), labels_(), is_intersections_(), table_(), node_stack_(), recycle_bin_(), num_states_(0) {} + ~DawgBuilder() { clear(); } + + id_type root() const { return 0; } + + id_type child(id_type id) const { return units_[id].child(); } + id_type sibling(id_type id) const { return units_[id].has_sibling() ? (id + 1) : 0; } + int value(id_type id) const { return units_[id].value(); } + + bool is_leaf(id_type id) const { return label(id) == '\0'; } + uchar_type label(id_type id) const { return labels_[id]; } + + bool is_intersection(id_type id) const { return is_intersections_[id]; } + id_type intersection_id(id_type id) const { return is_intersections_.rank(id) - 1; } + + std::size_t num_intersections() const { return is_intersections_.num_ones(); } + + std::size_t size() const { return units_.size(); } + + void init(); + void finish(); + + void insert(const char *key, std::size_t length, value_type value); + + void clear(); + +private: + enum { INITIAL_TABLE_SIZE = 1 << 10 }; + + AutoPool nodes_; + AutoPool units_; + AutoPool labels_; + BitVector is_intersections_; + AutoPool table_; + AutoStack node_stack_; + AutoStack recycle_bin_; + std::size_t num_states_; + + // Disallows copy and assignment. 
+ DawgBuilder(const DawgBuilder &); + DawgBuilder &operator=(const DawgBuilder &); + + void flush(id_type id); + + void expand_table(); + + id_type find_unit(id_type id, id_type *hash_id) const; + id_type find_node(id_type node_id, id_type *hash_id) const; + + bool are_equal(id_type node_id, id_type unit_id) const; + + id_type hash_unit(id_type id) const; + id_type hash_node(id_type id) const; + + id_type append_node(); + id_type append_unit(); + + void free_node(id_type id) { recycle_bin_.push(id); } + + static id_type hash(id_type key) { + key = ~key + (key << 15); // key = (key << 15) - key - 1; + key = key ^ (key >> 12); + key = key + (key << 2); + key = key ^ (key >> 4); + key = key * 2057; // key = (key + (key << 3)) + (key << 11); + key = key ^ (key >> 16); + return key; + } +}; + +inline void DawgBuilder::init() { + table_.resize(INITIAL_TABLE_SIZE, 0); + + append_node(); + append_unit(); + + num_states_ = 1; + + nodes_[0].set_label(0xFF); + node_stack_.push(0); +} + +inline void DawgBuilder::finish() { + flush(0); + + units_[0] = nodes_[0].unit(); + labels_[0] = nodes_[0].label(); + + nodes_.clear(); + table_.clear(); + node_stack_.clear(); + recycle_bin_.clear(); + + is_intersections_.build(); +} + +inline void DawgBuilder::insert(const char *key, std::size_t length, value_type value) { + if (value < 0) { + DARTS_THROW("failed to insert key: negative value"); + } else if (length == 0) { + DARTS_THROW("failed to insert key: zero-length key"); + } + + id_type id = 0; + std::size_t key_pos = 0; + + for (; key_pos <= length; ++key_pos) { + id_type child_id = nodes_[id].child(); + if (child_id == 0) { + break; + } + + uchar_type key_label = static_cast(key[key_pos]); + if (key_pos < length && key_label == '\0') { + DARTS_THROW("failed to insert key: invalid null character"); + } + + uchar_type unit_label = nodes_[child_id].label(); + if (key_label < unit_label) { + DARTS_THROW("failed to insert key: wrong key order"); + } else if (key_label > unit_label) { + nodes_[child_id].set_has_sibling(true); + flush(child_id); + break; + } + id = child_id; + } + + if (key_pos > length) { + return; + } + + for (; key_pos <= length; ++key_pos) { + uchar_type key_label = static_cast((key_pos < length) ? 
key[key_pos] : '\0'); + id_type child_id = append_node(); + + if (nodes_[id].child() == 0) { + nodes_[child_id].set_is_state(true); + } + nodes_[child_id].set_sibling(nodes_[id].child()); + nodes_[child_id].set_label(key_label); + nodes_[id].set_child(child_id); + node_stack_.push(child_id); + + id = child_id; + } + nodes_[id].set_value(value); +} + +inline void DawgBuilder::clear() { + nodes_.clear(); + units_.clear(); + labels_.clear(); + is_intersections_.clear(); + table_.clear(); + node_stack_.clear(); + recycle_bin_.clear(); + num_states_ = 0; +} + +inline void DawgBuilder::flush(id_type id) { + while (node_stack_.top() != id) { + id_type node_id = node_stack_.top(); + node_stack_.pop(); + + if (num_states_ >= table_.size() - (table_.size() >> 2)) { + expand_table(); + } + + id_type num_siblings = 0; + for (id_type i = node_id; i != 0; i = nodes_[i].sibling()) { + ++num_siblings; + } + + id_type hash_id; + id_type match_id = find_node(node_id, &hash_id); + if (match_id != 0) { + is_intersections_.set(match_id, true); + } else { + id_type unit_id = 0; + for (id_type i = 0; i < num_siblings; ++i) { + unit_id = append_unit(); + } + for (id_type i = node_id; i != 0; i = nodes_[i].sibling()) { + units_[unit_id] = nodes_[i].unit(); + labels_[unit_id] = nodes_[i].label(); + --unit_id; + } + match_id = unit_id + 1; + table_[hash_id] = match_id; + ++num_states_; + } + + for (id_type i = node_id, next; i != 0; i = next) { + next = nodes_[i].sibling(); + free_node(i); + } + + nodes_[node_stack_.top()].set_child(match_id); + } + node_stack_.pop(); +} + +inline void DawgBuilder::expand_table() { + std::size_t table_size = table_.size() << 1; + table_.clear(); + table_.resize(table_size, 0); + + for (std::size_t i = 1; i < units_.size(); ++i) { + id_type id = static_cast(i); + if (labels_[id] == '\0' || units_[id].is_state()) { + id_type hash_id; + find_unit(id, &hash_id); + table_[hash_id] = id; + } + } +} + +inline id_type DawgBuilder::find_unit(id_type id, id_type *hash_id) const { + *hash_id = hash_unit(id) % table_.size(); + for (;; *hash_id = (*hash_id + 1) % table_.size()) { + id_type unit_id = table_[*hash_id]; + if (unit_id == 0) { + break; + } + + // There must not be the same unit. 
+ } + return 0; +} + +inline id_type DawgBuilder::find_node(id_type node_id, id_type *hash_id) const { + *hash_id = hash_node(node_id) % table_.size(); + for (;; *hash_id = (*hash_id + 1) % table_.size()) { + id_type unit_id = table_[*hash_id]; + if (unit_id == 0) { + break; + } + + if (are_equal(node_id, unit_id)) { + return unit_id; + } + } + return 0; +} + +inline bool DawgBuilder::are_equal(id_type node_id, id_type unit_id) const { + for (id_type i = nodes_[node_id].sibling(); i != 0; i = nodes_[i].sibling()) { + if (units_[unit_id].has_sibling() == false) { + return false; + } + ++unit_id; + } + if (units_[unit_id].has_sibling() == true) { + return false; + } + + for (id_type i = node_id; i != 0; i = nodes_[i].sibling(), --unit_id) { + if (nodes_[i].unit() != units_[unit_id].unit() || nodes_[i].label() != labels_[unit_id]) { + return false; + } + } + return true; +} + +inline id_type DawgBuilder::hash_unit(id_type id) const { + id_type hash_value = 0; + for (; id != 0; ++id) { + id_type unit = units_[id].unit(); + uchar_type label = labels_[id]; + hash_value ^= hash((label << 24) ^ unit); + + if (units_[id].has_sibling() == false) { + break; + } + } + return hash_value; +} + +inline id_type DawgBuilder::hash_node(id_type id) const { + id_type hash_value = 0; + for (; id != 0; id = nodes_[id].sibling()) { + id_type unit = nodes_[id].unit(); + uchar_type label = nodes_[id].label(); + hash_value ^= hash((label << 24) ^ unit); + } + return hash_value; +} + +inline id_type DawgBuilder::append_unit() { + is_intersections_.append(); + units_.append(); + labels_.append(); + + return static_cast(is_intersections_.size() - 1); +} + +inline id_type DawgBuilder::append_node() { + id_type id; + if (recycle_bin_.empty()) { + id = static_cast(nodes_.size()); + nodes_.append(); + } else { + id = recycle_bin_.top(); + nodes_[id] = DawgNode(); + recycle_bin_.pop(); + } + return id; +} + +// +// Unit of double-array builder. +// + +class DoubleArrayBuilderUnit { +public: + DoubleArrayBuilderUnit() : unit_(0) {} + + void set_has_leaf(bool has_leaf) { + if (has_leaf) { + unit_ |= 1U << 8; + } else { + unit_ &= ~(1U << 8); + } + } + void set_value(value_type value) { unit_ = value | (1U << 31); } + void set_label(uchar_type label) { unit_ = (unit_ & ~0xFFU) | label; } + void set_offset(id_type offset) { + if (offset >= 1U << 29) { + DARTS_THROW("failed to modify unit: too large offset"); + } + unit_ &= (1U << 31) | (1U << 8) | 0xFF; + if (offset < 1U << 21) { + unit_ |= (offset << 10); + } else { + unit_ |= (offset << 2) | (1U << 9); + } + } + +private: + id_type unit_; + + // Copyable. +}; + +// +// Extra unit of double-array builder. +// + +class DoubleArrayBuilderExtraUnit { +public: + DoubleArrayBuilderExtraUnit() : prev_(0), next_(0), is_fixed_(false), is_used_(false) {} + + void set_prev(id_type prev) { prev_ = prev; } + void set_next(id_type next) { next_ = next; } + void set_is_fixed(bool is_fixed) { is_fixed_ = is_fixed; } + void set_is_used(bool is_used) { is_used_ = is_used; } + + id_type prev() const { return prev_; } + id_type next() const { return next_; } + bool is_fixed() const { return is_fixed_; } + bool is_used() const { return is_used_; } + +private: + id_type prev_; + id_type next_; + bool is_fixed_; + bool is_used_; + + // Copyable. +}; + +// +// DAWG -> double-array converter. 
+// + +class DoubleArrayBuilder { +public: + explicit DoubleArrayBuilder(progress_func_type progress_func) + : progress_func_(progress_func), units_(), extras_(), labels_(), table_(), extras_head_(0) {} + ~DoubleArrayBuilder() { clear(); } + + template + void build(const Keyset &keyset); + void copy(std::size_t *size_ptr, DoubleArrayUnit **buf_ptr) const; + + void clear(); + +private: + static constexpr auto BLOCK_SIZE = 256; + static constexpr auto NUM_EXTRA_BLOCKS = 16; + static constexpr auto NUM_EXTRAS = BLOCK_SIZE * NUM_EXTRA_BLOCKS; + + enum { UPPER_MASK = 0xFF << 21 }; + enum { LOWER_MASK = 0xFF }; + + typedef DoubleArrayBuilderUnit unit_type; + typedef DoubleArrayBuilderExtraUnit extra_type; + + progress_func_type progress_func_; + AutoPool units_; + AutoArray extras_; + AutoPool labels_; + AutoArray table_; + id_type extras_head_; + + // Disallows copy and assignment. + DoubleArrayBuilder(const DoubleArrayBuilder &); + DoubleArrayBuilder &operator=(const DoubleArrayBuilder &); + + std::size_t num_blocks() const { return units_.size() / BLOCK_SIZE; } + + const extra_type &extras(id_type id) const { return extras_[id % NUM_EXTRAS]; } + extra_type &extras(id_type id) { return extras_[id % NUM_EXTRAS]; } + + template + void build_dawg(const Keyset &keyset, DawgBuilder *dawg_builder); + void build_from_dawg(const DawgBuilder &dawg); + void build_from_dawg(const DawgBuilder &dawg, id_type dawg_id, id_type dic_id); + id_type arrange_from_dawg(const DawgBuilder &dawg, id_type dawg_id, id_type dic_id); + + template + void build_from_keyset(const Keyset &keyset); + template + void build_from_keyset(const Keyset &keyset, std::size_t begin, std::size_t end, std::size_t depth, id_type dic_id); + template + id_type arrange_from_keyset(const Keyset &keyset, std::size_t begin, std::size_t end, std::size_t depth, id_type dic_id); + + id_type find_valid_offset(id_type id) const; + bool is_valid_offset(id_type id, id_type offset) const; + + void reserve_id(id_type id); + void expand_units(); + + void fix_all_blocks(); + void fix_block(id_type block_id); +}; + +template +void DoubleArrayBuilder::build(const Keyset &keyset) { + if (keyset.has_values()) { + Details::DawgBuilder dawg_builder; + build_dawg(keyset, &dawg_builder); + build_from_dawg(dawg_builder); + dawg_builder.clear(); + } else { + build_from_keyset(keyset); + } +} + +inline void DoubleArrayBuilder::copy(std::size_t *size_ptr, DoubleArrayUnit **buf_ptr) const { + if (size_ptr != NULL) { + *size_ptr = units_.size(); + } + if (buf_ptr != NULL) { + *buf_ptr = new DoubleArrayUnit[units_.size()]; + unit_type *units = reinterpret_cast(*buf_ptr); + for (std::size_t i = 0; i < units_.size(); ++i) { + units[i] = units_[i]; + } + } +} + +inline void DoubleArrayBuilder::clear() { + units_.clear(); + extras_.clear(); + labels_.clear(); + table_.clear(); + extras_head_ = 0; +} + +template +void DoubleArrayBuilder::build_dawg(const Keyset &keyset, DawgBuilder *dawg_builder) { + dawg_builder->init(); + for (std::size_t i = 0; i < keyset.num_keys(); ++i) { + dawg_builder->insert(keyset.keys(i), keyset.lengths(i), keyset.values(i)); + if (progress_func_ != NULL) { + progress_func_(i + 1, keyset.num_keys() + 1); + } + } + dawg_builder->finish(); +} + +inline void DoubleArrayBuilder::build_from_dawg(const DawgBuilder &dawg) { + std::size_t num_units = 1; + while (num_units < dawg.size()) { + num_units <<= 1; + } + units_.reserve(num_units); + + table_.reset(new id_type[dawg.num_intersections()]); + for (std::size_t i = 0; i < dawg.num_intersections(); ++i) 
+
+template <typename T>
+void DoubleArrayBuilder::build(const Keyset<T> &keyset) {
+    if (keyset.has_values()) {
+        Details::DawgBuilder dawg_builder;
+        build_dawg(keyset, &dawg_builder);
+        build_from_dawg(dawg_builder);
+        dawg_builder.clear();
+    } else {
+        build_from_keyset(keyset);
+    }
+}
+
+inline void DoubleArrayBuilder::copy(std::size_t *size_ptr, DoubleArrayUnit **buf_ptr) const {
+    if (size_ptr != NULL) {
+        *size_ptr = units_.size();
+    }
+    if (buf_ptr != NULL) {
+        *buf_ptr = new DoubleArrayUnit[units_.size()];
+        unit_type *units = reinterpret_cast<unit_type *>(*buf_ptr);
+        for (std::size_t i = 0; i < units_.size(); ++i) {
+            units[i] = units_[i];
+        }
+    }
+}
+
+inline void DoubleArrayBuilder::clear() {
+    units_.clear();
+    extras_.clear();
+    labels_.clear();
+    table_.clear();
+    extras_head_ = 0;
+}
+
+template <typename T>
+void DoubleArrayBuilder::build_dawg(const Keyset<T> &keyset, DawgBuilder *dawg_builder) {
+    dawg_builder->init();
+    for (std::size_t i = 0; i < keyset.num_keys(); ++i) {
+        dawg_builder->insert(keyset.keys(i), keyset.lengths(i), keyset.values(i));
+        if (progress_func_ != NULL) {
+            progress_func_(i + 1, keyset.num_keys() + 1);
+        }
+    }
+    dawg_builder->finish();
+}
+
+inline void DoubleArrayBuilder::build_from_dawg(const DawgBuilder &dawg) {
+    std::size_t num_units = 1;
+    while (num_units < dawg.size()) {
+        num_units <<= 1;
+    }
+    units_.reserve(num_units);
+
+    table_.reset(new id_type[dawg.num_intersections()]);
+    for (std::size_t i = 0; i < dawg.num_intersections(); ++i) {
+        table_[i] = 0;
+    }
+
+    extras_.reset(new extra_type[NUM_EXTRAS]);
+
+    reserve_id(0);
+    extras(0).set_is_used(true);
+    units_[0].set_offset(1);
+    units_[0].set_label('\0');
+
+    if (dawg.child(dawg.root()) != 0) {
+        build_from_dawg(dawg, dawg.root(), 0);
+    }
+
+    fix_all_blocks();
+
+    extras_.clear();
+    labels_.clear();
+    table_.clear();
+}
+
+inline void DoubleArrayBuilder::build_from_dawg(const DawgBuilder &dawg, id_type dawg_id, id_type dic_id) {
+    id_type dawg_child_id = dawg.child(dawg_id);
+    if (dawg.is_intersection(dawg_child_id)) {
+        id_type intersection_id = dawg.intersection_id(dawg_child_id);
+        id_type offset = table_[intersection_id];
+        if (offset != 0) {
+            offset ^= dic_id;
+            if (!(offset & UPPER_MASK) || !(offset & LOWER_MASK)) {
+                if (dawg.is_leaf(dawg_child_id)) {
+                    units_[dic_id].set_has_leaf(true);
+                }
+                units_[dic_id].set_offset(offset);
+                return;
+            }
+        }
+    }
+
+    id_type offset = arrange_from_dawg(dawg, dawg_id, dic_id);
+    if (dawg.is_intersection(dawg_child_id)) {
+        table_[dawg.intersection_id(dawg_child_id)] = offset;
+    }
+
+    do {
+        uchar_type child_label = dawg.label(dawg_child_id);
+        id_type dic_child_id = offset ^ child_label;
+        if (child_label != '\0') {
+            build_from_dawg(dawg, dawg_child_id, dic_child_id);
+        }
+        dawg_child_id = dawg.sibling(dawg_child_id);
+    } while (dawg_child_id != 0);
+}
+
+inline id_type DoubleArrayBuilder::arrange_from_dawg(const DawgBuilder &dawg, id_type dawg_id, id_type dic_id) {
+    labels_.resize(0);
+
+    id_type dawg_child_id = dawg.child(dawg_id);
+    while (dawg_child_id != 0) {
+        labels_.append(dawg.label(dawg_child_id));
+        dawg_child_id = dawg.sibling(dawg_child_id);
+    }
+
+    id_type offset = find_valid_offset(dic_id);
+    units_[dic_id].set_offset(dic_id ^ offset);
+
+    dawg_child_id = dawg.child(dawg_id);
+    for (std::size_t i = 0; i < labels_.size(); ++i) {
+        id_type dic_child_id = offset ^ labels_[i];
+        reserve_id(dic_child_id);
+
+        if (dawg.is_leaf(dawg_child_id)) {
+            units_[dic_id].set_has_leaf(true);
+            units_[dic_child_id].set_value(dawg.value(dawg_child_id));
+        } else {
+            units_[dic_child_id].set_label(labels_[i]);
+        }
+
+        dawg_child_id = dawg.sibling(dawg_child_id);
+    }
+    extras(offset).set_is_used(true);
+
+    return offset;
+}
+
+template <typename T>
+void DoubleArrayBuilder::build_from_keyset(const Keyset<T> &keyset) {
+    std::size_t num_units = 1;
+    while (num_units < keyset.num_keys()) {
+        num_units <<= 1;
+    }
+    units_.reserve(num_units);
+
+    extras_.reset(new extra_type[NUM_EXTRAS]);
+
+    reserve_id(0);
+    extras(0).set_is_used(true);
+    units_[0].set_offset(1);
+    units_[0].set_label('\0');
+
+    if (keyset.num_keys() > 0) {
+        build_from_keyset(keyset, 0, keyset.num_keys(), 0, 0);
+    }
+
+    fix_all_blocks();
+
+    extras_.clear();
+    labels_.clear();
+}
+
+template <typename T>
+void DoubleArrayBuilder::build_from_keyset(const Keyset<T> &keyset, std::size_t begin, std::size_t end, std::size_t depth, id_type dic_id) {
+    id_type offset = arrange_from_keyset(keyset, begin, end, depth, dic_id);
+
+    while (begin < end) {
+        if (keyset.keys(begin, depth) != '\0') {
+            break;
+        }
+        ++begin;
+    }
+    if (begin == end) {
+        return;
+    }
+
+    std::size_t last_begin = begin;
+    uchar_type last_label = keyset.keys(begin, depth);
+    while (++begin < end) {
+        uchar_type label = keyset.keys(begin, depth);
+        if (label != last_label) {
+            build_from_keyset(keyset, last_begin, begin, depth + 1, offset ^ last_label);
+            last_begin = begin;
+            last_label = keyset.keys(begin, depth);
+        }
+    }
+    build_from_keyset(keyset, last_begin, end, depth + 1, offset ^ last_label);
+}
+
+template <typename T>
+id_type DoubleArrayBuilder::arrange_from_keyset(const Keyset<T> &keyset, std::size_t begin, std::size_t end, std::size_t depth, id_type dic_id) {
+    labels_.resize(0);
+
+    value_type value = -1;
+    for (std::size_t i = begin; i < end; ++i) {
+        uchar_type label = keyset.keys(i, depth);
+        if (label == '\0') {
+            if (keyset.has_lengths() && depth < keyset.lengths(i)) {
+                DARTS_THROW("failed to build double-array: invalid null character");
+            } else if (keyset.values(i) < 0) {
+                DARTS_THROW("failed to build double-array: negative value");
+            }
+
+            if (value == -1) {
+                value = keyset.values(i);
+            }
+            if (progress_func_ != NULL) {
+                progress_func_(i + 1, keyset.num_keys() + 1);
+            }
+        }
+
+        if (labels_.empty()) {
+            labels_.append(label);
+        } else if (label != labels_[labels_.size() - 1]) {
+            if (label < labels_[labels_.size() - 1]) {
+                DARTS_THROW("failed to build double-array: wrong key order");
+            }
+            labels_.append(label);
+        }
+    }
+
+    id_type offset = find_valid_offset(dic_id);
+    units_[dic_id].set_offset(dic_id ^ offset);
+
+    for (std::size_t i = 0; i < labels_.size(); ++i) {
+        id_type dic_child_id = offset ^ labels_[i];
+        reserve_id(dic_child_id);
+        if (labels_[i] == '\0') {
+            units_[dic_id].set_has_leaf(true);
+            units_[dic_child_id].set_value(value);
+        } else {
+            units_[dic_child_id].set_label(labels_[i]);
+        }
+    }
+    extras(offset).set_is_used(true);
+
+    return offset;
+}
+
+inline id_type DoubleArrayBuilder::find_valid_offset(id_type id) const {
+    if (extras_head_ >= units_.size()) {
+        return units_.size() | (id & LOWER_MASK);
+    }
+
+    id_type unfixed_id = extras_head_;
+    do {
+        id_type offset = unfixed_id ^ labels_[0];
+        if (is_valid_offset(id, offset)) {
+            return offset;
+        }
+        unfixed_id = extras(unfixed_id).next();
+    } while (unfixed_id != extras_head_);
+
+    return units_.size() | (id & LOWER_MASK);
+}
+
+inline bool DoubleArrayBuilder::is_valid_offset(id_type id, id_type offset) const {
+    if (extras(offset).is_used()) {
+        return false;
+    }
+
+    id_type rel_offset = id ^ offset;
+    if ((rel_offset & LOWER_MASK) && (rel_offset & UPPER_MASK)) {
+        return false;
+    }
+
+    for (std::size_t i = 1; i < labels_.size(); ++i) {
+        if (extras(offset ^ labels_[i]).is_fixed()) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+inline void DoubleArrayBuilder::reserve_id(id_type id) {
+    if (id >= units_.size()) {
+        expand_units();
+    }
+
+    if (id == extras_head_) {
+        extras_head_ = extras(id).next();
+        if (extras_head_ == id) {
+            extras_head_ = units_.size();
+        }
+    }
+    extras(extras(id).prev()).set_next(extras(id).next());
+    extras(extras(id).next()).set_prev(extras(id).prev());
+    extras(id).set_is_fixed(true);
+}
+
+inline void DoubleArrayBuilder::expand_units() {
+    id_type src_num_units = units_.size();
+    id_type src_num_blocks = num_blocks();
+
+    id_type dest_num_units = src_num_units + BLOCK_SIZE;
+    id_type dest_num_blocks = src_num_blocks + 1;
+
+    if (dest_num_blocks > NUM_EXTRA_BLOCKS) {
+        fix_block(src_num_blocks - NUM_EXTRA_BLOCKS);
+    }
+
+    units_.resize(dest_num_units);
+
+    if (dest_num_blocks > NUM_EXTRA_BLOCKS) {
+        for (std::size_t id = src_num_units; id < dest_num_units; ++id) {
+            extras(id).set_is_used(false);
+            extras(id).set_is_fixed(false);
+        }
+    }
+
+    for (id_type i = src_num_units + 1; i < dest_num_units; ++i) {
+        extras(i - 1).set_next(i);
+        extras(i).set_prev(i - 1);
+    }
+
+    extras(src_num_units).set_prev(dest_num_units - 1);
+    extras(dest_num_units - 1).set_next(src_num_units);
+
+    extras(src_num_units).set_prev(extras(extras_head_).prev());
+    extras(dest_num_units - 1).set_next(extras_head_);
+
+    extras(extras(extras_head_).prev()).set_next(src_num_units);
+    extras(extras_head_).set_prev(dest_num_units - 1);
+}
+
+inline void DoubleArrayBuilder::fix_all_blocks() {
+    id_type begin = 0;
+    if (num_blocks() > NUM_EXTRA_BLOCKS) {
+        begin = num_blocks() - NUM_EXTRA_BLOCKS;
+    }
+    id_type end = num_blocks();
+
+    for (id_type block_id = begin; block_id != end; ++block_id) {
+        fix_block(block_id);
+    }
+}
+
+inline void DoubleArrayBuilder::fix_block(id_type block_id) {
+    id_type begin = block_id * BLOCK_SIZE;
+    id_type end = begin + BLOCK_SIZE;
+
+    id_type unused_offset = 0;
+    for (id_type offset = begin; offset != end; ++offset) {
+        if (!extras(offset).is_used()) {
+            unused_offset = offset;
+            break;
+        }
+    }
+
+    for (id_type id = begin; id != end; ++id) {
+        if (!extras(id).is_fixed()) {
+            reserve_id(id);
+            units_[id].set_label(static_cast<uchar_type>(id ^ unused_offset));
+        }
+    }
+}
+
+}  // namespace Details
+
+//
+// Member function build() of DoubleArrayImpl.
+//
+
+template <typename A, typename B, typename T, typename C>
+int DoubleArrayImpl<A, B, T, C>::build(std::size_t num_keys,
+                                       const key_type *const *keys,
+                                       const std::size_t *lengths,
+                                       const value_type *values,
+                                       Details::progress_func_type progress_func) {
+    Details::Keyset<value_type> keyset(num_keys, keys, lengths, values);
+
+    Details::DoubleArrayBuilder builder(progress_func);
+    builder.build(keyset);
+
+    std::size_t size = 0;
+    unit_type *buf = NULL;
+    builder.copy(&size, &buf);
+
+    clear();
+
+    size_ = size;
+    array_ = buf;
+    buf_ = buf;
+
+    if (progress_func != NULL) {
+        progress_func(num_keys + 1, num_keys + 1);
+    }
+
+    return 0;
+}
+
+}  // namespace Darts
+
+#undef DARTS_INT_TO_STR
+#undef DARTS_LINE_TO_STR
+#undef DARTS_LINE_STR
+#undef DARTS_THROW
+
+#endif  // DARTS_H_
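For orientation, here is a minimal, hedged sketch of how the `Darts::DoubleArray` API exposed by this header is driven; keys must be pre-sorted, as `DartsTrie::Build()` in the next file also ensures. The header name `darts.h` and the `<int>` instantiation are assumptions, not confirmed by the patch:

```cpp
// Sketch only: exercises the build()/exactMatchSearch() calls used below.
#include <cstddef>
#include "darts.h"  // assumed header name for the code above

int main() {
    const char *keys[] = {"apple", "banana"};     // must be sorted
    std::size_t lengths[] = {5, 6};
    int values[] = {10, 20};

    Darts::DoubleArray da;
    da.build(2, keys, lengths, values, nullptr);  // signature shown above

    // exactMatchSearch returns the stored value, or -1 when absent.
    int hit = da.exactMatchSearch<int>("banana", 6);
    return hit == 20 ? 0 : 1;
}
```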
+
diff --git a/internal/cpp/darts_trie.cpp b/internal/cpp/darts_trie.cpp
new file mode 100644
index 00000000000..15b103b33ea
--- /dev/null
+++ b/internal/cpp/darts_trie.cpp
@@ -0,0 +1,109 @@
+// Copyright(C) 2024 InfiniFlow, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dart_trie.h"
+
+#include <algorithm>
+#include <fstream>
+#include <vector>
+
+POSTable::POSTable(const std::string &file_name) : file_(file_name) {
+}
+
+int32_t POSTable::Load() {
+    std::ifstream from(file_);
+    if (!from.good()) {
+        return -1;
+        // return Status::InvalidAnalyzerFile(file_);
+    }
+
+    std::string line;
+    int32_t index = 0;
+
+    while (getline(from, line)) {
+        line = line.substr(0, line.find('\r'));
+        if (line.empty())
+            continue;
+        pos_map_[line] = index;
+    }
+
+    // The map keeps tags sorted and deduplicated; assign final indices here.
+    for (auto &x : pos_map_) {
+        x.second = index++;
+        pos_vec_.push_back(x.first);
+    }
+    return 0;
+    // return Status::OK();
+}
+
+const char *POSTable::GetPOS(int32_t index) const {
+    if (index < 0 || index >= table_size_)
+        return "";
+
+    return pos_vec_[index].c_str();
+}
+
+int32_t POSTable::GetPOSIndex(const std::string &tag) const {
+    std::map<std::string, int32_t>::const_iterator it = pos_map_.find(tag);
+    if (it != pos_map_.end())
+        return it->second;
+    return -1;
+}
+
+DartsTrie::DartsTrie() : darts_{std::make_unique<Darts::DoubleArray>()} {
+}
+
+void DartsTrie::Add(const std::string &key, const int &value) { buffer_.push_back(DartsTuple(key, value)); }
+
+void DartsTrie::Build() {
+    std::sort(buffer_.begin(), buffer_.end(), [](const DartsTuple &l, const DartsTuple &r) { return l.key_ < r.key_; });
+    std::vector<const char *> keys;
+    std::vector<std::size_t> lengths;
+    std::vector<int> values;
+    for (auto &o : buffer_) {
+        keys.push_back(o.key_.c_str());
+        lengths.push_back(o.key_.size());
+        values.push_back(o.value_);
+    }
+    darts_->build(keys.size(), keys.data(), lengths.data(), values.data(), nullptr);
+    buffer_.clear();
+}
+
+void DartsTrie::Load(const std::string &file_name) { darts_->open(file_name.c_str()); }
+
+void DartsTrie::Save(const std::string &file_name) { darts_->save(file_name.c_str()); }
+
+// string literal "" is null-terminated
+constexpr std::string_view empty_null_terminated_sv = "";
+
+bool DartsTrie::HasKeysWithPrefix(std::string_view key) const {
+    if (key.empty()) [[unlikely]] {
+        key = empty_null_terminated_sv;
+    }
+    std::size_t id = 0;
+    std::size_t key_pos = 0;
+    const auto result = darts_->traverse(key.data(), id, key_pos, key.size());
+    return result != -2;
+}
+
+int DartsTrie::Traverse(const char *key, std::size_t &node_pos, std::size_t &key_pos, const std::size_t length) const {
+    return darts_->traverse(key, node_pos, key_pos, length);
+}
+
+int DartsTrie::Get(std::string_view key) const {
+    if (key.empty()) [[unlikely]] {
+        key = empty_null_terminated_sv;
+    }
+    return darts_->exactMatchSearch<int>(key.data(), key.size());
+}
\ No newline at end of file
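A hedged usage sketch of the `DartsTrie` wrapper above, mirroring its own API (`Add`/`Build`/`Get`/`HasKeysWithPrefix`); the header name follows the file's own `#include "dart_trie.h"`:

```cpp
// Sketch only: how the DartsTrie wrapper above is meant to be driven.
#include <iostream>
#include "dart_trie.h"  // header name as included by darts_trie.cpp

int main() {
    DartsTrie trie;
    trie.Add("beijing", 1);
    trie.Add("shanghai", 2);
    trie.Build();                              // sorts keys, then builds

    std::cout << trie.Get("beijing") << "\n";  // stored value, -1 if absent
    std::cout << trie.HasKeysWithPrefix("bei") << "\n";
    return 0;
}
```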
diff --git a/internal/cpp/main.cpp b/internal/cpp/main.cpp
new file mode 100644
index 00000000000..fb8c38d6f0b
--- /dev/null
+++ b/internal/cpp/main.cpp
@@ -0,0 +1,442 @@
+//
+// Created by infiniflow on 2/2/26.
+//
+
+#include <cassert>
+#include <cerrno>
+#include <cstring>
+#include <filesystem>
+#include <fstream>
+#include <iostream>
+#include <numeric>
+#include <sstream>
+#include <string>
+#include <unordered_set>
+#include <vector>
+#include "rag_analyzer.h"
+
+namespace fs = std::filesystem;
+
+void test_analyze_enable_position() {
+    fs::path RESOURCE_DIR = "/usr/share/infinity/resource";
+    if (!fs::exists(RESOURCE_DIR)) {
+        std::cerr << "Resource directory doesn't exist: " << RESOURCE_DIR << std::endl;
+        return;
+    }
+
+    std::string rag_tokenizer_path_ = "test";
+    std::string input_file_ = rag_tokenizer_path_ + "/tokenizer_input.txt";
+
+    std::cout << "Looking for input file: " << input_file_ << std::endl;
+    std::cout << "Current directory: " << fs::current_path() << std::endl;
+
+    if (!fs::exists(input_file_)) {
+        std::cerr << "ERROR: Input file doesn't exist: " << input_file_ << std::endl;
+        std::cerr << "Full path: " << fs::absolute(input_file_) << std::endl;
+        return;
+    }
+
+    std::ifstream infile(input_file_);
+    if (!infile.is_open()) {
+        std::cerr << "ERROR: Cannot open file: " << input_file_ << std::endl;
+        std::cerr << "Error code: " << strerror(errno) << std::endl;
+        return;
+    }
+
+    infile.seekg(0, std::ios::end);
+    size_t file_size = infile.tellg();
+    infile.seekg(0, std::ios::beg);
+    std::cout << "File size: " << file_size << " bytes" << std::endl;
+
+    auto analyzer_ = new RAGAnalyzer(RESOURCE_DIR.string());
+    analyzer_->Load();
+
+    analyzer_->SetEnablePosition(false);
+    analyzer_->SetFineGrained(false);
+
+    analyzer_->SetEnablePosition(true);
+    analyzer_->SetFineGrained(false);
+
+    std::string line;
+    while (std::getline(infile, line)) {
+        if (line.empty())
+            continue;
+
+        TermList term_list;
+        analyzer_->Analyze(line, term_list);
+        std::cout << "Input text: " << std::endl << line << std::endl;
+
+        std::cout << "Analyze result: " << std::endl;
+        for (unsigned i = 0; i < term_list.size(); ++i) {
+            std::cout << "[" << term_list[i].text_ << "@" << term_list[i].word_offset_ << "," << term_list[i].end_offset_ << "] ";
+        }
+        std::cout << std::endl;
+    }
+    infile.close();
+
+    delete analyzer_;
+    analyzer_ = nullptr;
+}
+
+void test_analyze_enable_position_fine_grained() {
+    fs::path RESOURCE_DIR = "/usr/share/infinity/resource";
+    if (!fs::exists(RESOURCE_DIR)) {
+        std::cerr << "Resource directory doesn't exist: " << RESOURCE_DIR << std::endl;
+        return;
+    }
+
+    std::string rag_tokenizer_path_ = "test";
+    std::string input_file_ = rag_tokenizer_path_ + "/tokenizer_input.txt";
+
+    std::cout << "Looking for input file: " << input_file_ << std::endl;
+    std::cout << "Current directory: " << fs::current_path() << std::endl;
+
+    if (!fs::exists(input_file_)) {
+        std::cerr << "ERROR: Input file doesn't exist: " << input_file_ << std::endl;
+        std::cerr << "Full path: " << fs::absolute(input_file_) << std::endl;
+        return;
+    }
+
+    std::ifstream infile(input_file_);
+    if (!infile.is_open()) {
+        std::cerr << "ERROR: Cannot open file: " << input_file_ << std::endl;
+        std::cerr << "Error code: " << strerror(errno) << std::endl;
+        return;
+    }
+
+    infile.seekg(0, std::ios::end);
+    size_t file_size = infile.tellg();
+    infile.seekg(0, std::ios::beg);
+    std::cout << "File size: " << file_size << " bytes" << std::endl;
+
+    auto analyzer_ = new RAGAnalyzer(RESOURCE_DIR.string());
+    analyzer_->Load();
+
+    analyzer_->SetEnablePosition(true);
+    analyzer_->SetFineGrained(true);
+
+    std::string line;
+
+    while (std::getline(infile, line)) {
+        if (line.empty())
+            continue;
+
+        TermList term_list;
+        analyzer_->Analyze(line, term_list);
+        std::cout << "Input text: " << std::endl << line << std::endl;
+
+        std::cout << "Analyze result: " << std::endl;
+        for (unsigned i = 0; i < term_list.size(); ++i) {
+            std::cout << "[" << term_list[i].text_ << "@" << term_list[i].word_offset_ << "," << term_list[i].end_offset_ << "] ";
+        }
+        std::cout << std::endl;
+    }
+    infile.close();
+
+    delete analyzer_;
+    analyzer_ = nullptr;
+}
+
+void test_tokenize_consistency_with_position() {
+    fs::path RESOURCE_DIR = "/usr/share/infinity/resource";
+    if (!fs::exists(RESOURCE_DIR)) {
+        std::cerr << "Resource directory doesn't exist: " << RESOURCE_DIR << std::endl;
+        return;
+    }
+
+    std::string rag_tokenizer_path_ = "test";
+    std::string input_file_ = rag_tokenizer_path_ + "/tokenizer_input.txt";
+
+    std::cout << "Looking for input file: " << input_file_ << std::endl;
+    std::cout << "Current directory: " << fs::current_path() << std::endl;
+
+    if (!fs::exists(input_file_)) {
+        std::cerr << "ERROR: Input file doesn't exist: " << input_file_ << std::endl;
+        std::cerr << "Full path: " << fs::absolute(input_file_) << std::endl;
+        return;
+    }
+
+    std::ifstream infile(input_file_);
+    if (!infile.is_open()) {
+        std::cerr << "ERROR: Cannot open file: " << input_file_ << std::endl;
+        std::cerr << "Error code: " << strerror(errno) << std::endl;
+        return;
+    }
+
+    infile.seekg(0, std::ios::end);
+    size_t file_size = infile.tellg();
+    infile.seekg(0, std::ios::beg);
+    std::cout << "File size: " << file_size << " bytes" << std::endl;
+
+    auto analyzer_ = new RAGAnalyzer(RESOURCE_DIR.string());
+    analyzer_->Load();
+
+    std::string line;
+
+    while (std::getline(infile, line)) {
+        if (line.empty())
+            continue;
+
+        // Test Tokenize (returns string)
+        std::string tokens_str = analyzer_->Tokenize(line);
+        std::istringstream iss(tokens_str);
+        std::string token;
+        std::vector<std::string> tokenize_result;
+        while (iss >> token) {
+            tokenize_result.push_back(token);
+        }
+
+        std::cout << "Input text: " << std::endl << line << std::endl;
+        std::cout << "Tokenize result: " << std::endl << tokens_str << std::endl;
+
+        // Test TokenizeWithPosition (returns vector of tokens and positions)
+        auto [tokenize_with_pos_result, positions] = analyzer_->TokenizeWithPosition(line);
+
+        // Check if results are identical
+        bool tokens_match = (tokenize_result.size() == tokenize_with_pos_result.size());
+        if (tokens_match) {
+            for (size_t i = 0; i < tokenize_result.size(); ++i) {
+                if (tokenize_result[i] != tokenize_with_pos_result[i]) {
+                    tokens_match = false;
+                    break;
+                }
+            }
+        }
+
+        assert(tokens_match == true);
+        if (!tokens_match) {
+            std::cout << "Tokenize count: " << tokenize_result.size() << ", TokenizeWithPosition count: " << tokenize_with_pos_result.size() << std::endl;
+
+            std::cout << "TokenizeWithPosition result: " << std::endl;
+            std::string result_str = std::accumulate(tokenize_with_pos_result.begin(),
+                                                     tokenize_with_pos_result.end(),
+                                                     std::string(""),
+                                                     [](const std::string &a, const std::string &b) {
+                                                         return a + (a.empty() ? "" : " ") + b;
+                                                     });
+            std::cout << result_str << std::endl;
+        }
+    }
+    infile.close();
+
+    delete analyzer_;
+    analyzer_ = nullptr;
+}
+
+std::vector<std::string> SplitString(const std::string &str) {
+    std::vector<std::string> tokens;
+    std::stringstream ss(str);
+    std::string token;
+
+    while (ss >> token) {
+        tokens.push_back(token);
+    }
+
+    return tokens;
+}
+
+void test_tokenize_consistency_with_python() {
+    fs::path RESOURCE_DIR = "/usr/share/infinity/resource";
+    if (!fs::exists(RESOURCE_DIR)) {
+        std::cerr << "Resource directory doesn't exist: " << RESOURCE_DIR << std::endl;
+        return;
+    }
+
+    std::string rag_tokenizer_path_ = "test";
+    std::string input_file_ = rag_tokenizer_path_ + "/tokenizer_input.txt";
+
+    std::cout << "Looking for input file: " << input_file_ << std::endl;
+    std::cout << "Current directory: " << fs::current_path() << std::endl;
+
+    if (!fs::exists(input_file_)) {
+        std::cerr << "ERROR: Input file doesn't exist: " << input_file_ << std::endl;
+        std::cerr << "Full path: " << fs::absolute(input_file_) << std::endl;
+        return;
+    }
+
+    std::ifstream infile(input_file_);
+    if (!infile.is_open()) {
+        std::cerr << "ERROR: Cannot open file: " << input_file_ << std::endl;
+        std::cerr << "Error code: " << strerror(errno) << std::endl;
+        return;
+    }
+
+    infile.seekg(0, std::ios::end);
+    size_t file_size = infile.tellg();
+    infile.seekg(0, std::ios::beg);
+    std::cout << "File size: " << file_size << " bytes" << std::endl;
+
+    auto analyzer_ = new RAGAnalyzer(RESOURCE_DIR.string());
+    analyzer_->Load();
+
+    std::unordered_set<std::string> mismatch_tokens_ = {"be", "datum", "ccs", "experi", "fast", "llms", "larg", "ass"};
+
+    std::ifstream infile_python(rag_tokenizer_path_ + "/tokenizer_python_output.txt");
+    std::string line;
+    std::string python_tokens;
+    while (std::getline(infile, line)) {
+        if (line.empty())
+            continue;
+
+        std::string tokens = analyzer_->Tokenize(line);
+        std::cout << "Input text: " << std::endl << line << std::endl;
+        std::cout << "Tokenize result: " << std::endl << tokens << std::endl;
+
+        std::getline(infile_python, python_tokens);
+
+        std::vector<std::string> tokenize_result = SplitString(tokens);
+        std::vector<std::string> python_tokenize_result = SplitString(python_tokens);
+
+        bool is_size_match = tokenize_result.size() == python_tokenize_result.size();
+        assert(is_size_match == true);
+
+        bool is_match = true;
+        bool is_bad_token = false;
+        if (is_size_match) {
+            for (size_t i = 0; i < tokenize_result.size(); ++i) {
+                if (tokenize_result[i] != python_tokenize_result[i]) {
+                    is_bad_token = mismatch_tokens_.contains(tokenize_result[i]);
+                    if (!is_bad_token) {
+                        is_match = false;
+                        break;
+                    }
+                }
+            }
+            assert(is_match == true);
+        }
+        if (!is_size_match || !is_match || is_bad_token) {
+            std::cout << "Tokenize count: " << tokenize_result.size() << ", Python tokenize count: " << python_tokenize_result.size() << std::endl;
+
+            std::cout << "Python tokenize result: " << std::endl << python_tokens << std::endl;
+        }
+    }
+    infile.close();
+
+    delete analyzer_;
+    analyzer_ = nullptr;
+}
+
+void test_fine_grained_tokenize_consistency_with_python() {
+    fs::path RESOURCE_DIR = "/usr/share/infinity/resource";
+    if (!fs::exists(RESOURCE_DIR)) {
+        std::cerr << "Resource directory doesn't exist: " << RESOURCE_DIR << std::endl;
+        return;
+    }
+
+    std::string rag_tokenizer_path_ = "test";
+    std::string input_file_ = rag_tokenizer_path_ + "/tokenizer_input.txt";
+
+    std::cout << "Looking for input file: " << input_file_ << std::endl;
+    std::cout << "Current directory: " << fs::current_path() << std::endl;
+
+    if (!fs::exists(input_file_)) {
+        std::cerr << "ERROR: Input file doesn't exist: " << input_file_ << std::endl;
+        std::cerr << "Full path: " << fs::absolute(input_file_) << std::endl;
+        return;
+    }
+
+    std::ifstream infile(input_file_);
+    if (!infile.is_open()) {
+        std::cerr << "ERROR: Cannot open file: " << input_file_ << std::endl;
+        std::cerr << "Error code: " << strerror(errno) << std::endl;
+        return;
+    }
+
+    infile.seekg(0, std::ios::end);
+    size_t file_size = infile.tellg();
+    infile.seekg(0, std::ios::beg);
+    std::cout << "File size: " << file_size << " bytes" << std::endl;
+
+    auto analyzer_ = new RAGAnalyzer(RESOURCE_DIR.string());
+    analyzer_->Load();
+
+    std::unordered_set<std::string> mismatch_tokens_ = {"be", "datum", "ccs", "experi", "fast", "llms", "larg", "ass"};
+
+    analyzer_->SetEnablePosition(false);
+    analyzer_->SetFineGrained(true);
+
+    std::ifstream infile_python(rag_tokenizer_path_ + "/fine_grained_tokenizer_python_output.txt");
+    std::string line;
+    std::string python_tokens;
+    while (std::getline(infile, line)) {
+        if (line.empty())
+            continue;
+
+        TermList term_list;
+        analyzer_->Analyze(line, term_list);
+
+        std::string fine_grained_tokens =
+            std::accumulate(term_list.begin(),
+                            term_list.end(),
+                            std::string(""),
+                            [](const std::string &a, const Term &b) {
+                                return a + (a.empty() ? "" : " ") + b.text_;
+                            });
+
+        std::cout << "Input text: " << std::endl << line << std::endl;
+        std::cout << "Fine grained tokenize result: " << std::endl << fine_grained_tokens << std::endl;
+
+        std::getline(infile_python, python_tokens);
+        std::vector<std::string> python_tokenize_result = SplitString(python_tokens);
+
+        bool is_size_match = term_list.size() == python_tokenize_result.size();
+        assert(is_size_match == true);
+
+        bool is_match = true;
+        bool is_bad_token = false;
+        if (is_size_match) {
+            for (size_t i = 0; i < term_list.size(); ++i) {
+                if (term_list[i].text_ != python_tokenize_result[i]) {
+                    is_bad_token = mismatch_tokens_.contains(term_list[i].text_);
+                    if (!is_bad_token) {
+                        is_match = false;
+                        break;
+                    }
+                }
+            }
+            assert(is_match == true);
+        }
+        if (!is_size_match || !is_match || is_bad_token) {
+            std::cout << "Tokenize count: " << term_list.size() << ", Python tokenize count: " << python_tokenize_result.size() << std::endl;
+
+            std::cout << "Python tokenize result: " << std::endl << python_tokens << std::endl;
+        }
+    }
+    infile.close();
+
+    delete analyzer_;
+    analyzer_ = nullptr;
+}
+
+void test_tokenize_text(const std::string &text) {
+    fs::path RESOURCE_DIR = "/usr/share/infinity/resource";
+    if (!fs::exists(RESOURCE_DIR)) {
+        std::cerr << "Resource directory doesn't exist: " << RESOURCE_DIR << std::endl;
+        return;
+    }
+    auto analyzer_ = new RAGAnalyzer(RESOURCE_DIR.string());
+    analyzer_->Load();
+
+    analyzer_->SetEnablePosition(false);
+    analyzer_->SetFineGrained(false);
+
+    std::string tokens = analyzer_->Tokenize(text);
+    std::cout << "Input text: " << std::endl << text << std::endl;
+    std::cout << "Tokenize result: " << std::endl << tokens << std::endl;
+
+    delete analyzer_;
+    analyzer_ = nullptr;
+}
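A possible convenience, not in the patch: dispatching the manual checks above by name instead of editing `main()` and recompiling. The helper below calls only functions defined in this file:

```cpp
// Sketch only: name-based dispatch for the manual tests above.
#include <cstring>

int run_named_test(const char *name) {
    if (std::strcmp(name, "position") == 0) {
        test_analyze_enable_position();
    } else if (std::strcmp(name, "fine_grained") == 0) {
        test_analyze_enable_position_fine_grained();
    } else if (std::strcmp(name, "consistency") == 0) {
        test_tokenize_consistency_with_position();
    } else {
        return 1;  // unknown test name
    }
    return 0;
}
```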
+
+int main() {
+    // test_analyze_enable_position();
+    // test_analyze_enable_position_fine_grained();
+    // test_tokenize_consistency_with_position();
+    // test_tokenize_consistency_with_python();
+    // test_fine_grained_tokenize_consistency_with_python();
+    test_tokenize_text("在本研究中,我们提出了一种novel的neural network架构,用于解决multi-modal learning问题。我们的方法结合了CNN(Convolutional Neural Networks)和Transformer的优势,在ImageNet数据集上达到了state-of-the-art性能。实验结果表明,在batch size为256、learning rate为0.001的条件下,我们的模型在validation set上的accuracy达到了95.7%,比baseline方法提高了3.2%。此外,我们还进行了ablation study来分析不同components的contribution。所有代码已在GitHub上开源,地址是https://github.com/example/our-project。未来工作将集中在model compression和real-time inference optimization上。");
+    return 0;
+}
\ No newline at end of file
diff --git a/internal/cpp/opencc/config_reader.c b/internal/cpp/opencc/config_reader.c
new file mode 100644
index 00000000000..06f191e75b0
--- /dev/null
+++ b/internal/cpp/opencc/config_reader.c
@@ -0,0 +1,289 @@
+/*
+ * Open Chinese Convert
+ *
+ * Copyright 2010 BYVoid
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "config_reader.h"
+#include "dictionary_set.h"
+
+#include <stdlib.h>
+
+#define BUFFER_SIZE 8192
+#define DICTIONARY_MAX_COUNT 1024
+#define CONFIG_DICT_TYPE_OCD "OCD"
+#define CONFIG_DICT_TYPE_TEXT "TEXT"
+
+typedef struct {
+    opencc_dictionary_type dict_type;
+    char *file_name;
+    size_t index;
+    size_t stamp;
+} dictionary_buffer;
+
+struct _config_desc {
+    char *title;
+    char *description;
+    dictionary_set_t dictionary_set;
+    char *home_dir;
+    dictionary_buffer dicts[DICTIONARY_MAX_COUNT];
+    size_t dicts_count;
+    size_t stamp;
+};
+typedef struct _config_desc config_desc;
+
+static config_error errnum = CONFIG_ERROR_VOID;
+
+static int qsort_dictionary_buffer_cmp(const void *a, const void *b) {
+    if (((dictionary_buffer *)a)->index < ((dictionary_buffer *)b)->index)
+        return -1;
+    if (((dictionary_buffer *)a)->index > ((dictionary_buffer *)b)->index)
+        return 1;
+    return ((dictionary_buffer *)a)->stamp < ((dictionary_buffer *)b)->stamp ? -1 : 1;
+}
+
+static int load_dictionary(config_desc *config) {
+    if (config->dicts_count == 0)
+        return 0;
+
+    qsort(config->dicts, config->dicts_count, sizeof(config->dicts[0]), qsort_dictionary_buffer_cmp);
+
+    size_t i, last_index = 0;
+    dictionary_group_t group = dictionary_set_new_group(config->dictionary_set);
+
+    for (i = 0; i < config->dicts_count; i++) {
+        if (config->dicts[i].index > last_index) {
+            last_index = config->dicts[i].index;
+            group = dictionary_set_new_group(config->dictionary_set);
+        }
+        dictionary_group_load(group, config->dicts[i].file_name, config->home_dir, config->dicts[i].dict_type);
+    }
+
+    return 0;
+}
+
+static int parse_add_dict(config_desc *config, size_t index, const char *dictstr) {
+    const char *pstr = dictstr;
+
+    while (*pstr != '\0' && *pstr != ' ')
+        pstr++;
+
+    opencc_dictionary_type dict_type;
+
+    if (strncmp(dictstr, CONFIG_DICT_TYPE_OCD, sizeof(CONFIG_DICT_TYPE_OCD) - 1) == 0)
+        dict_type = OPENCC_DICTIONARY_TYPE_DATRIE;
+    else if (strncmp(dictstr, CONFIG_DICT_TYPE_TEXT, sizeof(CONFIG_DICT_TYPE_TEXT) - 1) == 0)
+        dict_type = OPENCC_DICTIONARY_TYPE_TEXT;
+    else {
+        errnum = CONFIG_ERROR_INVALID_DICT_TYPE;
+        return -1;
+    }
+
+    while (*pstr != '\0' && (*pstr == ' ' || *pstr == '\t'))
+        pstr++;
+
+    size_t i = config->dicts_count++;
+
+    config->dicts[i].dict_type = dict_type;
+    config->dicts[i].file_name = mstrcpy(pstr);
+    config->dicts[i].index = index;
+    config->dicts[i].stamp = config->stamp++;
+
+    return 0;
+}
+
+static int parse_property(config_desc *config, const char *key, const char *value) {
+    if (strncmp(key, "dict", 4) == 0) {
+        int index = 0;
+        sscanf(key + 4, "%d", &index);
+        return parse_add_dict(config, index, value);
+    } else if (strcmp(key, "title") == 0) {
+        free(config->title);
+        config->title = mstrcpy(value);
+        return 0;
+    } else if (strcmp(key, "description") == 0) {
+        free(config->description);
+        config->description = mstrcpy(value);
+        return 0;
+    }
+
+    errnum = CONFIG_ERROR_NO_PROPERTY;
+    return -1;
+}
+
+static int parse_line(const char *line, char **key, char **value) {
+    const char *line_begin = line;
+
+    while (*line != '\0' && (*line != ' ' && *line != '\t' && *line != '='))
+        line++;
+
+    size_t key_len = line - line_begin;
+
+    while (*line != '\0' && *line != '=')
+        line++;
+
+    if (*line == '\0')
+        return -1;
+
+    assert(*line == '=');
+
+    *key = mstrncpy(line_begin, key_len);
+
+    line++;
+    while (*line != '\0' && (*line == ' ' || *line == '\t'))
+        line++;
+
+    if (*line == '\0') {
+        free(*key);
+        return -1;
+    }
+
+    *value = mstrcpy(line);
+
+    return 0;
+}
+
+static char *parse_trim(char *str) {
+    for (; *str != '\0' && (*str == ' ' || *str == '\t'); str++)
+        ;
+    register char *prs = str;
+    for (; *prs != '\0' && *prs != '\n' && *prs != '\r'; prs++)
+        ;
+    for (prs--; prs > str && (*prs == ' ' || *prs == '\t'); prs--)
+        ;
+    *(++prs) = '\0';
+    return str;
+}
+
+static int parse(config_desc *config, const char *filename, const char *home_path) {
+    FILE *fp = fopen(filename, "rb");
+    if (!fp) {
+        char *pkg_filename = (char *)malloc(sizeof(char) * (strlen(filename) + strlen(home_path) + 2));
+        sprintf(pkg_filename, "%s/%s", home_path, filename);
+        printf("pkg_filename %s\n", pkg_filename);
+        fp = fopen(pkg_filename, "rb");
+        if (!fp) {
+            free(pkg_filename);
+            errnum = CONFIG_ERROR_CANNOT_ACCESS_CONFIG_FILE;
+            return -1;
+        }
+        free(pkg_filename);
+    }
+
+    config->home_dir = (char *)malloc(sizeof(char) * (strlen(home_path) + 1));
+    sprintf(config->home_dir, "%s", home_path);
+
+    static char buff[BUFFER_SIZE];
+
+    while (fgets(buff, BUFFER_SIZE, fp) != NULL) {
+        char *trimed_buff = parse_trim(buff);
+        if (*trimed_buff == ';' || *trimed_buff == '#' || *trimed_buff == '\0') {
+            /* Comment line or empty line */
+            continue;
+        }
+
+        char *key = NULL, *value = NULL;
+
+        if (parse_line(trimed_buff, &key, &value) == -1) {
+            free(key);
+            free(value);
+            fclose(fp);
+            errnum = CONFIG_ERROR_PARSE;
+            return -1;
+        }
+
+        if (parse_property(config, key, value) == -1) {
+            free(key);
+            free(value);
+            fclose(fp);
+            return -1;
+        }
+
+        free(key);
+        free(value);
+    }
+
+    fclose(fp);
+    return 0;
+}
+
+dictionary_set_t config_get_dictionary_set(config_t t_config) {
+    config_desc *config = (config_desc *)t_config;
+
+    if (config->dictionary_set != NULL) {
+        dictionary_set_close(config->dictionary_set);
+    }
+
+    config->dictionary_set = dictionary_set_open();
+    load_dictionary(config);
+
+    return config->dictionary_set;
+}
+
+config_error config_errno(void) { return errnum; }
+
+void config_perror(const char *spec) {
+    perr(spec);
+    perr("\n");
+    switch (errnum) {
+    case CONFIG_ERROR_VOID:
+        break;
+    case CONFIG_ERROR_CANNOT_ACCESS_CONFIG_FILE:
+        perror(_("Can not access configuration file"));
+        break;
+    case CONFIG_ERROR_PARSE:
+        perr(_("Configuration file parse error"));
+        break;
+    case CONFIG_ERROR_NO_PROPERTY:
+        perr(_("Invalid property"));
+        break;
+    case CONFIG_ERROR_INVALID_DICT_TYPE:
+        perr(_("Invalid dictionary type"));
+        break;
+    default:
+        perr(_("Unknown"));
+    }
+}
+
+config_t config_open(const char *filename, const char *home_path) {
+    config_desc *config = (config_desc *)malloc(sizeof(config_desc));
+
+    config->title = NULL;
+    config->description = NULL;
+    config->home_dir = NULL;
+    config->dicts_count = 0;
+    config->stamp = 0;
+    config->dictionary_set = NULL;
+
+    if (parse(config, filename, home_path) == -1) {
+        config_close((config_t)config);
+        return (config_t)-1;
+    }
+
+    return (config_t)config;
+}
+
+void config_close(config_t t_config) {
+    config_desc *config = (config_desc *)t_config;
+
+    size_t i;
+    for (i = 0; i < config->dicts_count; i++)
+        free(config->dicts[i].file_name);
+
+    free(config->title);
+    free(config->description);
+    free(config->home_dir);
+    free(config);
+}
diff --git a/internal/cpp/opencc/config_reader.h b/internal/cpp/opencc/config_reader.h
new file mode 100644
index 00000000000..becfba04ecf
--- /dev/null
+++ b/internal/cpp/opencc/config_reader.h
@@ -0,0 +1,46 @@
+/*
+* Open Chinese Convert
+*
+* Copyright 2010 BYVoid
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+#ifndef __OPENCC_CONFIG_H_
+#define __OPENCC_CONFIG_H_
+
+#include "utils.h"
+#include "dictionary_set.h"
+
+typedef void * config_t;
+
+typedef enum
+{
+    CONFIG_ERROR_VOID,
+    CONFIG_ERROR_CANNOT_ACCESS_CONFIG_FILE,
+    CONFIG_ERROR_PARSE,
+    CONFIG_ERROR_NO_PROPERTY,
+    CONFIG_ERROR_INVALID_DICT_TYPE,
+} config_error;
+
+config_t config_open(const char * filename, const char* home_path);
+
+void config_close(config_t t_config);
+
+dictionary_set_t config_get_dictionary_set(config_t t_config);
+
+config_error config_errno(void);
+
+void config_perror(const char * spec);
+
+#endif /* __OPENCC_CONFIG_H_ */
diff --git a/internal/cpp/opencc/converter.c b/internal/cpp/opencc/converter.c
new file mode 100644
index 00000000000..2b433bd678b
--- /dev/null
+++ b/internal/cpp/opencc/converter.c
@@ -0,0 +1,590 @@
+/*
+ * Open Chinese Convert
+ *
+ * Copyright 2010 BYVoid
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "converter.h"
+#include "dictionary_set.h"
+#include "encoding.h"
+#include "utils.h"
+
+#define DELIMITER ' '
+#define SEGMENT_MAXIMUM_LENGTH 0
+#define SEGMENT_SHORTEST_PATH 1
+#define SEGMENT_METHOD SEGMENT_SHORTEST_PATH
+
+#if SEGMENT_METHOD == SEGMENT_SHORTEST_PATH
+
+#define OPENCC_SP_SEG_DEFAULT_BUFFER_SIZE 1024
+
+typedef struct {
+    int initialized;
+    size_t buffer_size;
+    size_t *match_length;
+    size_t *min_len;
+    size_t *parent;
+    size_t *path;
+} spseg_buffer_desc;
+
+#endif
+
+typedef struct {
+#if SEGMENT_METHOD == SEGMENT_SHORTEST_PATH
+    spseg_buffer_desc spseg_buffer;
+#endif
+    dictionary_set_t dictionary_set;
+    dictionary_group_t current_dictionary_group;
+    opencc_conversion_mode conversion_mode;
+} converter_desc;
+static converter_error errnum = CONVERTER_ERROR_VOID;
+
+#if SEGMENT_METHOD == SEGMENT_SHORTEST_PATH
+static void sp_seg_buffer_free(spseg_buffer_desc *ossb) {
+    free(ossb->match_length);
+    free(ossb->min_len);
+    free(ossb->parent);
+    free(ossb->path);
+}
+
+static void sp_seg_set_buffer_size(spseg_buffer_desc *ossb, size_t buffer_size) {
+    if (ossb->initialized == TRUE)
+        sp_seg_buffer_free(ossb);
+
+    ossb->buffer_size = buffer_size;
+    ossb->match_length = (size_t *)malloc((buffer_size + 1) * sizeof(size_t));
+    ossb->min_len = (size_t *)malloc(buffer_size * sizeof(size_t));
+    ossb->parent = (size_t *)malloc(buffer_size * sizeof(size_t));
+    ossb->path = (size_t *)malloc(buffer_size * sizeof(size_t));
+
+    ossb->initialized = TRUE;
+}
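Before the segmentation internals, a hedged sketch of how this converter API fits together. The function names come from `converter.h` further below; the dictionary set and the UCS-4 buffers are assumed to be prepared by the config_reader and encoding modules:

```cpp
// Sketch only: the intended call sequence for the converter_* API below.
// Dictionary loading and UCS-4 encoding are assumed to happen elsewhere.
size_t convert_once(dictionary_set_t dicts, ucs4_t *in, size_t in_left,
                    ucs4_t *out, size_t out_left) {
    converter_t conv = converter_open();
    converter_assign_dictionary(conv, dicts);
    converter_set_conversion_mode(conv, OPENCC_CONVERSION_FAST);

    // converter_convert() advances the buffer pointers and returns the
    // number of input code points consumed, or (size_t)-1 on error.
    size_t converted = converter_convert(conv, &in, &in_left, &out, &out_left);
    if (converted == (size_t)-1)
        converter_perror("convert");

    converter_close(conv);
    return converted;
}
```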
+
+static size_t sp_seg(converter_desc *converter, ucs4_t **inbuf, size_t *inbuf_left, ucs4_t **outbuf, size_t *outbuf_left, size_t length) {
+    /* Shortest-path segmentation */
+
+    /* Special-case optimization for length 1 */
+    if (length == 1) {
+        const ucs4_t *const *match_rs = dictionary_group_match_longest(converter->current_dictionary_group, *inbuf, 1, NULL);
+
+        size_t match_len = 1;
+        if (converter->conversion_mode == OPENCC_CONVERSION_FAST) {
+            if (match_rs == NULL) {
+                **outbuf = **inbuf;
+                (*outbuf)++, (*outbuf_left)--;
+                (*inbuf)++, (*inbuf_left)--;
+            } else {
+                const ucs4_t *result = match_rs[0];
+
+                /* Remaining output buffer is smaller than the segment */
+                if (ucs4len(result) > *outbuf_left) {
+                    errnum = CONVERTER_ERROR_OUTBUF;
+                    return (size_t)-1;
+                }
+
+                for (; *result; result++) {
+                    **outbuf = *result;
+                    (*outbuf)++, (*outbuf_left)--;
+                }
+
+                *inbuf += match_len;
+                *inbuf_left -= match_len;
+            }
+        } else if (converter->conversion_mode == OPENCC_CONVERSION_LIST_CANDIDATES) {
+            if (match_rs == NULL) {
+                **outbuf = **inbuf;
+                (*outbuf)++, (*outbuf_left)--;
+                (*inbuf)++, (*inbuf_left)--;
+            } else {
+                size_t i;
+                for (i = 0; match_rs[i] != NULL; i++) {
+                    const ucs4_t *result = match_rs[i];
+                    int show_delimiter = match_rs[i + 1] != NULL ? 1 : 0;
+
+                    /* Remaining output buffer is smaller than the segment */
+                    if (ucs4len(result) + show_delimiter > *outbuf_left) {
+                        errnum = CONVERTER_ERROR_OUTBUF;
+                        return (size_t)-1;
+                    }
+
+                    for (; *result; result++) {
+                        **outbuf = *result;
+                        (*outbuf)++, (*outbuf_left)--;
+                    }
+
+                    if (show_delimiter) {
+                        **outbuf = DELIMITER;
+                        (*outbuf)++, (*outbuf_left)--;
+                    }
+                }
+                *inbuf += match_len;
+                *inbuf_left -= match_len;
+            }
+        } else if (converter->conversion_mode == OPENCC_CONVERSION_SEGMENT_ONLY) {
+            if (match_rs == NULL) {
+                **outbuf = **inbuf;
+                (*outbuf)++, (*outbuf_left)--;
+                (*inbuf)++, (*inbuf_left)--;
+            } else {
+                /* Remaining output buffer is smaller than the segment */
+                if (match_len + 1 > *outbuf_left) {
+                    errnum = CONVERTER_ERROR_OUTBUF;
+                    return (size_t)-1;
+                }
+
+                size_t i;
+                for (i = 0; i < match_len; i++) {
+                    **outbuf = **inbuf;
+                    (*outbuf)++, (*outbuf_left)--;
+                    (*inbuf)++, (*inbuf_left)--;
+                }
+            }
+            **outbuf = DELIMITER;
+            (*outbuf)++, (*outbuf_left)--;
+        } else
+            debug_should_not_be_here();
+        /* Must guarantee space for at least one character */
+        return match_len;
+    }
+
+    /* Set up the working buffers */
+    spseg_buffer_desc *ossb = &(converter->spseg_buffer);
+    size_t buffer_size_need = length + 1;
+    if (ossb->initialized == FALSE || ossb->buffer_size < buffer_size_need)
+        sp_seg_set_buffer_size(&(converter->spseg_buffer), buffer_size_need);
+
+    size_t i, j;
+
+    for (i = 0; i <= length; i++)
+        ossb->min_len[i] = INFINITY_INT;
+
+    ossb->min_len[0] = ossb->parent[0] = 0;
+
+    for (i = 0; i < length; i++) {
+        /* Get all match lengths */
+        size_t match_count = dictionary_group_get_all_match_lengths(converter->current_dictionary_group, (*inbuf) + i, ossb->match_length);
+
+        if (ossb->match_length[0] != 1)
+            ossb->match_length[match_count++] = 1;
+
+        /* Dynamic programming for the shortest segmentation path */
+        for (j = 0; j < match_count; j++) {
+            size_t k = ossb->match_length[j];
+            ossb->match_length[j] = 0;
+
+            if (k > 1 && ossb->min_len[i] + 1 <= ossb->min_len[i + k]) {
+                ossb->min_len[i + k] = ossb->min_len[i] + 1;
+                ossb->parent[i + k] = i;
+            } else if (k == 1 && ossb->min_len[i] + 1 < ossb->min_len[i + k]) {
+                ossb->min_len[i + k] = ossb->min_len[i] + 1;
+                ossb->parent[i + k] = i;
+            }
+        }
+    }
+
+    /* Retrieve the shortest segmentation path */
+    for (i = length, j = ossb->min_len[length]; i != 0; i = ossb->parent[i])
+        ossb->path[--j] = i;
+
+    size_t inbuf_left_start = *inbuf_left;
+    size_t begin, end;
+
+    /* Convert along the shortest segmentation path */
+    for (i = begin = 0; i < ossb->min_len[length]; i++) {
+        end = ossb->path[i];
+
+        size_t match_len;
+        const ucs4_t *const *match_rs = dictionary_group_match_longest(converter->current_dictionary_group, *inbuf, end - begin, &match_len);
+
+        if (match_rs == NULL) {
+            **outbuf = **inbuf;
+            (*outbuf)++, (*outbuf_left)--;
+            (*inbuf)++, (*inbuf_left)--;
+        } else {
+            if (converter->conversion_mode == OPENCC_CONVERSION_FAST) {
+                if (match_rs == NULL) {
+                    **outbuf = **inbuf;
+                    (*outbuf)++, (*outbuf_left)--;
+                    (*inbuf)++, (*inbuf_left)--;
+                } else {
+                    const ucs4_t *result = match_rs[0];
+
+                    /* Remaining output buffer is smaller than the segment */
+                    if (ucs4len(result) > *outbuf_left) {
+                        if (inbuf_left_start - *inbuf_left > 0)
+                            break;
+                        errnum = CONVERTER_ERROR_OUTBUF;
+                        return (size_t)-1;
+                    }
+
+                    for (; *result; result++) {
+                        **outbuf = *result;
+                        (*outbuf)++, (*outbuf_left)--;
+                    }
+
+                    *inbuf += match_len;
+                    *inbuf_left -= match_len;
+                }
+            } else if (converter->conversion_mode == OPENCC_CONVERSION_LIST_CANDIDATES) {
+                if (match_rs == NULL) {
+                    **outbuf = **inbuf;
+                    (*outbuf)++, (*outbuf_left)--;
+                    (*inbuf)++, (*inbuf_left)--;
+                } else {
+                    size_t i;
+                    for (i = 0; match_rs[i] != NULL; i++) {
+                        const ucs4_t *result = match_rs[i];
+                        int show_delimiter = match_rs[i + 1] != NULL ? 1 : 0;
+
+                        /* Remaining output buffer is smaller than the segment */
+                        if (ucs4len(result) + show_delimiter > *outbuf_left) {
+                            if (inbuf_left_start - *inbuf_left > 0)
+                                break;
+                            errnum = CONVERTER_ERROR_OUTBUF;
+                            return (size_t)-1;
+                        }
+
+                        for (; *result; result++) {
+                            **outbuf = *result;
+                            (*outbuf)++, (*outbuf_left)--;
+                        }
+
+                        if (show_delimiter) {
+                            **outbuf = DELIMITER;
+                            (*outbuf)++, (*outbuf_left)--;
+                        }
+                    }
+                    *inbuf += match_len;
+                    *inbuf_left -= match_len;
+                }
+            } else if (converter->conversion_mode == OPENCC_CONVERSION_SEGMENT_ONLY) {
+                if (match_rs == NULL) {
+                    **outbuf = **inbuf;
+                    (*outbuf)++, (*outbuf_left)--;
+                    (*inbuf)++, (*inbuf_left)--;
+                } else {
+                    /* Remaining output buffer is smaller than the segment */
+                    if (match_len + 1 > *outbuf_left) {
+                        if (inbuf_left_start - *inbuf_left > 0)
+                            break;
+                        errnum = CONVERTER_ERROR_OUTBUF;
+                        return (size_t)-1;
+                    }
+
+                    size_t i;
+                    for (i = 0; i < match_len; i++) {
+                        **outbuf = **inbuf;
+                        (*outbuf)++, (*outbuf_left)--;
+                        (*inbuf)++, (*inbuf_left)--;
+                    }
+                }
+                **outbuf = DELIMITER;
+                (*outbuf)++, (*outbuf_left)--;
+            } else
+                debug_should_not_be_here();
+        }
+
+        begin = end;
+    }
+
+    return inbuf_left_start - *inbuf_left;
+}
+
+static size_t segment(converter_desc *converter, ucs4_t **inbuf, size_t *inbuf_left, ucs4_t **outbuf, size_t *outbuf_left) {
+    /* Shortest-path segmentation over ambiguity-split spans */
+    size_t i, start, bound;
+    const ucs4_t *inbuf_start = *inbuf;
+    size_t inbuf_left_start = *inbuf_left;
+    size_t sp_seg_length;
+
+    bound = 0;
+
+    for (i = start = 0; inbuf_start[i] && *inbuf_left > 0 && *outbuf_left > 0; i++) {
+        if (i != 0 && i == bound) {
+            /* Run shortest-path segmentation on the ambiguous span */
+            sp_seg_length = sp_seg(converter, inbuf, inbuf_left, outbuf, outbuf_left, bound - start);
+            if (sp_seg_length == (size_t)-1)
+                return (size_t)-1;
+            if (sp_seg_length == 0) {
+                if (inbuf_left_start - *inbuf_left > 0)
+                    return inbuf_left_start - *inbuf_left;
+                /* Not enough space */
+                errnum = CONVERTER_ERROR_OUTBUF;
+                return (size_t)-1;
+            }
+            start = i;
+        }
+
+        size_t match_len;
+        dictionary_group_match_longest(converter->current_dictionary_group, inbuf_start + i, 0, &match_len);
+
+        if (match_len == 0)
+            match_len = 1;
+
+        if (i + match_len > bound)
+            bound = i + match_len;
+    }
+
+    if (*inbuf_left > 0 && *outbuf_left > 0) {
+        sp_seg_length = sp_seg(converter, inbuf, inbuf_left, outbuf, outbuf_left, bound - start);
+        if (sp_seg_length == (size_t)-1)
+            return (size_t)-1;
+        if (sp_seg_length == 0) {
+            if (inbuf_left_start - *inbuf_left > 0)
+                return inbuf_left_start - *inbuf_left;
+            /* Not enough space */
+            errnum = CONVERTER_ERROR_OUTBUF;
+            return (size_t)-1;
+        }
+    }
+
+    if (converter->conversion_mode == OPENCC_CONVERSION_SEGMENT_ONLY) {
+        (*outbuf)--;
+        (*outbuf_left)++;
+    }
+
+    return inbuf_left_start - *inbuf_left;
+}
+
+#endif
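The dynamic program in `sp_seg()` above picks, for every prefix, the split with the fewest segments, preferring longer words on ties. The standalone sketch below restates that recurrence over precomputed match lengths, with hypothetical names and C++ for brevity:

```cpp
// Sketch only: the min-segment-count DP that sp_seg() implements in place.
#include <cstddef>
#include <limits>
#include <vector>

// matches[i] lists lengths of dictionary words starting at position i
// (always containing 1, mirroring the single-character fallback above),
// with i + k never exceeding the input length.
std::vector<std::size_t> shortest_path(
        const std::vector<std::vector<std::size_t>> &matches) {
    const std::size_t n = matches.size();
    const std::size_t INF = std::numeric_limits<std::size_t>::max();
    std::vector<std::size_t> min_len(n + 1, INF), parent(n + 1, 0);
    min_len[0] = 0;

    for (std::size_t i = 0; i < n; ++i) {
        if (min_len[i] == INF) continue;         // unreachable prefix
        for (std::size_t k : matches[i]) {       // try every match length
            if (min_len[i] + 1 < min_len[i + k] ||
                (k > 1 && min_len[i] + 1 == min_len[i + k])) {
                min_len[i + k] = min_len[i] + 1; // multi-char words win ties
                parent[i + k] = i;
            }
        }
    }
    // Walk parents back from n to recover split points, like path[] above.
    std::vector<std::size_t> cuts;
    for (std::size_t i = n; i != 0; i = parent[i]) cuts.push_back(i);
    return {cuts.rbegin(), cuts.rend()};
}
```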
+
+#if SEGMENT_METHOD == SEGMENT_MAXIMUM_LENGTH
+static size_t segment(converter_desc *converter, ucs4_t **inbuf, size_t *inbuf_left, ucs4_t **outbuf, size_t *outbuf_left) {
+    /* Forward maximum matching */
+    size_t inbuf_left_start = *inbuf_left;
+
+    for (; **inbuf && *inbuf_left > 0 && *outbuf_left > 0;) {
+        size_t match_len;
+        const ucs4_t *const *match_rs = dictionary_group_match_longest(converter->current_dictionary_group, *inbuf, *inbuf_left, &match_len);
+
+        if (converter->conversion_mode == OPENCC_CONVERSION_FAST) {
+            if (match_rs == NULL) {
+                **outbuf = **inbuf;
+                (*outbuf)++, (*outbuf_left)--;
+                (*inbuf)++, (*inbuf_left)--;
+            } else {
+                const ucs4_t *result = match_rs[0];
+
+                /* Remaining output buffer is smaller than the segment */
+                if (ucs4len(result) > *outbuf_left) {
+                    if (inbuf_left_start - *inbuf_left > 0)
+                        break;
+                    errnum = CONVERTER_ERROR_OUTBUF;
+                    return (size_t)-1;
+                }
+
+                for (; *result; result++) {
+                    **outbuf = *result;
+                    (*outbuf)++, (*outbuf_left)--;
+                }
+
+                *inbuf += match_len;
+                *inbuf_left -= match_len;
+            }
+        } else if (converter->conversion_mode == OPENCC_CONVERSION_LIST_CANDIDATES) {
+            if (match_rs == NULL) {
+                **outbuf = **inbuf;
+                (*outbuf)++, (*outbuf_left)--;
+                (*inbuf)++, (*inbuf_left)--;
+            } else {
+                size_t i;
+                for (i = 0; match_rs[i] != NULL; i++) {
+                    const ucs4_t *result = match_rs[i];
+                    int show_delimiter = match_rs[i + 1] != NULL ? 1 : 0;
+
+                    /* Remaining output buffer is smaller than the segment */
+                    if (ucs4len(result) + show_delimiter > *outbuf_left) {
+                        if (inbuf_left_start - *inbuf_left > 0)
+                            break;
+                        errnum = CONVERTER_ERROR_OUTBUF;
+                        return (size_t)-1;
+                    }
+
+                    for (; *result; result++) {
+                        **outbuf = *result;
+                        (*outbuf)++, (*outbuf_left)--;
+                    }
+
+                    if (show_delimiter) {
+                        **outbuf = DELIMITER;
+                        (*outbuf)++, (*outbuf_left)--;
+                    }
+                }
+
+                *inbuf += match_len;
+                *inbuf_left -= match_len;
+            }
+        } else if (converter->conversion_mode == OPENCC_CONVERSION_SEGMENT_ONLY) {
+            if (match_rs == NULL) {
+                **outbuf = **inbuf;
+                (*outbuf)++, (*outbuf_left)--;
+                (*inbuf)++, (*inbuf_left)--;
+            } else {
+                /* Remaining output buffer is smaller than the segment */
+                if (match_len + 1 > *outbuf_left) {
+                    if (inbuf_left_start - *inbuf_left > 0)
+                        break;
+                    errnum = CONVERTER_ERROR_OUTBUF;
+                    return (size_t)-1;
+                }
+
+                size_t i;
+                for (i = 0; i < match_len; i++) {
+                    **outbuf = **inbuf;
+                    (*outbuf)++, (*outbuf_left)--;
+                    (*inbuf)++, (*inbuf_left)--;
+                }
+            }
+            **outbuf = DELIMITER;
+            (*outbuf)++, (*outbuf_left)--;
+        } else
+            debug_should_not_be_here();
+    }
+
+    if (converter->conversion_mode == OPENCC_CONVERSION_SEGMENT_ONLY) {
+        (*outbuf)--;
+        (*outbuf_left)++;
+    }
+
+    return inbuf_left_start - *inbuf_left;
+}
+#endif
+
+size_t converter_convert(converter_t t_converter, ucs4_t **inbuf, size_t *inbuf_left, ucs4_t **outbuf, size_t *outbuf_left) {
+    converter_desc *converter = (converter_desc *)t_converter;
+
+    if (converter->dictionary_set == NULL) {
+        errnum = CONVERTER_ERROR_NODICT;
+        return (size_t)-1;
+    }
+
+    if (dictionary_set_count_group(converter->dictionary_set) == 1) {
+        /* Only one dictionary group: convert directly */
+        return segment(converter, inbuf, inbuf_left, outbuf, outbuf_left);
+    }
+
+    // Enable the dictionary conversion chain
+    size_t inbuf_size = *inbuf_left;
+    size_t outbuf_size = *outbuf_left;
+    size_t retval = (size_t)-1;
+    size_t cinbuf_left, coutbuf_left;
+    size_t coutbuf_delta = 0;
+    size_t i, cur;
+
+    ucs4_t *tmpbuf = (ucs4_t *)malloc(sizeof(ucs4_t) * outbuf_size);
+    ucs4_t *orig_outbuf = *outbuf;
+    ucs4_t *cinbuf, *coutbuf;
+
+    cinbuf_left = inbuf_size;
+    coutbuf_left = outbuf_size;
+    cinbuf = *inbuf;
+    coutbuf = tmpbuf;
+
+    for (i = cur = 0; i < dictionary_set_count_group(converter->dictionary_set); ++i, cur = 1 - cur) {
+        if (i > 0) {
+            cinbuf_left = coutbuf_delta;
+            coutbuf_left = outbuf_size;
+            if (cur == 1) {
+                cinbuf = tmpbuf;
+                coutbuf = orig_outbuf;
+            } else {
+                cinbuf = orig_outbuf;
+                coutbuf = tmpbuf;
+            }
+        }
+
+        converter->current_dictionary_group = dictionary_set_get_group(converter->dictionary_set, i);
+
+        size_t ret = segment(converter, &cinbuf, &cinbuf_left, &coutbuf, &coutbuf_left);
+        if (ret == (size_t)-1) {
+            free(tmpbuf);
+            return (size_t)-1;
+        }
+        coutbuf_delta = outbuf_size - coutbuf_left;
+        if (i == 0) {
+            retval = ret;
+            *inbuf = cinbuf;
+            *inbuf_left = cinbuf_left;
+        }
+    }
+
+    if (cur == 1) {
+        // The result is in the temporary buffer
+        memcpy(*outbuf, tmpbuf, coutbuf_delta * sizeof(ucs4_t));
+    }
+
+    *outbuf += coutbuf_delta;
+    *outbuf_left = coutbuf_left;
+    free(tmpbuf);
+
+    return retval;
+}
+
+void converter_assign_dictionary(converter_t t_converter, dictionary_set_t dictionary_set) {
+    converter_desc *converter = (converter_desc *)t_converter;
+    converter->dictionary_set = dictionary_set;
+    if (dictionary_set_count_group(converter->dictionary_set) > 0)
+        converter->current_dictionary_group = dictionary_set_get_group(converter->dictionary_set, 0);
+}
+
+converter_t converter_open(void) {
+    converter_desc *converter = (converter_desc *)malloc(sizeof(converter_desc));
+
+    converter->dictionary_set = NULL;
+    converter->current_dictionary_group = NULL;
+
+#if SEGMENT_METHOD == SEGMENT_SHORTEST_PATH
+    converter->spseg_buffer.initialized = FALSE;
+    converter->spseg_buffer.match_length = converter->spseg_buffer.min_len = converter->spseg_buffer.parent = converter->spseg_buffer.path = NULL;
+
+    sp_seg_set_buffer_size(&converter->spseg_buffer, OPENCC_SP_SEG_DEFAULT_BUFFER_SIZE);
+#endif
+
+    return (converter_t)converter;
+}
+
+void converter_close(converter_t t_converter) {
+    converter_desc *converter = (converter_desc *)t_converter;
+
+#if SEGMENT_METHOD == SEGMENT_SHORTEST_PATH
+    sp_seg_buffer_free(&(converter->spseg_buffer));
+#endif
+
+    free(converter);
+}
+
+void converter_set_conversion_mode(converter_t t_converter, opencc_conversion_mode conversion_mode) {
+    converter_desc *converter = (converter_desc *)t_converter;
+    converter->conversion_mode = conversion_mode;
+}
+
+converter_error converter_errno(void) { return errnum; }
+
+void converter_perror(const char *spec) {
+    perr(spec);
+    perr("\n");
+    switch (errnum) {
+    case CONVERTER_ERROR_VOID:
+        break;
+    case CONVERTER_ERROR_NODICT:
+        perr(_("No dictionary loaded"));
+        break;
+    case CONVERTER_ERROR_OUTBUF:
+        perr(_("Output buffer not enough for one segment"));
+        break;
+    default:
+        perr(_("Unknown"));
+    }
+}
diff --git a/internal/cpp/opencc/converter.h b/internal/cpp/opencc/converter.h
new file mode 100644
index 00000000000..e778600d3b2
--- /dev/null
+++ b/internal/cpp/opencc/converter.h
@@ -0,0 +1,48 @@
+/*
+* Open Chinese Convert
+*
+* Copyright 2010 BYVoid
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+#ifndef __CONVERTER_H_
+#define __CONVERTER_H_
+
+#include "dictionary_set.h"
+
+typedef void * converter_t;
+
+typedef enum
+{
+    CONVERTER_ERROR_VOID,
+    CONVERTER_ERROR_NODICT,
+    CONVERTER_ERROR_OUTBUF,
+} converter_error;
+
+void converter_assign_dictionary(converter_t t_converter, dictionary_set_t dictionary_set);
+
+converter_t converter_open(void);
+
+void converter_close(converter_t t_converter);
+
+size_t converter_convert(converter_t t_converter, ucs4_t ** inbuf, size_t * inbuf_left,
+                         ucs4_t ** outbuf, size_t * outbuf_left);
+
+void converter_set_conversion_mode(converter_t t_converter, opencc_conversion_mode conversion_mode);
+
+converter_error converter_errno(void);
+
+void converter_perror(const char * spec);
+
+#endif /* __CONVERTER_H_ */
diff --git a/internal/cpp/opencc/dictionary/abstract.c b/internal/cpp/opencc/dictionary/abstract.c
new file mode 100644
index 00000000000..d59524d4af0
--- /dev/null
+++ b/internal/cpp/opencc/dictionary/abstract.c
@@ -0,0 +1,94 @@
+/*
+ * Open Chinese Convert
+ *
+ * Copyright 2010 BYVoid
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "abstract.h"
+#include "datrie.h"
+#include "text.h"
+
+struct _dictionary {
+    opencc_dictionary_type type;
+    dictionary_t dict;
+};
+typedef struct _dictionary dictionary_desc;
+
+dictionary_t dictionary_open(const char *filename, opencc_dictionary_type type) {
+    dictionary_desc *dictionary = (dictionary_desc *)malloc(sizeof(dictionary_desc));
+    dictionary->type = type;
+    switch (type) {
+    case OPENCC_DICTIONARY_TYPE_TEXT:
+        dictionary->dict = dictionary_text_open(filename);
+        break;
+    case OPENCC_DICTIONARY_TYPE_DATRIE:
+        dictionary->dict = dictionary_datrie_open(filename);
+        break;
+    default:
+        free(dictionary);
+        dictionary = (dictionary_t)-1; /* TODO: unsupported dictionary format */
+    }
+    return dictionary;
+}
+
+dictionary_t dictionary_get(dictionary_t t_dictionary) {
+    dictionary_desc *dictionary = (dictionary_desc *)t_dictionary;
+    return dictionary->dict;
+}
+
+void dictionary_close(dictionary_t t_dictionary) {
+    dictionary_desc *dictionary = (dictionary_desc *)t_dictionary;
+    switch (dictionary->type) {
+    case OPENCC_DICTIONARY_TYPE_TEXT:
+        dictionary_text_close(dictionary->dict);
+        break;
+    case OPENCC_DICTIONARY_TYPE_DATRIE:
+        dictionary_datrie_close(dictionary->dict);
+        break;
+    default:
+        debug_should_not_be_here();
+    }
+    free(dictionary);
+}
+
+const ucs4_t *const *dictionary_match_longest(dictionary_t t_dictionary, const ucs4_t *word, size_t maxlen, size_t *match_length) {
+    dictionary_desc *dictionary = (dictionary_desc *)t_dictionary;
+    switch (dictionary->type) {
+    case OPENCC_DICTIONARY_TYPE_TEXT:
+        return dictionary_text_match_longest(dictionary->dict, word, maxlen, match_length);
+        break;
+    case OPENCC_DICTIONARY_TYPE_DATRIE:
+        return dictionary_datrie_match_longest(dictionary->dict, word, maxlen, match_length);
+        break;
+    default:
+        debug_should_not_be_here();
+    }
+    return (const ucs4_t *const *)-1;
+}
+
+size_t dictionary_get_all_match_lengths(dictionary_t t_dictionary, const ucs4_t *word, size_t *match_length) {
+    dictionary_desc *dictionary = (dictionary_desc *)t_dictionary;
+    switch (dictionary->type) {
+    case OPENCC_DICTIONARY_TYPE_TEXT:
+        return dictionary_text_get_all_match_lengths(dictionary->dict, word, match_length);
+        break;
+    case OPENCC_DICTIONARY_TYPE_DATRIE:
+        return dictionary_datrie_get_all_match_lengths(dictionary->dict, word, match_length);
+        break;
+    default:
+        debug_should_not_be_here();
+    }
+    return (size_t)-1;
+}
diff --git a/internal/cpp/opencc/dictionary/abstract.h b/internal/cpp/opencc/dictionary/abstract.h
new file mode 100644
index 00000000000..fd8171e0e3a
--- /dev/null
+++ b/internal/cpp/opencc/dictionary/abstract.h
@@ -0,0 +1,45 @@
+/*
+* Open Chinese Convert
+*
+* Copyright 2010 BYVoid
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+#ifndef __OPENCC_DICTIONARY_ABSTRACT_H_
+#define __OPENCC_DICTIONARY_ABSTRACT_H_
+
+#include "../utils.h"
+
+struct _entry
+{
+    ucs4_t * key;
+    ucs4_t ** value;
+};
+typedef struct _entry entry;
+
+typedef void * dictionary_t;
+
+dictionary_t dictionary_open(const char * filename, opencc_dictionary_type type);
+
+void dictionary_close(dictionary_t t_dictionary);
+
+dictionary_t dictionary_get(dictionary_t t_dictionary);
+
+const ucs4_t * const * dictionary_match_longest(dictionary_t t_dictionary, const ucs4_t * word,
+                                                size_t maxlen, size_t * match_length);
+
+size_t dictionary_get_all_match_lengths(dictionary_t t_dictionary, const ucs4_t * word,
+                                        size_t * match_length);
+
+#endif /* __OPENCC_DICTIONARY_ABSTRACT_H_ */
diff --git a/internal/cpp/opencc/dictionary/datrie.c b/internal/cpp/opencc/dictionary/datrie.c
new file mode 100644
index 00000000000..5cf36bd7c80
--- /dev/null
+++ b/internal/cpp/opencc/dictionary/datrie.c
@@ -0,0 +1,250 @@
+/*
+ * Open Chinese Convert
+ *
+ * Copyright 2010 BYVoid
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "datrie.h"
+#include <string.h>
+#include <unistd.h>
+
+#ifdef __WIN32
+/* Todo: Win32 mmap*/
+#else
+#include <sys/mman.h>
+#define MMAP_ENABLED
+#endif
+
+typedef enum { MEMORY_TYPE_MMAP, MEMORY_TYPE_ALLOCATE } memory_type;
+
+struct _datrie_dictionary {
+  const DoubleArrayTrieItem *dat;
+  uint32_t dat_item_count;
+  ucs4_t *lexicon;
+  uint32_t lexicon_count;
+
+  ucs4_t ***lexicon_set;
+  void *dic_memory;
+  size_t dic_size;
+  memory_type dic_memory_type;
+};
+typedef struct _datrie_dictionary datrie_dictionary_desc;
+
+static int load_allocate(datrie_dictionary_desc *datrie_dictionary, int fd) {
+  datrie_dictionary->dic_memory_type = MEMORY_TYPE_ALLOCATE;
+  datrie_dictionary->dic_memory = malloc(datrie_dictionary->dic_size);
+  if (datrie_dictionary->dic_memory == NULL) {
+    /* memory allocation failed */
+    return -1;
+  }
+  lseek(fd, 0, SEEK_SET);
+  if (read(fd, datrie_dictionary->dic_memory, datrie_dictionary->dic_size) == -1) {
+    /* read failed */
+    return -1;
+  }
+  return 0;
+}
+
+static int load_mmap(datrie_dictionary_desc *datrie_dictionary, int fd) {
+#ifdef MMAP_ENABLED
+  datrie_dictionary->dic_memory_type = MEMORY_TYPE_MMAP;
+  datrie_dictionary->dic_memory = mmap(NULL, datrie_dictionary->dic_size, PROT_READ, MAP_PRIVATE, fd, 0);
+  if (datrie_dictionary->dic_memory == MAP_FAILED) {
+    /* failed to create the memory mapping */
+    datrie_dictionary->dic_memory = NULL;
+    return -1;
+  }
+  return 0;
+#else
+  return -1;
+#endif
+}
+
+static int load_dict(datrie_dictionary_desc *datrie_dictionary, FILE *fp) {
+  int fd = fileno(fp);
+
+  fseek(fp, 0, SEEK_END);
+  datrie_dictionary->dic_size = ftell(fp);
+
+  /* try mmap first; fall back to allocating memory if that fails */
+  if (load_mmap(datrie_dictionary, fd) == -1) {
+    if (load_allocate(datrie_dictionary, fd) == -1) {
+      return -1;
+    }
+  }
+
+  size_t header_len = strlen("OPENCCDATRIE");
+
+  if (strncmp((const char *)datrie_dictionary->dic_memory, "OPENCCDATRIE", header_len) != 0) {
+    return -1;
+  }
+
+  size_t offset = 0;
+
+  offset += header_len * sizeof(char);
+
+  /* lexicon */
+  uint32_t lexicon_length = *((uint32_t *)(datrie_dictionary->dic_memory + offset));
+  offset += sizeof(uint32_t);
+
+  datrie_dictionary->lexicon = (ucs4_t *)(datrie_dictionary->dic_memory + offset);
+  offset += lexicon_length * sizeof(ucs4_t);
+
+  /* lexicon index table */
+  uint32_t lexicon_index_length = *((uint32_t *)(datrie_dictionary->dic_memory + offset));
+  offset += sizeof(uint32_t);
+
+  uint32_t *lexicon_index = (uint32_t *)(datrie_dictionary->dic_memory + offset);
+  offset += lexicon_index_length * sizeof(uint32_t);
+
+  datrie_dictionary->lexicon_count = *((uint32_t *)(datrie_dictionary->dic_memory + offset));
+  offset += sizeof(uint32_t);
+
+  datrie_dictionary->dat_item_count = *((uint32_t *)(datrie_dictionary->dic_memory + offset));
+  offset += sizeof(uint32_t);
+
+  datrie_dictionary->dat = (DoubleArrayTrieItem *)(datrie_dictionary->dic_memory + offset);
+
+  /* build the index table */
+  datrie_dictionary->lexicon_set = (ucs4_t ***)malloc(datrie_dictionary->lexicon_count * sizeof(ucs4_t **));
+  size_t i, last = 0;
+  for (i = 0; i < datrie_dictionary->lexicon_count; i++) {
+    size_t count, j;
+    for (j = last; j < lexicon_index_length; j++) {
+      if (lexicon_index[j] == (uint32_t)-1)
+        break;
+    }
+    count = j - last;
+
+    datrie_dictionary->lexicon_set[i] = (ucs4_t **)malloc((count + 1) * sizeof(ucs4_t *));
+    for (j = 0; j < count; j++) {
+      datrie_dictionary->lexicon_set[i][j] = datrie_dictionary->lexicon + lexicon_index[last + j];
+    }
+    datrie_dictionary->lexicon_set[i][count] = NULL;
+    last += j + 1;
+  }
+
+  return 0;
+}
+
+static int unload_dict(datrie_dictionary_desc *datrie_dictionary) {
+  if
(datrie_dictionary->dic_memory != NULL) {
+    size_t i;
+    for (i = 0; i < datrie_dictionary->lexicon_count; i++) {
+      free(datrie_dictionary->lexicon_set[i]);
+    }
+    free(datrie_dictionary->lexicon_set);
+
+    if (MEMORY_TYPE_MMAP == datrie_dictionary->dic_memory_type) {
+#ifdef MMAP_ENABLED
+      return munmap(datrie_dictionary->dic_memory, datrie_dictionary->dic_size);
+#else
+      debug_should_not_be_here();
+#endif
+    } else if (MEMORY_TYPE_ALLOCATE == datrie_dictionary->dic_memory_type) {
+      free(datrie_dictionary->dic_memory);
+    } else {
+      return -1;
+    }
+  }
+  return 0;
+}
+
+dictionary_t dictionary_datrie_open(const char *filename) {
+  datrie_dictionary_desc *datrie_dictionary = (datrie_dictionary_desc *)malloc(sizeof(datrie_dictionary_desc));
+  datrie_dictionary->dat = NULL;
+  datrie_dictionary->lexicon = NULL;
+  datrie_dictionary->lexicon_set = NULL;
+  datrie_dictionary->lexicon_count = 0;
+  datrie_dictionary->dic_memory = NULL;
+
+  FILE *fp = fopen(filename, "rb");
+  if (fp == NULL) {
+    /* cannot open the dictionary file */
+    free(datrie_dictionary);
+    return (dictionary_t)-1;
+  }
+
+  if (load_dict(datrie_dictionary, fp) == -1) {
+    fclose(fp);
+    dictionary_datrie_close((dictionary_t)datrie_dictionary);
+    return (dictionary_t)-1;
+  }
+
+  fclose(fp);
+
+  return (dictionary_t)datrie_dictionary;
+}
+
+int dictionary_datrie_close(dictionary_t t_dictionary) {
+  datrie_dictionary_desc *datrie_dictionary = (datrie_dictionary_desc *)t_dictionary;
+
+  if (unload_dict(datrie_dictionary) == -1) {
+    free(datrie_dictionary);
+    return -1;
+  }
+
+  free(datrie_dictionary);
+  return 0;
+}
+
+int encode_char(ucs4_t ch) { return (int)ch; }
+
+void datrie_match(const datrie_dictionary_desc *datrie_dictionary, const ucs4_t *word, size_t *match_pos, size_t *id, size_t limit) {
+  size_t i, p;
+  for (i = 0, p = 0; word[p] && (limit == 0 || p < limit) && datrie_dictionary->dat[i].base != DATRIE_UNUSED; p++) {
+    int k = encode_char(word[p]);
+    int j = datrie_dictionary->dat[i].base + k;
+    if (j < 0 || j >= datrie_dictionary->dat_item_count || datrie_dictionary->dat[j].parent != i)
+      break;
+    i = j;
+  }
+  if (match_pos)
+    *match_pos = p;
+  if (id)
+    *id = i;
+}
+
+const ucs4_t *const *dictionary_datrie_match_longest(dictionary_t t_dictionary, const ucs4_t *word, size_t maxlen, size_t *match_length) {
+  datrie_dictionary_desc *datrie_dictionary = (datrie_dictionary_desc *)t_dictionary;
+
+  size_t pos, item;
+  datrie_match(datrie_dictionary, word, &pos, &item, maxlen);
+
+  while (datrie_dictionary->dat[item].word == -1 && pos > 1)
+    datrie_match(datrie_dictionary, word, &pos, &item, pos - 1);
+
+  if (pos == 0 || datrie_dictionary->dat[item].word == -1) {
+    if (match_length != NULL)
+      *match_length = 0;
+    return NULL;
+  }
+
+  if (match_length != NULL)
+    *match_length = pos;
+
+  return (const ucs4_t *const *)datrie_dictionary->lexicon_set[datrie_dictionary->dat[item].word];
+}
+
+size_t dictionary_datrie_get_all_match_lengths(dictionary_t t_dictionary, const ucs4_t *word, size_t *match_length) {
+  datrie_dictionary_desc *datrie_dictionary = (datrie_dictionary_desc *)t_dictionary;
+
+  size_t rscnt = 0;
+
+  size_t i, p;
+  for (i = 0, p = 0; word[p] && datrie_dictionary->dat[i].base != DATRIE_UNUSED; p++) {
+    int k = encode_char(word[p]);
+    int j = datrie_dictionary->dat[i].base + k;
+    if (j < 0 || j >= datrie_dictionary->dat_item_count || datrie_dictionary->dat[j].parent != i)
+      break;
+    i = j;
+
+    if (datrie_dictionary->dat[i].word != -1)
+      match_length[rscnt++] = p + 1;
+  }
+
+  return rscnt;
+}
diff --git a/internal/cpp/opencc/dictionary/datrie.h b/internal/cpp/opencc/dictionary/datrie.h
new file mode 100644
index 00000000000..ae2767de334
--- /dev/null
+++ b/internal/cpp/opencc/dictionary/datrie.h
@@ -0,0 +1,45 @@
+/*
+* Open Chinese Convert
+*
+* Copyright 2010 BYVoid
+*
+*
Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef __OPENCC_DICTIONARY_DATRIE_H_ +#define __OPENCC_DICTIONARY_DATRIE_H_ + +#include "abstract.h" + +#define DATRIE_UNUSED -1 + +typedef struct +{ + int base; + int parent; + int word; +} DoubleArrayTrieItem; + +dictionary_t dictionary_datrie_open(const char * filename); + +int dictionary_datrie_close(dictionary_t t_dictionary); + +const ucs4_t * const * dictionary_datrie_match_longest(dictionary_t t_dictionary, const ucs4_t * word, + size_t maxlen, size_t * match_length); + +size_t dictionary_datrie_get_all_match_lengths(dictionary_t t_dictionary, const ucs4_t * word, + size_t * match_length); + +int encode_char(ucs4_t ch); + +#endif /* __OPENCC_DICTIONARY_DATRIE_H_ */ diff --git a/internal/cpp/opencc/dictionary/text.c b/internal/cpp/opencc/dictionary/text.c new file mode 100644 index 00000000000..41bcdbb45af --- /dev/null +++ b/internal/cpp/opencc/dictionary/text.c @@ -0,0 +1,232 @@ +/* + * Open Chinese Convert + * + * Copyright 2010 BYVoid + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
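+ */
+
+/* The DATRIE dictionary declared above memory-maps a precompiled double-array
+   trie. A sketch of the walk it performs, following the conventions of
+   datrie_match() (base indexes transitions, parent back-links validate them,
+   and word == -1 marks non-terminal nodes; `dat` and `word` are placeholders):
+
+     size_t i = 0, p = 0;
+     while (word[p] && dat[i].base != DATRIE_UNUSED) {
+       int j = dat[i].base + encode_char(word[p]);
+       if (j < 0 || dat[j].parent != (int)i)
+         break;               // no transition from node i on this character
+       i = (size_t)j;
+       p++;
+     }
+     // dat[i].word >= 0 iff the prefix word[0..p) is a dictionary entry.
+*/
+
+/*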
+ */
+
+#include "text.h"
+#include "../encoding.h"
+
+#define INITIAL_DICTIONARY_SIZE 1024
+#define ENTRY_BUFF_SIZE 128
+#define ENTRY_WBUFF_SIZE (ENTRY_BUFF_SIZE / sizeof(size_t))
+
+struct _text_dictionary {
+  size_t entry_count;
+  size_t max_length;
+  entry *lexicon;
+  ucs4_t *word_buff;
+};
+typedef struct _text_dictionary text_dictionary_desc;
+
+int qsort_entry_cmp(const void *a, const void *b) { return ucs4cmp(((entry *)a)->key, ((entry *)b)->key); }
+
+int parse_entry(const char *buff, entry *entry_i) {
+  size_t length;
+  const char *pbuff;
+
+  /* parse the key */
+  for (pbuff = buff; *pbuff != '\t' && *pbuff != '\0'; ++pbuff)
+    ;
+  if (*pbuff == '\0')
+    return -1;
+  length = pbuff - buff;
+
+  ucs4_t *ucs4_buff;
+  ucs4_buff = utf8_to_ucs4(buff, length);
+  if (ucs4_buff == (ucs4_t *)-1)
+    return -1;
+  entry_i->key = (ucs4_t *)malloc((length + 1) * sizeof(ucs4_t));
+  ucs4cpy(entry_i->key, ucs4_buff);
+  free(ucs4_buff);
+
+  /* parse the values */
+  size_t value_i, value_count = INITIAL_DICTIONARY_SIZE;
+  entry_i->value = (ucs4_t **)malloc(value_count * sizeof(ucs4_t *));
+
+  for (value_i = 0; *pbuff != '\0' && *pbuff != '\n'; ++value_i) {
+    if (value_i >= value_count) {
+      value_count += value_count;
+      entry_i->value = (ucs4_t **)realloc(entry_i->value, value_count * sizeof(ucs4_t *));
+    }
+
+    for (buff = ++pbuff; *pbuff != ' ' && *pbuff != '\0' && *pbuff != '\n'; ++pbuff)
+      ;
+    length = pbuff - buff;
+    ucs4_buff = utf8_to_ucs4(buff, length);
+    if (ucs4_buff == (ucs4_t *)-1) {
+      /* on error, roll back the allocations made so far */
+      ssize_t i;
+      for (i = value_i - 1; i >= 0; --i)
+        free(entry_i->value[i]);
+      free(entry_i->value);
+      free(entry_i->key);
+      return -1;
+    }
+
+    entry_i->value[value_i] = (ucs4_t *)malloc((length + 1) * sizeof(ucs4_t));
+    ucs4cpy(entry_i->value[value_i], ucs4_buff);
+    free(ucs4_buff);
+  }
+
+  /* shrink to the number of values actually parsed, plus the NULL terminator */
+  entry_i->value = (ucs4_t **)realloc(entry_i->value, (value_i + 1) * sizeof(ucs4_t *));
+  entry_i->value[value_i] = NULL;
+
+  return 0;
+}
+
+dictionary_t dictionary_text_open(const char *filename) {
+  text_dictionary_desc *text_dictionary;
+  text_dictionary = (text_dictionary_desc *)malloc(sizeof(text_dictionary_desc));
+  text_dictionary->entry_count = INITIAL_DICTIONARY_SIZE;
+  text_dictionary->max_length = 0;
+  text_dictionary->lexicon = (entry *)malloc(sizeof(entry) * text_dictionary->entry_count);
+  text_dictionary->word_buff = NULL;
+
+  static char buff[ENTRY_BUFF_SIZE];
+
+  FILE *fp = fopen(filename, "rb");
+  if (fp == NULL) {
+    text_dictionary->entry_count = 0;
+    dictionary_text_close((dictionary_t)text_dictionary);
+    return (dictionary_t)-1;
+  }
+
+  size_t i = 0;
+  while (fgets(buff, ENTRY_BUFF_SIZE, fp)) {
+    if (i >= text_dictionary->entry_count) {
+      text_dictionary->entry_count += text_dictionary->entry_count;
+      text_dictionary->lexicon = (entry *)realloc(text_dictionary->lexicon, sizeof(entry) * text_dictionary->entry_count);
+    }
+
+    if (parse_entry(buff, text_dictionary->lexicon + i) == -1) {
+      fclose(fp);
+      text_dictionary->entry_count = i;
+      dictionary_text_close((dictionary_t)text_dictionary);
+      return (dictionary_t)-1;
+    }
+
+    size_t length = ucs4len(text_dictionary->lexicon[i].key);
+    if (length > text_dictionary->max_length)
+      text_dictionary->max_length = length;
+
+    i++;
+  }
+
+  fclose(fp);
+
+  text_dictionary->entry_count = i;
+  text_dictionary->lexicon = (entry *)realloc(text_dictionary->lexicon, sizeof(entry) * text_dictionary->entry_count);
+  text_dictionary->word_buff = (ucs4_t *)malloc(sizeof(ucs4_t) * (text_dictionary->max_length + 1));
+
+  qsort(text_dictionary->lexicon, text_dictionary->entry_count, sizeof(text_dictionary->lexicon[0]),
qsort_entry_cmp); + + return (dictionary_t)text_dictionary; +} + +void dictionary_text_close(dictionary_t t_dictionary) { + text_dictionary_desc *text_dictionary = (text_dictionary_desc *)t_dictionary; + + size_t i; + for (i = 0; i < text_dictionary->entry_count; ++i) { + free(text_dictionary->lexicon[i].key); + + ucs4_t **j; + for (j = text_dictionary->lexicon[i].value; *j; ++j) { + free(*j); + } + free(text_dictionary->lexicon[i].value); + } + + free(text_dictionary->lexicon); + free(text_dictionary->word_buff); + free(text_dictionary); +} + +const ucs4_t *const *dictionary_text_match_longest(dictionary_t t_dictionary, const ucs4_t *word, size_t maxlen, size_t *match_length) { + text_dictionary_desc *text_dictionary = (text_dictionary_desc *)t_dictionary; + + if (text_dictionary->entry_count == 0) + return NULL; + + if (maxlen == 0) + maxlen = ucs4len(word); + size_t len = text_dictionary->max_length; + if (maxlen < len) + len = maxlen; + + ucs4ncpy(text_dictionary->word_buff, word, len); + text_dictionary->word_buff[len] = L'\0'; + + entry buff; + buff.key = text_dictionary->word_buff; + + for (; len > 0; len--) { + text_dictionary->word_buff[len] = L'\0'; + entry *brs = + (entry *)bsearch(&buff, text_dictionary->lexicon, text_dictionary->entry_count, sizeof(text_dictionary->lexicon[0]), qsort_entry_cmp); + + if (brs != NULL) { + if (match_length != NULL) + *match_length = len; + return (const ucs4_t *const *)brs->value; + } + } + + if (match_length != NULL) + *match_length = 0; + return NULL; +} + +size_t dictionary_text_get_all_match_lengths(dictionary_t t_dictionary, const ucs4_t *word, size_t *match_length) { + text_dictionary_desc *text_dictionary = (text_dictionary_desc *)t_dictionary; + + size_t rscnt = 0; + + if (text_dictionary->entry_count == 0) + return rscnt; + + size_t length = ucs4len(word); + size_t len = text_dictionary->max_length; + if (length < len) + len = length; + + ucs4ncpy(text_dictionary->word_buff, word, len); + text_dictionary->word_buff[len] = L'\0'; + + entry buff; + buff.key = text_dictionary->word_buff; + + for (; len > 0; len--) { + text_dictionary->word_buff[len] = L'\0'; + entry *brs = + (entry *)bsearch(&buff, text_dictionary->lexicon, text_dictionary->entry_count, sizeof(text_dictionary->lexicon[0]), qsort_entry_cmp); + + if (brs != NULL) + match_length[rscnt++] = len; + } + + return rscnt; +} + +size_t dictionary_text_get_lexicon(dictionary_t t_dictionary, entry *lexicon) { + text_dictionary_desc *text_dictionary = (text_dictionary_desc *)t_dictionary; + + size_t i; + for (i = 0; i < text_dictionary->entry_count; i++) { + lexicon[i].key = text_dictionary->lexicon[i].key; + lexicon[i].value = text_dictionary->lexicon[i].value; + } + + return text_dictionary->entry_count; +} diff --git a/internal/cpp/opencc/dictionary/text.h b/internal/cpp/opencc/dictionary/text.h new file mode 100644 index 00000000000..bc52d008a25 --- /dev/null +++ b/internal/cpp/opencc/dictionary/text.h @@ -0,0 +1,36 @@ +/* +* Open Chinese Convert +* +* Copyright 2010 BYVoid +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
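+*/
+
+/* The text dictionary above expects one entry per line: a UTF-8 key, a tab,
+   then one or more space-separated values, schematically:
+
+     key<TAB>value1 value2
+
+   Lookup starts from the longest candidate prefix and shortens it one code
+   point at a time, binary-searching the qsort()-sorted lexicon each round:
+
+     size_t len;
+     const ucs4_t *const *vals = dictionary_text_match_longest(dict, word, 0, &len);
+
+   where maxlen == 0 means "no limit beyond the dictionary's longest key"
+   (dict and word are placeholders). */
+
+/*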
+* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef __OPENCC_DICTIONARY_TEXT_H_ +#define __OPENCC_DICTIONARY_TEXT_H_ + +#include "abstract.h" + +dictionary_t dictionary_text_open(const char * filename); + +void dictionary_text_close(dictionary_t t_dictionary); + +const ucs4_t * const * dictionary_text_match_longest(dictionary_t t_dictionary, const ucs4_t * word, + size_t maxlen, size_t * match_length); + +size_t dictionary_text_get_all_match_lengths(dictionary_t t_dictionary, const ucs4_t * word, + size_t * match_length); + +size_t dictionary_text_get_lexicon(dictionary_t t_dictionary, entry * lexicon); + +#endif /* __OPENCC_DICTIONARY_TEXT_H_ */ diff --git a/internal/cpp/opencc/dictionary_group.c b/internal/cpp/opencc/dictionary_group.c new file mode 100644 index 00000000000..f96e09e9176 --- /dev/null +++ b/internal/cpp/opencc/dictionary_group.c @@ -0,0 +1,177 @@ +/* + * Open Chinese Convert + * + * Copyright 2010 BYVoid + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "dictionary_group.h" + +#define DICTIONARY_MAX_COUNT 128 + +struct _dictionary_group { + size_t count; + dictionary_t dicts[DICTIONARY_MAX_COUNT]; +}; +typedef struct _dictionary_group dictionary_group_desc; + +static dictionary_error errnum = DICTIONARY_ERROR_VOID; + +dictionary_group_t dictionary_group_open(void) { + dictionary_group_desc *dictionary_group = (dictionary_group_desc *)malloc(sizeof(dictionary_group_desc)); + + dictionary_group->count = 0; + + return dictionary_group; +} + +void dictionary_group_close(dictionary_group_t t_dictionary) { + dictionary_group_desc *dictionary_group = (dictionary_group_desc *)t_dictionary; + + size_t i; + for (i = 0; i < dictionary_group->count; i++) + dictionary_close(dictionary_group->dicts[i]); + + free(dictionary_group); +} + +int dictionary_group_load(dictionary_group_t t_dictionary, const char *filename, const char *home_path, opencc_dictionary_type type) { + dictionary_group_desc *dictionary_group = (dictionary_group_desc *)t_dictionary; + dictionary_t dictionary; + + FILE *fp = fopen(filename, "rb"); + if (!fp) { + char *new_filename = (char *)malloc(sizeof(char) * (strlen(filename) + strlen(home_path) + 2)); + sprintf(new_filename, "%s/%s", home_path, filename); + + fp = fopen(new_filename, "rb"); + if (!fp) { + free(new_filename); + errnum = DICTIONARY_ERROR_CANNOT_ACCESS_DICTFILE; + return -1; + } + dictionary = dictionary_open(new_filename, type); + free(new_filename); + } else { + dictionary = dictionary_open(filename, type); + } + fclose(fp); + + if (dictionary == (dictionary_t)-1) { + errnum = DICTIONARY_ERROR_INVALID_DICT; + return -1; + } + dictionary_group->dicts[dictionary_group->count++] = dictionary; + return 0; +} + +dictionary_t dictionary_group_get_dictionary(dictionary_group_t t_dictionary, size_t index) { + dictionary_group_desc *dictionary_group = (dictionary_group_desc *)t_dictionary; + + if (index < 0 || index >= dictionary_group->count) { + errnum = 
DICTIONARY_ERROR_INVALID_INDEX;
+    return (dictionary_t)-1;
+  }
+
+  return dictionary_group->dicts[index];
+}
+
+size_t dictionary_group_count(dictionary_group_t t_dictionary) {
+  dictionary_group_desc *dictionary_group = (dictionary_group_desc *)t_dictionary;
+  return dictionary_group->count;
+}
+
+const ucs4_t *const *dictionary_group_match_longest(dictionary_group_t t_dictionary, const ucs4_t *word, size_t maxlen, size_t *match_length) {
+  dictionary_group_desc *dictionary_group = (dictionary_group_desc *)t_dictionary;
+
+  if (dictionary_group->count == 0) {
+    errnum = DICTIONARY_ERROR_NODICT;
+    return (const ucs4_t *const *)-1;
+  }
+
+  const ucs4_t *const *retval = NULL;
+  size_t t_match_length, max_length = 0;
+
+  size_t i;
+  for (i = 0; i < dictionary_group->count; i++) {
+    /* query each dictionary in turn and keep the longest match */
+    const ucs4_t *const *t_retval = dictionary_match_longest(dictionary_group->dicts[i], word, maxlen, &t_match_length);
+
+    if (t_retval != NULL) {
+      if (t_match_length > max_length) {
+        max_length = t_match_length;
+        retval = t_retval;
+      }
+    }
+  }
+
+  if (match_length != NULL) {
+    *match_length = max_length;
+  }
+
+  return retval;
+}
+
+size_t dictionary_group_get_all_match_lengths(dictionary_group_t t_dictionary, const ucs4_t *word, size_t *match_length) {
+  dictionary_group_desc *dictionary_group = (dictionary_group_desc *)t_dictionary;
+
+  if (dictionary_group->count == 0) {
+    errnum = DICTIONARY_ERROR_NODICT;
+    return (size_t)-1;
+  }
+
+  size_t rscnt = 0;
+  size_t i;
+  for (i = 0; i < dictionary_group->count; i++) {
+    size_t retval;
+    retval = dictionary_get_all_match_lengths(dictionary_group->dicts[i], word, match_length + rscnt);
+    rscnt += retval;
+    /* drop duplicate lengths */
+    if (i > 0 && rscnt > 1) {
+      qsort(match_length, rscnt, sizeof(match_length[0]), qsort_int_cmp);
+      size_t j, k;
+      for (j = 0, k = 1; k < rscnt; k++) {
+        if (match_length[k] != match_length[j])
+          match_length[++j] = match_length[k];
+      }
+      rscnt = j + 1;
+    }
+  }
+  return rscnt;
+}
+
+dictionary_error dictionary_errno(void) { return errnum; }
+
+void dictionary_perror(const char *spec) {
+  perr(spec);
+  perr("\n");
+  switch (errnum) {
+  case DICTIONARY_ERROR_VOID:
+    break;
+  case DICTIONARY_ERROR_NODICT:
+    perr(_("No dictionary loaded"));
+    break;
+  case DICTIONARY_ERROR_CANNOT_ACCESS_DICTFILE:
+    perror(_("Cannot open dictionary file"));
+    break;
+  case DICTIONARY_ERROR_INVALID_DICT:
+    perror(_("Invalid dictionary file"));
+    break;
+  case DICTIONARY_ERROR_INVALID_INDEX:
+    perror(_("Invalid dictionary index"));
+    break;
+  default:
+    perr(_("Unknown"));
+  }
+}
diff --git a/internal/cpp/opencc/dictionary_group.h b/internal/cpp/opencc/dictionary_group.h
new file mode 100644
index 00000000000..f0fc064fd7d
--- /dev/null
+++ b/internal/cpp/opencc/dictionary_group.h
@@ -0,0 +1,57 @@
+/*
+* Open Chinese Convert
+*
+* Copyright 2010 BYVoid
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
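+*/
+
+/* A dictionary_group chains several dictionaries and merges their answers:
+   dictionary_group_match_longest() keeps the longest hit across all members,
+   and dictionary_group_get_all_match_lengths() unions and de-duplicates the
+   per-dictionary match lengths. A sketch (file name, home and word are
+   placeholders):
+
+     dictionary_group_t g = dictionary_group_open();
+     if (dictionary_group_load(g, "phrases.txt", home, OPENCC_DICTIONARY_TYPE_TEXT) == -1)
+       dictionary_perror("load");
+     size_t len;
+     const ucs4_t *const *vals = dictionary_group_match_longest(g, word, 0, &len);
+     dictionary_group_close(g);
+*/
+
+/*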
+*/ + +#ifndef __DICTIONARY_GROUP_H_ +#define __DICTIONARY_GROUP_H_ + +#include "utils.h" +#include "dictionary/abstract.h" + +typedef void * dictionary_group_t; + +typedef enum +{ + DICTIONARY_ERROR_VOID, + DICTIONARY_ERROR_NODICT, + DICTIONARY_ERROR_CANNOT_ACCESS_DICTFILE, + DICTIONARY_ERROR_INVALID_DICT, + DICTIONARY_ERROR_INVALID_INDEX, +} dictionary_error; + +dictionary_group_t dictionary_group_open(void); + +void dictionary_group_close(dictionary_group_t t_dictionary); + +int dictionary_group_load(dictionary_group_t t_dictionary, const char * filename, const char* home_dir, + opencc_dictionary_type type); + +const ucs4_t * const * dictionary_group_match_longest(dictionary_group_t t_dictionary, const ucs4_t * word, + size_t maxlen, size_t * match_length); + +size_t dictionary_group_get_all_match_lengths(dictionary_group_t t_dictionary, const ucs4_t * word, + size_t * match_length); + +dictionary_t dictionary_group_get_dictionary(dictionary_group_t t_dictionary, size_t index); + +size_t dictionary_group_count(dictionary_group_t t_dictionary); + +dictionary_error dictionary_errno(void); + +void dictionary_perror(const char * spec); + +#endif /* __DICTIONARY_GROUP_H_ */ diff --git a/internal/cpp/opencc/dictionary_set.c b/internal/cpp/opencc/dictionary_set.c new file mode 100644 index 00000000000..7a01f537136 --- /dev/null +++ b/internal/cpp/opencc/dictionary_set.c @@ -0,0 +1,73 @@ +/* + * Open Chinese Convert + * + * Copyright 2010 BYVoid + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "dictionary_set.h" + +#define DICTIONARY_GROUP_MAX_COUNT 128 + +struct _dictionary_set { + size_t count; + dictionary_group_t groups[DICTIONARY_GROUP_MAX_COUNT]; +}; +typedef struct _dictionary_set dictionary_set_desc; + +dictionary_set_t dictionary_set_open(void) { + dictionary_set_desc *dictionary_set = (dictionary_set_desc *)malloc(sizeof(dictionary_set_desc)); + + dictionary_set->count = 0; + + return dictionary_set; +} + +void dictionary_set_close(dictionary_set_t t_dictionary) { + dictionary_set_desc *dictionary_set = (dictionary_set_desc *)t_dictionary; + + size_t i; + for (i = 0; i < dictionary_set->count; i++) + dictionary_group_close(dictionary_set->groups[i]); + + free(dictionary_set); +} + +dictionary_group_t dictionary_set_new_group(dictionary_set_t t_dictionary) { + dictionary_set_desc *dictionary_set = (dictionary_set_desc *)t_dictionary; + + if (dictionary_set->count + 1 == DICTIONARY_GROUP_MAX_COUNT) { + return (dictionary_group_t)-1; + } + + dictionary_group_t group = dictionary_group_open(); + dictionary_set->groups[dictionary_set->count++] = group; + + return group; +} + +dictionary_group_t dictionary_set_get_group(dictionary_set_t t_dictionary, size_t index) { + dictionary_set_desc *dictionary_set = (dictionary_set_desc *)t_dictionary; + + if (index < 0 || index >= dictionary_set->count) { + return (dictionary_group_t)-1; + } + + return dictionary_set->groups[index]; +} + +size_t dictionary_set_count_group(dictionary_set_t t_dictionary) { + dictionary_set_desc *dictionary_set = (dictionary_set_desc *)t_dictionary; + return dictionary_set->count; +} diff --git a/internal/cpp/opencc/dictionary_set.h b/internal/cpp/opencc/dictionary_set.h new file mode 100644 index 00000000000..39be7b6132c --- /dev/null +++ b/internal/cpp/opencc/dictionary_set.h @@ -0,0 +1,37 @@ +/* +* Open Chinese Convert +* +* Copyright 2010 BYVoid +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef __DICTIONARY_SET_H_ +#define __DICTIONARY_SET_H_ + +#include "utils.h" +#include "dictionary_group.h" + +typedef void * dictionary_set_t; + +dictionary_set_t dictionary_set_open(void); + +void dictionary_set_close(dictionary_set_t t_dictionary); + +dictionary_group_t dictionary_set_new_group(dictionary_set_t t_dictionary); + +dictionary_group_t dictionary_set_get_group(dictionary_set_t t_dictionary, size_t index); + +size_t dictionary_set_count_group(dictionary_set_t t_dictionary); + +#endif /* __DICTIONARY_SET_H_ */ diff --git a/internal/cpp/opencc/encoding.c b/internal/cpp/opencc/encoding.c new file mode 100644 index 00000000000..d2e3056d7f5 --- /dev/null +++ b/internal/cpp/opencc/encoding.c @@ -0,0 +1,230 @@ +/* + * Open Chinese Convert + * + * Copyright 2010 BYVoid + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
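+ */
+
+/* A dictionary_set is one level up again: an ordered list of groups, each of
+   which the converter applies as one conversion pass. A sketch:
+
+     dictionary_set_t s = dictionary_set_open();
+     dictionary_group_t g = dictionary_set_new_group(s);
+     // ... dictionary_group_load(g, ...) as shown earlier ...
+     // groups are later retrieved in order via dictionary_set_get_group(s, i)
+     dictionary_set_close(s);   // also closes every group it owns
+
+   dictionary_set_new_group() returns (dictionary_group_t)-1 once the fixed
+   capacity (DICTIONARY_GROUP_MAX_COUNT) is exhausted. */
+
+/*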
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "encoding.h"
+#include "opencc.h"
+
+#define INITIAL_BUFF_SIZE 1024
+#define GET_BIT(byte, pos) (((byte) >> (pos)) & 1)
+#define BITMASK(length) ((1 << (length)) - 1)
+
+ucs4_t *utf8_to_ucs4(const char *utf8, size_t length) {
+  if (length == 0)
+    length = (size_t)-1;
+  size_t i;
+  for (i = 0; i < length && utf8[i] != '\0'; i++)
+    ;
+  length = i;
+
+  size_t freesize = INITIAL_BUFF_SIZE;
+  ucs4_t *ucs4 = (ucs4_t *)malloc(sizeof(ucs4_t) * freesize);
+  ucs4_t *pucs4 = ucs4;
+
+  for (i = 0; i < length; i++) {
+    ucs4_t byte[4] = {0};
+    if (GET_BIT(utf8[i], 7) == 0) {
+      /* U-00000000 - U-0000007F */
+      /* 0xxxxxxx */
+      byte[0] = utf8[i] & BITMASK(7);
+    } else if (GET_BIT(utf8[i], 5) == 0) {
+      /* U-00000080 - U-000007FF */
+      /* 110xxxxx 10xxxxxx */
+      if (i + 1 >= length)
+        goto err;
+
+      byte[0] = (utf8[i + 1] & BITMASK(6)) + ((utf8[i] & BITMASK(2)) << 6);
+      byte[1] = (utf8[i] >> 2) & BITMASK(3);
+
+      i += 1;
+    } else if (GET_BIT(utf8[i], 4) == 0) {
+      /* U-00000800 - U-0000FFFF */
+      /* 1110xxxx 10xxxxxx 10xxxxxx */
+      if (i + 2 >= length)
+        goto err;
+
+      byte[0] = (utf8[i + 2] & BITMASK(6)) + ((utf8[i + 1] & BITMASK(2)) << 6);
+      byte[1] = ((utf8[i + 1] >> 2) & BITMASK(4)) + ((utf8[i] & BITMASK(4)) << 4);
+
+      i += 2;
+    } else if (GET_BIT(utf8[i], 3) == 0) {
+      /* U-00010000 - U-001FFFFF */
+      /* 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx */
+      if (i + 3 >= length)
+        goto err;
+
+      byte[0] = (utf8[i + 3] & BITMASK(6)) + ((utf8[i + 2] & BITMASK(2)) << 6);
+      byte[1] = ((utf8[i + 2] >> 2) & BITMASK(4)) + ((utf8[i + 1] & BITMASK(4)) << 4);
+      byte[2] = ((utf8[i + 1] >> 4) & BITMASK(2)) + ((utf8[i] & BITMASK(3)) << 2);
+
+      i += 3;
+    } else if (GET_BIT(utf8[i], 2) == 0) {
+      /* U-00200000 - U-03FFFFFF */
+      /* 111110xx 10xxxxxx 10xxxxxx 10xxxxxx 10xxxxxx */
+      if (i + 4 >= length)
+        goto err;
+
+      byte[0] = (utf8[i + 4] & BITMASK(6)) + ((utf8[i + 3] & BITMASK(2)) << 6);
+      byte[1] = ((utf8[i + 3] >> 2) & BITMASK(4)) + ((utf8[i + 2] & BITMASK(4)) << 4);
+      byte[2] = ((utf8[i + 2] >> 4) & BITMASK(2)) + ((utf8[i + 1] & BITMASK(6)) << 2);
+      byte[3] = utf8[i] & BITMASK(2);
+      i += 4;
+    } else if (GET_BIT(utf8[i], 1) == 0) {
+      /* U-04000000 - U-7FFFFFFF */
+      /* 1111110x 10xxxxxx 10xxxxxx 10xxxxxx 10xxxxxx 10xxxxxx */
+      if (i + 5 >= length)
+        goto err;
+
+      byte[0] = (utf8[i + 5] & BITMASK(6)) + ((utf8[i + 4] & BITMASK(2)) << 6);
+      byte[1] = ((utf8[i + 4] >> 2) & BITMASK(4)) + ((utf8[i + 3] & BITMASK(4)) << 4);
+      byte[2] = ((utf8[i + 3] >> 4) & BITMASK(2)) + ((utf8[i + 2] & BITMASK(6)) << 2);
+      byte[3] = (utf8[i + 1] & BITMASK(6)) + ((utf8[i] & BITMASK(1)) << 6);
+      i += 5;
+    } else
+      goto err;
+
+    if (freesize == 0) {
+      freesize = pucs4 - ucs4;
+      ucs4 = (ucs4_t *)realloc(ucs4, sizeof(ucs4_t) * (freesize + freesize));
+      pucs4 = ucs4 + freesize;
+    }
+
+    *pucs4 = (byte[3] << 24) + (byte[2] << 16) + (byte[1] << 8) + byte[0];
+
+    pucs4++;
+    freesize--;
+  }
+
+  length = (pucs4 - ucs4 + 1);
+  ucs4 = (ucs4_t *)realloc(ucs4, sizeof(ucs4_t) * length);
+  ucs4[length - 1] = 0;
+  return ucs4;
+
+err:
+  free(ucs4);
+  return (ucs4_t *)-1;
+}
+
+char *ucs4_to_utf8(const ucs4_t *ucs4, size_t length) {
+  if (length == 0)
+ length = (size_t)-1; + size_t i; + for (i = 0; i < length && ucs4[i] != 0; i++) + ; + length = i; + + size_t freesize = INITIAL_BUFF_SIZE; + char *utf8 = (char *)malloc(sizeof(char) * freesize); + char *putf8 = utf8; + + for (i = 0; i < length; i++) { + if ((ssize_t)freesize - 6 <= 0) { + freesize = putf8 - utf8; + utf8 = (char *)realloc(utf8, sizeof(char) * (freesize + freesize)); + putf8 = utf8 + freesize; + } + + ucs4_t c = ucs4[i]; + ucs4_t byte[4] = {(c >> 0) & BITMASK(8), (c >> 8) & BITMASK(8), (c >> 16) & BITMASK(8), (c >> 24) & BITMASK(8)}; + + size_t delta = 0; + + if (c <= 0x7F) { + /* U-00000000 - U-0000007F */ + /* 0xxxxxxx */ + putf8[0] = byte[0] & BITMASK(7); + delta = 1; + } else if (c <= 0x7FF) { + /* U-00000080 - U-000007FF */ + /* 110xxxxx 10xxxxxx */ + putf8[1] = 0x80 + (byte[0] & BITMASK(6)); + putf8[0] = 0xC0 + ((byte[0] >> 6) & BITMASK(2)) + ((byte[1] & BITMASK(3)) << 2); + delta = 2; + } else if (c <= 0xFFFF) { + /* U-00000800 - U-0000FFFF */ + /* 1110xxxx 10xxxxxx 10xxxxxx */ + putf8[2] = 0x80 + (byte[0] & BITMASK(6)); + putf8[1] = 0x80 + ((byte[0] >> 6) & BITMASK(2)) + ((byte[1] & BITMASK(4)) << 2); + putf8[0] = 0xE0 + ((byte[1] >> 4) & BITMASK(4)); + delta = 3; + } else if (c <= 0x1FFFFF) { + /* U-00010000 - U-001FFFFF */ + /* 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx */ + putf8[3] = 0x80 + (byte[0] & BITMASK(6)); + putf8[2] = 0x80 + ((byte[0] >> 6) & BITMASK(2)) + ((byte[1] & BITMASK(4)) << 2); + putf8[1] = 0x80 + ((byte[1] >> 4) & BITMASK(4)) + ((byte[2] & BITMASK(2)) << 4); + putf8[0] = 0xF0 + ((byte[2] >> 2) & BITMASK(3)); + delta = 4; + } else if (c <= 0x3FFFFFF) { + /* U-00200000 - U-03FFFFFF */ + /* 111110xx 10xxxxxx 10xxxxxx 10xxxxxx 10xxxxxx */ + putf8[4] = 0x80 + (byte[0] & BITMASK(6)); + putf8[3] = 0x80 + ((byte[0] >> 6) & BITMASK(2)) + ((byte[1] & BITMASK(4)) << 2); + putf8[2] = 0x80 + ((byte[1] >> 4) & BITMASK(4)) + ((byte[2] & BITMASK(2)) << 4); + putf8[1] = 0x80 + ((byte[2] >> 2) & BITMASK(6)); + putf8[0] = 0xF8 + (byte[3] & BITMASK(2)); + delta = 5; + + } else if (c <= 0x7FFFFFFF) { + /* U-04000000 - U-7FFFFFFF */ + /* 1111110x 10xxxxxx 10xxxxxx 10xxxxxx 10xxxxxx 10xxxxxx */ + putf8[5] = 0x80 + (byte[0] & BITMASK(6)); + putf8[4] = 0x80 + ((byte[0] >> 6) & BITMASK(2)) + ((byte[1] & BITMASK(4)) << 2); + putf8[3] = 0x80 + ((byte[1] >> 4) & BITMASK(4)) + ((byte[2] & BITMASK(2)) << 4); + putf8[2] = 0x80 + ((byte[2] >> 2) & BITMASK(6)); + putf8[1] = 0x80 + (byte[3] & BITMASK(6)); + putf8[0] = 0xFC + ((byte[3] >> 6) & BITMASK(1)); + delta = 6; + } else { + free(utf8); + return (char *)-1; + } + + putf8 += delta; + freesize -= delta; + } + + length = (putf8 - utf8 + 1); + utf8 = (char *)realloc(utf8, sizeof(char) * length); + utf8[length - 1] = '\0'; + return utf8; +} + +size_t ucs4len(const ucs4_t *str) { + const register ucs4_t *pstr = str; + while (*pstr) + ++pstr; + return pstr - str; +} + +int ucs4cmp(const ucs4_t *src, const ucs4_t *dst) { + register int ret = 0; + while (!(ret = *src - *dst) && *dst) + ++src, ++dst; + return ret; +} + +void ucs4cpy(ucs4_t *dest, const ucs4_t *src) { + while (*src) + *dest++ = *src++; + *dest = 0; +} + +void ucs4ncpy(ucs4_t *dest, const ucs4_t *src, size_t len) { + while (*src && len-- > 0) + *dest++ = *src++; +} diff --git a/internal/cpp/opencc/encoding.h b/internal/cpp/opencc/encoding.h new file mode 100644 index 00000000000..d54a526ab0d --- /dev/null +++ b/internal/cpp/opencc/encoding.h @@ -0,0 +1,36 @@ +/* +* Open Chinese Convert +* +* Copyright 2010 BYVoid +* +* Licensed under the Apache License, Version 2.0 (the 
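+*/
+
+/* The helpers above convert between UTF-8 byte strings and NUL-terminated
+   UCS-4 (ucs4_t) arrays, allocating their result with malloc(). A round-trip
+   sketch:
+
+     ucs4_t *w = utf8_to_ucs4("漢字", 0);        // 0 means "whole string"
+     if (w != (ucs4_t *)-1) {
+       char *u = ucs4_to_utf8(w, (size_t)-1);    // (size_t)-1: up to the NUL
+       if (u != (char *)-1)
+         free(u);
+       free(w);
+     }
+
+   Note that both functions signal failure with a -1 pointer, not NULL. */
+
+/*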
"License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef __OPENCC_ENCODING_H_ +#define __OPENCC_ENCODING_H_ + +#include "utils.h" + +ucs4_t * utf8_to_ucs4(const char * utf8, size_t length); + +char * ucs4_to_utf8(const ucs4_t * ucs4, size_t length); + +size_t ucs4len(const ucs4_t * str); + +int ucs4cmp(const ucs4_t * str1, const ucs4_t * str2); + +void ucs4cpy(ucs4_t * dest, const ucs4_t * src); + +void ucs4ncpy(ucs4_t * dest, const ucs4_t * src, size_t len); + +#endif /* __OPENCC_ENCODING_H_ */ diff --git a/internal/cpp/opencc/opencc.c b/internal/cpp/opencc/opencc.c new file mode 100644 index 00000000000..58c23958479 --- /dev/null +++ b/internal/cpp/opencc/opencc.c @@ -0,0 +1,219 @@ +/* + * Open Chinese Convert + * + * Copyright 2010 BYVoid + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "opencc.h" +#include "config_reader.h" +#include "converter.h" +#include "dictionary_set.h" +#include "encoding.h" +#include "utils.h" + +typedef struct { + dictionary_set_t dictionary_set; + converter_t converter; +} opencc_desc; + +static opencc_error errnum = OPENCC_ERROR_VOID; +static int lib_initialized = FALSE; + +static void lib_initialize(void) { lib_initialized = TRUE; } + +size_t opencc_convert(opencc_t t_opencc, ucs4_t **inbuf, size_t *inbuf_left, ucs4_t **outbuf, size_t *outbuf_left) { + if (!lib_initialized) + lib_initialize(); + + opencc_desc *opencc = (opencc_desc *)t_opencc; + + size_t retval = converter_convert(opencc->converter, inbuf, inbuf_left, outbuf, outbuf_left); + + if (retval == (size_t)-1) + errnum = OPENCC_ERROR_CONVERTER; + + return retval; +} + +char *opencc_convert_utf8(opencc_t t_opencc, const char *inbuf, size_t length) { + if (!lib_initialized) + lib_initialize(); + + if (length == (size_t)-1 || length > strlen(inbuf)) + length = strlen(inbuf); + + /* 將輸入數據轉換爲ucs4_t字符串 */ + ucs4_t *winbuf = utf8_to_ucs4(inbuf, length); + if (winbuf == (ucs4_t *)-1) { + /* 輸入數據轉換失敗 */ + errnum = OPENCC_ERROR_ENCODIND; + return (char *)-1; + } + + /* 設置輸出UTF8文本緩衝區空間 */ + size_t outbuf_len = length; + size_t outsize = outbuf_len; + char *original_outbuf = (char *)malloc(sizeof(char) * (outbuf_len + 1)); + char *outbuf = original_outbuf; + original_outbuf[0] = '\0'; + + /* 設置轉換緩衝區空間 */ + size_t wbufsize = length + 64; + ucs4_t *woutbuf = (ucs4_t *)malloc(sizeof(ucs4_t) * (wbufsize + 1)); + + ucs4_t *pinbuf = winbuf; + ucs4_t *poutbuf = woutbuf; + size_t inbuf_left, outbuf_left; + + inbuf_left = ucs4len(winbuf); + outbuf_left = wbufsize; + + while (inbuf_left > 0) { + size_t retval = opencc_convert(t_opencc, &pinbuf, 
&inbuf_left, &poutbuf, &outbuf_left);
+    if (retval == (size_t)-1) {
+      free(original_outbuf);
+      free(winbuf);
+      free(woutbuf);
+      return (char *)-1;
+    }
+
+    *poutbuf = L'\0';
+
+    char *ubuff = ucs4_to_utf8(woutbuf, (size_t)-1);
+
+    if (ubuff == (char *)-1) {
+      free(original_outbuf);
+      free(winbuf);
+      free(woutbuf);
+      errnum = OPENCC_ERROR_ENCODIND;
+      return (char *)-1;
+    }
+
+    size_t ubuff_len = strlen(ubuff);
+
+    while (ubuff_len > outsize) {
+      size_t outbuf_offset = outbuf - original_outbuf;
+      outsize += outbuf_len;
+      outbuf_len += outbuf_len;
+      original_outbuf = (char *)realloc(original_outbuf, sizeof(char) * outbuf_len);
+      outbuf = original_outbuf + outbuf_offset;
+    }
+
+    strncpy(outbuf, ubuff, ubuff_len);
+    free(ubuff);
+
+    outbuf += ubuff_len;
+    *outbuf = '\0';
+
+    outbuf_left = wbufsize;
+    poutbuf = woutbuf;
+  }
+
+  free(winbuf);
+  free(woutbuf);
+
+  original_outbuf = (char *)realloc(original_outbuf, sizeof(char) * (strlen(original_outbuf) + 1));
+
+  return original_outbuf;
+}
+
+opencc_t opencc_open(const char *config_file, const char *home_path) {
+  if (!lib_initialized)
+    lib_initialize();
+
+  opencc_desc *opencc;
+  opencc = (opencc_desc *)malloc(sizeof(opencc_desc));
+
+  opencc->dictionary_set = NULL;
+  opencc->converter = converter_open();
+  converter_set_conversion_mode(opencc->converter, OPENCC_CONVERSION_FAST);
+
+  /* load the dictionaries named by the configuration, if one was given */
+  if (config_file != NULL) {
+    config_t config = config_open(config_file, home_path);
+
+    if (config == (config_t)-1) {
+      errnum = OPENCC_ERROR_CONFIG;
+      converter_close(opencc->converter);
+      free(opencc);
+      return (opencc_t)-1;
+    }
+
+    opencc->dictionary_set = config_get_dictionary_set(config);
+    converter_assign_dictionary(opencc->converter, opencc->dictionary_set);
+
+    config_close(config);
+  }
+
+  return (opencc_t)opencc;
+}
+
+int opencc_close(opencc_t t_opencc) {
+  if (!lib_initialized)
+    lib_initialize();
+
+  opencc_desc *opencc = (opencc_desc *)t_opencc;
+
+  converter_close(opencc->converter);
+  if (opencc->dictionary_set != NULL)
+    dictionary_set_close(opencc->dictionary_set);
+  free(opencc);
+
+  return 0;
+}
+
+void opencc_set_conversion_mode(opencc_t t_opencc, opencc_conversion_mode conversion_mode) {
+  if (!lib_initialized)
+    lib_initialize();
+
+  opencc_desc *opencc = (opencc_desc *)t_opencc;
+
+  converter_set_conversion_mode(opencc->converter, conversion_mode);
+}
+
+opencc_error opencc_errno(void) {
+  if (!lib_initialized)
+    lib_initialize();
+
+  return errnum;
+}
+
+void opencc_perror(const char *spec) {
+  if (!lib_initialized)
+    lib_initialize();
+
+  perr(spec);
+  perr("\n");
+  switch (errnum) {
+  case OPENCC_ERROR_VOID:
+    break;
+  case OPENCC_ERROR_DICTLOAD:
+    dictionary_perror(_("Dictionary loading error"));
+    break;
+  case OPENCC_ERROR_CONFIG:
+    config_perror(_("Configuration error"));
+    break;
+  case OPENCC_ERROR_CONVERTER:
+    converter_perror(_("Converter error"));
+    break;
+  case OPENCC_ERROR_ENCODIND:
+    perr(_("Encoding error"));
+    break;
+  default:
+    perr(_("Unknown"));
+  }
+  perr("\n");
+}
diff --git a/internal/cpp/opencc/opencc.h b/internal/cpp/opencc/opencc.h
new file mode 100644
index 00000000000..11a1f2e6745
--- /dev/null
+++ b/internal/cpp/opencc/opencc.h
@@ -0,0 +1,116 @@
+/*
+ * Open Chinese Convert
+ *
+ * Copyright 2010 BYVoid
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
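+ */
+
+/* opencc_convert() follows an iconv-style contract: it advances *inbuf and
+   *outbuf and decrements the two counters as it consumes input and produces
+   output, so it can be driven with a small, reused output buffer. A sketch of
+   the loop (input, outbuf and OUTBUF_SIZE are placeholders; compare
+   opencc_convert_utf8() in opencc.c above):
+
+     ucs4_t *in = input, *out;
+     size_t in_left = ucs4len(input), out_left;
+     while (in_left > 0) {
+       out = outbuf;
+       out_left = OUTBUF_SIZE;                      // reset for each pass
+       size_t n = opencc_convert(od, &in, &in_left, &out, &out_left);
+       if (n == (size_t)-1)
+         break;                                     // conversion error
+       *out = 0;                                    // caller terminates
+       // ... consume outbuf ...
+     }
+*/
+
+/*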
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __OPENCC_H_
+#define __OPENCC_H_
+
+#include "opencc_types.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*
+ * Headers from C standard library
+ */
+#include <stddef.h>
+
+/* Macros */
+#define OPENCC_DEFAULT_CONFIG_SIMP_TO_TRAD "zhs2zht.ini"
+#define OPENCC_DEFAULT_CONFIG_TRAD_TO_SIMP "zht2zhs.ini"
+
+/**
+ * opencc_open:
+ * @config_file: Location of the configuration file.
+ * @home_path: Directory that is searched for the configuration file.
+ * @returns: A description pointer of the newly allocated instance of opencc.
+ *
+ * Make an instance of opencc.
+ *
+ * Note: Set config_file to NULL if you do not want to load any configuration file.
+ *
+ */
+opencc_t opencc_open(const char *config_file, const char *home_path);
+
+/**
+ * opencc_close:
+ * @od: The description pointer.
+ * @returns: 0 on success or a non-zero number on failure.
+ *
+ * Destroy an instance of opencc.
+ *
+ */
+int opencc_close(opencc_t od);
+
+/**
+ * opencc_convert:
+ * @od: The opencc description pointer.
+ * @inbuf: The pointer to the wide character string of the input buffer.
+ * @inbufleft: The maximum number of characters in *inbuf to convert.
+ * @outbuf: The pointer to the wide character string of the output buffer.
+ * @outbufleft: The size of the output buffer.
+ *
+ * @returns: The number of characters of the input buffer that were converted.
+ *
+ * Convert string from *inbuf to *outbuf.
+ *
+ * Note: Don't forget to assign **outbuf to L'\0' after the call.
+ *
+ */
+size_t opencc_convert(opencc_t od, ucs4_t **inbuf, size_t *inbufleft, ucs4_t **outbuf, size_t *outbufleft);
+
+/**
+ * opencc_convert_utf8:
+ * @t_opencc: The opencc description pointer.
+ * @inbuf: The UTF-8 encoded string.
+ * @length: The maximum number of characters in inbuf to convert.
+ *
+ * @returns: A newly allocated UTF-8 string converted from inbuf.
+ *
+ * Convert the UTF-8 string in inbuf. This function returns a newly allocated
+ * C-style string via malloc(), which stores the converted string.
+ * DON'T FORGET TO CALL free() to recycle memory.
+ *
+ */
+char *opencc_convert_utf8(opencc_t t_opencc, const char *inbuf, size_t length);
+
+void opencc_set_conversion_mode(opencc_t t_opencc, opencc_conversion_mode conversion_mode);
+
+/**
+ * opencc_errno:
+ *
+ * @returns: The error number.
+ *
+ * Return an opencc_error value that describes the last error that occurred, or
+ * OPENCC_ERROR_VOID if there has been none.
+ *
+ */
+opencc_error opencc_errno(void);
+
+/**
+ * opencc_perror:
+ * @spec: Prefix message.
+ *
+ * Print the error message to stderr.
+ *
+ */
+void opencc_perror(const char *spec);
+
+#ifdef __cplusplus
+};
+#endif
+
+#endif /* __OPENCC_H_ */
diff --git a/internal/cpp/opencc/opencc_types.h b/internal/cpp/opencc/opencc_types.h
new file mode 100644
index 00000000000..03dd4245919
--- /dev/null
+++ b/internal/cpp/opencc/opencc_types.h
@@ -0,0 +1,59 @@
+/*
+ * Open Chinese Convert
+ *
+ * Copyright 2010 BYVoid
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
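+ */
+
+/* An end-to-end sketch of the UTF-8 convenience API documented above ("." as
+   home_path means the configuration is looked up in the current directory):
+
+     opencc_t od = opencc_open(OPENCC_DEFAULT_CONFIG_SIMP_TO_TRAD, ".");
+     if (od != (opencc_t)-1) {
+       char *out = opencc_convert_utf8(od, "汉字", (size_t)-1);
+       if (out != (char *)-1) {
+         puts(out);
+         free(out);        // opencc_convert_utf8() allocates with malloc()
+       }
+       opencc_close(od);
+     } else {
+       opencc_perror("opencc_open");
+     }
+*/
+
+/*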
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __OPENCC_TYPES_H_
+#define __OPENCC_TYPES_H_
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include <stdint.h>
+#include <stddef.h>
+
+typedef void *opencc_t;
+
+typedef uint32_t ucs4_t;
+
+enum _opencc_error {
+  OPENCC_ERROR_VOID,
+  OPENCC_ERROR_DICTLOAD,
+  OPENCC_ERROR_CONFIG,
+  OPENCC_ERROR_ENCODIND,
+  OPENCC_ERROR_CONVERTER,
+};
+typedef enum _opencc_error opencc_error;
+
+enum _opencc_dictionary_type {
+  OPENCC_DICTIONARY_TYPE_TEXT,
+  OPENCC_DICTIONARY_TYPE_DATRIE,
+};
+typedef enum _opencc_dictionary_type opencc_dictionary_type;
+
+enum _opencc_conversion_mode {
+  OPENCC_CONVERSION_FAST,
+  OPENCC_CONVERSION_SEGMENT_ONLY,
+  OPENCC_CONVERSION_LIST_CANDIDATES,
+};
+typedef enum _opencc_conversion_mode opencc_conversion_mode;
+
+#ifdef __cplusplus
+};
+#endif
+
+#endif /* __OPENCC_TYPES_H_ */
diff --git a/internal/cpp/opencc/openccxx.cpp b/internal/cpp/opencc/openccxx.cpp
new file mode 100644
index 00000000000..54b27e0d26f
--- /dev/null
+++ b/internal/cpp/opencc/openccxx.cpp
@@ -0,0 +1,80 @@
+#include "openccxx.h"
+#include "opencc.h"
+#include "utils.h"
+
+#include <cstdlib>
+#include <cstring>
+
+OpenCC::OpenCC(const std::string &home_dir) : od((opencc_t)-1) {
+  config_file = mstrcpy(OPENCC_DEFAULT_CONFIG_TRAD_TO_SIMP);
+  open(config_file, home_dir.c_str());
+}
+
+OpenCC::~OpenCC() {
+  if (od != (opencc_t)-1)
+    opencc_close(od);
+  free(config_file);
+}
+
+int OpenCC::open(const char *config_file, const char *home_dir) {
+  if (od != (opencc_t)-1)
+    opencc_close(od);
+  od = opencc_open(config_file, home_dir);
+  return (od == (opencc_t)-1) ? (-1) : (0);
+}
+
+long OpenCC::convert(const std::string &in, std::string &out, long length) {
+  if (od == (opencc_t)-1)
+    return -1;
+
+  if (length == -1)
+    length = in.length();
+
+  char *outbuf = opencc_convert_utf8(od, in.c_str(), length);
+
+  if (outbuf == (char *)-1)
+    return -1;
+
+  out = outbuf;
+  free(outbuf);
+
+  return length;
+}
+
+/**
+ * Warning:
+ * This method can be used only if wchar_t is encoded in UCS4 on your platform.
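+ *
+ * A sketch under that assumption (e.g. Linux or macOS, where wchar_t is
+ * 4 bytes; the home directory below is hypothetical):
+ *
+ *   OpenCC cc("/usr/share/opencc");
+ *   std::wstring out;
+ *   if (cc.convert(std::wstring(L"漢字"), out) >= 0) {
+ *     // out holds the converted text
+ *   }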
+ */ +long OpenCC::convert(const std::wstring &in, std::wstring &out, long length) { + if (od == (opencc_t)-1) + return -1; + + size_t inbuf_left = in.length(); + if (length >= 0 && length < (long)inbuf_left) + inbuf_left = length; + + const ucs4_t *inbuf = (const ucs4_t *)in.c_str(); + long count = 0; + + while (inbuf_left != 0) { + size_t retval; + size_t outbuf_left; + ucs4_t *outbuf; + + /* occupy space */ + outbuf_left = inbuf_left + 64; + out.resize(count + outbuf_left); + outbuf = (ucs4_t *)out.c_str() + count; + + retval = opencc_convert(od, (ucs4_t **)&inbuf, &inbuf_left, &outbuf, &outbuf_left); + if (retval == (size_t)-1) + return -1; + count += retval; + } + + /* set the zero termination and shrink the size */ + out.resize(count + 1); + out[count] = L'\0'; + + return count; +} diff --git a/internal/cpp/opencc/openccxx.h b/internal/cpp/opencc/openccxx.h new file mode 100644 index 00000000000..844bbacdb5e --- /dev/null +++ b/internal/cpp/opencc/openccxx.h @@ -0,0 +1,20 @@ +#pragma once + +#include "opencc_types.h" +#include + +class OpenCC { +public: + OpenCC(const std::string &home_dir); + virtual ~OpenCC(); + + int open(const char *config_file, const char *home_dir); + + long convert(const std::string &in, std::string &out, long length = -1); + + long convert(const std::wstring &in, std::wstring &out, long length = -1); + +private: + char *config_file; + opencc_t od; +}; diff --git a/internal/cpp/opencc/utils.c b/internal/cpp/opencc/utils.c new file mode 100644 index 00000000000..9f93aae8f3f --- /dev/null +++ b/internal/cpp/opencc/utils.c @@ -0,0 +1,36 @@ +/* + * Open Chinese Convert + * + * Copyright 2010 BYVoid + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils.h" + +void perr(const char *str) { fputs(str, stderr); } + +int qsort_int_cmp(const void *a, const void *b) { return *((int *)a) - *((int *)b); } + +char *mstrcpy(const char *str) { + char *strbuf = (char *)malloc(sizeof(char) * (strlen(str) + 1)); + strcpy(strbuf, str); + return strbuf; +} + +char *mstrncpy(const char *str, size_t n) { + char *strbuf = (char *)malloc(sizeof(char) * (n + 1)); + strncpy(strbuf, str, n); + strbuf[n] = '\0'; + return strbuf; +} diff --git a/internal/cpp/opencc/utils.h b/internal/cpp/opencc/utils.h new file mode 100644 index 00000000000..693249a6651 --- /dev/null +++ b/internal/cpp/opencc/utils.h @@ -0,0 +1,71 @@ +/* + * Open Chinese Convert + * + * Copyright 2010 BYVoid + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
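+ */
+
+/* The C++ wrapper declared in openccxx.h above defaults to the
+   traditional-to-simplified configuration and reports failure with a negative
+   return value. A sketch (the home directory is hypothetical):
+
+     OpenCC cc("/usr/share/opencc");
+     std::string out;
+     if (cc.convert("漢字", out) < 0) {
+       // conversion failed: no instance was opened, or conversion errored
+     }
+*/
+
+/*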
+ */
+
+#ifndef __OPENCC_UTILS_H_
+#define __OPENCC_UTILS_H_
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <assert.h>
+
+#include "opencc_types.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define FALSE (0)
+#define TRUE (!(0))
+#define INFINITY_INT ((~0U) >> 1)
+
+#ifndef BIG_ENDIAN
+#define BIG_ENDIAN (0)
+#endif
+
+#ifndef LITTLE_ENDIAN
+#define LITTLE_ENDIAN (1)
+#endif
+
+#ifdef ENABLE_GETTEXT
+#include <libintl.h>
+#include <locale.h>
+#define _(STRING) dgettext(PACKAGE_NAME, STRING)
+#else
+#define _(STRING) STRING
+#endif
+
+#define debug_should_not_be_here() \
+  do { \
+    fprintf(stderr, "Should not be here %s: %d\n", __FILE__, __LINE__); \
+    assert(0); \
+  } while (0)
+
+void perr(const char *str);
+
+int qsort_int_cmp(const void *a, const void *b);
+
+char *mstrcpy(const char *str);
+
+char *mstrncpy(const char *str, size_t n);
+
+#ifdef __cplusplus
+};
+#endif
+
+#endif /* __OPENCC_UTILS_H_ */
diff --git a/internal/cpp/pcre2.h b/internal/cpp/pcre2.h
new file mode 100644
index 00000000000..37431c72452
--- /dev/null
+++ b/internal/cpp/pcre2.h
@@ -0,0 +1,1079 @@
+/*************************************************
+*       Perl-Compatible Regular Expressions      *
+*************************************************/
+
+/* This is the public header file for the PCRE library, second API, to be
+#included by applications that call PCRE2 functions.
+
+       Copyright (c) 2016-2024 University of Cambridge
+
+-----------------------------------------------------------------------------
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+  * Redistributions of source code must retain the above copyright notice,
+    this list of conditions and the following disclaimer.
+
+  * Redistributions in binary form must reproduce the above copyright
+    notice, this list of conditions and the following disclaimer in the
+    documentation and/or other materials provided with the distribution.
+
+  * Neither the name of the University of Cambridge nor the names of its
+    contributors may be used to endorse or promote products derived from
+    this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGE.
+-----------------------------------------------------------------------------
+*/
+
+#ifndef PCRE2_H_IDEMPOTENT_GUARD
+#define PCRE2_H_IDEMPOTENT_GUARD
+
+/* The current PCRE version information. */
+
+#define PCRE2_MAJOR 10
+#define PCRE2_MINOR 47
+#define PCRE2_PRERELEASE
+#define PCRE2_DATE 2025-10-21
+
+/* When an application links to a PCRE2 DLL in Windows, the symbols that are
+imported have to be identified as such. When building PCRE2, the appropriate
+export setting is defined in pcre2_internal.h, which includes this file. So, we
+don't change existing definitions of PCRE2_EXP_DECL.
+
+By default, we use the standard "extern" declarations. */
+
+#ifndef PCRE2_EXP_DECL
+# if defined(_WIN32) && !1
+# define PCRE2_EXP_DECL extern __declspec(dllimport)
+# elif defined __cplusplus
+# define PCRE2_EXP_DECL extern "C"
+# else
+# define PCRE2_EXP_DECL extern
+# endif
+#endif
+
+/* When compiling with the MSVC compiler, it is sometimes necessary to include
+a "calling convention" before exported function names. For example:
+
+     void __cdecl function(....)
+
+might be needed. In order to make this easy, all the exported functions have
+PCRE2_CALL_CONVENTION just before their names.
+
+PCRE2 normally uses the platform's standard calling convention, so this should
+not be set unless you know you need it. */
+
+#ifndef PCRE2_CALL_CONVENTION
+#define PCRE2_CALL_CONVENTION
+#endif
+
+/* Have to include limits.h, stdlib.h, and inttypes.h to ensure that size_t and
+uint8_t, UCHAR_MAX, etc are defined. Some systems that do have inttypes.h do
+not have stdint.h, which is why we use inttypes.h, which according to the C
+standard is a superset of stdint.h. If inttypes.h is not available the build
+will break and the relevant values must be provided by some other means. */
+
+#include <limits.h>
+#include <stdlib.h>
+#include <inttypes.h>
+
+/* Allow for C++ users compiling this directly. */
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* The following option bits can be passed to pcre2_compile(), pcre2_match(),
+or pcre2_dfa_match(). PCRE2_NO_UTF_CHECK affects only the function to which it
+is passed. Put these bits at the most significant end of the options word so
+others can be added next to them */
+
+#define PCRE2_ANCHORED     0x80000000u
+#define PCRE2_NO_UTF_CHECK 0x40000000u
+#define PCRE2_ENDANCHORED  0x20000000u
+
+/* The following option bits can be passed only to pcre2_compile(). However,
+they may affect compilation, JIT compilation, and/or interpretive execution.
+The following tags indicate which: + +C alters what is compiled by pcre2_compile() +J alters what is compiled by pcre2_jit_compile() +M is inspected during pcre2_match() execution +D is inspected during pcre2_dfa_match() execution +*/ + +#define PCRE2_ALLOW_EMPTY_CLASS 0x00000001u /* C */ +#define PCRE2_ALT_BSUX 0x00000002u /* C */ +#define PCRE2_AUTO_CALLOUT 0x00000004u /* C */ +#define PCRE2_CASELESS 0x00000008u /* C */ +#define PCRE2_DOLLAR_ENDONLY 0x00000010u /* J M D */ +#define PCRE2_DOTALL 0x00000020u /* C */ +#define PCRE2_DUPNAMES 0x00000040u /* C */ +#define PCRE2_EXTENDED 0x00000080u /* C */ +#define PCRE2_FIRSTLINE 0x00000100u /* J M D */ +#define PCRE2_MATCH_UNSET_BACKREF 0x00000200u /* C J M */ +#define PCRE2_MULTILINE 0x00000400u /* C */ +#define PCRE2_NEVER_UCP 0x00000800u /* C */ +#define PCRE2_NEVER_UTF 0x00001000u /* C */ +#define PCRE2_NO_AUTO_CAPTURE 0x00002000u /* C */ +#define PCRE2_NO_AUTO_POSSESS 0x00004000u /* C */ +#define PCRE2_NO_DOTSTAR_ANCHOR 0x00008000u /* C */ +#define PCRE2_NO_START_OPTIMIZE 0x00010000u /* J M D */ +#define PCRE2_UCP 0x00020000u /* C J M D */ +#define PCRE2_UNGREEDY 0x00040000u /* C */ +#define PCRE2_UTF 0x00080000u /* C J M D */ +#define PCRE2_NEVER_BACKSLASH_C 0x00100000u /* C */ +#define PCRE2_ALT_CIRCUMFLEX 0x00200000u /* J M D */ +#define PCRE2_ALT_VERBNAMES 0x00400000u /* C */ +#define PCRE2_USE_OFFSET_LIMIT 0x00800000u /* J M D */ +#define PCRE2_EXTENDED_MORE 0x01000000u /* C */ +#define PCRE2_LITERAL 0x02000000u /* C */ +#define PCRE2_MATCH_INVALID_UTF 0x04000000u /* J M D */ +#define PCRE2_ALT_EXTENDED_CLASS 0x08000000u /* C */ + +/* An additional compile options word is available in the compile context. */ + +#define PCRE2_EXTRA_ALLOW_SURROGATE_ESCAPES 0x00000001u /* C */ +#define PCRE2_EXTRA_BAD_ESCAPE_IS_LITERAL 0x00000002u /* C */ +#define PCRE2_EXTRA_MATCH_WORD 0x00000004u /* C */ +#define PCRE2_EXTRA_MATCH_LINE 0x00000008u /* C */ +#define PCRE2_EXTRA_ESCAPED_CR_IS_LF 0x00000010u /* C */ +#define PCRE2_EXTRA_ALT_BSUX 0x00000020u /* C */ +#define PCRE2_EXTRA_ALLOW_LOOKAROUND_BSK 0x00000040u /* C */ +#define PCRE2_EXTRA_CASELESS_RESTRICT 0x00000080u /* C */ +#define PCRE2_EXTRA_ASCII_BSD 0x00000100u /* C */ +#define PCRE2_EXTRA_ASCII_BSS 0x00000200u /* C */ +#define PCRE2_EXTRA_ASCII_BSW 0x00000400u /* C */ +#define PCRE2_EXTRA_ASCII_POSIX 0x00000800u /* C */ +#define PCRE2_EXTRA_ASCII_DIGIT 0x00001000u /* C */ +#define PCRE2_EXTRA_PYTHON_OCTAL 0x00002000u /* C */ +#define PCRE2_EXTRA_NO_BS0 0x00004000u /* C */ +#define PCRE2_EXTRA_NEVER_CALLOUT 0x00008000u /* C */ +#define PCRE2_EXTRA_TURKISH_CASING 0x00010000u /* C */ + +/* These are for pcre2_jit_compile(). */ + +#define PCRE2_JIT_COMPLETE 0x00000001u /* For full matching */ +#define PCRE2_JIT_PARTIAL_SOFT 0x00000002u +#define PCRE2_JIT_PARTIAL_HARD 0x00000004u +#define PCRE2_JIT_INVALID_UTF 0x00000100u +#define PCRE2_JIT_TEST_ALLOC 0x00000200u + +/* These are for pcre2_match(), pcre2_dfa_match(), pcre2_jit_match(), and +pcre2_substitute(). Some are allowed only for one of the functions, and in +these cases it is noted below. Note that PCRE2_ANCHORED, PCRE2_ENDANCHORED and +PCRE2_NO_UTF_CHECK can also be passed to these functions (though +pcre2_jit_match() ignores the latter since it bypasses all sanity checks). */ + +#define PCRE2_NOTBOL 0x00000001u +#define PCRE2_NOTEOL 0x00000002u +#define PCRE2_NOTEMPTY 0x00000004u /* ) These two must be kept */ +#define PCRE2_NOTEMPTY_ATSTART 0x00000008u /* ) adjacent to each other. 
*/ +#define PCRE2_PARTIAL_SOFT 0x00000010u +#define PCRE2_PARTIAL_HARD 0x00000020u +#define PCRE2_DFA_RESTART 0x00000040u /* pcre2_dfa_match() only */ +#define PCRE2_DFA_SHORTEST 0x00000080u /* pcre2_dfa_match() only */ +#define PCRE2_SUBSTITUTE_GLOBAL 0x00000100u /* pcre2_substitute() only */ +#define PCRE2_SUBSTITUTE_EXTENDED 0x00000200u /* pcre2_substitute() only */ +#define PCRE2_SUBSTITUTE_UNSET_EMPTY 0x00000400u /* pcre2_substitute() only */ +#define PCRE2_SUBSTITUTE_UNKNOWN_UNSET 0x00000800u /* pcre2_substitute() only */ +#define PCRE2_SUBSTITUTE_OVERFLOW_LENGTH 0x00001000u /* pcre2_substitute() only */ +#define PCRE2_NO_JIT 0x00002000u /* not for pcre2_dfa_match() */ +#define PCRE2_COPY_MATCHED_SUBJECT 0x00004000u +#define PCRE2_SUBSTITUTE_LITERAL 0x00008000u /* pcre2_substitute() only */ +#define PCRE2_SUBSTITUTE_MATCHED 0x00010000u /* pcre2_substitute() only */ +#define PCRE2_SUBSTITUTE_REPLACEMENT_ONLY 0x00020000u /* pcre2_substitute() only */ +#define PCRE2_DISABLE_RECURSELOOP_CHECK 0x00040000u /* not for pcre2_dfa_match() or pcre2_jit_match() */ + +/* Options for pcre2_pattern_convert(). */ + +#define PCRE2_CONVERT_UTF 0x00000001u +#define PCRE2_CONVERT_NO_UTF_CHECK 0x00000002u +#define PCRE2_CONVERT_POSIX_BASIC 0x00000004u +#define PCRE2_CONVERT_POSIX_EXTENDED 0x00000008u +#define PCRE2_CONVERT_GLOB 0x00000010u +#define PCRE2_CONVERT_GLOB_NO_WILD_SEPARATOR 0x00000030u +#define PCRE2_CONVERT_GLOB_NO_STARSTAR 0x00000050u + +/* Newline and \R settings, for use in compile contexts. The newline values +must be kept in step with values set in config.h and both sets must all be +greater than zero. */ + +#define PCRE2_NEWLINE_CR 1 +#define PCRE2_NEWLINE_LF 2 +#define PCRE2_NEWLINE_CRLF 3 +#define PCRE2_NEWLINE_ANY 4 +#define PCRE2_NEWLINE_ANYCRLF 5 +#define PCRE2_NEWLINE_NUL 6 + +#define PCRE2_BSR_UNICODE 1 +#define PCRE2_BSR_ANYCRLF 2 + +/* Error codes for pcre2_compile(). Some of these are also used by +pcre2_pattern_convert(). 
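+
+As an illustrative sketch (not part of the definitions below), a caller
+normally turns one of these codes into text with pcre2_get_error_message(),
+declared later in this file; this assumes an 8-bit build:
+
+  int errcode;
+  PCRE2_SIZE erroffset;
+  pcre2_code *re = pcre2_compile((PCRE2_SPTR)"a(b", PCRE2_ZERO_TERMINATED,
+    0, &errcode, &erroffset, NULL);
+  if (re == NULL) {                  // compile failed: errcode holds a code
+    PCRE2_UCHAR message[120];
+    pcre2_get_error_message(errcode, message, 120);
+  }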
*/ + +#define PCRE2_ERROR_END_BACKSLASH 101 +#define PCRE2_ERROR_END_BACKSLASH_C 102 +#define PCRE2_ERROR_UNKNOWN_ESCAPE 103 +#define PCRE2_ERROR_QUANTIFIER_OUT_OF_ORDER 104 +#define PCRE2_ERROR_QUANTIFIER_TOO_BIG 105 +#define PCRE2_ERROR_MISSING_SQUARE_BRACKET 106 +#define PCRE2_ERROR_ESCAPE_INVALID_IN_CLASS 107 +#define PCRE2_ERROR_CLASS_RANGE_ORDER 108 +#define PCRE2_ERROR_QUANTIFIER_INVALID 109 +#define PCRE2_ERROR_INTERNAL_UNEXPECTED_REPEAT 110 +#define PCRE2_ERROR_INVALID_AFTER_PARENS_QUERY 111 +#define PCRE2_ERROR_POSIX_CLASS_NOT_IN_CLASS 112 +#define PCRE2_ERROR_POSIX_NO_SUPPORT_COLLATING 113 +#define PCRE2_ERROR_MISSING_CLOSING_PARENTHESIS 114 +#define PCRE2_ERROR_BAD_SUBPATTERN_REFERENCE 115 +#define PCRE2_ERROR_NULL_PATTERN 116 +#define PCRE2_ERROR_BAD_OPTIONS 117 +#define PCRE2_ERROR_MISSING_COMMENT_CLOSING 118 +#define PCRE2_ERROR_PARENTHESES_NEST_TOO_DEEP 119 +#define PCRE2_ERROR_PATTERN_TOO_LARGE 120 +#define PCRE2_ERROR_HEAP_FAILED 121 +#define PCRE2_ERROR_UNMATCHED_CLOSING_PARENTHESIS 122 +#define PCRE2_ERROR_INTERNAL_CODE_OVERFLOW 123 +#define PCRE2_ERROR_MISSING_CONDITION_CLOSING 124 +#define PCRE2_ERROR_LOOKBEHIND_NOT_FIXED_LENGTH 125 +#define PCRE2_ERROR_ZERO_RELATIVE_REFERENCE 126 +#define PCRE2_ERROR_TOO_MANY_CONDITION_BRANCHES 127 +#define PCRE2_ERROR_CONDITION_ASSERTION_EXPECTED 128 +#define PCRE2_ERROR_BAD_RELATIVE_REFERENCE 129 +#define PCRE2_ERROR_UNKNOWN_POSIX_CLASS 130 +#define PCRE2_ERROR_INTERNAL_STUDY_ERROR 131 +#define PCRE2_ERROR_UNICODE_NOT_SUPPORTED 132 +#define PCRE2_ERROR_PARENTHESES_STACK_CHECK 133 +#define PCRE2_ERROR_CODE_POINT_TOO_BIG 134 +#define PCRE2_ERROR_LOOKBEHIND_TOO_COMPLICATED 135 +#define PCRE2_ERROR_LOOKBEHIND_INVALID_BACKSLASH_C 136 +#define PCRE2_ERROR_UNSUPPORTED_ESCAPE_SEQUENCE 137 +#define PCRE2_ERROR_CALLOUT_NUMBER_TOO_BIG 138 +#define PCRE2_ERROR_MISSING_CALLOUT_CLOSING 139 +#define PCRE2_ERROR_ESCAPE_INVALID_IN_VERB 140 +#define PCRE2_ERROR_UNRECOGNIZED_AFTER_QUERY_P 141 +#define PCRE2_ERROR_MISSING_NAME_TERMINATOR 142 +#define PCRE2_ERROR_DUPLICATE_SUBPATTERN_NAME 143 +#define PCRE2_ERROR_INVALID_SUBPATTERN_NAME 144 +#define PCRE2_ERROR_UNICODE_PROPERTIES_UNAVAILABLE 145 +#define PCRE2_ERROR_MALFORMED_UNICODE_PROPERTY 146 +#define PCRE2_ERROR_UNKNOWN_UNICODE_PROPERTY 147 +#define PCRE2_ERROR_SUBPATTERN_NAME_TOO_LONG 148 +#define PCRE2_ERROR_TOO_MANY_NAMED_SUBPATTERNS 149 +#define PCRE2_ERROR_CLASS_INVALID_RANGE 150 +#define PCRE2_ERROR_OCTAL_BYTE_TOO_BIG 151 +#define PCRE2_ERROR_INTERNAL_OVERRAN_WORKSPACE 152 +#define PCRE2_ERROR_INTERNAL_MISSING_SUBPATTERN 153 +#define PCRE2_ERROR_DEFINE_TOO_MANY_BRANCHES 154 +#define PCRE2_ERROR_BACKSLASH_O_MISSING_BRACE 155 +#define PCRE2_ERROR_INTERNAL_UNKNOWN_NEWLINE 156 +#define PCRE2_ERROR_BACKSLASH_G_SYNTAX 157 +#define PCRE2_ERROR_PARENS_QUERY_R_MISSING_CLOSING 158 +/* Error 159 is obsolete and should now never occur */ +#define PCRE2_ERROR_VERB_ARGUMENT_NOT_ALLOWED 159 +#define PCRE2_ERROR_VERB_UNKNOWN 160 +#define PCRE2_ERROR_SUBPATTERN_NUMBER_TOO_BIG 161 +#define PCRE2_ERROR_SUBPATTERN_NAME_EXPECTED 162 +#define PCRE2_ERROR_INTERNAL_PARSED_OVERFLOW 163 +#define PCRE2_ERROR_INVALID_OCTAL 164 +#define PCRE2_ERROR_SUBPATTERN_NAMES_MISMATCH 165 +#define PCRE2_ERROR_MARK_MISSING_ARGUMENT 166 +#define PCRE2_ERROR_INVALID_HEXADECIMAL 167 +#define PCRE2_ERROR_BACKSLASH_C_SYNTAX 168 +#define PCRE2_ERROR_BACKSLASH_K_SYNTAX 169 +#define PCRE2_ERROR_INTERNAL_BAD_CODE_LOOKBEHINDS 170 +#define PCRE2_ERROR_BACKSLASH_N_IN_CLASS 171 +#define PCRE2_ERROR_CALLOUT_STRING_TOO_LONG 172 +#define 
PCRE2_ERROR_UNICODE_DISALLOWED_CODE_POINT 173 +#define PCRE2_ERROR_UTF_IS_DISABLED 174 +#define PCRE2_ERROR_UCP_IS_DISABLED 175 +#define PCRE2_ERROR_VERB_NAME_TOO_LONG 176 +#define PCRE2_ERROR_BACKSLASH_U_CODE_POINT_TOO_BIG 177 +#define PCRE2_ERROR_MISSING_OCTAL_OR_HEX_DIGITS 178 +#define PCRE2_ERROR_VERSION_CONDITION_SYNTAX 179 +#define PCRE2_ERROR_INTERNAL_BAD_CODE_AUTO_POSSESS 180 +#define PCRE2_ERROR_CALLOUT_NO_STRING_DELIMITER 181 +#define PCRE2_ERROR_CALLOUT_BAD_STRING_DELIMITER 182 +#define PCRE2_ERROR_BACKSLASH_C_CALLER_DISABLED 183 +#define PCRE2_ERROR_QUERY_BARJX_NEST_TOO_DEEP 184 +#define PCRE2_ERROR_BACKSLASH_C_LIBRARY_DISABLED 185 +#define PCRE2_ERROR_PATTERN_TOO_COMPLICATED 186 +#define PCRE2_ERROR_LOOKBEHIND_TOO_LONG 187 +#define PCRE2_ERROR_PATTERN_STRING_TOO_LONG 188 +#define PCRE2_ERROR_INTERNAL_BAD_CODE 189 +#define PCRE2_ERROR_INTERNAL_BAD_CODE_IN_SKIP 190 +#define PCRE2_ERROR_NO_SURROGATES_IN_UTF16 191 +#define PCRE2_ERROR_BAD_LITERAL_OPTIONS 192 +#define PCRE2_ERROR_SUPPORTED_ONLY_IN_UNICODE 193 +#define PCRE2_ERROR_INVALID_HYPHEN_IN_OPTIONS 194 +#define PCRE2_ERROR_ALPHA_ASSERTION_UNKNOWN 195 +#define PCRE2_ERROR_SCRIPT_RUN_NOT_AVAILABLE 196 +#define PCRE2_ERROR_TOO_MANY_CAPTURES 197 +#define PCRE2_ERROR_MISSING_OCTAL_DIGIT 198 +#define PCRE2_ERROR_BACKSLASH_K_IN_LOOKAROUND 199 +#define PCRE2_ERROR_MAX_VAR_LOOKBEHIND_EXCEEDED 200 +#define PCRE2_ERROR_PATTERN_COMPILED_SIZE_TOO_BIG 201 +#define PCRE2_ERROR_OVERSIZE_PYTHON_OCTAL 202 +#define PCRE2_ERROR_CALLOUT_CALLER_DISABLED 203 +#define PCRE2_ERROR_EXTRA_CASING_REQUIRES_UNICODE 204 +#define PCRE2_ERROR_TURKISH_CASING_REQUIRES_UTF 205 +#define PCRE2_ERROR_EXTRA_CASING_INCOMPATIBLE 206 +#define PCRE2_ERROR_ECLASS_NEST_TOO_DEEP 207 +#define PCRE2_ERROR_ECLASS_INVALID_OPERATOR 208 +#define PCRE2_ERROR_ECLASS_UNEXPECTED_OPERATOR 209 +#define PCRE2_ERROR_ECLASS_EXPECTED_OPERAND 210 +#define PCRE2_ERROR_ECLASS_MIXED_OPERATORS 211 +#define PCRE2_ERROR_ECLASS_HINT_SQUARE_BRACKET 212 +#define PCRE2_ERROR_PERL_ECLASS_UNEXPECTED_EXPR 213 +#define PCRE2_ERROR_PERL_ECLASS_EMPTY_EXPR 214 +#define PCRE2_ERROR_PERL_ECLASS_MISSING_CLOSE 215 +#define PCRE2_ERROR_PERL_ECLASS_UNEXPECTED_CHAR 216 +#define PCRE2_ERROR_EXPECTED_CAPTURE_GROUP 217 +#define PCRE2_ERROR_MISSING_OPENING_PARENTHESIS 218 +#define PCRE2_ERROR_MISSING_NUMBER_TERMINATOR 219 +#define PCRE2_ERROR_NULL_ERROROFFSET 220 + +/* "Expected" matching error codes: no match and partial match. 
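+
+For example (a sketch), a caller usually treats these two codes as ordinary
+outcomes of pcre2_match() rather than failures:
+
+  int rc = pcre2_match(re, subject, length, 0, 0, match_data, NULL);
+  if (rc == PCRE2_ERROR_NOMATCH) { ... }       // no match: not an error
+  else if (rc == PCRE2_ERROR_PARTIAL) { ... }  // only with PCRE2_PARTIAL_* set
+  else if (rc < 0) { ... }                     // a genuine error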
*/ + +#define PCRE2_ERROR_NOMATCH (-1) +#define PCRE2_ERROR_PARTIAL (-2) + +/* Error codes for UTF-8 validity checks */ + +#define PCRE2_ERROR_UTF8_ERR1 (-3) +#define PCRE2_ERROR_UTF8_ERR2 (-4) +#define PCRE2_ERROR_UTF8_ERR3 (-5) +#define PCRE2_ERROR_UTF8_ERR4 (-6) +#define PCRE2_ERROR_UTF8_ERR5 (-7) +#define PCRE2_ERROR_UTF8_ERR6 (-8) +#define PCRE2_ERROR_UTF8_ERR7 (-9) +#define PCRE2_ERROR_UTF8_ERR8 (-10) +#define PCRE2_ERROR_UTF8_ERR9 (-11) +#define PCRE2_ERROR_UTF8_ERR10 (-12) +#define PCRE2_ERROR_UTF8_ERR11 (-13) +#define PCRE2_ERROR_UTF8_ERR12 (-14) +#define PCRE2_ERROR_UTF8_ERR13 (-15) +#define PCRE2_ERROR_UTF8_ERR14 (-16) +#define PCRE2_ERROR_UTF8_ERR15 (-17) +#define PCRE2_ERROR_UTF8_ERR16 (-18) +#define PCRE2_ERROR_UTF8_ERR17 (-19) +#define PCRE2_ERROR_UTF8_ERR18 (-20) +#define PCRE2_ERROR_UTF8_ERR19 (-21) +#define PCRE2_ERROR_UTF8_ERR20 (-22) +#define PCRE2_ERROR_UTF8_ERR21 (-23) + +/* Error codes for UTF-16 validity checks */ + +#define PCRE2_ERROR_UTF16_ERR1 (-24) +#define PCRE2_ERROR_UTF16_ERR2 (-25) +#define PCRE2_ERROR_UTF16_ERR3 (-26) + +/* Error codes for UTF-32 validity checks */ + +#define PCRE2_ERROR_UTF32_ERR1 (-27) +#define PCRE2_ERROR_UTF32_ERR2 (-28) + +/* Miscellaneous error codes for pcre2[_dfa]_match(), substring extraction +functions, context functions, and serializing functions. They are in numerical +order. Originally they were in alphabetical order too, but now that PCRE2 is +released, the numbers must not be changed. */ + +#define PCRE2_ERROR_BADDATA (-29) +#define PCRE2_ERROR_MIXEDTABLES (-30) /* Name was changed */ +#define PCRE2_ERROR_BADMAGIC (-31) +#define PCRE2_ERROR_BADMODE (-32) +#define PCRE2_ERROR_BADOFFSET (-33) +#define PCRE2_ERROR_BADOPTION (-34) +#define PCRE2_ERROR_BADREPLACEMENT (-35) +#define PCRE2_ERROR_BADUTFOFFSET (-36) +#define PCRE2_ERROR_CALLOUT (-37) /* Never used by PCRE2 itself */ +#define PCRE2_ERROR_DFA_BADRESTART (-38) +#define PCRE2_ERROR_DFA_RECURSE (-39) +#define PCRE2_ERROR_DFA_UCOND (-40) +#define PCRE2_ERROR_DFA_UFUNC (-41) +#define PCRE2_ERROR_DFA_UITEM (-42) +#define PCRE2_ERROR_DFA_WSSIZE (-43) +#define PCRE2_ERROR_INTERNAL (-44) +#define PCRE2_ERROR_JIT_BADOPTION (-45) +#define PCRE2_ERROR_JIT_STACKLIMIT (-46) +#define PCRE2_ERROR_MATCHLIMIT (-47) +#define PCRE2_ERROR_NOMEMORY (-48) +#define PCRE2_ERROR_NOSUBSTRING (-49) +#define PCRE2_ERROR_NOUNIQUESUBSTRING (-50) +#define PCRE2_ERROR_NULL (-51) +#define PCRE2_ERROR_RECURSELOOP (-52) +#define PCRE2_ERROR_DEPTHLIMIT (-53) +#define PCRE2_ERROR_RECURSIONLIMIT (-53) /* Obsolete synonym */ +#define PCRE2_ERROR_UNAVAILABLE (-54) +#define PCRE2_ERROR_UNSET (-55) +#define PCRE2_ERROR_BADOFFSETLIMIT (-56) +#define PCRE2_ERROR_BADREPESCAPE (-57) +#define PCRE2_ERROR_REPMISSINGBRACE (-58) +#define PCRE2_ERROR_BADSUBSTITUTION (-59) +#define PCRE2_ERROR_BADSUBSPATTERN (-60) +#define PCRE2_ERROR_TOOMANYREPLACE (-61) +#define PCRE2_ERROR_BADSERIALIZEDDATA (-62) +#define PCRE2_ERROR_HEAPLIMIT (-63) +#define PCRE2_ERROR_CONVERT_SYNTAX (-64) +#define PCRE2_ERROR_INTERNAL_DUPMATCH (-65) +#define PCRE2_ERROR_DFA_UINVALID_UTF (-66) +#define PCRE2_ERROR_INVALIDOFFSET (-67) +#define PCRE2_ERROR_JIT_UNSUPPORTED (-68) +#define PCRE2_ERROR_REPLACECASE (-69) +#define PCRE2_ERROR_TOOLARGEREPLACE (-70) +#define PCRE2_ERROR_DIFFSUBSPATTERN (-71) +#define PCRE2_ERROR_DIFFSUBSSUBJECT (-72) +#define PCRE2_ERROR_DIFFSUBSOFFSET (-73) +#define PCRE2_ERROR_DIFFSUBSOPTIONS (-74) +#define PCRE2_ERROR_BAD_BACKSLASH_K (-75) + + +/* Request types for pcre2_pattern_info() */ + +#define PCRE2_INFO_ALLOPTIONS 0 
+#define PCRE2_INFO_ARGOPTIONS 1 +#define PCRE2_INFO_BACKREFMAX 2 +#define PCRE2_INFO_BSR 3 +#define PCRE2_INFO_CAPTURECOUNT 4 +#define PCRE2_INFO_FIRSTCODEUNIT 5 +#define PCRE2_INFO_FIRSTCODETYPE 6 +#define PCRE2_INFO_FIRSTBITMAP 7 +#define PCRE2_INFO_HASCRORLF 8 +#define PCRE2_INFO_JCHANGED 9 +#define PCRE2_INFO_JITSIZE 10 +#define PCRE2_INFO_LASTCODEUNIT 11 +#define PCRE2_INFO_LASTCODETYPE 12 +#define PCRE2_INFO_MATCHEMPTY 13 +#define PCRE2_INFO_MATCHLIMIT 14 +#define PCRE2_INFO_MAXLOOKBEHIND 15 +#define PCRE2_INFO_MINLENGTH 16 +#define PCRE2_INFO_NAMECOUNT 17 +#define PCRE2_INFO_NAMEENTRYSIZE 18 +#define PCRE2_INFO_NAMETABLE 19 +#define PCRE2_INFO_NEWLINE 20 +#define PCRE2_INFO_DEPTHLIMIT 21 +#define PCRE2_INFO_RECURSIONLIMIT 21 /* Obsolete synonym */ +#define PCRE2_INFO_SIZE 22 +#define PCRE2_INFO_HASBACKSLASHC 23 +#define PCRE2_INFO_FRAMESIZE 24 +#define PCRE2_INFO_HEAPLIMIT 25 +#define PCRE2_INFO_EXTRAOPTIONS 26 + +/* Request types for pcre2_config(). */ + +#define PCRE2_CONFIG_BSR 0 +#define PCRE2_CONFIG_JIT 1 +#define PCRE2_CONFIG_JITTARGET 2 +#define PCRE2_CONFIG_LINKSIZE 3 +#define PCRE2_CONFIG_MATCHLIMIT 4 +#define PCRE2_CONFIG_NEWLINE 5 +#define PCRE2_CONFIG_PARENSLIMIT 6 +#define PCRE2_CONFIG_DEPTHLIMIT 7 +#define PCRE2_CONFIG_RECURSIONLIMIT 7 /* Obsolete synonym */ +#define PCRE2_CONFIG_STACKRECURSE 8 /* Obsolete */ +#define PCRE2_CONFIG_UNICODE 9 +#define PCRE2_CONFIG_UNICODE_VERSION 10 +#define PCRE2_CONFIG_VERSION 11 +#define PCRE2_CONFIG_HEAPLIMIT 12 +#define PCRE2_CONFIG_NEVER_BACKSLASH_C 13 +#define PCRE2_CONFIG_COMPILED_WIDTHS 14 +#define PCRE2_CONFIG_TABLES_LENGTH 15 +#define PCRE2_CONFIG_EFFECTIVE_LINKSIZE 16 + +/* Optimization directives for pcre2_set_optimize(). +For binary compatibility, only add to this list; do not renumber. */ + +#define PCRE2_OPTIMIZATION_NONE 0 +#define PCRE2_OPTIMIZATION_FULL 1 + +#define PCRE2_AUTO_POSSESS 64 +#define PCRE2_AUTO_POSSESS_OFF 65 +#define PCRE2_DOTSTAR_ANCHOR 66 +#define PCRE2_DOTSTAR_ANCHOR_OFF 67 +#define PCRE2_START_OPTIMIZE 68 +#define PCRE2_START_OPTIMIZE_OFF 69 + +/* Types used in pcre2_set_substitute_case_callout(). + +PCRE2_SUBSTITUTE_CASE_LOWER and PCRE2_SUBSTITUTE_CASE_UPPER are passed to the +callout to indicate that the case of the entire callout input should be +case-transformed. PCRE2_SUBSTITUTE_CASE_TITLE_FIRST is passed to indicate that +only the first character or glyph should be transformed to Unicode titlecase, +and the rest to lowercase. */ + +#define PCRE2_SUBSTITUTE_CASE_LOWER 1 +#define PCRE2_SUBSTITUTE_CASE_UPPER 2 +#define PCRE2_SUBSTITUTE_CASE_TITLE_FIRST 3 + +/* Types for code units in patterns and subject strings. */ + +typedef uint8_t PCRE2_UCHAR8; +typedef uint16_t PCRE2_UCHAR16; +typedef uint32_t PCRE2_UCHAR32; + +typedef const PCRE2_UCHAR8 *PCRE2_SPTR8; +typedef const PCRE2_UCHAR16 *PCRE2_SPTR16; +typedef const PCRE2_UCHAR32 *PCRE2_SPTR32; + +/* The PCRE2_SIZE type is used for all string lengths and offsets in PCRE2, +including pattern offsets for errors and subject offsets after a match. We +define special values to indicate zero-terminated strings and unset offsets in +the offset vector (ovector). */ + +#define PCRE2_SIZE size_t +#define PCRE2_SIZE_MAX SIZE_MAX +#define PCRE2_ZERO_TERMINATED (~(PCRE2_SIZE)0) +#define PCRE2_UNSET (~(PCRE2_SIZE)0) + +/* Generic types for opaque structures and JIT callback functions. These +declarations are defined in a macro that is expanded for each width later. 
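+
+As an illustrative sketch (not part of this header), an application that wants
+PCRE2 to allocate through caller-supplied functions creates one such opaque
+context up front and hands it to the create functions declared below;
+my_malloc and my_free here are hypothetical caller-defined functions:
+
+  pcre2_general_context *gc =
+    pcre2_general_context_create(my_malloc, my_free, NULL);
+  ...
+  pcre2_general_context_free(gc);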
*/ + +#define PCRE2_TYPES_LIST \ +struct pcre2_real_general_context; \ +typedef struct pcre2_real_general_context pcre2_general_context; \ +\ +struct pcre2_real_compile_context; \ +typedef struct pcre2_real_compile_context pcre2_compile_context; \ +\ +struct pcre2_real_match_context; \ +typedef struct pcre2_real_match_context pcre2_match_context; \ +\ +struct pcre2_real_convert_context; \ +typedef struct pcre2_real_convert_context pcre2_convert_context; \ +\ +struct pcre2_real_code; \ +typedef struct pcre2_real_code pcre2_code; \ +\ +struct pcre2_real_match_data; \ +typedef struct pcre2_real_match_data pcre2_match_data; \ +\ +struct pcre2_real_jit_stack; \ +typedef struct pcre2_real_jit_stack pcre2_jit_stack; \ +\ +typedef pcre2_jit_stack *(*pcre2_jit_callback)(void *); + + +/* The structures for passing out data via callout functions. We use structures +so that new fields can be added on the end in future versions, without changing +the API of the function, thereby allowing old clients to work without +modification. Define the generic versions in a macro; the width-specific +versions are generated from this macro below. */ + +/* Flags for the callout_flags field. These are cleared after a callout. */ + +#define PCRE2_CALLOUT_STARTMATCH 0x00000001u /* Set for each bumpalong */ +#define PCRE2_CALLOUT_BACKTRACK 0x00000002u /* Set after a backtrack */ + +#define PCRE2_STRUCTURE_LIST \ +typedef struct pcre2_callout_block { \ + uint32_t version; /* Identifies version of block */ \ + /* ------------------------ Version 0 ------------------------------- */ \ + uint32_t callout_number; /* Number compiled into pattern */ \ + uint32_t capture_top; /* Max current capture */ \ + uint32_t capture_last; /* Most recently closed capture */ \ + PCRE2_SIZE *offset_vector; /* The offset vector */ \ + PCRE2_SPTR mark; /* Pointer to current mark or NULL */ \ + PCRE2_SPTR subject; /* The subject being matched */ \ + PCRE2_SIZE subject_length; /* The length of the subject */ \ + PCRE2_SIZE start_match; /* Offset to start of this match attempt */ \ + PCRE2_SIZE current_position; /* Where we currently are in the subject */ \ + PCRE2_SIZE pattern_position; /* Offset to next item in the pattern */ \ + PCRE2_SIZE next_item_length; /* Length of next item in the pattern */ \ + /* ------------------- Added for Version 1 -------------------------- */ \ + PCRE2_SIZE callout_string_offset; /* Offset to string within pattern */ \ + PCRE2_SIZE callout_string_length; /* Length of string compiled into pattern */ \ + PCRE2_SPTR callout_string; /* String compiled into pattern */ \ + /* ------------------- Added for Version 2 -------------------------- */ \ + uint32_t callout_flags; /* See above for list */ \ + /* ------------------------------------------------------------------ */ \ +} pcre2_callout_block; \ +\ +typedef struct pcre2_callout_enumerate_block { \ + uint32_t version; /* Identifies version of block */ \ + /* ------------------------ Version 0 ------------------------------- */ \ + PCRE2_SIZE pattern_position; /* Offset to next item in the pattern */ \ + PCRE2_SIZE next_item_length; /* Length of next item in the pattern */ \ + uint32_t callout_number; /* Number compiled into pattern */ \ + PCRE2_SIZE callout_string_offset; /* Offset to string within pattern */ \ + PCRE2_SIZE callout_string_length; /* Length of string compiled into pattern */ \ + PCRE2_SPTR callout_string; /* String compiled into pattern */ \ + /* ------------------------------------------------------------------ */ \ +} pcre2_callout_enumerate_block; 
\ +\ +typedef struct pcre2_substitute_callout_block { \ + uint32_t version; /* Identifies version of block */ \ + /* ------------------------ Version 0 ------------------------------- */ \ + PCRE2_SPTR input; /* Pointer to input subject string */ \ + PCRE2_SPTR output; /* Pointer to output buffer */ \ + PCRE2_SIZE output_offsets[2]; /* Changed portion of the output */ \ + PCRE2_SIZE *ovector; /* Pointer to current ovector */ \ + uint32_t oveccount; /* Count of pairs set in ovector */ \ + uint32_t subscount; /* Substitution number */ \ + /* ------------------------------------------------------------------ */ \ +} pcre2_substitute_callout_block; + + +/* List the generic forms of all other functions in macros, which will be +expanded for each width below. Start with functions that give general +information. */ + +#define PCRE2_GENERAL_INFO_FUNCTIONS \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION pcre2_config(uint32_t, void *); + + +/* Functions for manipulating contexts. */ + +#define PCRE2_GENERAL_CONTEXT_FUNCTIONS \ +PCRE2_EXP_DECL pcre2_general_context *PCRE2_CALL_CONVENTION \ + pcre2_general_context_copy(pcre2_general_context *); \ +PCRE2_EXP_DECL pcre2_general_context *PCRE2_CALL_CONVENTION \ + pcre2_general_context_create(void *(*)(size_t, void *), \ + void (*)(void *, void *), void *); \ +PCRE2_EXP_DECL void PCRE2_CALL_CONVENTION \ + pcre2_general_context_free(pcre2_general_context *); + +#define PCRE2_COMPILE_CONTEXT_FUNCTIONS \ +PCRE2_EXP_DECL pcre2_compile_context *PCRE2_CALL_CONVENTION \ + pcre2_compile_context_copy(pcre2_compile_context *); \ +PCRE2_EXP_DECL pcre2_compile_context *PCRE2_CALL_CONVENTION \ + pcre2_compile_context_create(pcre2_general_context *);\ +PCRE2_EXP_DECL void PCRE2_CALL_CONVENTION \ + pcre2_compile_context_free(pcre2_compile_context *); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_bsr(pcre2_compile_context *, uint32_t); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_character_tables(pcre2_compile_context *, const uint8_t *); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_compile_extra_options(pcre2_compile_context *, uint32_t); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_max_pattern_length(pcre2_compile_context *, PCRE2_SIZE); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_max_pattern_compiled_length(pcre2_compile_context *, PCRE2_SIZE); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_max_varlookbehind(pcre2_compile_context *, uint32_t); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_newline(pcre2_compile_context *, uint32_t); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_parens_nest_limit(pcre2_compile_context *, uint32_t); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_compile_recursion_guard(pcre2_compile_context *, \ + int (*)(uint32_t, void *), void *); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_optimize(pcre2_compile_context *, uint32_t); + +#define PCRE2_MATCH_CONTEXT_FUNCTIONS \ +PCRE2_EXP_DECL pcre2_match_context *PCRE2_CALL_CONVENTION \ + pcre2_match_context_copy(pcre2_match_context *); \ +PCRE2_EXP_DECL pcre2_match_context *PCRE2_CALL_CONVENTION \ + pcre2_match_context_create(pcre2_general_context *); \ +PCRE2_EXP_DECL void PCRE2_CALL_CONVENTION \ + pcre2_match_context_free(pcre2_match_context *); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_callout(pcre2_match_context *, \ + int (*)(pcre2_callout_block *, void *), void *); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_substitute_callout(pcre2_match_context 
*, \ + int (*)(pcre2_substitute_callout_block *, void *), void *); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_substitute_case_callout(pcre2_match_context *, \ + PCRE2_SIZE (*)(PCRE2_SPTR, PCRE2_SIZE, PCRE2_UCHAR *, PCRE2_SIZE, int, \ + void *), \ + void *); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_depth_limit(pcre2_match_context *, uint32_t); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_heap_limit(pcre2_match_context *, uint32_t); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_match_limit(pcre2_match_context *, uint32_t); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_offset_limit(pcre2_match_context *, PCRE2_SIZE); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_recursion_limit(pcre2_match_context *, uint32_t); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_recursion_memory_management(pcre2_match_context *, \ + void *(*)(size_t, void *), void (*)(void *, void *), void *); + +#define PCRE2_CONVERT_CONTEXT_FUNCTIONS \ +PCRE2_EXP_DECL pcre2_convert_context *PCRE2_CALL_CONVENTION \ + pcre2_convert_context_copy(pcre2_convert_context *); \ +PCRE2_EXP_DECL pcre2_convert_context *PCRE2_CALL_CONVENTION \ + pcre2_convert_context_create(pcre2_general_context *); \ +PCRE2_EXP_DECL void PCRE2_CALL_CONVENTION \ + pcre2_convert_context_free(pcre2_convert_context *); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_glob_escape(pcre2_convert_context *, uint32_t); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_set_glob_separator(pcre2_convert_context *, uint32_t); + + +/* Functions concerned with compiling a pattern to PCRE internal code. */ + +#define PCRE2_COMPILE_FUNCTIONS \ +PCRE2_EXP_DECL pcre2_code *PCRE2_CALL_CONVENTION \ + pcre2_compile(PCRE2_SPTR, PCRE2_SIZE, uint32_t, int *, PCRE2_SIZE *, \ + pcre2_compile_context *); \ +PCRE2_EXP_DECL void PCRE2_CALL_CONVENTION \ + pcre2_code_free(pcre2_code *); \ +PCRE2_EXP_DECL pcre2_code *PCRE2_CALL_CONVENTION \ + pcre2_code_copy(const pcre2_code *); \ +PCRE2_EXP_DECL pcre2_code *PCRE2_CALL_CONVENTION \ + pcre2_code_copy_with_tables(const pcre2_code *); + + +/* Functions that give information about a compiled pattern. */ + +#define PCRE2_PATTERN_INFO_FUNCTIONS \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_pattern_info(const pcre2_code *, uint32_t, void *); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_callout_enumerate(const pcre2_code *, \ + int (*)(pcre2_callout_enumerate_block *, void *), void *); + + +/* Functions for running a match and inspecting the result. 
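+
+A minimal calling sequence, as a sketch (8-bit library, error checks elided):
+
+  pcre2_match_data *md = pcre2_match_data_create_from_pattern(re, NULL);
+  int rc = pcre2_match(re, (PCRE2_SPTR)"subject", 7, 0, 0, md, NULL);
+  if (rc > 0) {
+    PCRE2_SIZE *ov = pcre2_get_ovector_pointer(md);
+    // ov[0] and ov[1] bound the whole match; ov[2]/ov[3] bound capture 1, ...
+  }
+  pcre2_match_data_free(md);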
*/ + +#define PCRE2_MATCH_FUNCTIONS \ +PCRE2_EXP_DECL pcre2_match_data *PCRE2_CALL_CONVENTION \ + pcre2_match_data_create(uint32_t, pcre2_general_context *); \ +PCRE2_EXP_DECL pcre2_match_data *PCRE2_CALL_CONVENTION \ + pcre2_match_data_create_from_pattern(const pcre2_code *, \ + pcre2_general_context *); \ +PCRE2_EXP_DECL void PCRE2_CALL_CONVENTION \ + pcre2_match_data_free(pcre2_match_data *); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_dfa_match(const pcre2_code *, PCRE2_SPTR, PCRE2_SIZE, PCRE2_SIZE, \ + uint32_t, pcre2_match_data *, pcre2_match_context *, int *, PCRE2_SIZE); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_match(const pcre2_code *, PCRE2_SPTR, PCRE2_SIZE, PCRE2_SIZE, \ + uint32_t, pcre2_match_data *, pcre2_match_context *); \ +PCRE2_EXP_DECL PCRE2_SPTR PCRE2_CALL_CONVENTION \ + pcre2_get_mark(pcre2_match_data *); \ +PCRE2_EXP_DECL PCRE2_SIZE PCRE2_CALL_CONVENTION \ + pcre2_get_match_data_size(pcre2_match_data *); \ +PCRE2_EXP_DECL PCRE2_SIZE PCRE2_CALL_CONVENTION \ + pcre2_get_match_data_heapframes_size(pcre2_match_data *); \ +PCRE2_EXP_DECL uint32_t PCRE2_CALL_CONVENTION \ + pcre2_get_ovector_count(pcre2_match_data *); \ +PCRE2_EXP_DECL PCRE2_SIZE *PCRE2_CALL_CONVENTION \ + pcre2_get_ovector_pointer(pcre2_match_data *); \ +PCRE2_EXP_DECL PCRE2_SIZE PCRE2_CALL_CONVENTION \ + pcre2_get_startchar(pcre2_match_data *); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_next_match(pcre2_match_data *, PCRE2_SIZE *, uint32_t *); + + +/* Convenience functions for handling matched substrings. */ + +#define PCRE2_SUBSTRING_FUNCTIONS \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_substring_copy_byname(pcre2_match_data *, PCRE2_SPTR, PCRE2_UCHAR *, \ + PCRE2_SIZE *); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_substring_copy_bynumber(pcre2_match_data *, uint32_t, PCRE2_UCHAR *, \ + PCRE2_SIZE *); \ +PCRE2_EXP_DECL void PCRE2_CALL_CONVENTION \ + pcre2_substring_free(PCRE2_UCHAR *); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_substring_get_byname(pcre2_match_data *, PCRE2_SPTR, PCRE2_UCHAR **, \ + PCRE2_SIZE *); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_substring_get_bynumber(pcre2_match_data *, uint32_t, PCRE2_UCHAR **, \ + PCRE2_SIZE *); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_substring_length_byname(pcre2_match_data *, PCRE2_SPTR, PCRE2_SIZE *); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_substring_length_bynumber(pcre2_match_data *, uint32_t, PCRE2_SIZE *); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_substring_nametable_scan(const pcre2_code *, PCRE2_SPTR, PCRE2_SPTR *, \ + PCRE2_SPTR *); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_substring_number_from_name(const pcre2_code *, PCRE2_SPTR); \ +PCRE2_EXP_DECL void PCRE2_CALL_CONVENTION \ + pcre2_substring_list_free(PCRE2_UCHAR **); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_substring_list_get(pcre2_match_data *, PCRE2_UCHAR ***, PCRE2_SIZE **); + + +/* Functions for serializing / deserializing compiled patterns. 
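+
+For example (a sketch), one compiled pattern can be flattened to bytes and
+reloaded later, provided the decoding build uses the same PCRE2 version and
+code unit width:
+
+  uint8_t *bytes;
+  PCRE2_SIZE byte_count;
+  pcre2_serialize_encode((const pcre2_code **)&re, 1, &bytes, &byte_count, NULL);
+  pcre2_code *re2;
+  pcre2_serialize_decode(&re2, 1, bytes, NULL);
+  pcre2_serialize_free(bytes);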
*/ + +#define PCRE2_SERIALIZE_FUNCTIONS \ +PCRE2_EXP_DECL int32_t PCRE2_CALL_CONVENTION \ + pcre2_serialize_encode(const pcre2_code **, int32_t, uint8_t **, \ + PCRE2_SIZE *, pcre2_general_context *); \ +PCRE2_EXP_DECL int32_t PCRE2_CALL_CONVENTION \ + pcre2_serialize_decode(pcre2_code **, int32_t, const uint8_t *, \ + pcre2_general_context *); \ +PCRE2_EXP_DECL int32_t PCRE2_CALL_CONVENTION \ + pcre2_serialize_get_number_of_codes(const uint8_t *); \ +PCRE2_EXP_DECL void PCRE2_CALL_CONVENTION \ + pcre2_serialize_free(uint8_t *); + + +/* Convenience function for match + substitute. */ + +#define PCRE2_SUBSTITUTE_FUNCTION \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_substitute(const pcre2_code *, PCRE2_SPTR, PCRE2_SIZE, PCRE2_SIZE, \ + uint32_t, pcre2_match_data *, pcre2_match_context *, PCRE2_SPTR, \ + PCRE2_SIZE, PCRE2_UCHAR *, PCRE2_SIZE *); + + +/* Functions for converting pattern source strings. */ + +#define PCRE2_CONVERT_FUNCTIONS \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_pattern_convert(PCRE2_SPTR, PCRE2_SIZE, uint32_t, PCRE2_UCHAR **, \ + PCRE2_SIZE *, pcre2_convert_context *); \ +PCRE2_EXP_DECL void PCRE2_CALL_CONVENTION \ + pcre2_converted_pattern_free(PCRE2_UCHAR *); + + +/* Functions for JIT processing */ + +#define PCRE2_JIT_FUNCTIONS \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_jit_compile(pcre2_code *, uint32_t); \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_jit_match(const pcre2_code *, PCRE2_SPTR, PCRE2_SIZE, PCRE2_SIZE, \ + uint32_t, pcre2_match_data *, pcre2_match_context *); \ +PCRE2_EXP_DECL void PCRE2_CALL_CONVENTION \ + pcre2_jit_free_unused_memory(pcre2_general_context *); \ +PCRE2_EXP_DECL pcre2_jit_stack *PCRE2_CALL_CONVENTION \ + pcre2_jit_stack_create(size_t, size_t, pcre2_general_context *); \ +PCRE2_EXP_DECL void PCRE2_CALL_CONVENTION \ + pcre2_jit_stack_assign(pcre2_match_context *, pcre2_jit_callback, void *); \ +PCRE2_EXP_DECL void PCRE2_CALL_CONVENTION \ + pcre2_jit_stack_free(pcre2_jit_stack *); + + +/* Other miscellaneous functions. */ + +#define PCRE2_OTHER_FUNCTIONS \ +PCRE2_EXP_DECL int PCRE2_CALL_CONVENTION \ + pcre2_get_error_message(int, PCRE2_UCHAR *, PCRE2_SIZE); \ +PCRE2_EXP_DECL const uint8_t *PCRE2_CALL_CONVENTION \ + pcre2_maketables(pcre2_general_context *); \ +PCRE2_EXP_DECL void PCRE2_CALL_CONVENTION \ + pcre2_maketables_free(pcre2_general_context *, const uint8_t *); + +/* Define macros that generate width-specific names from generic versions. The +three-level macro scheme is necessary to get the macros expanded when we want +them to be. First we get the width from PCRE2_LOCAL_WIDTH, which is used for +generating three versions of everything below. After that, PCRE2_SUFFIX will be +re-defined to use PCRE2_CODE_UNIT_WIDTH, for use when macros such as +pcre2_compile are called by application code. 
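+
+For example, in an application compiled with PCRE2_CODE_UNIT_WIDTH set to 8, a
+call written as pcre2_compile(...) expands to pcre2_compile_8(...), and the
+PCRE2_UCHAR and PCRE2_SPTR types resolve to PCRE2_UCHAR8 and PCRE2_SPTR8.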
*/ + +#define PCRE2_JOIN(a,b) a ## b +#define PCRE2_GLUE(a,b) PCRE2_JOIN(a,b) +#define PCRE2_SUFFIX(a) PCRE2_GLUE(a,PCRE2_LOCAL_WIDTH) + + +/* Data types */ + +#define PCRE2_UCHAR PCRE2_SUFFIX(PCRE2_UCHAR) +#define PCRE2_SPTR PCRE2_SUFFIX(PCRE2_SPTR) + +#define pcre2_code PCRE2_SUFFIX(pcre2_code_) +#define pcre2_jit_callback PCRE2_SUFFIX(pcre2_jit_callback_) +#define pcre2_jit_stack PCRE2_SUFFIX(pcre2_jit_stack_) + +#define pcre2_real_code PCRE2_SUFFIX(pcre2_real_code_) +#define pcre2_real_general_context PCRE2_SUFFIX(pcre2_real_general_context_) +#define pcre2_real_compile_context PCRE2_SUFFIX(pcre2_real_compile_context_) +#define pcre2_real_convert_context PCRE2_SUFFIX(pcre2_real_convert_context_) +#define pcre2_real_match_context PCRE2_SUFFIX(pcre2_real_match_context_) +#define pcre2_real_jit_stack PCRE2_SUFFIX(pcre2_real_jit_stack_) +#define pcre2_real_match_data PCRE2_SUFFIX(pcre2_real_match_data_) + + +/* Data blocks */ + +#define pcre2_callout_block PCRE2_SUFFIX(pcre2_callout_block_) +#define pcre2_callout_enumerate_block PCRE2_SUFFIX(pcre2_callout_enumerate_block_) +#define pcre2_substitute_callout_block PCRE2_SUFFIX(pcre2_substitute_callout_block_) +#define pcre2_general_context PCRE2_SUFFIX(pcre2_general_context_) +#define pcre2_compile_context PCRE2_SUFFIX(pcre2_compile_context_) +#define pcre2_convert_context PCRE2_SUFFIX(pcre2_convert_context_) +#define pcre2_match_context PCRE2_SUFFIX(pcre2_match_context_) +#define pcre2_match_data PCRE2_SUFFIX(pcre2_match_data_) + + +/* Functions: the complete list in alphabetical order */ + +#define pcre2_callout_enumerate PCRE2_SUFFIX(pcre2_callout_enumerate_) +#define pcre2_code_copy PCRE2_SUFFIX(pcre2_code_copy_) +#define pcre2_code_copy_with_tables PCRE2_SUFFIX(pcre2_code_copy_with_tables_) +#define pcre2_code_free PCRE2_SUFFIX(pcre2_code_free_) +#define pcre2_compile PCRE2_SUFFIX(pcre2_compile_) +#define pcre2_compile_context_copy PCRE2_SUFFIX(pcre2_compile_context_copy_) +#define pcre2_compile_context_create PCRE2_SUFFIX(pcre2_compile_context_create_) +#define pcre2_compile_context_free PCRE2_SUFFIX(pcre2_compile_context_free_) +#define pcre2_config PCRE2_SUFFIX(pcre2_config_) +#define pcre2_convert_context_copy PCRE2_SUFFIX(pcre2_convert_context_copy_) +#define pcre2_convert_context_create PCRE2_SUFFIX(pcre2_convert_context_create_) +#define pcre2_convert_context_free PCRE2_SUFFIX(pcre2_convert_context_free_) +#define pcre2_converted_pattern_free PCRE2_SUFFIX(pcre2_converted_pattern_free_) +#define pcre2_dfa_match PCRE2_SUFFIX(pcre2_dfa_match_) +#define pcre2_general_context_copy PCRE2_SUFFIX(pcre2_general_context_copy_) +#define pcre2_general_context_create PCRE2_SUFFIX(pcre2_general_context_create_) +#define pcre2_general_context_free PCRE2_SUFFIX(pcre2_general_context_free_) +#define pcre2_get_error_message PCRE2_SUFFIX(pcre2_get_error_message_) +#define pcre2_get_mark PCRE2_SUFFIX(pcre2_get_mark_) +#define pcre2_get_match_data_heapframes_size PCRE2_SUFFIX(pcre2_get_match_data_heapframes_size_) +#define pcre2_get_match_data_size PCRE2_SUFFIX(pcre2_get_match_data_size_) +#define pcre2_get_ovector_pointer PCRE2_SUFFIX(pcre2_get_ovector_pointer_) +#define pcre2_get_ovector_count PCRE2_SUFFIX(pcre2_get_ovector_count_) +#define pcre2_get_startchar PCRE2_SUFFIX(pcre2_get_startchar_) +#define pcre2_jit_compile PCRE2_SUFFIX(pcre2_jit_compile_) +#define pcre2_jit_match PCRE2_SUFFIX(pcre2_jit_match_) +#define pcre2_jit_free_unused_memory PCRE2_SUFFIX(pcre2_jit_free_unused_memory_) +#define pcre2_jit_stack_assign 
PCRE2_SUFFIX(pcre2_jit_stack_assign_) +#define pcre2_jit_stack_create PCRE2_SUFFIX(pcre2_jit_stack_create_) +#define pcre2_jit_stack_free PCRE2_SUFFIX(pcre2_jit_stack_free_) +#define pcre2_maketables PCRE2_SUFFIX(pcre2_maketables_) +#define pcre2_maketables_free PCRE2_SUFFIX(pcre2_maketables_free_) +#define pcre2_match PCRE2_SUFFIX(pcre2_match_) +#define pcre2_match_context_copy PCRE2_SUFFIX(pcre2_match_context_copy_) +#define pcre2_match_context_create PCRE2_SUFFIX(pcre2_match_context_create_) +#define pcre2_match_context_free PCRE2_SUFFIX(pcre2_match_context_free_) +#define pcre2_match_data_create PCRE2_SUFFIX(pcre2_match_data_create_) +#define pcre2_match_data_create_from_pattern PCRE2_SUFFIX(pcre2_match_data_create_from_pattern_) +#define pcre2_match_data_free PCRE2_SUFFIX(pcre2_match_data_free_) +#define pcre2_next_match PCRE2_SUFFIX(pcre2_next_match_) +#define pcre2_pattern_convert PCRE2_SUFFIX(pcre2_pattern_convert_) +#define pcre2_pattern_info PCRE2_SUFFIX(pcre2_pattern_info_) +#define pcre2_serialize_decode PCRE2_SUFFIX(pcre2_serialize_decode_) +#define pcre2_serialize_encode PCRE2_SUFFIX(pcre2_serialize_encode_) +#define pcre2_serialize_free PCRE2_SUFFIX(pcre2_serialize_free_) +#define pcre2_serialize_get_number_of_codes PCRE2_SUFFIX(pcre2_serialize_get_number_of_codes_) +#define pcre2_set_bsr PCRE2_SUFFIX(pcre2_set_bsr_) +#define pcre2_set_callout PCRE2_SUFFIX(pcre2_set_callout_) +#define pcre2_set_character_tables PCRE2_SUFFIX(pcre2_set_character_tables_) +#define pcre2_set_compile_extra_options PCRE2_SUFFIX(pcre2_set_compile_extra_options_) +#define pcre2_set_compile_recursion_guard PCRE2_SUFFIX(pcre2_set_compile_recursion_guard_) +#define pcre2_set_depth_limit PCRE2_SUFFIX(pcre2_set_depth_limit_) +#define pcre2_set_glob_escape PCRE2_SUFFIX(pcre2_set_glob_escape_) +#define pcre2_set_glob_separator PCRE2_SUFFIX(pcre2_set_glob_separator_) +#define pcre2_set_heap_limit PCRE2_SUFFIX(pcre2_set_heap_limit_) +#define pcre2_set_match_limit PCRE2_SUFFIX(pcre2_set_match_limit_) +#define pcre2_set_max_varlookbehind PCRE2_SUFFIX(pcre2_set_max_varlookbehind_) +#define pcre2_set_max_pattern_length PCRE2_SUFFIX(pcre2_set_max_pattern_length_) +#define pcre2_set_max_pattern_compiled_length PCRE2_SUFFIX(pcre2_set_max_pattern_compiled_length_) +#define pcre2_set_newline PCRE2_SUFFIX(pcre2_set_newline_) +#define pcre2_set_parens_nest_limit PCRE2_SUFFIX(pcre2_set_parens_nest_limit_) +#define pcre2_set_offset_limit PCRE2_SUFFIX(pcre2_set_offset_limit_) +#define pcre2_set_optimize PCRE2_SUFFIX(pcre2_set_optimize_) +#define pcre2_set_substitute_callout PCRE2_SUFFIX(pcre2_set_substitute_callout_) +#define pcre2_set_substitute_case_callout PCRE2_SUFFIX(pcre2_set_substitute_case_callout_) +#define pcre2_substitute PCRE2_SUFFIX(pcre2_substitute_) +#define pcre2_substring_copy_byname PCRE2_SUFFIX(pcre2_substring_copy_byname_) +#define pcre2_substring_copy_bynumber PCRE2_SUFFIX(pcre2_substring_copy_bynumber_) +#define pcre2_substring_free PCRE2_SUFFIX(pcre2_substring_free_) +#define pcre2_substring_get_byname PCRE2_SUFFIX(pcre2_substring_get_byname_) +#define pcre2_substring_get_bynumber PCRE2_SUFFIX(pcre2_substring_get_bynumber_) +#define pcre2_substring_length_byname PCRE2_SUFFIX(pcre2_substring_length_byname_) +#define pcre2_substring_length_bynumber PCRE2_SUFFIX(pcre2_substring_length_bynumber_) +#define pcre2_substring_list_get PCRE2_SUFFIX(pcre2_substring_list_get_) +#define pcre2_substring_list_free PCRE2_SUFFIX(pcre2_substring_list_free_) +#define pcre2_substring_nametable_scan 
PCRE2_SUFFIX(pcre2_substring_nametable_scan_) +#define pcre2_substring_number_from_name PCRE2_SUFFIX(pcre2_substring_number_from_name_) + +/* Keep this old function name for backwards compatibility */ +#define pcre2_set_recursion_limit PCRE2_SUFFIX(pcre2_set_recursion_limit_) + +/* Keep this obsolete function for backwards compatibility: it is now a noop. */ +#define pcre2_set_recursion_memory_management PCRE2_SUFFIX(pcre2_set_recursion_memory_management_) + +/* Now generate all three sets of width-specific structures and function +prototypes. */ + +#define PCRE2_TYPES_STRUCTURES_AND_FUNCTIONS \ +PCRE2_TYPES_LIST \ +PCRE2_STRUCTURE_LIST \ +PCRE2_GENERAL_INFO_FUNCTIONS \ +PCRE2_GENERAL_CONTEXT_FUNCTIONS \ +PCRE2_COMPILE_CONTEXT_FUNCTIONS \ +PCRE2_CONVERT_CONTEXT_FUNCTIONS \ +PCRE2_CONVERT_FUNCTIONS \ +PCRE2_MATCH_CONTEXT_FUNCTIONS \ +PCRE2_COMPILE_FUNCTIONS \ +PCRE2_PATTERN_INFO_FUNCTIONS \ +PCRE2_MATCH_FUNCTIONS \ +PCRE2_SUBSTRING_FUNCTIONS \ +PCRE2_SERIALIZE_FUNCTIONS \ +PCRE2_SUBSTITUTE_FUNCTION \ +PCRE2_JIT_FUNCTIONS \ +PCRE2_OTHER_FUNCTIONS + +#define PCRE2_LOCAL_WIDTH 8 +PCRE2_TYPES_STRUCTURES_AND_FUNCTIONS +#undef PCRE2_LOCAL_WIDTH + +#define PCRE2_LOCAL_WIDTH 16 +PCRE2_TYPES_STRUCTURES_AND_FUNCTIONS +#undef PCRE2_LOCAL_WIDTH + +#define PCRE2_LOCAL_WIDTH 32 +PCRE2_TYPES_STRUCTURES_AND_FUNCTIONS +#undef PCRE2_LOCAL_WIDTH + +/* Undefine the list macros; they are no longer needed. */ + +#undef PCRE2_TYPES_LIST +#undef PCRE2_STRUCTURE_LIST +#undef PCRE2_GENERAL_INFO_FUNCTIONS +#undef PCRE2_GENERAL_CONTEXT_FUNCTIONS +#undef PCRE2_COMPILE_CONTEXT_FUNCTIONS +#undef PCRE2_CONVERT_CONTEXT_FUNCTIONS +#undef PCRE2_MATCH_CONTEXT_FUNCTIONS +#undef PCRE2_COMPILE_FUNCTIONS +#undef PCRE2_PATTERN_INFO_FUNCTIONS +#undef PCRE2_MATCH_FUNCTIONS +#undef PCRE2_SUBSTRING_FUNCTIONS +#undef PCRE2_SERIALIZE_FUNCTIONS +#undef PCRE2_SUBSTITUTE_FUNCTION +#undef PCRE2_JIT_FUNCTIONS +#undef PCRE2_OTHER_FUNCTIONS +#undef PCRE2_TYPES_STRUCTURES_AND_FUNCTIONS + +/* PCRE2_CODE_UNIT_WIDTH must be defined. If it is 8, 16, or 32, redefine +PCRE2_SUFFIX to use it. If it is 0, undefine the other macros and make +PCRE2_SUFFIX a no-op. Otherwise, generate an error. */ + +#undef PCRE2_SUFFIX +#ifndef PCRE2_CODE_UNIT_WIDTH +#error PCRE2_CODE_UNIT_WIDTH must be defined before including pcre2.h. +#error Use 8, 16, or 32; or 0 for a multi-width application. +#else /* PCRE2_CODE_UNIT_WIDTH is defined */ +#if PCRE2_CODE_UNIT_WIDTH == 8 || \ + PCRE2_CODE_UNIT_WIDTH == 16 || \ + PCRE2_CODE_UNIT_WIDTH == 32 +#define PCRE2_SUFFIX(a) PCRE2_GLUE(a, PCRE2_CODE_UNIT_WIDTH) +#elif PCRE2_CODE_UNIT_WIDTH == 0 +#undef PCRE2_JOIN +#undef PCRE2_GLUE +#define PCRE2_SUFFIX(a) a +#else +#error PCRE2_CODE_UNIT_WIDTH must be 0, 8, 16, or 32. +#endif +#endif /* PCRE2_CODE_UNIT_WIDTH is defined */ + +#ifdef __cplusplus +} /* extern "C" */ +#endif + +#endif /* PCRE2_H_IDEMPOTENT_GUARD */ + +/* End of pcre2.h */ diff --git a/internal/cpp/pcre2posix.h b/internal/cpp/pcre2posix.h new file mode 100644 index 00000000000..198612afcbc --- /dev/null +++ b/internal/cpp/pcre2posix.h @@ -0,0 +1,184 @@ +/************************************************* +* Perl-Compatible Regular Expressions * +*************************************************/ + +/* PCRE2 is a library of functions to support regular expressions whose syntax +and semantics are as close as possible to those of the Perl 5 language. This is +the public header file to be #included by applications that call PCRE2 via the +POSIX wrapper interface. 
+
+         Written by Philip Hazel
+     Original API code Copyright (c) 1997-2012 University of Cambridge
+          New API code Copyright (c) 2016-2023 University of Cambridge
+
+-----------------------------------------------------------------------------
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+    * Redistributions of source code must retain the above copyright notice,
+      this list of conditions and the following disclaimer.
+
+    * Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in the
+      documentation and/or other materials provided with the distribution.
+
+    * Neither the name of the University of Cambridge nor the names of its
+      contributors may be used to endorse or promote products derived from
+      this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGE.
+-----------------------------------------------------------------------------
+*/
+
+#ifndef PCRE2POSIX_H_IDEMPOTENT_GUARD
+#define PCRE2POSIX_H_IDEMPOTENT_GUARD
+
+/* Have to include stdlib.h in order to ensure that size_t is defined. */
+
+#include <stdlib.h>
+
+/* Allow for C++ users */
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* Options, mostly defined by POSIX, but with some extras. */
+
+#define REG_ICASE 0x0001 /* Maps to PCRE2_CASELESS */
+#define REG_NEWLINE 0x0002 /* Maps to PCRE2_MULTILINE */
+#define REG_NOTBOL 0x0004 /* Maps to PCRE2_NOTBOL */
+#define REG_NOTEOL 0x0008 /* Maps to PCRE2_NOTEOL */
+#define REG_DOTALL 0x0010 /* NOT defined by POSIX; maps to PCRE2_DOTALL */
+#define REG_NOSUB 0x0020 /* Do not report what was matched */
+#define REG_UTF 0x0040 /* NOT defined by POSIX; maps to PCRE2_UTF */
+#define REG_STARTEND 0x0080 /* BSD feature: pass subject string by so,eo */
+#define REG_NOTEMPTY 0x0100 /* NOT defined by POSIX; maps to PCRE2_NOTEMPTY */
+#define REG_UNGREEDY 0x0200 /* NOT defined by POSIX; maps to PCRE2_UNGREEDY */
+#define REG_UCP 0x0400 /* NOT defined by POSIX; maps to PCRE2_UCP */
+#define REG_PEND 0x0800 /* GNU feature: pass end pattern by re_endp */
+#define REG_NOSPEC 0x1000 /* Maps to PCRE2_LITERAL */
+
+/* This is not used by PCRE2, but by defining it we make it easier
+to slot PCRE2 into existing programs that make POSIX calls. */
+
+#define REG_EXTENDED 0
+
+/* Error values. Not all these are relevant or used by the wrapper. */
+
+enum {
+  REG_ASSERT = 1,  /* internal error ? */
+  REG_BADBR,       /* invalid repeat counts in {} */
+  REG_BADPAT,      /* pattern error */
+  REG_BADRPT,      /* ? * + invalid */
+  REG_EBRACE,      /* unbalanced {} */
+  REG_EBRACK,      /* unbalanced [] */
+  REG_ECOLLATE,    /* collation error - not relevant */
+  REG_ECTYPE,      /* bad class */
+  REG_EESCAPE,     /* bad escape sequence */
+  REG_EMPTY,       /* empty expression */
+  REG_EPAREN,      /* unbalanced () */
+  REG_ERANGE,      /* bad range inside [] */
+  REG_ESIZE,       /* expression too big */
+  REG_ESPACE,      /* failed to get memory */
+  REG_ESUBREG,     /* bad back reference */
+  REG_INVARG,      /* bad argument */
+  REG_NOMATCH      /* match failed */
+};
+
+
+/* The structure representing a compiled regular expression. It is also used
+for passing the pattern end pointer when REG_PEND is set. */
+
+typedef struct {
+  void *re_pcre2_code;
+  void *re_match_data;
+  const char *re_endp;
+  size_t re_nsub;
+  size_t re_erroffset;
+  int re_cflags;
+} regex_t;
+
+/* The structure in which a captured offset is returned. */
+
+typedef int regoff_t;
+
+typedef struct {
+  regoff_t rm_so;
+  regoff_t rm_eo;
+} regmatch_t;
+
+/* When an application links to a PCRE2 DLL in Windows, the symbols that are
+imported have to be identified as such. When building PCRE2, the appropriate
+export settings are needed, and are set in pcre2posix.c before including this
+file. So, we don't change existing definitions of PCRE2POSIX_EXP_DECL.
+
+By default, we use the standard "extern" declarations. */
+
+#ifndef PCRE2POSIX_EXP_DECL
+# if defined(_WIN32) && defined(PCRE2POSIX_SHARED)
+# define PCRE2POSIX_EXP_DECL extern __declspec(dllimport)
+# elif defined __cplusplus
+# define PCRE2POSIX_EXP_DECL extern "C"
+# else
+# define PCRE2POSIX_EXP_DECL extern
+# endif
+#endif
+
+/* When compiling with the MSVC compiler, it is sometimes necessary to include
+a "calling convention" before exported function names. For example:
+
+     void __cdecl function(....)
+
+might be needed. In order to make this easy, all the exported functions have
+PCRE2_CALL_CONVENTION just before their names.
+
+PCRE2 normally uses the platform's standard calling convention, so this should
+not be set unless you know you need it. */
+
+#ifndef PCRE2_CALL_CONVENTION
+#define PCRE2_CALL_CONVENTION
+#endif
+
+/* The functions. The actual code is in functions with pcre2_xxx names for
+uniqueness. POSIX names are provided as macros for API compatibility with POSIX
+regex functions. It's done this way to ensure that they are always linked from
+the PCRE2 library and not by accident from elsewhere (regex_t differs in size
+elsewhere). */
+
+PCRE2POSIX_EXP_DECL int PCRE2_CALL_CONVENTION pcre2_regcomp(regex_t *, const char *, int);
+PCRE2POSIX_EXP_DECL int PCRE2_CALL_CONVENTION pcre2_regexec(const regex_t *, const char *, size_t,
+                                                            regmatch_t *, int);
+PCRE2POSIX_EXP_DECL size_t PCRE2_CALL_CONVENTION pcre2_regerror(int, const regex_t *, char *, size_t);
+PCRE2POSIX_EXP_DECL void PCRE2_CALL_CONVENTION pcre2_regfree(regex_t *);
+
+#define regcomp pcre2_regcomp
+#define regexec pcre2_regexec
+#define regerror pcre2_regerror
+#define regfree pcre2_regfree
+
+/* Debian had a patch that used different names. These are now here to save
+them having to maintain their own patch, but are not documented by PCRE2.
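+
+Whichever spelling is used, the calls resolve to the pcre2_reg* functions
+declared above. A minimal use of the wrapper, as a sketch (error checks
+elided):
+
+  regex_t re;
+  regmatch_t match[1];
+  regcomp(&re, "ab+c", REG_ICASE);       // expands to pcre2_regcomp()
+  regexec(&re, "xxABBC", 1, match, 0);   // match[0].rm_so / rm_eo get set
+  regfree(&re);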
+*/
+
+#define PCRE2regcomp pcre2_regcomp
+#define PCRE2regexec pcre2_regexec
+#define PCRE2regerror pcre2_regerror
+#define PCRE2regfree pcre2_regfree
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
+
+#endif /* PCRE2POSIX_H_IDEMPOTENT_GUARD */
+
+/* End of pcre2posix.h */
diff --git a/internal/cpp/rag_analyzer.cpp b/internal/cpp/rag_analyzer.cpp
new file mode 100644
index 00000000000..5f7799bb19f
--- /dev/null
+++ b/internal/cpp/rag_analyzer.cpp
@@ -0,0 +1,2431 @@
+// Copyright(C) 2024 InfiniFlow, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#define PCRE2_CODE_UNIT_WIDTH 8
+
+#include "opencc/openccxx.h"
+#include "pcre2.h"
+
+#include "string_utils.h"
+#include "rag_analyzer.h"
+#include "re2/re2.h"
+
+#include <filesystem>
+#include <iostream>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <string_view>
+// import :term;
+// import :stemmer;
+// import :analyzer;
+// import :darts_trie;
+// import :wordnet_lemmatizer;
+// import :stemmer;
+// import :term;
+//
+// import std.compat;
+
+namespace fs = std::filesystem;
+
+static const std::string DICT_PATH = "rag/huqie.txt";
+static const std::string POS_DEF_PATH = "rag/pos-id.def";
+static const std::string TRIE_PATH = "rag/huqie.trie";
+static const std::string WORDNET_PATH = "wordnet";
+
+static const std::string OPENCC_PATH = "opencc";
+
+static const std::string REGEX_SPLIT_CHAR =
+    R"#(([ ,\.<>/?;'\[\]\`!@#$%^&*$$\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z\.-]+|[0-9,\.-]+))#";
+
+static const std::string NLTK_TOKENIZE_PATTERN =
+    R"((?:\-{2,}|\.{2,}|(?:\.\s){2,}\.)|(?=[^\(\"\`{\[:;&\#\*@\)}\]\-,])\S+?(?=\s|$|(?:[)\";}\]\*:@\'\({\[\?!])|(?:\-{2,}|\.{2,}|(?:\.\s){2,}\.)|,(?=$|\s|(?:[)\";}\]\*:@\'\({\[\?!])|(?:\-{2,}|\.{2,}|(?:\.\s){2,}\.)))|\S)";
+
+static constexpr std::size_t MAX_SENTENCE_LEN = 100;
+
+// Pack a signed frequency and a part-of-speech index into one trie value:
+// the frequency magnitude occupies the low 23 bits, bit 23 carries its sign,
+// and the POS index occupies the high 8 bits (e.g. Encode(-5, 3) sets bit 23,
+// stores 5 in the low bits and 3 in bits 24-31).
+static inline int32_t Encode(int32_t freq, int32_t idx) {
+    uint32_t encoded_value = 0;
+    if (freq < 0) {
+        encoded_value |= static_cast<uint32_t>(-freq);
+        encoded_value |= (1U << 23);
+    } else {
+        encoded_value = static_cast<uint32_t>(freq & 0x7FFFFF);
+    }
+
+    encoded_value |= static_cast<uint32_t>(idx) << 24;
+    return static_cast<int32_t>(encoded_value);
+}
+
+// Recover the signed frequency from the low 24 bits of an encoded value.
+static inline int32_t DecodeFreq(int32_t value) {
+    uint32_t v1 = static_cast<uint32_t>(value) & 0xFFFFFF;
+    if (v1 & (1 << 23)) {
+        v1 &= 0x7FFFFF;
+        return -static_cast<int32_t>(v1);
+    } else {
+        v1 = static_cast<uint32_t>(v1);
+    }
+    return v1;
+}
+
+static inline int32_t DecodePOSIndex(int32_t value) {
+    // POS index is stored in the high 8 bits (bits 24-31)
+    return static_cast<int32_t>(static_cast<uint32_t>(value) >> 24);
+}
+
+// Split the input on every match of split_pattern, optionally keeping the
+// matched delimiters themselves as tokens.
+void Split(const std::string &input, const std::string &split_pattern, std::vector<std::string> &result, bool keep_delim = false) {
+    re2::RE2 pattern(split_pattern);
+    re2::StringPiece leftover(input.data());
+    re2::StringPiece last_end = leftover;
+    re2::StringPiece extracted_delim_token;
+
+    while (RE2::FindAndConsume(&leftover, pattern, &extracted_delim_token)) {
+        std::string_view token(last_end.data(), extracted_delim_token.data() - last_end.data());
+        if (!token.empty()) {
+            result.emplace_back(token.data(), token.size());
+        }
+        if (keep_delim)
+            result.emplace_back(extracted_delim_token.data(), extracted_delim_token.size());
+        last_end = leftover;
+    }
+
+    if (!leftover.empty()) {
+        result.emplace_back(leftover.data(), leftover.size());
+    }
+}
+
+// Overload taking an already-compiled RE2 pattern.
+void Split(const std::string &input, const re2::RE2 &pattern, std::vector<std::string> &result, bool keep_delim = false) {
+    re2::StringPiece leftover(input.data());
+    re2::StringPiece last_end = leftover;
+    re2::StringPiece extracted_delim_token;
+
+    while (RE2::FindAndConsume(&leftover, pattern, &extracted_delim_token)) {
+        std::string_view token(last_end.data(), extracted_delim_token.data() - last_end.data());
+        if (!token.empty()) {
+            result.emplace_back(token.data(), token.size());
+        }
+        if (keep_delim)
+            result.emplace_back(extracted_delim_token.data(), extracted_delim_token.size());
+        last_end = leftover;
+    }
+
+    if (!leftover.empty()) {
+        result.emplace_back(leftover.data(), leftover.size());
+    }
+}
+
+std::string Replace(const re2::RE2 &re, const std::string &replacement, const std::string &input) {
+    std::string output = input;
+    re2::RE2::GlobalReplace(&output, re, replacement);
+    return output;
+}
+
+template <typename T>
+std::string Join(const std::vector<T> &tokens, int start, int end, const std::string &delim = " ") {
+    std::ostringstream oss;
+    for (int i = start; i < end; ++i) {
+        if (i > start)
+            oss << delim;
+        oss << tokens[i];
+    }
+    return std::move(oss).str();
+}
+
+template <typename T>
+std::string Join(const std::vector<T> &tokens, int start, const std::string &delim = " ") {
+    return Join(tokens, start, tokens.size(), delim);
+}
+
+std::string Join(const TermList &tokens, int start, int end, const std::string &delim = " ") {
+    std::ostringstream oss;
+    for (int i = start; i < end; ++i) {
+        if (i > start)
+            oss << delim;
+        oss << tokens[i].text_;
+    }
+    return std::move(oss).str();
+}
+
+// The Is* helpers below are UTF-8 byte-pattern heuristics: they scan for the
+// lead/continuation byte ranges of each script rather than decoding code points.
+bool IsChinese(const std::string &str) {
+    for (std::size_t i = 0; i < str.length(); ++i) {
+        unsigned char c = str[i];
+        if (c >= 0xE4 && c <= 0xE9) {
+            if (i + 2 < str.length()) {
+                unsigned char c2 = str[i + 1];
+                unsigned char c3 = str[i + 2];
+                if ((c2 >= 0x80 && c2 <= 0xBF) && (c3 >= 0x80 && c3 <= 0xBF)) {
+                    return true;
+                }
+            }
+        }
+    }
+    return false;
+}
+
+bool IsAlphabet(const std::string &str) {
+    for (std::size_t i = 0; i < str.length(); ++i) {
+        unsigned char c = str[i];
+        if (c > 0x7F) {
+            return false;
+        }
+    }
+    return true;
+}
+
+bool IsKorean(const std::string &str) {
+    for (std::size_t i = 0; i < str.length(); ++i) {
+        unsigned char c = str[i];
+        if (c == 0xE1) {
+            if (i + 2 < str.length()) {
+                unsigned char c2 = str[i + 1];
+                unsigned char c3 = str[i + 2];
+                if ((c2 == 0x84 || c2 == 0x85 || c2 == 0x86 || c2 == 0x87) && (c3 >= 0x80 && c3 <= 0xBF)) {
+                    return true;
+                }
+            }
+        }
+    }
+    return false;
+}
+
+bool IsJapanese(const std::string &str) {
+    for (std::size_t i = 0; i < str.length(); ++i) {
+        unsigned char c = str[i];
+        if (c == 0xE3) {
+            if (i + 2 < str.length()) {
+                unsigned char c2 = str[i + 1];
+                unsigned char c3 = str[i + 2];
+                if ((c2 == 0x81 || c2 == 0x82 || c2 == 0x83) && (c3 >= 0x81 && c3 <= 0xBF)) {
+                    return true;
+                }
+            }
+        }
+    }
+    return false;
+}
+
+bool IsCJK(const std::string &str) {
+    for (std::size_t i = 0; i < str.length(); ++i) {
+        unsigned char c = str[i];
+
+        // Check Chinese
+        if (c >= 0xE4 && c <= 0xE9) {
+            if (i + 2 < str.length()) {
+                unsigned char c2 = str[i + 1];
+                unsigned char c3 = str[i + 2];
+                if ((c2 >= 0x80 && c2 <= 0xBF) && (c3 >= 0x80 && c3 <= 0xBF)) {
+                    return true;
+                }
+            }
+        }
+
+        // Check Japanese
+        if (c == 0xE3) {
+            if (i + 2 < str.length()) {
+                unsigned
char c2 = str[i + 1]; + unsigned char c3 = str[i + 2]; + if ((c2 == 0x81 || c2 == 0x82 || c2 == 0x83) && (c3 >= 0x81 && c3 <= 0xBF)) { + return true; + } + } + } + + // Check Korean + if (c == 0xE1) { + if (i + 2 < str.length()) { + unsigned char c2 = str[i + 1]; + unsigned char c3 = str[i + 2]; + if ((c2 == 0x84 || c2 == 0x85 || c2 == 0x86 || c2 == 0x87) && (c3 >= 0x80 && c3 <= 0xBF)) { + return true; + } + } + } + } + return false; +} + +class RegexTokenizer { +public: + RegexTokenizer() { + int errorcode = 0; + PCRE2_SIZE erroffset = 0; + + re_ = pcre2_compile((PCRE2_SPTR)(NLTK_TOKENIZE_PATTERN.c_str()), + PCRE2_ZERO_TERMINATED, + PCRE2_MULTILINE | PCRE2_UTF, + &errorcode, + &erroffset, + nullptr); + } + + ~RegexTokenizer() { + pcre2_code_free(re_); + } + + void RegexTokenize(const std::string &input, TermList &tokens) { + PCRE2_SPTR subject = (PCRE2_SPTR)input.c_str(); + PCRE2_SIZE subject_length = input.length(); + + pcre2_match_data_8 *match_data = pcre2_match_data_create_8(1024, nullptr); + + PCRE2_SIZE start_offset = 0; + + while (start_offset < subject_length) { + int res = pcre2_match(re_, subject, subject_length, start_offset, 0, match_data, nullptr); + + if (res < 0) { + if (res == PCRE2_ERROR_NOMATCH) { + break; // No more matches + } else { + std::cerr << "Matching error code: " << res << std::endl; + break; // Other error + } + } + + // Extract matched substring + PCRE2_SIZE *ovector = pcre2_get_ovector_pointer(match_data); + for (int i = 0; i < res; ++i) { + PCRE2_SIZE start = ovector[2 * i]; + PCRE2_SIZE end = ovector[2 * i + 1]; + tokens.Add(input.c_str() + start, end - start, start, end); + } + + // Update the start offset for the next search + start_offset = ovector[1]; // Move to the end of the last match + } + + // Free memory + pcre2_match_data_free(match_data); + } + +private: + pcre2_code_8 *re_{nullptr}; +}; + +class MacIntyreContractions { +public: + // List of contractions adapted from Robert MacIntyre's tokenizer. 
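+    // Illustrative behavior (a sketch, not exhaustive): applied with the
+    // " $1 $2 " substitution in NLTKWordTokenizer, CONTRACTIONS2 splits fused
+    // forms such as "cannot" -> "can not", "gimme" -> "gim me",
+    // "gonna" -> "gon na"; CONTRACTIONS3 splits the leading-apostrophe forms
+    // " 'tis" -> " 't is" and " 'twas" -> " 't was".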
+ std::vector CONTRACTIONS2 = {R"((?i)\b(can)(?#X)(not)\b)", + R"((?i)\b(d)(?#X)('ye)\b)", + R"((?i)\b(gim)(?#X)(me)\b)", + R"((?i)\b(gon)(?#X)(na)\b)", + R"((?i)\b(got)(?#X)(ta)\b)", + R"((?i)\b(lem)(?#X)(me)\b)", + R"((?i)\b(more)(?#X)('n)\b)", + R"((?i)\b(wan)(?#X)(na)(?=\s))"}; + std::vector CONTRACTIONS3 = {R"((?i) ('t)(?#X)(is)\b)", R"((?i) ('t)(?#X)(was)\b)"}; + std::vector CONTRACTIONS4 = {R"((?i)\b(whad)(dd)(ya)\b)", R"((?i)\b(wha)(t)(cha)\b)"}; +}; + +// Structure to hold precompiled regex patterns +struct CompiledRegex { + pcre2_code *re{nullptr}; + std::string substitution; + + CompiledRegex(pcre2_code *r, std::string sub) : re(r), substitution(std::move(sub)) { + } + + CompiledRegex(const CompiledRegex &) = delete; + CompiledRegex &operator=(const CompiledRegex &) = delete; + CompiledRegex(CompiledRegex &&other) noexcept : re(other.re), substitution(std::move(other.substitution)) { other.re = nullptr; } + + CompiledRegex &operator=(CompiledRegex &&other) noexcept { + if (this != &other) { + if (re) + pcre2_code_free(re); + re = other.re; + substitution = std::move(other.substitution); + other.re = nullptr; + } + return *this; + } + + ~CompiledRegex() { + if (re) { + pcre2_code_free(re); + } + } +}; + +class NLTKWordTokenizer { + MacIntyreContractions contractions_; + + // Static singleton instance + static std::unique_ptr instance_; + static std::once_flag init_flag_; + +public: + // Static method to get the singleton instance + static NLTKWordTokenizer &GetInstance() { + std::call_once(init_flag_, []() { instance_ = std::make_unique(); }); + return *instance_; + } + + // Starting quotes. + std::vector> STARTING_QUOTES = { + {std::string(R"(([«“‘„]|[`]+))"), std::string(R"( $1 )")}, + {std::string(R"(^\")"), std::string(R"(``)")}, + {std::string(R"((``))"), std::string(R"( $1 )")}, + {std::string(R"(([ \(\[{<])(\"|\'{2}))"), std::string(R"($1 `` )")}, + {std::string(R"((?i)(\')(?!re|ve|ll|m|t|s|d|n)(\w)\b)"), std::string(R"($1 $2)")}}; + + // Ending quotes. + std::vector> ENDING_QUOTES = { + {std::string(R"(([»”’]))"), std::string(R"( $1 )")}, + {std::string(R"('')"), std::string(R"( '' )")}, + {std::string(R"(")"), std::string(R"( '' )")}, + {std::string(R"(\s+)"), std::string(R"( )")}, + {std::string(R"(([^' ])('[sS]|'[mM]|'[dD]|') )"), std::string(R"($1 $2 )")}, + {std::string(R"(([^' ])('ll|'LL|'re|'RE|'ve|'VE|n't|N'T) )"), std::string(R"($1 $2 )")}}; + + // Punctuation. 
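+    // Rough illustration: these rules pad punctuation with spaces so that it
+    // ends up as standalone tokens, e.g. "Hello, world!" tokenizes as
+    // "Hello", ",", "world", "!".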
+ std::vector> PUNCTUATION = { + {std::string(R"(([^\.])(\.)([\]\)}>"\'»”’ ]*)\s*$)"), std::string(R"($1 $2 $3 )")}, + {std::string(R"(([:,])([^\d]))"), std::string(R"( $1 $2)")}, + {std::string(R"(([:,])$)"), std::string(R"($1 )")}, + {std::string(R"(\.{2,})"), std::string(R"($0 )")}, + {std::string(R"([;@#$%&])"), std::string(R"($0 )")}, + {std::string(R"(([^\.])(\.)([\]\)}>"\']*)\s*$)"), std::string(R"($1 $2 $3 )")}, + {std::string(R"([?!])"), std::string(R"($0 )")}, + {std::string(R"(([^'])' )"), std::string(R"($1 ' )")}, + {std::string(R"([*])"), std::string(R"($0 )")}}; + + // Pads parentheses + std::pair PARENS_BRACKETS = {std::string(R"([\]\[\(\)\{\}\<\>])"), std::string(R"( $0 )")}; + + std::vector> CONVERT_PARENTHESES = {{std::string(R"(\()"), std::string("-LRB-")}, + {std::string(R"(\))"), std::string("-RRB-")}, + {std::string(R"(\[)"), std::string("-LSB-")}, + {std::string(R"(\])"), std::string("-RSB-")}, + {std::string(R"(\{)"), std::string("-LCB-")}, + {std::string(R"(\})"), std::string("-RCB-")}}; + + std::pair DOUBLE_DASHES = {std::string(R"(--)"), std::string(R"( -- )")}; + + // Cache for compiled regex patterns + std::vector compiled_starting_quotes_; + std::vector compiled_ending_quotes_; + std::vector compiled_punctuation_; + CompiledRegex compiled_parens_brackets_; + std::vector compiled_convert_parentheses_; + CompiledRegex compiled_double_dashes_; + std::vector compiled_contractions2_; + std::vector compiled_contractions3_; + + // Constructor that precompiles all regex patterns + NLTKWordTokenizer() : compiled_parens_brackets_(nullptr, ""), compiled_double_dashes_(nullptr, "") { CompileRegexPatterns(); } + + void Tokenize(const std::string &text, std::vector &tokens, bool convert_parentheses = false) { + std::string result = text; + + for (const auto &compiled : compiled_starting_quotes_) { + result = ApplyRegex(result, compiled); + } + for (const auto &compiled : compiled_punctuation_) { + result = ApplyRegex(result, compiled); + } + + // Handles parentheses. + result = ApplyRegex(result, compiled_parens_brackets_); + + // Optionally convert parentheses + if (convert_parentheses) { + for (const auto &compiled : compiled_convert_parentheses_) { + result = ApplyRegex(result, compiled); + } + } + + // Handles double dash. 
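+        // e.g. "New York--based" becomes "New York -- based" (the pattern
+        // simply pads "--" with spaces).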
+ result = ApplyRegex(result, compiled_double_dashes_); + + // Add extra space to make things easier + result = " " + result + " "; + + for (const auto &compiled : compiled_ending_quotes_) { + result = ApplyRegex(result, compiled); + } + + for (const auto &compiled : compiled_contractions2_) { + result = ApplyRegex(result, compiled); + } + + for (const auto &compiled : compiled_contractions3_) { + result = ApplyRegex(result, compiled); + } + + // Split the result into tokens + size_t start = 0; + size_t end = result.find(' '); + while (end != std::string::npos) { + if (end != start) { + std::string token = result.substr(start, end - start); + // Handle underscore tokens properly + if (token == "_") { + // Single underscore token + tokens.push_back("_"); + } else if (token.find('_') != std::string::npos) { + // Split tokens containing underscores and keep underscores as separate tokens + std::stringstream ss(token); + std::string sub_token; + bool first = true; + while (std::getline(ss, sub_token, '_')) { + if (!first) { + tokens.push_back("_"); + } + if (!sub_token.empty()) { + tokens.push_back(sub_token); + } + first = false; + } + // Handle case where token ends with underscore + if (token.back() == '_') { + tokens.push_back("_"); + } + } else { + tokens.push_back(token); + } + } + start = end + 1; + end = result.find(' ', start); + } + if (start != result.length()) { + std::string token = result.substr(start); + // Handle underscore tokens properly + if (token == "_") { + // Single underscore token + tokens.push_back("_"); + } else if (token.find('_') != std::string::npos) { + // Split tokens containing underscores and keep underscores as separate tokens + std::stringstream ss(token); + std::string sub_token; + bool first = true; + while (std::getline(ss, sub_token, '_')) { + if (!first) { + tokens.push_back("_"); + } + if (!sub_token.empty()) { + tokens.push_back(sub_token); + } + first = false; + } + // Handle case where token ends with underscore + if (token.back() == '_') { + tokens.push_back("_"); + } + } else { + tokens.push_back(token); + } + } + } + +private: + void CompileRegexPatterns() { + compiled_starting_quotes_.reserve(STARTING_QUOTES.size()); + for (const auto &[pattern, substitution] : STARTING_QUOTES) { + compiled_starting_quotes_.emplace_back(CompilePattern(pattern), substitution); + } + + compiled_ending_quotes_.reserve(ENDING_QUOTES.size()); + for (const auto &[pattern, substitution] : ENDING_QUOTES) { + compiled_ending_quotes_.emplace_back(CompilePattern(pattern), substitution); + } + + compiled_punctuation_.reserve(PUNCTUATION.size()); + for (const auto &[pattern, substitution] : PUNCTUATION) { + compiled_punctuation_.emplace_back(CompilePattern(pattern), substitution); + } + + compiled_parens_brackets_ = CompiledRegex(CompilePattern(PARENS_BRACKETS.first), PARENS_BRACKETS.second); + + compiled_convert_parentheses_.reserve(CONVERT_PARENTHESES.size()); + for (const auto &[pattern, substitution] : CONVERT_PARENTHESES) { + compiled_convert_parentheses_.emplace_back(CompilePattern(pattern), substitution); + } + + compiled_double_dashes_ = CompiledRegex(CompilePattern(DOUBLE_DASHES.first), DOUBLE_DASHES.second); + + compiled_contractions2_.reserve(contractions_.CONTRACTIONS2.size()); + for (const auto &pattern : contractions_.CONTRACTIONS2) { + compiled_contractions2_.emplace_back(CompilePattern(pattern), R"( $1 $2 )"); + } + + compiled_contractions3_.reserve(contractions_.CONTRACTIONS3.size()); + for (const auto &pattern : contractions_.CONTRACTIONS3) { + 
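+            // Both contraction tables share the " $1 $2 " replacement;
+            // pcre2_substitute expands $1/$2 to the two captured halves of the
+            // contraction.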
compiled_contractions3_.emplace_back(CompilePattern(pattern), R"( $1 $2 )"); + } + } + + pcre2_code *CompilePattern(const std::string &pattern) { + int errorcode = 0; + PCRE2_SIZE erroffset = 0; + pcre2_code *re = pcre2_compile(reinterpret_cast(pattern.c_str()), + PCRE2_ZERO_TERMINATED, + PCRE2_MULTILINE | PCRE2_UTF, + &errorcode, + &erroffset, + nullptr); + + if (re == nullptr) { + PCRE2_UCHAR buffer[256]; + pcre2_get_error_message(errorcode, buffer, sizeof(buffer)); + std::cerr << "PCRE2 compilation failed at offset " << erroffset << ": " << buffer << std::endl; + return nullptr; + } + return re; + } + + std::string ApplyRegex(const std::string &text, const CompiledRegex &compiled) { + if (compiled.re == nullptr) { + return text; + } + + PCRE2_SPTR pcre2_subject = reinterpret_cast(text.c_str()); + PCRE2_SPTR pcre2_replacement = reinterpret_cast(compiled.substitution.c_str()); + + size_t outlength = text.length() * 2 < 1024 ? 1024 : text.length() * 2; + auto buffer = std::make_unique(outlength); + int rc = pcre2_substitute(compiled.re, + pcre2_subject, + text.length(), + 0, + PCRE2_SUBSTITUTE_GLOBAL, + nullptr, + nullptr, + pcre2_replacement, + PCRE2_ZERO_TERMINATED, + buffer.get(), + &outlength); + + if (rc < 0) { + return text; + } + + return std::string(reinterpret_cast(buffer.get()), outlength); + } +}; + +// Static member definitions for NLTKWordTokenizer singleton +std::unique_ptr NLTKWordTokenizer::instance_ = nullptr; +std::once_flag NLTKWordTokenizer::init_flag_; + +void SentenceSplitter(const std::string &text, std::vector &result) { + int error_code; + PCRE2_SIZE error_offset; + const char *pattern = R"( *[\.\?!]['"\)\]]* *)"; + + pcre2_code *re = pcre2_compile((PCRE2_SPTR)pattern, PCRE2_ZERO_TERMINATED, PCRE2_MULTILINE | PCRE2_UTF, &error_code, &error_offset, nullptr); + + if (re == nullptr) { + PCRE2_UCHAR buffer[256]; + pcre2_get_error_message(error_code, buffer, sizeof(buffer)); + std::cerr << "PCRE2 compilation failed at offset " << error_offset << ": " << buffer << std::endl; + return; + } + + pcre2_match_data *match_data = pcre2_match_data_create_from_pattern(re, nullptr); + + PCRE2_SIZE start_offset = 0; + while (start_offset < text.size()) { + int rc = pcre2_match(re, (PCRE2_SPTR)text.c_str(), text.size(), start_offset, 0, match_data, nullptr); + + if (rc < 0) { + result.push_back(text.substr(start_offset)); + break; + } + + PCRE2_SIZE *ovector = pcre2_get_ovector_pointer(match_data); + PCRE2_SIZE match_start = ovector[0]; + PCRE2_SIZE match_end = ovector[1]; + + if (match_start > start_offset) { + result.push_back(text.substr(start_offset, match_end - start_offset)); + } + + start_offset = match_end; + } + + pcre2_match_data_free(match_data); + pcre2_code_free(re); +} + +RAGAnalyzer::RAGAnalyzer(const std::string &path) + : dict_path_(path), stemmer_(std::make_unique()) { + InitStemmer(STEM_LANG_ENGLISH); +} + +RAGAnalyzer::RAGAnalyzer(const RAGAnalyzer &other) + : own_dict_(false), trie_(other.trie_), pos_table_(other.pos_table_), wordnet_lemma_(other.wordnet_lemma_), stemmer_(std::make_unique()), + opencc_(other.opencc_), fine_grained_(other.fine_grained_) { + InitStemmer(STEM_LANG_ENGLISH); +} + +RAGAnalyzer::~RAGAnalyzer() { + if (own_dict_) { + delete trie_; + delete pos_table_; + delete wordnet_lemma_; + delete opencc_; + } +} + +int32_t RAGAnalyzer::Load() { + fs::path root(dict_path_); + fs::path dict_path(root / DICT_PATH); + + if (!fs::exists(dict_path)) { + printf("Invalid analyzer file: %s", dict_path.string().c_str()); + // return 
Status::InvalidAnalyzerFile(dict_path); + return -1; + } + + fs::path pos_def_path(root / POS_DEF_PATH); + if (!fs::exists(pos_def_path)) { + printf("Invalid post file: %s", pos_def_path.string().c_str()); + // return Status::InvalidAnalyzerFile(pos_def_path); + return -1; + } + own_dict_ = true; + trie_ = new DartsTrie(); + pos_table_ = new POSTable(pos_def_path.string()); + if (pos_table_->Load() != 0) { + printf("Fail to load post table: %s", pos_def_path.string().c_str()); + return -1; + // return Status::InvalidAnalyzerFile("Failed to load RAGAnalyzer POS definition"); + } + + fs::path trie_path(root / TRIE_PATH); + if (fs::exists(trie_path)) { + trie_->Load(trie_path.string()); + } else { + // Build trie + try { + std::ifstream from(dict_path.string()); + std::string line; + re2::RE2 re_pattern(R"([\r\n]+)"); + std::string split_pattern("([ \t])"); + + while (getline(from, line)) { + line = line.substr(0, line.find('\r')); + if (line.empty()) + continue; + line = Replace(re_pattern, "", line); + std::vector results; + Split(line, split_pattern, results); + if (results.size() != 3) + throw std::runtime_error("Invalid dictionary format"); + int32_t freq = std::stoi(results[1]); + freq = int32_t(std::log(float(freq) / DENOMINATOR) + 0.5); + int32_t pos_idx = pos_table_->GetPOSIndex(results[2]); + int value = Encode(freq, pos_idx); + trie_->Add(results[0], value); + std::string rkey = RKey(results[0]); + trie_->Add(rkey, Encode(1, 0)); + } + trie_->Build(); + } catch (const std::exception &e) { + return -1; + // return Status::InvalidAnalyzerFile("Failed to load RAGAnalyzer analyzer"); + } + trie_->Save(trie_path.string()); + } + + fs::path lemma_path(root / WORDNET_PATH); + if (!fs::exists(lemma_path)) { + printf("Fail to load wordnet: %s", lemma_path.string().c_str()); + return -1; + // return Status::InvalidAnalyzerFile(lemma_path); + } + + wordnet_lemma_ = new WordNetLemmatizer(lemma_path.string()); + + fs::path opencc_path(root / OPENCC_PATH); + + if (!fs::exists(opencc_path)) { + printf("Fail to load opencc_path: %s", opencc_path.string().c_str()); + return -1; + // return Status::InvalidAnalyzerFile(opencc_path); + } + try { + opencc_ = new ::OpenCC(opencc_path.string()); + } catch (const std::exception &e) { + return -1; + // return Status::InvalidAnalyzerFile("Failed to load OpenCC"); + } + + // return Status::OK(); + return 0; +} + +void RAGAnalyzer::BuildPositionMapping(const std::string &original, const std::string &converted, std::vector &pos_mapping) { + pos_mapping.clear(); + pos_mapping.resize(converted.size() + 1); + + size_t orig_pos = 0; + size_t conv_pos = 0; + + // Map each character position from converted string to original string + while (orig_pos < original.size() && conv_pos < converted.size()) { + // Get character lengths + size_t orig_char_len = UTF8_BYTE_LENGTH_TABLE[static_cast(original[orig_pos])]; + size_t conv_char_len = UTF8_BYTE_LENGTH_TABLE[static_cast(converted[conv_pos])]; + + // Map all bytes of current converted character to current original position + for (size_t i = 0; i < conv_char_len && conv_pos + i < pos_mapping.size(); ++i) { + pos_mapping[conv_pos + i] = static_cast(orig_pos); + } + + // Move to next character in both strings + orig_pos += orig_char_len; + conv_pos += conv_char_len; + } + + // Fill any remaining positions + for (size_t i = conv_pos; i < pos_mapping.size(); ++i) { + pos_mapping[i] = static_cast(original.size()); + } +} + +std::string RAGAnalyzer::StrQ2B(const std::string &input) { + std::string output; + size_t i = 0; + + 
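+    // StrQ2B folds full-width (SBC) forms to half-width (DBC), e.g.
+    // "Ｈｅｌｌｏ！" -> "Hello!" and ideographic space U+3000 -> ' '.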
while (i < input.size()) { + unsigned char c = input[i]; + + uint32_t codepoint = 0; + if (c < 0x80) { + codepoint = c; + i += 1; + } else if ((c & 0xE0) == 0xC0) { + codepoint = (c & 0x1F) << 6; + codepoint |= (input[i + 1] & 0x3F); + i += 2; + } else if ((c & 0xF0) == 0xE0) { + codepoint = (c & 0x0F) << 12; + codepoint |= (input[i + 1] & 0x3F) << 6; + codepoint |= (input[i + 2] & 0x3F); + i += 3; + } else { + output += c; + i += 1; + continue; + } + + if (codepoint >= 0xFF01 && codepoint <= 0xFF5E) { + output += static_cast(codepoint - 0xFEE0); + } else if (codepoint == 0x3000) { + output += ' '; + } else { + if (codepoint < 0x80) { + output += static_cast(codepoint); + } else if (codepoint < 0x800) { + output += static_cast(0xC0 | (codepoint >> 6)); + output += static_cast(0x80 | (codepoint & 0x3F)); + } else if (codepoint < 0x10000) { + output += static_cast(0xE0 | (codepoint >> 12)); + output += static_cast(0x80 | ((codepoint >> 6) & 0x3F)); + output += static_cast(0x80 | (codepoint & 0x3F)); + } + } + } + + return output; +} + +int32_t RAGAnalyzer::Freq(const std::string_view key) const { + int32_t v = trie_->Get(key); + v = DecodeFreq(v); + return static_cast(std::exp(v) * DENOMINATOR + 0.5); +} + +std::string RAGAnalyzer::Tag(std::string_view key) const { + std::string lower_key = Key(std::string(key)); + int32_t encoded_value = trie_->Get(lower_key); + if (encoded_value == -1) { + return ""; + } + int32_t pos_idx = DecodePOSIndex(encoded_value); + if (pos_table_ == nullptr) { + return ""; + } + const char* pos_tag = pos_table_->GetPOS(pos_idx); + return pos_tag ? std::string(pos_tag) : ""; +} + +std::string RAGAnalyzer::Key(const std::string_view line) { return ToLowerString(line); } + +std::string RAGAnalyzer::RKey(const std::string_view line) { + std::string reversed; + reversed.reserve(line.size() + 2); + reversed += "DD"; + for (size_t i = line.size(); i > 0;) { + size_t start = i - 1; + while (start > 0 && (line[start] & 0xC0) == 0x80) { + --start; + } + reversed += line.substr(start, i - start); + i = start; + } + ToLower(reversed.data() + 2, reversed.size() - 2); + return reversed; +} + +std::pair, double> RAGAnalyzer::Score(const std::vector> &token_freqs) { + constexpr int64_t B = 30; + int64_t F = 0, L = 0; + std::vector tokens; + tokens.reserve(token_freqs.size()); + for (const auto &[token, freq_tag] : token_freqs) { + F += DecodeFreq(freq_tag); + L += (UTF8Length(token) < 2) ? 
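+        // (the score computed below is B/N + L/N + F: N = token count,
+        // L = tokens at least two characters long, F = summed log-frequency)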
0 : 1; + tokens.push_back(token); + } + const auto score = B / static_cast(tokens.size()) + L / static_cast(tokens.size()) + F; + return {std::move(tokens), score}; +} + +void RAGAnalyzer::SortTokens(const std::vector>> &token_list, + std::vector, double>> &res) { + for (const auto &tfts : token_list) { + res.push_back(Score(tfts)); + } + std::sort(res.begin(), res.end(), [](const auto &a, const auto &b) { return a.second > b.second; }); +} + +std::pair, double> RAGAnalyzer::MaxForward(const std::string &line) const { + std::vector> res; + std::size_t s = 0; + std::size_t len = UTF8Length(line); + + while (s < len) { + std::size_t e = s + 1; + std::string t = UTF8Substr(line, s, e - s); + + while (e < len && trie_->HasKeysWithPrefix(Key(t))) { + e += 1; + t = UTF8Substr(line, s, e - s); + } + + while (e - 1 > s && trie_->Get(Key(t)) == -1) { + e -= 1; + t = UTF8Substr(line, s, e - s); + } + + int v = trie_->Get(Key(t)); + if (v != -1) { + res.emplace_back(std::move(t), v); + } else { + res.emplace_back(std::move(t), 0); + } + + s = e; + } + + return Score(res); +} + +std::pair, double> RAGAnalyzer::MaxBackward(const std::string &line) const { + std::vector> res; + int s = UTF8Length(line) - 1; + + while (s >= 0) { + const int e = s + 1; + std::string t = UTF8Substr(line, s, e - s); + while (s > 0 && trie_->HasKeysWithPrefix(RKey(t))) { + s -= 1; + t = UTF8Substr(line, s, e - s); + } + while (s + 1 < e && trie_->Get(Key(t)) == -1) { + s += 1; + t = UTF8Substr(line, s, e - s); + } + + int v = trie_->Get(Key(t)); + if (v != -1) { + res.emplace_back(std::move(t), v); + } else { + res.emplace_back(std::move(t), 0); + } + + s -= 1; + } + + std::reverse(res.begin(), res.end()); + return Score(res); +} + +int RAGAnalyzer::DFS(const std::string &chars, + const int s, + std::vector> &pre_tokens, + std::vector>> &token_list, + std::vector &best_tokens, + double &max_score, + const bool memo_all) const { + int res = s; + const int len = UTF8Length(chars); + if (s >= len) { + if (memo_all) { + token_list.push_back(pre_tokens); + } else if (auto [vec_str, current_score] = Score(pre_tokens); current_score > max_score) { + best_tokens = std::move(vec_str); + max_score = current_score; + } + return res; + } + // pruning + int S = s + 1; + if (s + 2 <= len) { + std::string t1 = UTF8Substr(chars, s, 1); + std::string t2 = UTF8Substr(chars, s, 2); + if (trie_->HasKeysWithPrefix(Key(t1)) && !trie_->HasKeysWithPrefix(Key(t2))) { + S = s + 2; + } + } + + if (pre_tokens.size() > 2 && UTF8Length(pre_tokens[pre_tokens.size() - 1].first) == 1 && + UTF8Length(pre_tokens[pre_tokens.size() - 2].first) == 1 && UTF8Length(pre_tokens[pre_tokens.size() - 3].first) == 1) { + std::string t1 = pre_tokens[pre_tokens.size() - 1].first + UTF8Substr(chars, s, 1); + if (trie_->HasKeysWithPrefix(Key(t1))) { + S = s + 2; + } + } + + for (int e = S; e <= len; ++e) { + std::string t = UTF8Substr(chars, s, e - s); + std::string k = Key(t); + + if (e > s + 1 && !trie_->HasKeysWithPrefix(k)) { + break; + } + + if (const int v = trie_->Get(k); v != -1) { + auto pretks = pre_tokens; + pretks.emplace_back(std::move(t), v); + res = std::max(res, DFS(chars, e, pretks, token_list, best_tokens, max_score, memo_all)); + } + } + + if (res > s) { + return res; + } + + std::string t = UTF8Substr(chars, s, 1); + if (const int v = trie_->Get(Key(t)); v != -1) { + pre_tokens.emplace_back(std::move(t), v); + } else { + pre_tokens.emplace_back(std::move(t), Encode(-12, 0)); + } + + return DFS(chars, s + 1, pre_tokens, token_list, best_tokens, max_score, 
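+    // (illustration, assuming the dictionary holds 研究, 研究生, 生命 and 命:
+    // for "研究生命" the DFS explores both "研究生 / 命" and "研究 / 生命",
+    // and the caller keeps whichever segmentation scores higher)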
memo_all); +} + +struct TokensList { + const TokensList *prev = nullptr; + std::string_view token = {}; +}; + +struct BestTokenCandidate { + static constexpr int64_t B = 30; + TokensList tl{}; + // N: token num + // L: num of tokens with length >= 2 + // F: sum of freq + uint32_t N{}; + uint32_t L{}; + int64_t F{}; + + auto k() const { +#ifdef DIVIDE_F_BY_N + return N; +#else + return std::make_pair(N, L); +#endif + } + + auto v() const { return F; } + + auto score() const { +#ifdef DIVIDE_F_BY_N + return static_cast(B + L + F) / N; +#else + return F + (static_cast(B + L) / N); +#endif + } + + BestTokenCandidate update(const std::string_view new_token_sv, const int32_t key_f, const uint32_t add_l) const { + return {{&tl, new_token_sv}, N + 1, L + add_l, F + key_f}; + } +}; + +struct GrowingBestTokenCandidatesTopN { + int32_t top_n{}; + std::vector candidates{}; + + explicit GrowingBestTokenCandidatesTopN(const int32_t top_n) : top_n(top_n) { + } + + void AddBestTokenCandidateTopN(const BestTokenCandidate &add_candidate) { + const auto [it_b, it_e] = + std::equal_range(candidates.begin(), candidates.end(), add_candidate, [](const auto &a, const auto &b) { return a.k() < b.k(); }); + auto target_it = it_b; + bool do_replace = false; + if (const auto match_cnt = std::distance(it_b, it_e); match_cnt >= top_n) { + assert(match_cnt == top_n); + const auto it = std::min_element(it_b, it_e, [](const auto &a, const auto &b) { return a.v() < b.v(); }); + if (it->v() >= add_candidate.v()) { + return; + } + target_it = it; + do_replace = true; + } + if (do_replace) { + *target_it = add_candidate; + } else { + candidates.insert(target_it, add_candidate); + } + } +}; + +std::vector, double>> RAGAnalyzer::GetBestTokensTopN(const std::string_view chars, const uint32_t n) const { + const auto utf8_len = UTF8Length(chars); + std::vector dp_vec(utf8_len + 1, GrowingBestTokenCandidatesTopN(n)); + dp_vec[0].candidates.resize(1); + const char *current_utf8_ptr = chars.data(); + uint32_t current_left_chars = chars.size(); + std::string growing_key; // in lower case + for (uint32_t i = 0; i < utf8_len; ++i) { + const std::string_view current_chars{current_utf8_ptr, current_left_chars}; + const uint32_t left_utf8_cnt = utf8_len - i; + growing_key.clear(); + const char *lookup_until = current_utf8_ptr; + uint32_t lookup_left_chars = current_left_chars; + std::size_t reuse_node_pos = 0; + std::size_t reuse_key_pos = 0; + for (uint32_t j = 1; j <= left_utf8_cnt; ++j) { + { + // handle growing_key + const auto next_one_utf8 = UTF8Substrview({lookup_until, lookup_left_chars}, 0, 1); + if (next_one_utf8.size() == 1 && next_one_utf8[0] >= 'A' && next_one_utf8[0] <= 'Z') { + growing_key.push_back(next_one_utf8[0] - 'A' + 'a'); + } else { + growing_key.append(next_one_utf8); + } + lookup_until += next_one_utf8.size(); + lookup_left_chars -= next_one_utf8.size(); + } + auto dp_f = [&dp_vec, i, j, original_sv = std::string_view{current_utf8_ptr, growing_key.size()}]( + const int32_t key_f, + const uint32_t add_l) { + auto &target_dp = dp_vec[i + j]; + for (const auto &c : dp_vec[i].candidates) { + target_dp.AddBestTokenCandidateTopN(c.update(original_sv, key_f, add_l)); + } + }; + if (const auto traverse_result = trie_->Traverse(growing_key.data(), reuse_node_pos, reuse_key_pos, growing_key.size()); + traverse_result >= 0) { + // in dictionary + const int32_t key_f = DecodeFreq(traverse_result); + const auto add_l = static_cast(j >= 2); + dp_f(key_f, add_l); + } else { + // not in dictionary + if (j == 1) { + // also give a 
score: -12 + dp_f(-12, 0); + } + if (traverse_result == -2) { + // no more results + break; + } + } + } + // update current_utf8_ptr and current_left_chars + const auto forward_cnt = UTF8Substrview(current_chars, 0, 1).size(); + current_utf8_ptr += forward_cnt; + current_left_chars -= forward_cnt; + } + std::vector> mid_result; + mid_result.reserve(n); + for (const auto &c : dp_vec.back().candidates) { + const auto new_pair = std::make_pair(&(c.tl), c.score()); + if (mid_result.size() < n) { + mid_result.push_back(new_pair); + } else { + assert(mid_result.size() == n); + if (new_pair.second > mid_result.back().second) { + mid_result.pop_back(); + const auto insert_pos = std::lower_bound(mid_result.begin(), + mid_result.end(), + new_pair, + [](const auto &a, const auto &b) { + return a.second > b.second; + }); + mid_result.insert(insert_pos, new_pair); + } + } + } + class HelperFunc { + uint32_t cnt = 0; + std::vector result{}; + + void GetTokensInner(const TokensList *tl) { + if (!tl->prev) { + result.reserve(cnt); + return; + } + ++cnt; + GetTokensInner(tl->prev); + result.push_back(tl->token); + } + + public: + std::vector GetTokens(const TokensList *tl) { + GetTokensInner(tl); + return std::move(result); + } + }; + std::vector, double>> result; + result.reserve(mid_result.size()); + for (const auto [tl, score] : mid_result) { + result.emplace_back(HelperFunc{}.GetTokens(tl), score); + } + return result; +} + +// TODO: for test +// #ifndef INFINITY_DEBUG +// #define INFINITY_DEBUG 1 +// #endif + +#ifdef INFINITY_DEBUG +namespace dp_debug { +template +std::string TestPrintTokens(const std::vector &tokens) { + std::ostringstream oss; + for (std::size_t i = 0; i < tokens.size(); ++i) { + oss << (i ? " #" : "#") << tokens[i] << "#"; + } + return std::move(oss).str(); +} + +auto print_1 = [](const bool b) { return b ? "✅" : "❌"; }; +auto print_2 = [](const bool b) { return b ? "equal" : "not equal"; }; + +void compare_score_and_tokens(const std::vector &dfs_tokens, + const double dfs_score, + const std::vector &dp_tokens, + const double dp_score, + const std::string &prefix) { + std::ostringstream oss; + const auto b_score_eq = dp_score == dfs_score; + oss << fmt::format("\n{} {} DFS and DP score {}:\nDFS: {}\nDP : {}\n", print_1(b_score_eq), prefix, print_2(b_score_eq), dfs_score, dp_score); + bool vec_equal = true; + if (dp_tokens.size() != dfs_tokens.size()) { + vec_equal = false; + } else { + for (std::size_t k = 0; k < dp_tokens.size(); ++k) { + if (dp_tokens[k] != dfs_tokens[k]) { + vec_equal = false; + break; + } + } + } + oss << fmt::format("{} {} DFS and DP result {}:\nDFS: {}\nDP : {}\n", + print_1(vec_equal), + prefix, + print_2(vec_equal), + TestPrintTokens(dfs_tokens), + TestPrintTokens(dp_tokens)); + std::cerr << std::move(oss).str() << std::endl; +} + +inline void CheckDP(const RAGAnalyzer *this_ptr, + const std::string_view input_str, + const std::vector &dfs_tokens, + const double dfs_score, + const auto t0, + const auto t1) { + const auto dp_result = this_ptr->GetBestTokensTopN(input_str, 1); + const auto t2 = std::chrono::high_resolution_clock::now(); + const auto dfs_duration = std::chrono::duration_cast>(t1 - t0); + const auto dp_duration = std::chrono::duration_cast>(t2 - t1); + const auto dp_faster = dp_duration < dfs_duration; + std::cerr << "\n!!! 
" << print_1(dp_faster) << "\nTOP1 DFS duration: " << dfs_duration << " \nDP duration: " << dp_duration; + const auto &[dp_vec, dp_score] = dp_result[0]; + compare_score_and_tokens(dfs_tokens, dfs_score, dp_vec, dp_score, "[1 in top1]"); +} + +inline void CheckDP2(const RAGAnalyzer *this_ptr, const std::string_view input_str, auto get_dfs_sorted_tokens, const auto t0, const auto t1) { + constexpr int topn = 2; + const auto dp_result = this_ptr->GetBestTokensTopN(input_str, topn); + const auto t2 = std::chrono::high_resolution_clock::now(); + const auto dfs_duration = std::chrono::duration_cast>(t1 - t0); + const auto dp_duration = std::chrono::duration_cast>(t2 - t1); + const auto dp_faster = dp_duration < dfs_duration; + std::cerr << "\n!!! " << print_1(dp_faster) << "\nTOP2 DFS duration: " << dfs_duration << " \nTOP2 DP duration: " << dp_duration; + const auto dfs_sorted_tokens = get_dfs_sorted_tokens(); + for (int i = 0; i < std::min(topn, (int)dfs_sorted_tokens.size()); ++i) { + compare_score_and_tokens(dfs_sorted_tokens[i].first, + dfs_sorted_tokens[i].second, + dp_result[i].first, + dp_result[i].second, + std::format("[{} in top{}]", i + 1, topn)); + } +} +} // namespace dp_debug +#endif + +std::string RAGAnalyzer::Merge(const std::string &tks_str) const { + std::string tks = tks_str; + + tks = Replace(replace_space_pattern_, " ", tks); + + std::vector tokens; + Split(tks, blank_pattern_, tokens); + std::vector res; + std::size_t s = 0; + while (true) { + if (s >= tokens.size()) + break; + + std::size_t E = s + 1; + for (std::size_t e = s + 2; e < std::min(tokens.size() + 1, s + 6); ++e) { + std::string tk = Join(tokens, s, e, ""); + if (re2::RE2::PartialMatch(tk, regex_split_pattern_)) { + if (Freq(tk) > 0) { + E = e; + } + } + } + res.push_back(Join(tokens, s, E, "")); + s = E; + } + + return Join(res, 0, res.size()); +} + +void RAGAnalyzer::MergeWithPosition(const std::vector &tokens, + const std::vector> &positions, + std::vector &merged_tokens, + std::vector> &merged_positions) const { + // Filter out empty tokens first (like spaces) to match Merge behavior + std::vector filtered_tokens; + std::vector> filtered_positions; + + for (size_t i = 0; i < tokens.size(); ++i) { + if (!tokens[i].empty() && tokens[i] != " ") { + filtered_tokens.push_back(tokens[i]); + filtered_positions.push_back(positions[i]); + } + } + + std::vector res; + std::size_t s = 0; + std::vector> res_positions; + + while (true) { + if (s >= filtered_tokens.size()) + break; + + std::size_t E = s + 1; + for (std::size_t e = s + 2; e < std::min(filtered_tokens.size() + 1, s + 6); ++e) { + std::string tk = Join(filtered_tokens, s, e, ""); + if (re2::RE2::PartialMatch(tk, regex_split_pattern_)) { + if (Freq(tk) > 0) { + E = e; + } + } + } + + std::string merged_token = Join(filtered_tokens, s, E, ""); + res.push_back(merged_token); + + unsigned start_pos = filtered_positions[s].first; + unsigned end_pos = filtered_positions[E - 1].second; + res_positions.emplace_back(start_pos, end_pos); + + s = E; + } + + merged_tokens = std::move(res); + merged_positions = std::move(res_positions); +} + +void RAGAnalyzer::EnglishNormalize(const std::vector &tokens, std::vector &res) const { + for (auto &t : tokens) { + if (re2::RE2::PartialMatch(t, pattern1_)) { + //"[a-zA-Z_-]+$" + std::string lemma_term = wordnet_lemma_->Lemmatize(t); + std::vector lowercase_buffer(term_string_buffer_limit_); + char *lowercase_term = lowercase_buffer.data(); + ToLower(lemma_term.c_str(), lemma_term.size(), lowercase_term, 
term_string_buffer_limit_); + std::string stem_term; + stemmer_->Stem(lowercase_term, stem_term); + res.push_back(stem_term); + } else { + res.push_back(t); + } + } +} + +void RAGAnalyzer::SplitByLang(const std::string &line, std::vector> &txt_lang_pairs) const { + std::vector arr; + Split(line, regex_split_pattern_, arr, true); + + for (const auto &a : arr) { + if (a.empty()) { + continue; + } + + std::size_t s = 0; + std::size_t e = s + 1; + bool zh = IsChinese(UTF8Substr(a, s, 1)); + + while (e < UTF8Length(a)) { + bool _zh = IsChinese(UTF8Substr(a, e, 1)); + if (_zh == zh) { + e++; + continue; + } + + std::string segment = UTF8Substr(a, s, e - s); + txt_lang_pairs.emplace_back(segment, zh); + + s = e; + e = s + 1; + zh = _zh; + } + + if (s >= UTF8Length(a)) { + continue; + } + + std::string segment = UTF8Substr(a, s, e - s); + txt_lang_pairs.emplace_back(segment, zh); + } +} + +void RAGAnalyzer::TokenizeInner(std::vector &res, const std::string &L) const { + auto [tks, s] = MaxForward(L); + auto [tks1, s1] = MaxBackward(L); + +#if 0 + std::size_t i = 0, j = 0, _i = 0, _j = 0, same = 0; + while ((i + same < tks1.size()) && (j + same < tks.size()) && tks1[i + same] == tks[j + same]) { + same++; + } + if (same > 0) { + res.push_back(Join(tks, j, j + same)); + } + _i = i + same; + _j = j + same; + j = _j + 1; + i = _i + 1; + while (i < tks1.size() && j < tks.size()) { + std::string tk1 = Join(tks1, _i, i, ""); + std::string tk = Join(tks, _j, j, ""); + if (tk1 != tk) { + if (tk1.length() > tk.length()) { + j++; + } else { + i++; + } + continue; + } + if (tks1[i] != tks[j]) { + i++; + j++; + continue; + } + std::vector> pre_tokens; + std::vector>> token_list; + std::vector best_tokens; + double max_score = std::numeric_limits::lowest(); + const auto str_for_dfs = Join(tks, _j, j, ""); +#ifdef INFINITY_DEBUG + const auto t0 = std::chrono::high_resolution_clock::now(); +#endif + DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false); +#ifdef INFINITY_DEBUG + const auto t1 = std::chrono::high_resolution_clock::now(); + dp_debug::CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1); +#endif + res.push_back(Join(best_tokens, 0)); + + same = 1; + while (i + same < tks1.size() && j + same < tks.size() && tks1[i + same] == tks[j + same]) + same++; + res.push_back(Join(tks, j, j + same)); + _i = i + same; + _j = j + same; + j = _j + 1; + i = _i + 1; + } + if (_i < tks1.size()) { + std::vector> pre_tokens; + std::vector>> token_list; + std::vector best_tokens; + double max_score = std::numeric_limits::lowest(); + const auto str_for_dfs = Join(tks, _j, tks.size(), ""); +#ifdef INFINITY_DEBUG + const auto t0 = std::chrono::high_resolution_clock::now(); +#endif + DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false); +#ifdef INFINITY_DEBUG + const auto t1 = std::chrono::high_resolution_clock::now(); + dp_debug::CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1); +#endif + res.push_back(Join(best_tokens, 0)); + } + +#else + std::size_t i = 0, j = 0, _i = 0, _j = 0, same = 0; + while ((i + same < tks1.size()) && (j + same < tks.size()) && tks1[i + same] == tks[j + same]) { + same++; + } + if (same > 0) { + res.push_back(Join(tks, j, j + same)); + } + _i = i + same; + _j = j + same; + j = _j + 1; + i = _i + 1; + while (i < tks1.size() && j < tks.size()) { + std::string tk1 = Join(tks1, _i, i, ""); + std::string tk = Join(tks, _j, j, ""); + if (tk1 != tk) { + if (tk1.length() > tk.length()) { + j++; + } else { + i++; + } + continue; + } + if (tks1[i] != 
tks[j]) { + i++; + j++; + continue; + } + + std::vector> pre_tokens; + std::vector>> token_list; + std::vector best_tokens; + double max_score = std::numeric_limits::lowest(); + const auto str_for_dfs = Join(tks, _j, j, ""); +#ifdef INFINITY_DEBUG + const auto t0 = std::chrono::high_resolution_clock::now(); +#endif + DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false); +#ifdef INFINITY_DEBUG + const auto t1 = std::chrono::high_resolution_clock::now(); + dp_debug::CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1); +#endif + res.push_back(Join(best_tokens, 0)); + + same = 1; + while (i + same < tks1.size() && j + same < tks.size() && tks1[i + same] == tks[j + same]) + same++; + res.push_back(Join(tks, j, j + same)); + _i = i + same; + _j = j + same; + j = _j + 1; + i = _i + 1; + } + if (_i < tks1.size()) { + std::vector> pre_tokens; + std::vector>> token_list; + std::vector best_tokens; + double max_score = std::numeric_limits::lowest(); + const auto str_for_dfs = Join(tks, _j, tks.size(), ""); +#ifdef INFINITY_DEBUG + const auto t0 = std::chrono::high_resolution_clock::now(); +#endif + DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false); +#ifdef INFINITY_DEBUG + const auto t1 = std::chrono::high_resolution_clock::now(); + dp_debug::CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1); +#endif + res.push_back(Join(best_tokens, 0)); + } +#endif +} + +void RAGAnalyzer::SplitLongText(const std::string &L, uint32_t length, std::vector &sublines) const { + uint32_t slice_count = length / MAX_SENTENCE_LEN + 1; + sublines.reserve(slice_count); + std::size_t last_sentence_start = 0; + std::size_t next_sentence_start = 0; + for (unsigned i = 0; i < slice_count; ++i) { + next_sentence_start = MAX_SENTENCE_LEN * (i + 1) - 5; + if (next_sentence_start + 5 < length) { + std::size_t sentence_length = MAX_SENTENCE_LEN * (i + 1) + 5 > length ? 
length - next_sentence_start : 10; + std::string substr = UTF8Substr(L, next_sentence_start, sentence_length); + auto [tks, s] = MaxForward(substr); + auto [tks1, s1] = MaxBackward(substr); + std::vector diff(std::max(tks.size(), tks1.size()), 0); + for (std::size_t j = 0; j < std::min(tks.size(), tks1.size()); ++j) { + if (tks[j] != tks1[j]) { + diff[j] = 1; + } + } + + if (s1 > s) { + tks = tks1; + } + std::size_t start = 0; + std::size_t forward_same_len = 0; + while (start < tks.size() && diff[start] == 0) { + forward_same_len += UTF8Length(tks[start]); + start++; + } + if (forward_same_len == 0) { + std::size_t end = tks.size() - 1; + std::size_t backward_same_len = 0; + while (end >= 0 && diff[end] == 0) { + backward_same_len += UTF8Length(tks[end]); + end--; + } + next_sentence_start += sentence_length - backward_same_len; + } else + next_sentence_start += forward_same_len; + } else + next_sentence_start = length; + if (next_sentence_start == last_sentence_start) + continue; + std::string str = UTF8Substr(L, last_sentence_start, next_sentence_start - last_sentence_start); + sublines.push_back(str); + last_sentence_start = next_sentence_start; + } +} + +// PCRE2-based replacement function to match Python's re.sub behavior +// Returns processed string and position mapping from processed to original +std::pair>> +PCRE2GlobalReplaceWithPosition(const std::string &text, const std::string &pattern, const std::string &replacement) { + + std::vector> pos_mapping; + std::string result; + + pcre2_code *re; + PCRE2_SPTR pcre2_pattern = reinterpret_cast(pattern.c_str()); + PCRE2_SPTR pcre2_subject = reinterpret_cast(text.c_str()); + // Note: pcre2_replacement is used in the replacement logic below + int errorcode; + PCRE2_SIZE erroroffset; + + // Compile the pattern with UTF and UCP flags for Unicode support + re = pcre2_compile(pcre2_pattern, PCRE2_ZERO_TERMINATED, PCRE2_UCP | PCRE2_UTF, &errorcode, &erroroffset, nullptr); + + if (re == nullptr) { + PCRE2_UCHAR buffer[256]; + pcre2_get_error_message(errorcode, buffer, sizeof(buffer)); + std::cerr << "PCRE2 compilation failed at offset " << erroroffset << ": " << buffer << std::endl; + return {text, {}}; + } + + pcre2_match_data *match_data = pcre2_match_data_create_from_pattern(re, nullptr); + + PCRE2_SIZE current_pos = 0; + PCRE2_SIZE last_match_end = 0; + + // Process the string match by match + while (current_pos < text.length()) { + int rc = pcre2_match(re, pcre2_subject, text.length(), current_pos, 0, match_data, nullptr); + + if (rc < 0) { + // No more matches, copy remaining text + if (last_match_end < text.length()) { + std::string remaining = text.substr(last_match_end); + result += remaining; + + // Map each character in remaining text + for (size_t i = 0; i < remaining.length(); ++i) { + pos_mapping.emplace_back(last_match_end + i, last_match_end + i); + } + } + break; + } + + PCRE2_SIZE *ovector = pcre2_get_ovector_pointer(match_data); + PCRE2_SIZE match_start = ovector[0]; + PCRE2_SIZE match_end = ovector[1]; + + // Copy text before the match + if (last_match_end < match_start) { + std::string before_match = text.substr(last_match_end, match_start - last_match_end); + result += before_match; + + // Map each character in before_match + for (size_t i = 0; i < before_match.length(); ++i) { + pos_mapping.emplace_back(last_match_end + i, last_match_end + i); + } + } + + // Add the replacement string + result += replacement; + + // Map each character in replacement to the start of the match + for (size_t i = 0; i < replacement.length(); 
++i) { + pos_mapping.emplace_back(match_start, match_start); + } + + last_match_end = match_end; + current_pos = match_end; + + // If the match was zero-length, move forward one character to avoid infinite loop + if (match_start == match_end) { + if (current_pos < text.length()) { + current_pos++; + } else { + break; + } + } + } + + pcre2_match_data_free(match_data); + pcre2_code_free(re); + + return {result, pos_mapping}; +} + +// Original PCRE2GlobalReplace for backward compatibility +std::string PCRE2GlobalReplace(const std::string &text, const std::string &pattern, const std::string &replacement) { + auto [result, _] = PCRE2GlobalReplaceWithPosition(text, pattern, replacement); + return result; +} + +std::string RAGAnalyzer::Tokenize(const std::string &line) const { + // Python-style simple tokenization: re.sub(r"\\W+", " ", line) + std::string processed_line = PCRE2GlobalReplace(line, R"#(\W+)#", " "); + std::string str1 = StrQ2B(processed_line); + std::string strline; + opencc_->convert(str1, strline); + + std::vector res; + + // Use SplitByLang to separate by language + std::vector> arr; + SplitByLang(strline, arr); + + for (const auto &[L, lang] : arr) { + if (!lang) { + // Non-Chinese text: use NLTK tokenizer, lemmatize and stem + std::vector term_list; + std::vector sentences; + SentenceSplitter(L, sentences); + for (auto &sentence : sentences) { + NLTKWordTokenizer::GetInstance().Tokenize(sentence, term_list); + } + for (unsigned i = 0; i < term_list.size(); ++i) { + std::string t = wordnet_lemma_->Lemmatize(term_list[i]); + std::vector lowercase_buffer(term_string_buffer_limit_); + char *lowercase_term = lowercase_buffer.data(); + ToLower(t.c_str(), t.size(), lowercase_term, term_string_buffer_limit_); + std::string stem_term; + stemmer_->Stem(lowercase_term, stem_term); + res.push_back(stem_term); + } + continue; + } + auto length = UTF8Length(L); + if (length < 2 || re2::RE2::PartialMatch(L, pattern2_) || re2::RE2::PartialMatch(L, pattern3_)) { + //[a-z\\.-]+$ [0-9\\.-]+$ + res.push_back(L); + continue; + } + + // Chinese processing: use TokenizeInner +#if 0 + if (length > MAX_SENTENCE_LEN) { + std::vector sublines; + SplitLongText(L, length, sublines); + for (auto &l : sublines) { + TokenizeInner(res, l); + } + } else +#endif + TokenizeInner(res, L); + } + + // std::vector normalize_res; + // EnglishNormalize(res, normalize_res); + std::string r = Join(res, 0); + std::string ret = Merge(r); + return ret; +} + +std::pair, std::vector>> RAGAnalyzer::TokenizeWithPosition(const std::string &line) const { + // Python-style simple tokenization: re.sub(r"\W+", " ", line) + // Get processed line and position mapping from PCRE2GlobalReplace + auto [processed_line, pcre2_pos_mapping] = PCRE2GlobalReplaceWithPosition(line, R"#(\W+)#", " "); + + std::string str1 = StrQ2B(processed_line); + std::string strline; + opencc_->convert(str1, strline); + std::vector tokens; + std::vector> positions; + + // Build character position mapping from StrQ2B conversion + std::vector strq2b_pos_mapping; + BuildPositionMapping(processed_line, str1, strq2b_pos_mapping); + + // Build character position mapping from OpenCC conversion + std::vector opencc_pos_mapping; + BuildPositionMapping(str1, strline, opencc_pos_mapping); + + // Combine all position mappings: strline -> str1 -> processed_line -> line + std::vector final_pos_mapping; + final_pos_mapping.resize(strline.size() + 1); + + for (size_t i = 0; i < strline.size(); ++i) { + if (i < opencc_pos_mapping.size()) { + unsigned str1_pos = 
opencc_pos_mapping[i]; + if (str1_pos < strq2b_pos_mapping.size()) { + unsigned processed_pos = strq2b_pos_mapping[str1_pos]; + if (processed_pos < pcre2_pos_mapping.size()) { + final_pos_mapping[i] = pcre2_pos_mapping[processed_pos].first; + } else { + final_pos_mapping[i] = static_cast(line.size()); + } + } else { + final_pos_mapping[i] = static_cast(line.size()); + } + } else { + final_pos_mapping[i] = static_cast(line.size()); + } + } + + // Fill the last position + if (strline.size() < final_pos_mapping.size()) { + final_pos_mapping[strline.size()] = static_cast(line.size()); + } + + // Use SplitByLang to separate by language + std::vector> arr; + SplitByLang(strline, arr); + unsigned current_pos = 0; + + for (const auto &[L, lang] : arr) { + if (L.empty()) { + continue; + } + + std::size_t processed_pos = strline.find(L, current_pos); + if (processed_pos == std::string::npos) { + continue; + } + + unsigned original_start = current_pos; + current_pos = original_start + static_cast(L.size()); + + if (!lang) { + // Non-Chinese text: use NLTK tokenizer, lemmatize and stem + std::vector term_list; + std::vector sentences; + SentenceSplitter(L, sentences); + + unsigned sentence_start_pos = original_start; + for (auto &sentence : sentences) { + std::vector sentence_terms; + NLTKWordTokenizer::GetInstance().Tokenize(sentence, sentence_terms); + + unsigned current_search_pos = 0; + for (auto &term : sentence_terms) { + size_t pos_in_sentence = sentence.find(term, current_search_pos); + if (pos_in_sentence != std::string::npos) { + unsigned start_pos = sentence_start_pos + static_cast(pos_in_sentence); + unsigned end_pos = start_pos + static_cast(term.size()); + std::string t = wordnet_lemma_->Lemmatize(term); + std::vector lowercase_buffer(term_string_buffer_limit_); + char *lowercase_term = lowercase_buffer.data(); + ToLower(t.c_str(), t.size(), lowercase_term, term_string_buffer_limit_); + std::string stem_term; + stemmer_->Stem(lowercase_term, stem_term); + + tokens.push_back(stem_term); + + // Map positions back to original string using final_pos_mapping + if (start_pos < final_pos_mapping.size()) { + positions.emplace_back(final_pos_mapping[start_pos], final_pos_mapping[end_pos]); + } else { + positions.emplace_back(static_cast(line.size()), static_cast(line.size())); + } + + current_search_pos = pos_in_sentence + term.size(); + } + } + sentence_start_pos += static_cast(sentence.size()); + } + continue; + } + + auto length = UTF8Length(L); + if (length < 2 || re2::RE2::PartialMatch(L, pattern2_) || re2::RE2::PartialMatch(L, pattern3_)) { + tokens.push_back(L); + + // Map positions back to original string using final_pos_mapping + unsigned start_pos = original_start; + unsigned end_pos = original_start + static_cast(L.size()); + if (start_pos < final_pos_mapping.size() && end_pos < final_pos_mapping.size()) { + positions.emplace_back(final_pos_mapping[start_pos], final_pos_mapping[end_pos]); + } else { + positions.emplace_back(static_cast(line.size()), static_cast(line.size())); + } + continue; + } + + // Chinese processing: use TokenizeInnerWithPosition +#if 0 + if (length > MAX_SENTENCE_LEN) { + std::vector sublines; + SplitLongText(L, length, sublines); + unsigned subline_start_pos = original_start; + for (auto &l : sublines) { + TokenizeInnerWithPosition(l, tokens, positions, subline_start_pos, &final_pos_mapping); + subline_start_pos += static_cast(l.size()); + } + } else +#endif + TokenizeInnerWithPosition(L, tokens, positions, original_start, &final_pos_mapping); + } + + // 
std::vector normalize_tokens; + // std::vector> normalize_positions; + // EnglishNormalizeWithPosition(tokens, positions, normalize_tokens, normalize_positions); + + // Apply MergeWithPosition to match Tokenize behavior + std::vector merged_tokens; + std::vector> merged_positions; + MergeWithPosition(tokens, positions, merged_tokens, merged_positions); + + tokens = std::move(merged_tokens); + positions = std::move(merged_positions); + + return {std::move(tokens), std::move(positions)}; +} + +unsigned RAGAnalyzer::MapToOriginalPosition(unsigned processed_pos, const std::vector> &mapping) const { + for (const auto &[orig, proc] : mapping) { + if (proc == processed_pos) { + return orig; + } + } + return processed_pos; +} + +static unsigned CalculateTokensLength(const std::vector &tokens, int start, int end) { + unsigned total_length = 0; + for (int i = start; i < end; ++i) { + total_length += static_cast(tokens[i].size()); + } + return total_length; +} + +void RAGAnalyzer::TokenizeInnerWithPosition(const std::string &L, + std::vector &tokens, + std::vector> &positions, + unsigned base_pos, + const std::vector *pos_mapping) const { + auto [tks, s] = MaxForward(L); + auto [tks1, s1] = MaxBackward(L); + + // Use the same algorithm as Python version + std::size_t i = 0, j = 0, _i = 0, _j = 0, same = 0; + while ((i + same < tks1.size()) && (j + same < tks.size()) && tks1[i + same] == tks[j + same]) { + same++; + } + if (same > 0) { + std::string token_str = Join(tks, j, j + same); + unsigned token_len = static_cast(token_str.size()); + unsigned start_pos = base_pos + CalculateTokensLength(tks, 0, j); + + if (token_str.find(' ') != std::string::npos) { + std::vector space_split_tokens; + Split(token_str, blank_pattern_, space_split_tokens, false); + unsigned space_start_pos = start_pos; + for (const auto &space_token : space_split_tokens) { + if (space_token.empty()) { + continue; + } + unsigned space_token_len = static_cast(space_token.size()); + tokens.push_back(space_token); + // Map position back to original string if mapping is provided + if (pos_mapping) { + unsigned mapped_start = space_start_pos < pos_mapping->size() ? (*pos_mapping)[space_start_pos] : 0; + unsigned mapped_end = + (space_start_pos + space_token_len) < pos_mapping->size() ? (*pos_mapping)[space_start_pos + space_token_len] : 0; + positions.emplace_back(mapped_start, mapped_end); + } else { + positions.emplace_back(space_start_pos, space_start_pos + space_token_len); + } + space_start_pos += space_token_len; + } + } else { + tokens.push_back(token_str); + // Map position back to original string if mapping is provided + if (pos_mapping) { + unsigned mapped_start = start_pos < pos_mapping->size() ? (*pos_mapping)[start_pos] : 0; + unsigned mapped_end = (start_pos + token_len) < pos_mapping->size() ? 
(*pos_mapping)[start_pos + token_len] : 0; + positions.emplace_back(mapped_start, mapped_end); + } else { + positions.emplace_back(start_pos, start_pos + token_len); + } + } + } + _i = i + same; + _j = j + same; + j = _j + 1; + i = _i + 1; + + while (i < tks1.size() && j < tks.size()) { + std::string tk1 = Join(tks1, _i, i, ""); + std::string tk = Join(tks, _j, j, ""); + if (tk1 != tk) { + if (tk1.length() > tk.length()) { + j++; + } else { + i++; + } + continue; + } + if (tks1[i] != tks[j]) { + i++; + j++; + continue; + } + + // Handle different part with DFS + std::vector> pre_tokens; + std::vector>> token_list; + std::vector best_tokens; + double max_score = std::numeric_limits::lowest(); + const auto str_for_dfs = Join(tks, _j, j, ""); +#ifdef INFINITY_DEBUG + const auto t0 = std::chrono::high_resolution_clock::now(); +#endif + DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false); +#ifdef INFINITY_DEBUG + const auto t1 = std::chrono::high_resolution_clock::now(); + dp_debug::CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1); +#endif + + std::string best_token_str = Join(best_tokens, 0); + unsigned start_pos = base_pos + CalculateTokensLength(tks, 0, _j); + std::string original_token_str = Join(tks, _j, j, ""); + unsigned end_pos = start_pos + static_cast(original_token_str.size()); + + if (best_token_str.find(' ') != std::string::npos) { + std::vector space_split_tokens; + Split(best_token_str, blank_pattern_, space_split_tokens, false); + unsigned space_start_pos = start_pos; + for (const auto &space_token : space_split_tokens) { + if (space_token.empty()) { + continue; + } + unsigned space_token_len = static_cast(space_token.size()); + tokens.push_back(space_token); + // Map position back to original string if mapping is provided + if (pos_mapping) { + unsigned mapped_start = space_start_pos < pos_mapping->size() ? (*pos_mapping)[space_start_pos] : 0; + unsigned mapped_end = + (space_start_pos + space_token_len) < pos_mapping->size() ? (*pos_mapping)[space_start_pos + space_token_len] : 0; + positions.emplace_back(mapped_start, mapped_end); + } else { + positions.emplace_back(space_start_pos, space_start_pos + space_token_len); + } + space_start_pos += space_token_len; + } + } else { + tokens.push_back(best_token_str); + // Map position back to original string if mapping is provided + if (pos_mapping) { + unsigned mapped_start = start_pos < pos_mapping->size() ? (*pos_mapping)[start_pos] : 0; + unsigned mapped_end = end_pos < pos_mapping->size() ? (*pos_mapping)[end_pos] : 0; + positions.emplace_back(mapped_start, mapped_end); + } else { + positions.emplace_back(start_pos, end_pos); + } + } + + same = 1; + while (i + same < tks1.size() && j + same < tks.size() && tks1[i + same] == tks[j + same]) + same++; + + // Handle same part after different tokens + std::string token_str = Join(tks, j, j + same); + unsigned token_len = static_cast(token_str.size()); + start_pos = base_pos + CalculateTokensLength(tks, 0, j); + + if (token_str.find(' ') != std::string::npos) { + std::vector space_split_tokens; + Split(token_str, blank_pattern_, space_split_tokens, false); + unsigned space_start_pos = start_pos; + for (const auto &space_token : space_split_tokens) { + if (space_token.empty()) { + continue; + } + unsigned space_token_len = static_cast(space_token.size()); + tokens.push_back(space_token); + // Map position back to original string if mapping is provided + if (pos_mapping) { + unsigned mapped_start = space_start_pos < pos_mapping->size() ? 
(*pos_mapping)[space_start_pos] : 0; + unsigned mapped_end = + (space_start_pos + space_token_len) < pos_mapping->size() ? (*pos_mapping)[space_start_pos + space_token_len] : 0; + positions.emplace_back(mapped_start, mapped_end); + } else { + positions.emplace_back(space_start_pos, space_start_pos + space_token_len); + } + space_start_pos += space_token_len; + } + } else { + tokens.push_back(token_str); + // Map position back to original string if mapping is provided + if (pos_mapping) { + unsigned mapped_start = start_pos < pos_mapping->size() ? (*pos_mapping)[start_pos] : 0; + unsigned mapped_end = (start_pos + token_len) < pos_mapping->size() ? (*pos_mapping)[start_pos + token_len] : 0; + positions.emplace_back(mapped_start, mapped_end); + } else { + positions.emplace_back(start_pos, start_pos + token_len); + } + } + + _i = i + same; + _j = j + same; + j = _j + 1; + i = _i + 1; + } + + // Handle remaining part + if (_i < tks1.size()) { + std::vector> pre_tokens; + std::vector>> token_list; + std::vector best_tokens; + double max_score = std::numeric_limits::lowest(); + const auto str_for_dfs = Join(tks, _j, tks.size(), ""); +#ifdef INFINITY_DEBUG + const auto t0 = std::chrono::high_resolution_clock::now(); +#endif + DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false); +#ifdef INFINITY_DEBUG + const auto t1 = std::chrono::high_resolution_clock::now(); + dp_debug::CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1); +#endif + + std::string best_token_str = Join(best_tokens, 0); + unsigned start_pos = base_pos + CalculateTokensLength(tks, 0, _j); + std::string original_token_str = Join(tks, _j, tks.size(), ""); + unsigned end_pos = start_pos + static_cast(original_token_str.size()); + + if (best_token_str.find(' ') != std::string::npos) { + std::vector space_split_tokens; + Split(best_token_str, blank_pattern_, space_split_tokens, false); + unsigned space_start_pos = start_pos; + for (const auto &space_token : space_split_tokens) { + if (space_token.empty()) { + continue; + } + unsigned space_token_len = static_cast(space_token.size()); + tokens.push_back(space_token); + // Map position back to original string if mapping is provided + if (pos_mapping) { + unsigned mapped_start = space_start_pos < pos_mapping->size() ? (*pos_mapping)[space_start_pos] : 0; + unsigned mapped_end = + (space_start_pos + space_token_len) < pos_mapping->size() ? (*pos_mapping)[space_start_pos + space_token_len] : 0; + positions.emplace_back(mapped_start, mapped_end); + } else { + positions.emplace_back(space_start_pos, space_start_pos + space_token_len); + } + space_start_pos += space_token_len; + } + } else { + tokens.push_back(best_token_str); + // Map position back to original string if mapping is provided + if (pos_mapping) { + unsigned mapped_start = start_pos < pos_mapping->size() ? (*pos_mapping)[start_pos] : 0; + unsigned mapped_end = end_pos < pos_mapping->size() ? 
(*pos_mapping)[end_pos] : 0; + positions.emplace_back(mapped_start, mapped_end); + } else { + positions.emplace_back(start_pos, end_pos); + } + } + } +} + +void RAGAnalyzer::EnglishNormalizeWithPosition(const std::vector &tokens, + const std::vector> &positions, + std::vector &normalize_tokens, + std::vector> &normalize_positions) const { + for (size_t i = 0; i < tokens.size(); ++i) { + const auto &token = tokens[i]; + const auto &[start_pos, end_pos] = positions[i]; + + if (re2::RE2::PartialMatch(token, pattern1_)) { + //"[a-zA-Z_-]+$" + std::string lemma_term = wordnet_lemma_->Lemmatize(token); + std::vector lowercase_buffer(term_string_buffer_limit_); + char *lowercase_term = lowercase_buffer.data(); + ToLower(lemma_term.c_str(), lemma_term.size(), lowercase_term, term_string_buffer_limit_); + std::string stem_term; + stemmer_->Stem(lowercase_term, stem_term); + + normalize_tokens.push_back(stem_term); + normalize_positions.emplace_back(start_pos, end_pos); + } else { + normalize_tokens.push_back(token); + normalize_positions.emplace_back(start_pos, end_pos); + } + } +} + +void RAGAnalyzer::FineGrainedTokenizeWithPosition(const std::string &tokens_str, + const std::vector> &positions, + std::vector &fine_tokens, + std::vector> &fine_positions) const { + std::vector tks; + Split(tokens_str, blank_pattern_, tks); + + std::size_t zh_num = 0; + for (auto &token : tks) { + int len = UTF8Length(token); + for (int i = 0; i < len; ++i) { + std::string t = UTF8Substr(token, i, 1); + if (IsChinese(t)) { + zh_num++; + } + } + } + + if (zh_num < tks.size() * 0.2) { + // English text processing - apply normalization + std::vector temp_tokens; + for (size_t i = 0; i < tks.size(); ++i) { + const auto &token = tks[i]; + const auto &[start_pos, end_pos] = positions[i]; + + std::istringstream iss(token); + std::string sub_token; + unsigned sub_start = start_pos; + + while (std::getline(iss, sub_token, '/')) { + if (!sub_token.empty()) { + unsigned sub_end = sub_start + sub_token.size(); + fine_tokens.push_back(sub_token); + fine_positions.emplace_back(sub_start, sub_end); + sub_start = sub_end + 1; + } + } + } + + // Apply English normalization to get lowercase and stemmed tokens + // std::vector> temp_positions = fine_positions; + // EnglishNormalizeWithPosition(temp_tokens, temp_positions, fine_tokens, fine_positions); + } else { + // Chinese or mixed text processing - match FineGrainedTokenize behavior + for (size_t i = 0; i < tks.size(); ++i) { + const auto &token = tks[i]; + const auto &[start_pos, end_pos] = positions[i]; + const auto token_len = UTF8Length(token); + + if (token_len < 3 || re2::RE2::PartialMatch(token, pattern4_)) { + fine_tokens.push_back(token); + fine_positions.emplace_back(start_pos, end_pos); + continue; + } + + std::vector>> token_list; + if (token_len > 10) { + std::vector> tk; + tk.emplace_back(token, Encode(-1, 0)); + token_list.push_back(tk); + } else { + std::vector> pre_tokens; + std::vector best_tokens; + double max_score = 0.0F; + DFS(token, 0, pre_tokens, token_list, best_tokens, max_score, true); + } + + if (token_list.size() < 2) { + fine_tokens.push_back(token); + fine_positions.emplace_back(start_pos, end_pos); + continue; + } + + std::vector, double>> sorted_tokens; + SortTokens(token_list, sorted_tokens); + const auto &stk = sorted_tokens[1].first; + + if (stk.size() == token_len) { + fine_tokens.push_back(token); + fine_positions.emplace_back(start_pos, end_pos); + } else if (re2::RE2::PartialMatch(token, pattern5_)) { + bool need_append_stk = true; + for 
(auto &t : stk) { + if (UTF8Length(t) < 3) { + fine_tokens.push_back(token); + fine_positions.emplace_back(start_pos, end_pos); + need_append_stk = false; + break; + } + } + if (need_append_stk) { + unsigned sub_pos = start_pos; + for (auto &t : stk) { + unsigned sub_end = sub_pos + UTF8Length(t); + fine_tokens.push_back(t); + fine_positions.emplace_back(sub_pos, sub_end); + sub_pos = sub_end; + } + } + } else { + unsigned sub_pos = start_pos; + for (auto &t : stk) { + unsigned sub_end = sub_pos + static_cast(t.size()); + fine_tokens.push_back(t); + fine_positions.emplace_back(sub_pos, sub_end); + sub_pos = sub_end; + } + } + } + } + + // Apply English normalization only if needed, similar to FineGrainedTokenize + // For Chinese text, no additional normalization needed + // fine_tokens already contains the correct Chinese tokens +} + +void RAGAnalyzer::FineGrainedTokenize(const std::string &tokens, std::vector &result) const { + std::vector tks; + Split(tokens, blank_pattern_, tks); + std::vector res; + std::size_t zh_num = 0; + for (auto &token : tks) { + int len = UTF8Length(token); + for (int i = 0; i < len; ++i) { + std::string t = UTF8Substr(token, i, 1); + if (IsChinese(t)) { + zh_num++; + } + } + } + if (zh_num < tks.size() * 0.2) { + for (auto &token : tks) { + std::istringstream iss(token); + std::string sub_token; + while (std::getline(iss, sub_token, '/')) { + result.push_back(sub_token); + } + } + // std::string ret = Join(res, 0); + return; + } + + for (auto &token : tks) { + const auto token_len = UTF8Length(token); + if (token_len < 3 || re2::RE2::PartialMatch(token, pattern4_)) { + //[0-9,\\.-]+$ + res.push_back(token); + continue; + } + std::vector>> token_list; + if (token_len > 10) { + std::vector> tk; + tk.emplace_back(token, Encode(-1, 0)); + token_list.push_back(tk); + } else { + std::vector> pre_tokens; + std::vector best_tokens; + double max_score = 0.0F; +#ifdef INFINITY_DEBUG + const auto t0 = std::chrono::high_resolution_clock::now(); +#endif + DFS(token, 0, pre_tokens, token_list, best_tokens, max_score, true); +#ifdef INFINITY_DEBUG + const auto t1 = std::chrono::high_resolution_clock::now(); + auto get_dfs_sorted_tokens = [&]() { + std::vector, double>> sorted_tokens; + SortTokens(token_list, sorted_tokens); + return sorted_tokens; + }; + dp_debug::CheckDP2(this, token, get_dfs_sorted_tokens, t0, t1); +#endif + } + if (token_list.size() < 2) { + res.push_back(token); + continue; + } + std::vector, double>> sorted_tokens; + SortTokens(token_list, sorted_tokens); + const auto &stk = sorted_tokens[1].first; + if (stk.size() == token_len) { + res.push_back(token); + } else if (re2::RE2::PartialMatch(token, pattern5_)) { + // [a-z\\.-]+ + bool need_append_stk = true; + for (auto &t : stk) { + if (UTF8Length(t) < 3) { + res.push_back(token); + need_append_stk = false; + break; + } + } + if (need_append_stk) { + for (auto &t : stk) { + res.push_back(t); + } + } + } else { + for (auto &t : stk) { + res.push_back(t); + } + } + } + EnglishNormalize(res, result); + // std::string ret = Join(normalize_res, 0); + // return ret; +} + +int RAGAnalyzer::AnalyzeImpl(const Term &input, void *data, bool fine_grained, bool enable_position, HookType func) const { + if (enable_position) { + auto [tokens, positions] = TokenizeWithPosition(input.text_); + + if (fine_grained) { + std::vector fine_tokens; + std::vector> fine_positions; + FineGrainedTokenizeWithPosition(Join(tokens, 0), positions, fine_tokens, fine_positions); + tokens = std::move(fine_tokens); + positions = 
std::move(fine_positions); + } + + for (size_t i = 0; i < tokens.size(); ++i) { + if (tokens[i].empty()) + continue; + const auto &[start_pos, end_pos] = positions[i]; + func(data, tokens[i].c_str(), tokens[i].size(), start_pos, end_pos, false, 0); + } + } else { + std::string result = Tokenize(input.text_); + std::vector tokens; + if (fine_grained) { + FineGrainedTokenize(result, tokens); + } else { + Split(result, blank_pattern_, tokens); + } + unsigned offset = 0; + for (auto &t : tokens) { + if (t.empty()) + continue; + func(data, t.c_str(), t.size(), offset++, 0, false, 0); + } + } + return 0; +} \ No newline at end of file diff --git a/internal/cpp/rag_analyzer.h b/internal/cpp/rag_analyzer.h new file mode 100644 index 00000000000..70331445de4 --- /dev/null +++ b/internal/cpp/rag_analyzer.h @@ -0,0 +1,177 @@ +// Copyright(C) 2024 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "opencc/openccxx.h" +#include "stemmer/stemmer.h" +#include "term.h" +#include "re2/re2.h" +#include "dart_trie.h" +#include "wordnet_lemmatizer.h" +#include "analyzer.h" +#include +#include +#include +#include +#include + +// C++ reimplementation of +// https://github.com/infiniflow/ragflow/blob/main/rag/nlp/rag_tokenizer.py + +typedef void (*HookType)(void* data, + const char* text, + const uint32_t len, + const uint32_t offset, + const uint32_t end_offset, + const bool is_special_char, + const uint16_t payload); + +class NLTKWordTokenizer; + +class RAGAnalyzer : public Analyzer +{ +public: + explicit + RAGAnalyzer(const std::string& path); + + RAGAnalyzer(const RAGAnalyzer& other); + + ~RAGAnalyzer(); + + void InitStemmer(Language language) { stemmer_->Init(language); } + + int32_t Load(); + + void SetFineGrained(bool fine_grained) { fine_grained_ = fine_grained; } + + void SetEnablePosition(bool enable_position) { enable_position_ = enable_position; } + + std::pair, std::vector>> TokenizeWithPosition( + const std::string& line) const; + std::string Tokenize(const std::string& line) const; + + void FineGrainedTokenize(const std::string& tokens, std::vector& result) const; + + void TokenizeInnerWithPosition(const std::string& L, + std::vector& tokens, + std::vector>& positions, + unsigned base_pos, + const std::vector* pos_mapping = nullptr) const; + void FineGrainedTokenizeWithPosition(const std::string& tokens_str, + const std::vector>& positions, + std::vector& fine_tokens, + std::vector>& fine_positions) const; + void EnglishNormalizeWithPosition(const std::vector& tokens, + const std::vector>& positions, + std::vector& normalize_tokens, + std::vector>& normalize_positions) const; + unsigned MapToOriginalPosition(unsigned processed_pos, + const std::vector>& mapping) const; + void MergeWithPosition(const std::vector& tokens, + const std::vector>& positions, + std::vector& merged_tokens, + std::vector>& merged_positions) const; + + void SplitByLang(const std::string& line, std::vector>& txt_lang_pairs) const; + + int32_t 
Freq(std::string_view key) const; + std::string Tag(std::string_view key) const; + +protected: + int AnalyzeImpl(const Term& input, void* data, bool fine_grained, bool enable_position, HookType func) const; + +private: + static constexpr float DENOMINATOR = 1000000; + + static std::string StrQ2B(const std::string& input); + + static void BuildPositionMapping(const std::string& original, const std::string& converted, + std::vector& pos_mapping); + + + static std::string Key(std::string_view line); + + static std::string RKey(std::string_view line); + + static std::pair, double> Score( + const std::vector>& token_freqs); + + static void SortTokens(const std::vector>>& token_list, + std::vector, double>>& res); + + std::pair, double> MaxForward(const std::string& line) const; + + std::pair, double> MaxBackward(const std::string& line) const; + + int DFS(const std::string& chars, + int s, + std::vector>& pre_tokens, + std::vector>>& token_list, + std::vector& best_tokens, + double& max_score, + bool memo_all) const; + + void TokenizeInner(std::vector& res, const std::string& L) const; + + void SplitLongText(const std::string& L, uint32_t length, std::vector& sublines) const; + + [[nodiscard]] std::string Merge(const std::string& tokens) const; + + void EnglishNormalize(const std::vector& tokens, std::vector& res) const; + +public: + [[nodiscard]] std::vector, double>> GetBestTokensTopN( + std::string_view chars, uint32_t n) const; + + static const size_t term_string_buffer_limit_ = 4096 * 3; + + std::string dict_path_; + + bool own_dict_{}; + + DartsTrie* trie_{nullptr}; + + POSTable* pos_table_{nullptr}; + + WordNetLemmatizer* wordnet_lemma_{nullptr}; + + std::unique_ptr stemmer_; + + OpenCC* opencc_{nullptr}; + + bool fine_grained_{false}; + + bool enable_position_{false}; + + static inline re2::RE2 pattern1_{"[a-zA-Z_-]+$"}; + + static inline re2::RE2 pattern2_{"[a-zA-Z\\.-]+$"}; + + static inline re2::RE2 pattern3_{"[0-9\\.-]+$"}; + + static inline re2::RE2 pattern4_{"[0-9,\\.-]+$"}; + + static inline re2::RE2 pattern5_{"[a-zA-Z\\.-]+"}; + + static inline re2::RE2 regex_split_pattern_{ + R"#(([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z0-9,\.-]+))#" + }; + + static inline re2::RE2 blank_pattern_{"( )"}; + + static inline re2::RE2 replace_space_pattern_{R"#(([ ]+))#"}; +}; + +void SentenceSplitter(const std::string& text, std::vector& result); diff --git a/internal/cpp/rag_analyzer_c_api.cpp b/internal/cpp/rag_analyzer_c_api.cpp new file mode 100644 index 00000000000..3ed07dc49e2 --- /dev/null +++ b/internal/cpp/rag_analyzer_c_api.cpp @@ -0,0 +1,225 @@ +// C API implementation for RAGAnalyzer + +#include "rag_analyzer_c_api.h" +#include "rag_analyzer.h" +#include "term.h" +#include +#include +#include + +extern "C" { + +RAGAnalyzerHandle RAGAnalyzer_Create(const char* path) { + if (!path) return nullptr; + try { + RAGAnalyzer* analyzer = new RAGAnalyzer(std::string(path)); + return static_cast(analyzer); + } catch (...) 
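+        // Never let a C++ exception cross the C ABI: CGO callers cannot
+        // unwind it, so each entry point swallows everything and signals
+        // failure through its NULL / error-code return instead.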
{ + return nullptr; + } +} + +void RAGAnalyzer_Destroy(RAGAnalyzerHandle handle) { + if (handle) { + RAGAnalyzer* analyzer = static_cast(handle); + delete analyzer; + } +} + +int RAGAnalyzer_Load(RAGAnalyzerHandle handle) { + if (!handle) return -1; + RAGAnalyzer* analyzer = static_cast(handle); + return analyzer->Load(); +} + +void RAGAnalyzer_SetFineGrained(RAGAnalyzerHandle handle, bool fine_grained) { + if (!handle) return; + RAGAnalyzer* analyzer = static_cast(handle); + analyzer->SetFineGrained(fine_grained); +} + +void RAGAnalyzer_SetEnablePosition(RAGAnalyzerHandle handle, bool enable_position) { + if (!handle) return; + RAGAnalyzer* analyzer = static_cast(handle); + analyzer->SetEnablePosition(enable_position); +} + +int RAGAnalyzer_Analyze(RAGAnalyzerHandle handle, const char* text, RAGTokenCallback callback) { + if (!handle || !text || !callback) return -1; + + RAGAnalyzer* analyzer = static_cast(handle); + + Term input; + input.text_ = std::string(text); + + TermList output; + // Use the analyzer's internal state for fine_grained and enable_position + int ret = analyzer->Analyze(input, output, analyzer->fine_grained_, analyzer->enable_position_); + + if (ret != 0) { + return ret; + } + + // Call callback for each token + for (const auto& term : output) { + callback(term.text_.c_str(), term.text_.length(), term.word_offset_, term.end_offset_); + } + + return 0; +} + +char* RAGAnalyzer_Tokenize(RAGAnalyzerHandle handle, const char* text) { + if (!handle || !text) return nullptr; + + RAGAnalyzer* analyzer = static_cast(handle); + + std::string result = analyzer->Tokenize(std::string(text)); + + // Allocate memory for C string + char* c_result = static_cast(malloc(result.size() + 1)); + if (c_result) { + std::memcpy(c_result, result.c_str(), result.size() + 1); + } + return c_result; +} + +RAGTokenList* RAGAnalyzer_TokenizeWithPosition(RAGAnalyzerHandle handle, const char* text) { + if (!handle || !text) return nullptr; + + RAGAnalyzer* analyzer = static_cast(handle); + + Term input; + input.text_ = std::string(text); + + TermList output; + // Pass fine_grained and enable_position=true to get position information + analyzer->Analyze(input, output, analyzer->fine_grained_, true); + + // Allocate memory for the token list structure + RAGTokenList* token_list = static_cast(malloc(sizeof(RAGTokenList))); + if (!token_list) { + return nullptr; + } + + // Allocate memory for the tokens array + token_list->tokens = static_cast( + malloc(sizeof(RAGTokenWithPosition) * output.size()) + ); + if (!token_list->tokens) { + free(token_list); + return nullptr; + } + + token_list->count = static_cast(output.size()); + + // Fill in the tokens + for (size_t i = 0; i < output.size(); ++i) { + // Allocate memory for the text and copy it + token_list->tokens[i].text = static_cast( + malloc(output[i].text_.size() + 1) + ); + if (token_list->tokens[i].text) { + std::memcpy(token_list->tokens[i].text, + output[i].text_.c_str(), + output[i].text_.size() + 1); + } + token_list->tokens[i].offset = output[i].word_offset_; + token_list->tokens[i].end_offset = output[i].end_offset_; + } + + return token_list; +} + +void RAGAnalyzer_FreeTokenList(RAGTokenList* token_list) { + if (!token_list) return; + + if (token_list->tokens) { + for (uint32_t i = 0; i < token_list->count; ++i) { + if (token_list->tokens[i].text) { + free(token_list->tokens[i].text); + } + } + free(token_list->tokens); + } + free(token_list); +} + +// Helper functions to access token fields +const char* RAGToken_GetText(void* token) { + if 
(!token) return nullptr; + RAGTokenWithPosition* t = static_cast(token); + return t->text; +} + +uint32_t RAGToken_GetOffset(void* token) { + if (!token) return 0; + RAGTokenWithPosition* t = static_cast(token); + return t->offset; +} + +uint32_t RAGToken_GetEndOffset(void* token) { + if (!token) return 0; + RAGTokenWithPosition* t = static_cast(token); + return t->end_offset; +} + +char* RAGAnalyzer_FineGrainedTokenize(RAGAnalyzerHandle handle, const char* tokens) { + if (!handle || !tokens) return nullptr; + + RAGAnalyzer* analyzer = static_cast(handle); + + std::vector result; + analyzer->FineGrainedTokenize(std::string(tokens), result); + + // Join results with space + std::string result_str; + for (size_t i = 0; i < result.size(); ++i) { + if (i > 0) result_str += " "; + result_str += result[i]; + } + + // Allocate memory for C string + char* c_result = static_cast(malloc(result_str.size() + 1)); + if (c_result) { + std::memcpy(c_result, result_str.c_str(), result_str.size() + 1); + } + return c_result; +} + +int32_t RAGAnalyzer_GetTermFreq(RAGAnalyzerHandle handle, const char* term) { + if (!handle || !term) return 0; + + RAGAnalyzer* analyzer = static_cast(handle); + return analyzer->Freq(term); +} + +char* RAGAnalyzer_GetTermTag(RAGAnalyzerHandle handle, const char* term) { + if (!handle || !term) return nullptr; + + RAGAnalyzer* analyzer = static_cast(handle); + std::string tag_result = analyzer->Tag(term); + + if (tag_result.empty()) { + return nullptr; + } + + // Allocate memory for C string + char* c_result = static_cast(malloc(tag_result.size() + 1)); + if (c_result) { + std::memcpy(c_result, tag_result.c_str(), tag_result.size() + 1); + } + return c_result; +} + +RAGAnalyzerHandle RAGAnalyzer_Copy(RAGAnalyzerHandle handle) { + if (!handle) return nullptr; + try { + RAGAnalyzer* original = static_cast(handle); + RAGAnalyzer* copy = new RAGAnalyzer(*original); + return static_cast(copy); + } catch (...) 
{
+        return nullptr;
+    }
+}
+
+} // extern "C"
diff --git a/internal/cpp/rag_analyzer_c_api.h b/internal/cpp/rag_analyzer_c_api.h
new file mode 100644
index 00000000000..2a874000134
--- /dev/null
+++ b/internal/cpp/rag_analyzer_c_api.h
@@ -0,0 +1,106 @@
+// C API wrapper for RAGAnalyzer
+// This file provides a C-compatible interface for CGO to call
+
+#ifndef RAG_ANALYZER_C_API_H
+#define RAG_ANALYZER_C_API_H
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include <stdbool.h>
+#include <stdint.h>
+
+// Opaque pointer to RAGAnalyzer
+typedef void* RAGAnalyzerHandle;
+
+// Callback function type for receiving tokens
+typedef void (*RAGTokenCallback)(
+    const char* text,
+    uint32_t len,
+    uint32_t offset,
+    uint32_t end_offset
+);
+
+// Create a new RAGAnalyzer instance
+// path: path to dictionary files
+// Returns: handle to the analyzer, or NULL on failure
+RAGAnalyzerHandle RAGAnalyzer_Create(const char* path);
+
+// Destroy a RAGAnalyzer instance
+void RAGAnalyzer_Destroy(RAGAnalyzerHandle handle);
+
+// Load the analyzer (must be called before Analyze)
+// Returns: 0 on success, negative value on failure
+int RAGAnalyzer_Load(RAGAnalyzerHandle handle);
+
+// Set fine-grained mode
+void RAGAnalyzer_SetFineGrained(RAGAnalyzerHandle handle, bool fine_grained);
+
+// Set enable position tracking
+void RAGAnalyzer_SetEnablePosition(RAGAnalyzerHandle handle, bool enable_position);
+
+// Analyze text and call callback for each token
+// Returns: 0 on success, negative value on failure
+int RAGAnalyzer_Analyze(
+    RAGAnalyzerHandle handle,
+    const char* text,
+    RAGTokenCallback callback
+);
+
+// Simple analyze that returns tokens as a single space-separated string
+// Caller is responsible for freeing the returned string
+// Returns: dynamically allocated string (must call free()), or NULL on failure
+char* RAGAnalyzer_Tokenize(RAGAnalyzerHandle handle, const char* text);
+
+// Structure for a token with position information
+typedef struct {
+    char* text;           // Token text (must be freed with free())
+    uint32_t offset;      // Byte offset of the token in the original text
+    uint32_t end_offset;  // Byte end offset of the token
+} RAGTokenWithPosition;
+
+// Helper functions to access token fields (for CGO)
+const char* RAGToken_GetText(void* token);
+uint32_t RAGToken_GetOffset(void* token);
+uint32_t RAGToken_GetEndOffset(void* token);
+
+// Structure for a list of tokens with positions
+typedef struct {
+    RAGTokenWithPosition* tokens;  // Array of tokens (must be freed with RAGAnalyzer_FreeTokenList)
+    uint32_t count;                // Number of tokens in the list
+} RAGTokenList;
+
+// Tokenize with position information
+// Caller is responsible for freeing the returned token list with RAGAnalyzer_FreeTokenList
+// Returns: dynamically allocated token list (must call RAGAnalyzer_FreeTokenList), or NULL on failure
+RAGTokenList* RAGAnalyzer_TokenizeWithPosition(RAGAnalyzerHandle handle, const char* text);
+
+// Free a token list allocated by RAGAnalyzer_TokenizeWithPosition
+void RAGAnalyzer_FreeTokenList(RAGTokenList* token_list);
+
+// Fine-grained tokenize: takes space-separated tokens and returns fine-grained tokens as space-separated string
+// Caller is responsible for freeing the returned string
+// Returns: dynamically allocated string (must call free()), or NULL on failure
+char* RAGAnalyzer_FineGrainedTokenize(RAGAnalyzerHandle handle, const char* tokens);
+
+// Get the frequency of a term (matching Python rag_tokenizer.freq)
+// Returns: frequency value, or 0 if term not found
+int32_t RAGAnalyzer_GetTermFreq(RAGAnalyzerHandle 
handle, const char* term); + +// Get the POS tag of a term (matching Python rag_tokenizer.tag) +// Caller is responsible for freeing the returned string +// Returns: dynamically allocated string (must call free()), or NULL if term not found or no tag +char* RAGAnalyzer_GetTermTag(RAGAnalyzerHandle handle, const char* term); + +// Copy an existing RAGAnalyzer instance to create a new independent instance +// This is useful for creating per-request analyzer instances in multi-threaded environments +// The new instance shares the loaded dictionaries with the original but has independent internal state +// Returns: handle to the new analyzer instance, or NULL on failure +RAGAnalyzerHandle RAGAnalyzer_Copy(RAGAnalyzerHandle handle); + +#ifdef __cplusplus +} +#endif + +#endif // RAG_ANALYZER_C_API_H diff --git a/internal/cpp/rag_analyzer_c_api_debug.cpp b/internal/cpp/rag_analyzer_c_api_debug.cpp new file mode 100644 index 00000000000..d083382646d --- /dev/null +++ b/internal/cpp/rag_analyzer_c_api_debug.cpp @@ -0,0 +1,168 @@ +// Debug version of C API with memory tracking +// Compile with: -DMEMORY_DEBUG to enable tracking + +#include "rag_analyzer_c_api.h" +#include "rag_analyzer.h" +#include "term.h" +#include +#include +#include +#include + +#ifdef MEMORY_DEBUG +#include +#include + +static std::mutex g_memory_mutex; +static std::map g_allocations; +static size_t g_total_allocated = 0; +static size_t g_total_freed = 0; + +void* debug_malloc(size_t size, const char* file, int line) { + void* ptr = malloc(size); + std::lock_guard lock(g_memory_mutex); + g_allocations[ptr] = size; + g_total_allocated += size; + fprintf(stderr, "[MEM_DEBUG] ALLOC: %p (%zu bytes) at %s:%d\n", ptr, size, file, line); + return ptr; +} + +void debug_free(void* ptr, const char* file, int line) { + if (!ptr) return; + { + std::lock_guard lock(g_memory_mutex); + auto it = g_allocations.find(ptr); + if (it != g_allocations.end()) { + g_total_freed += it->second; + g_allocations.erase(it); + } + } + fprintf(stderr, "[MEM_DEBUG] FREE: %p at %s:%d\n", ptr, file, line); + free(ptr); +} + +void print_memory_stats() { + std::lock_guard lock(g_memory_mutex); + fprintf(stderr, "\n[MEM_DEBUG] ===== Memory Statistics =====\n"); + fprintf(stderr, "[MEM_DEBUG] Total allocated: %zu bytes\n", g_total_allocated); + fprintf(stderr, "[MEM_DEBUG] Total freed: %zu bytes\n", g_total_freed); + fprintf(stderr, "[MEM_DEBUG] Current usage: %zu bytes\n", g_total_allocated - g_total_freed); + fprintf(stderr, "[MEM_DEBUG] Active allocations: %zu\n", g_allocations.size()); + if (!g_allocations.empty()) { + fprintf(stderr, "[MEM_DEBUG] Active blocks:\n"); + for (const auto& [ptr, size] : g_allocations) { + fprintf(stderr, "[MEM_DEBUG] %p: %zu bytes\n", ptr, size); + } + } + fprintf(stderr, "[MEM_DEBUG] ============================\n\n"); +} + +#define DEBUG_MALLOC(size) debug_malloc(size, __FILE__, __LINE__) +#define DEBUG_FREE(ptr) debug_free(ptr, __FILE__, __LINE__) + +#else + +#define DEBUG_MALLOC(size) malloc(size) +#define DEBUG_FREE(ptr) free(ptr) +void print_memory_stats() {} + +#endif + +extern "C" { + +RAGAnalyzerHandle RAGAnalyzer_Create(const char* path) { + if (!path) return nullptr; + try { + RAGAnalyzer* analyzer = new RAGAnalyzer(std::string(path)); + fprintf(stderr, "[C_API] Created analyzer: %p\n", (void*)analyzer); + return static_cast(analyzer); + } catch (...) 
{ + fprintf(stderr, "[C_API] Failed to create analyzer\n"); + return nullptr; + } +} + +void RAGAnalyzer_Destroy(RAGAnalyzerHandle handle) { + if (handle) { + fprintf(stderr, "[C_API] Destroying analyzer: %p\n", handle); + RAGAnalyzer* analyzer = static_cast(handle); + delete analyzer; + } +} + +int RAGAnalyzer_Load(RAGAnalyzerHandle handle) { + if (!handle) return -1; + RAGAnalyzer* analyzer = static_cast(handle); + int ret = analyzer->Load(); + fprintf(stderr, "[C_API] Load result: %d\n", ret); + return ret; +} + +void RAGAnalyzer_SetFineGrained(RAGAnalyzerHandle handle, bool fine_grained) { + if (!handle) return; + RAGAnalyzer* analyzer = static_cast(handle); + analyzer->SetFineGrained(fine_grained); + fprintf(stderr, "[C_API] SetFineGrained: %d\n", fine_grained); +} + +void RAGAnalyzer_SetEnablePosition(RAGAnalyzerHandle handle, bool enable_position) { + if (!handle) return; + RAGAnalyzer* analyzer = static_cast(handle); + analyzer->SetEnablePosition(enable_position); + fprintf(stderr, "[C_API] SetEnablePosition: %d\n", enable_position); +} + +int RAGAnalyzer_Analyze(RAGAnalyzerHandle handle, const char* text, RAGTokenCallback callback) { + if (!handle || !text || !callback) return -1; + + fprintf(stderr, "[C_API] Analyze called with text length: %zu\n", strlen(text)); + + RAGAnalyzer* analyzer = static_cast(handle); + + Term input; + input.text_ = std::string(text); + + TermList output; + int ret = analyzer->Analyze(input, output); + + fprintf(stderr, "[C_API] Analyze returned: %d, tokens: %zu\n", ret, output.size()); + + if (ret != 0) { + return ret; + } + + // Call callback for each token + for (const auto& term : output) { + callback(term.text_.c_str(), term.text_.length(), term.word_offset_, term.end_offset_); + } + + return 0; +} + +char* RAGAnalyzer_Tokenize(RAGAnalyzerHandle handle, const char* text) { + if (!handle || !text) { + fprintf(stderr, "[C_API] Tokenize called with null handle or text\n"); + return nullptr; + } + + fprintf(stderr, "[C_API] Tokenize called with text length: %zu\n", strlen(text)); + + RAGAnalyzer* analyzer = static_cast(handle); + + std::string result = analyzer->Tokenize(std::string(text)); + + // Allocate memory for C string + char* c_result = static_cast(DEBUG_MALLOC(result.size() + 1)); + if (c_result) { + std::memcpy(c_result, result.c_str(), result.size() + 1); + fprintf(stderr, "[C_API] Tokenize allocated result: %p\n", (void*)c_result); + } + return c_result; +} + +// Debug function to print memory stats +void RAGAnalyzer_PrintMemoryStats() { + print_memory_stats(); +} + +} // extern "C" diff --git a/internal/cpp/rag_analyzer_c_test.cpp b/internal/cpp/rag_analyzer_c_test.cpp new file mode 100644 index 00000000000..f62401a68e6 --- /dev/null +++ b/internal/cpp/rag_analyzer_c_test.cpp @@ -0,0 +1,120 @@ +#include +#include +#include +#include +#include +#include "rag_analyzer_c_api.h" + +// Test case 1: Single thread, loop 1000 times +void test_single_thread() { + std::cout << "Test 1: Single thread, 1000 iterations..." 
<< std::endl;
+
+    // Create analyzer instance
+    RAGAnalyzerHandle handle = RAGAnalyzer_Create(".");
+    assert(handle != nullptr && "Failed to create RAGAnalyzer");
+
+    // Load the analyzer
+    int result = RAGAnalyzer_Load(handle);
+    if (result != 0) {
+        std::cerr << "Failed to load RAGAnalyzer: " << result << std::endl;
+    }
+    assert(result == 0 && "Failed to load RAGAnalyzer");
+
+    const char* input = "rag";
+    bool all_passed = true;
+
+    for (int i = 0; i < 1000; ++i) {
+        char* tokens = RAGAnalyzer_Tokenize(handle, input);
+
+        if (tokens == nullptr || strlen(tokens) == 0) {
+            std::cerr << "Iteration " << i << ": Failed - returned empty or null string" << std::endl;
+            all_passed = false;
+        }
+
+        // Free the returned string
+        if (tokens != nullptr) {
+            free(tokens);
+        }
+    }
+
+    // Destroy analyzer instance
+    RAGAnalyzer_Destroy(handle);
+
+    if (all_passed) {
+        std::cout << "Test 1: PASSED" << std::endl;
+    } else {
+        std::cout << "Test 1: FAILED" << std::endl;
+        exit(1);
+    }
+}
+
+// Test case 2: 32 threads, each looping 100000 times
+void test_multi_thread() {
+    std::cout << "Test 2: 32 threads, each 100000 iterations..." << std::endl;
+
+    // Create analyzer instance (shared across threads)
+    RAGAnalyzerHandle handle = RAGAnalyzer_Create(".");
+    assert(handle != nullptr && "Failed to create RAGAnalyzer");
+
+    // Load the analyzer
+    int result = RAGAnalyzer_Load(handle);
+    assert(result == 0 && "Failed to load RAGAnalyzer");
+
+    const char* input = "rag";
+    const int num_threads = 32;
+    const int iterations_per_thread = 100000;
+
+    std::vector<std::thread> threads;
+    // char rather than bool: std::vector<bool> packs bits, so concurrent
+    // writes to different elements would be a data race.
+    std::vector<char> thread_results(num_threads, true);
+
+    for (int t = 0; t < num_threads; ++t) {
+        threads.emplace_back([&, t]() {
+            for (int i = 0; i < iterations_per_thread; ++i) {
+                char* tokens = RAGAnalyzer_Tokenize(handle, input);
+
+                if (tokens == nullptr || strlen(tokens) == 0) {
+                    std::cerr << "Thread " << t << " Iteration " << i << ": Failed - returned empty or null string" << std::endl;
+                    thread_results[t] = false;
+                }
+
+                // Free the returned string
+                if (tokens != nullptr) {
+                    free(tokens);
+                }
+            }
+        });
+    }
+
+    // Wait for all threads to complete
+    for (auto& t : threads) {
+        t.join();
+    }
+
+    // Destroy analyzer instance
+    RAGAnalyzer_Destroy(handle);
+
+    bool all_passed = true;
+    for (int t = 0; t < num_threads; ++t) {
+        if (!thread_results[t]) {
+            all_passed = false;
+            break;
+        }
+    }
+
+    if (all_passed) {
+        std::cout << "Test 2: PASSED" << std::endl;
+    } else {
+        std::cout << "Test 2: FAILED" << std::endl;
+        exit(1);
+    }
+}
+
+int main() {
+    std::cout << "=== RAGAnalyzer C API Test ===" << std::endl;
+
+    test_single_thread();
+    // test_multi_thread();
+
+    std::cout << "=== All tests PASSED ===" << std::endl;
+    return 0;
+}
diff --git a/internal/cpp/re2/bitmap256.cc b/internal/cpp/re2/bitmap256.cc
new file mode 100644
index 00000000000..9f402ee6f36
--- /dev/null
+++ b/internal/cpp/re2/bitmap256.cc
@@ -0,0 +1,44 @@
+// Copyright 2023 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "re2/bitmap256.h"
+
+#include <stdint.h>
+
+#include "util/logging.h"
+#include "util/util.h"
+
+namespace re2 {
+
+int Bitmap256::FindNextSetBit(int c) const {
+  DCHECK_GE(c, 0);
+  DCHECK_LE(c, 255);
+
+  // Check the word that contains the bit. Mask out any lower bits.
+  int i = c / 64;
+  uint64_t word = words_[i] & (~uint64_t{0} << (c % 64));
+  if (word != 0)
+    return (i * 64) + FindLSBSet(word);
+
+  // Check any following words.
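+  // (The switch below intentionally falls through: starting at word i+1,
+  // each remaining 64-bit word is tested in turn until a set bit is found
+  // or the words are exhausted.)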
+  i++;
+  switch (i) {
+    case 1:
+      if (words_[1] != 0)
+        return (1 * 64) + FindLSBSet(words_[1]);
+      FALLTHROUGH_INTENDED;
+    case 2:
+      if (words_[2] != 0)
+        return (2 * 64) + FindLSBSet(words_[2]);
+      FALLTHROUGH_INTENDED;
+    case 3:
+      if (words_[3] != 0)
+        return (3 * 64) + FindLSBSet(words_[3]);
+      FALLTHROUGH_INTENDED;
+    default:
+      return -1;
+  }
+}
+
+}  // namespace re2
diff --git a/internal/cpp/re2/bitmap256.h b/internal/cpp/re2/bitmap256.h
new file mode 100644
index 00000000000..d6f535b264b
--- /dev/null
+++ b/internal/cpp/re2/bitmap256.h
@@ -0,0 +1,82 @@
+// Copyright 2016 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef RE2_BITMAP256_H_
+#define RE2_BITMAP256_H_
+
+#ifdef _MSC_VER
+#include <intrin.h>
+#endif
+#include <stdint.h>
+#include <string.h>
+
+#include "util/logging.h"
+
+namespace re2 {
+
+class Bitmap256 {
+public:
+  Bitmap256() { Clear(); }
+
+  // Clears all of the bits.
+  void Clear() { memset(words_, 0, sizeof words_); }
+
+  // Tests the bit with index c.
+  bool Test(int c) const {
+    DCHECK_GE(c, 0);
+    DCHECK_LE(c, 255);
+
+    return (words_[c / 64] & (uint64_t{1} << (c % 64))) != 0;
+  }
+
+  // Sets the bit with index c.
+  void Set(int c) {
+    DCHECK_GE(c, 0);
+    DCHECK_LE(c, 255);
+
+    words_[c / 64] |= (uint64_t{1} << (c % 64));
+  }
+
+  // Finds the next non-zero bit with index >= c.
+  // Returns -1 if no such bit exists.
+  int FindNextSetBit(int c) const;
+
+private:
+  // Finds the least significant non-zero bit in n.
+  static int FindLSBSet(uint64_t n) {
+    DCHECK_NE(n, 0);
+#if defined(__GNUC__)
+    return __builtin_ctzll(n);
+#elif defined(_MSC_VER) && defined(_M_X64)
+    unsigned long c;
+    _BitScanForward64(&c, n);
+    return static_cast<int>(c);
+#elif defined(_MSC_VER) && defined(_M_IX86)
+    unsigned long c;
+    if (static_cast<uint32_t>(n) != 0) {
+      _BitScanForward(&c, static_cast<uint32_t>(n));
+      return static_cast<int>(c);
+    } else {
+      _BitScanForward(&c, static_cast<uint32_t>(n >> 32));
+      return static_cast<int>(c) + 32;
+    }
+#else
+    int c = 63;
+    for (int shift = 1 << 5; shift != 0; shift >>= 1) {
+      uint64_t word = n << shift;
+      if (word != 0) {
+        n = word;
+        c -= shift;
+      }
+    }
+    return c;
+#endif
+  }
+
+  uint64_t words_[4];
+};
+
+}  // namespace re2
+
+#endif  // RE2_BITMAP256_H_
diff --git a/internal/cpp/re2/bitstate.cc b/internal/cpp/re2/bitstate.cc
new file mode 100644
index 00000000000..322c4edae49
--- /dev/null
+++ b/internal/cpp/re2/bitstate.cc
@@ -0,0 +1,362 @@
+// Copyright 2008 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tested by search_test.cc, exhaustive_test.cc, tester.cc
+
+// Prog::SearchBitState is a regular expression search with submatch
+// tracking for small regular expressions and texts. Similarly to
+// testing/backtrack.cc, it allocates a bitmap with (count of
+// lists) * (length of text) bits to make sure it never explores the
+// same (instruction list, character position) multiple times. This
+// limits the search to run in time linear in the length of the text.
+//
+// Unlike testing/backtrack.cc, SearchBitState is not recursive
+// on the text.
+//
+// SearchBitState is a fast replacement for the NFA code on small
+// regexps and texts when SearchOnePass cannot be used.
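+//
+// Rough cost sketch (illustrative numbers, not from the original source):
+// for a program with L instruction lists and a text of n bytes, the
+// visited bitmap needs L * (n + 1) bits.  With L = 64 lists and a 4 KiB
+// text that is 64 * 4097 bits, about 32 KiB, which is why this engine is
+// reserved for small regexps and small texts.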
+ +#include +#include +#include +#include +#include + +#include "re2/pod_array.h" +#include "re2/prog.h" +#include "re2/regexp.h" +#include "util/logging.h" + +namespace re2 { + +struct Job { + int id; + int rle; // run length encoding + const char *p; +}; + +class BitState { +public: + explicit BitState(Prog *prog); + + // The usual Search prototype. + // Can only call Search once per BitState. + bool Search(const StringPiece &text, const StringPiece &context, bool anchored, bool longest, StringPiece *submatch, int nsubmatch); + +private: + inline bool ShouldVisit(int id, const char *p); + void Push(int id, const char *p); + void GrowStack(); + bool TrySearch(int id, const char *p); + + // Search parameters + Prog *prog_; // program being run + StringPiece text_; // text being searched + StringPiece context_; // greater context of text being searched + bool anchored_; // whether search is anchored at text.begin() + bool longest_; // whether search wants leftmost-longest match + bool endmatch_; // whether match must end at text.end() + StringPiece *submatch_; // submatches to fill in + int nsubmatch_; // # of submatches to fill in + + // Search state + static constexpr int kVisitedBits = 64; + PODArray visited_; // bitmap: (list ID, char*) pairs visited + PODArray cap_; // capture registers + PODArray job_; // stack of text positions to explore + int njob_; // stack size + + BitState(const BitState &) = delete; + BitState &operator=(const BitState &) = delete; +}; + +BitState::BitState(Prog *prog) : prog_(prog), anchored_(false), longest_(false), endmatch_(false), submatch_(NULL), nsubmatch_(0), njob_(0) {} + +// Given id, which *must* be a list head, we can look up its list ID. +// Then the question is: Should the search visit the (list ID, p) pair? +// If so, remember that it was visited so that the next time, +// we don't repeat the visit. +bool BitState::ShouldVisit(int id, const char *p) { + int n = prog_->list_heads()[id] * static_cast(text_.size() + 1) + static_cast(p - text_.data()); + if (visited_[n / kVisitedBits] & (uint64_t{1} << (n & (kVisitedBits - 1)))) + return false; + visited_[n / kVisitedBits] |= uint64_t{1} << (n & (kVisitedBits - 1)); + return true; +} + +// Grow the stack. +void BitState::GrowStack() { + PODArray tmp(2 * job_.size()); + memmove(tmp.data(), job_.data(), njob_ * sizeof job_[0]); + job_ = std::move(tmp); +} + +// Push (id, p) onto the stack, growing it if necessary. +void BitState::Push(int id, const char *p) { + if (njob_ >= job_.size()) { + GrowStack(); + if (njob_ >= job_.size()) { + LOG(DFATAL) << "GrowStack() failed: " + << "njob_ = " << njob_ << ", " + << "job_.size() = " << job_.size(); + return; + } + } + + // If id < 0, it's undoing a Capture, + // so we mustn't interfere with that. + if (id >= 0 && njob_ > 0) { + Job *top = &job_[njob_ - 1]; + if (id == top->id && p == top->p + top->rle + 1 && top->rle < std::numeric_limits::max()) { + ++top->rle; + return; + } + } + + Job *top = &job_[njob_++]; + top->id = id; + top->rle = 0; + top->p = p; +} + +// Try a search from instruction id0 in state p0. +// Return whether it succeeded. +bool BitState::TrySearch(int id0, const char *p0) { + bool matched = false; + const char *end = text_.data() + text_.size(); + njob_ = 0; + // Push() no longer checks ShouldVisit(), + // so we must perform the check ourselves. + if (ShouldVisit(id0, p0)) + Push(id0, p0); + while (njob_ > 0) { + // Pop job off stack. 
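+    // Jobs are run-length encoded: a job {id, rle, p} stands for the
+    // positions p, p+1, ..., p+rle.  The pop below consumes the highest
+    // position first and, while rle > 0, revivifies the job for the rest.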
+ --njob_; + int id = job_[njob_].id; + int &rle = job_[njob_].rle; + const char *p = job_[njob_].p; + + if (id < 0) { + // Undo the Capture. + cap_[prog_->inst(-id)->cap()] = p; + continue; + } + + if (rle > 0) { + p += rle; + // Revivify job on stack. + --rle; + ++njob_; + } + + Loop: + // Visit id, p. + Prog::Inst *ip = prog_->inst(id); + switch (ip->opcode()) { + default: + LOG(DFATAL) << "Unexpected opcode: " << ip->opcode(); + return false; + + case kInstFail: + break; + + case kInstAltMatch: + if (ip->greedy(prog_)) { + // out1 is the Match instruction. + id = ip->out1(); + p = end; + goto Loop; + } + if (longest_) { + // ip must be non-greedy... + // out is the Match instruction. + id = ip->out(); + p = end; + goto Loop; + } + goto Next; + + case kInstByteRange: { + int c = -1; + if (p < end) + c = *p & 0xFF; + if (!ip->Matches(c)) + goto Next; + + if (ip->hint() != 0) + Push(id + ip->hint(), p); // try the next when we're done + id = ip->out(); + p++; + goto CheckAndLoop; + } + + case kInstCapture: + if (!ip->last()) + Push(id + 1, p); // try the next when we're done + + if (0 <= ip->cap() && ip->cap() < cap_.size()) { + // Capture p to register, but save old value first. + Push(-id, cap_[ip->cap()]); // undo when we're done + cap_[ip->cap()] = p; + } + + id = ip->out(); + goto CheckAndLoop; + + case kInstEmptyWidth: + if (ip->empty() & ~Prog::EmptyFlags(context_, p)) + goto Next; + + if (!ip->last()) + Push(id + 1, p); // try the next when we're done + id = ip->out(); + goto CheckAndLoop; + + case kInstNop: + if (!ip->last()) + Push(id + 1, p); // try the next when we're done + id = ip->out(); + + CheckAndLoop: + // Sanity check: id is the head of its list, which must + // be the case if id-1 is the last of *its* list. :) + DCHECK(id == 0 || prog_->inst(id - 1)->last()); + if (ShouldVisit(id, p)) + goto Loop; + break; + + case kInstMatch: { + if (endmatch_ && p != end) + goto Next; + + // We found a match. If the caller doesn't care + // where the match is, no point going further. + if (nsubmatch_ == 0) + return true; + + // Record best match so far. + // Only need to check end point, because this entire + // call is only considering one start position. + matched = true; + cap_[1] = p; + if (submatch_[0].data() == NULL || (longest_ && p > submatch_[0].data() + submatch_[0].size())) { + for (int i = 0; i < nsubmatch_; i++) + submatch_[i] = StringPiece(cap_[2 * i], static_cast(cap_[2 * i + 1] - cap_[2 * i])); + } + + // If going for first match, we're done. + if (!longest_) + return true; + + // If we used the entire text, no longer match is possible. + if (p == end) + return true; + + // Otherwise, continue on in hope of a longer match. + // Note the absence of the ShouldVisit() check here + // due to execution remaining in the same list. + Next: + if (!ip->last()) { + id++; + goto Loop; + } + break; + } + } + } + return matched; +} + +// Search text (within context) for prog_. +bool BitState::Search(const StringPiece &text, const StringPiece &context, bool anchored, bool longest, StringPiece *submatch, int nsubmatch) { + // Search parameters. 
+ text_ = text; + context_ = context; + if (context_.data() == NULL) + context_ = text; + if (prog_->anchor_start() && BeginPtr(context_) != BeginPtr(text)) + return false; + if (prog_->anchor_end() && EndPtr(context_) != EndPtr(text)) + return false; + anchored_ = anchored || prog_->anchor_start(); + longest_ = longest || prog_->anchor_end(); + endmatch_ = prog_->anchor_end(); + submatch_ = submatch; + nsubmatch_ = nsubmatch; + for (int i = 0; i < nsubmatch_; i++) + submatch_[i] = StringPiece(); + + // Allocate scratch space. + int nvisited = prog_->list_count() * static_cast(text.size() + 1); + nvisited = (nvisited + kVisitedBits - 1) / kVisitedBits; + visited_ = PODArray(nvisited); + memset(visited_.data(), 0, nvisited * sizeof visited_[0]); + + int ncap = 2 * nsubmatch; + if (ncap < 2) + ncap = 2; + cap_ = PODArray(ncap); + memset(cap_.data(), 0, ncap * sizeof cap_[0]); + + // When sizeof(Job) == 16, we start with a nice round 1KiB. :) + job_ = PODArray(64); + + // Anchored search must start at text.begin(). + if (anchored_) { + cap_[0] = text.data(); + return TrySearch(prog_->start(), text.data()); + } + + // Unanchored search, starting from each possible text position. + // Notice that we have to try the empty string at the end of + // the text, so the loop condition is p <= text.end(), not p < text.end(). + // This looks like it's quadratic in the size of the text, + // but we are not clearing visited_ between calls to TrySearch, + // so no work is duplicated and it ends up still being linear. + const char *etext = text.data() + text.size(); + for (const char *p = text.data(); p <= etext; p++) { + // Try to use prefix accel (e.g. memchr) to skip ahead. + if (p < etext && prog_->can_prefix_accel()) { + p = reinterpret_cast(prog_->PrefixAccel(p, etext - p)); + if (p == NULL) + p = etext; + } + + cap_[0] = p; + if (TrySearch(prog_->start(), p)) // Match must be leftmost; done. + return true; + // Avoid invoking undefined behavior (arithmetic on a null pointer) + // by simply not continuing the loop. + if (p == NULL) + break; + } + return false; +} + +// Bit-state search. +bool Prog::SearchBitState(const StringPiece &text, const StringPiece &context, Anchor anchor, MatchKind kind, StringPiece *match, int nmatch) { + // If full match, we ask for an anchored longest match + // and then check that match[0] == text. + // So make sure match[0] exists. + StringPiece sp0; + if (kind == kFullMatch) { + anchor = kAnchored; + if (nmatch < 1) { + match = &sp0; + nmatch = 1; + } + } + + // Run the search. + BitState b(this); + bool anchored = anchor == kAnchored; + bool longest = kind != kFirstMatch; + if (!b.Search(text, context, anchored, longest, match, nmatch)) + return false; + if (kind == kFullMatch && EndPtr(match[0]) != EndPtr(text)) + return false; + return true; +} + +} // namespace re2 diff --git a/internal/cpp/re2/compile.cc b/internal/cpp/re2/compile.cc new file mode 100644 index 00000000000..925bf972e41 --- /dev/null +++ b/internal/cpp/re2/compile.cc @@ -0,0 +1,1221 @@ +// Copyright 2007 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Compile regular expression to Prog. +// +// Prog and Inst are defined in prog.h. +// This file's external interface is just Regexp::CompileToProg. +// The Compiler class defined in this file is private. 
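+//
+// Typical call path, as a sketch only (these are re2-internal signatures
+// and may differ between re2 versions):
+//
+//   re2::RegexpStatus status;
+//   re2::Regexp* re =
+//       re2::Regexp::Parse("a+b", re2::Regexp::LikePerl, &status);
+//   re2::Prog* prog = re->CompileToProg(0);  // <= 0 picks a default budget
+//   ... run one of the Prog search engines over prog ...
+//   delete prog;
+//   re->Decref();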
+ +#include +#include +#include +#include + +#include "re2/pod_array.h" +#include "re2/prog.h" +#include "re2/re2.h" +#include "re2/regexp.h" +#include "re2/walker-inl.h" +#include "util/logging.h" +#include "util/utf.h" + +namespace re2 { + +// List of pointers to Inst* that need to be filled in (patched). +// Because the Inst* haven't been filled in yet, +// we can use the Inst* word to hold the list's "next" pointer. +// It's kind of sleazy, but it works well in practice. +// See http://swtch.com/~rsc/regexp/regexp1.html for inspiration. +// +// Because the out and out1 fields in Inst are no longer pointers, +// we can't use pointers directly here either. Instead, head refers +// to inst_[head>>1].out (head&1 == 0) or inst_[head>>1].out1 (head&1 == 1). +// head == 0 represents the NULL list. This is okay because instruction #0 +// is always the fail instruction, which never appears on a list. +struct PatchList { + // Returns patch list containing just p. + static PatchList Mk(uint32_t p) { return {p, p}; } + + // Patches all the entries on l to have value p. + // Caller must not ever use patch list again. + static void Patch(Prog::Inst *inst0, PatchList l, uint32_t p) { + while (l.head != 0) { + Prog::Inst *ip = &inst0[l.head >> 1]; + if (l.head & 1) { + l.head = ip->out1(); + ip->out1_ = p; + } else { + l.head = ip->out(); + ip->set_out(p); + } + } + } + + // Appends two patch lists and returns result. + static PatchList Append(Prog::Inst *inst0, PatchList l1, PatchList l2) { + if (l1.head == 0) + return l2; + if (l2.head == 0) + return l1; + Prog::Inst *ip = &inst0[l1.tail >> 1]; + if (l1.tail & 1) + ip->out1_ = l2.head; + else + ip->set_out(l2.head); + return {l1.head, l2.tail}; + } + + uint32_t head; + uint32_t tail; // for constant-time append +}; + +static const PatchList kNullPatchList = {0, 0}; + +// Compiled program fragment. +struct Frag { + uint32_t begin; + PatchList end; + bool nullable; + + Frag() : begin(0), end(kNullPatchList), nullable(false) {} + Frag(uint32_t begin, PatchList end, bool nullable) : begin(begin), end(end), nullable(nullable) {} +}; + +// Input encodings. +enum Encoding { + kEncodingUTF8 = 1, // UTF-8 (0-10FFFF) + kEncodingLatin1, // Latin-1 (0-FF) +}; + +class Compiler : public Regexp::Walker { +public: + explicit Compiler(); + ~Compiler(); + + // Compiles Regexp to a new Prog. + // Caller is responsible for deleting Prog when finished with it. + // If reversed is true, compiles for walking over the input + // string backward (reverses all concatenations). + static Prog *Compile(Regexp *re, bool reversed, int64_t max_mem); + + // Compiles alternation of all the re to a new Prog. + // Each re has a match with an id equal to its index in the vector. + static Prog *CompileSet(Regexp *re, RE2::Anchor anchor, int64_t max_mem); + + // Interface for Regexp::Walker, which helps traverse the Regexp. + // The walk is purely post-recursive: given the machines for the + // children, PostVisit combines them to create the machine for + // the current node. The child_args are Frags. + // The Compiler traverses the Regexp parse tree, visiting + // each node in depth-first order. It invokes PreVisit before + // visiting the node's children and PostVisit after visiting + // the children. + Frag PreVisit(Regexp *re, Frag parent_arg, bool *stop); + Frag PostVisit(Regexp *re, Frag parent_arg, Frag pre_arg, Frag *child_args, int nchild_args); + Frag ShortVisit(Regexp *re, Frag parent_arg); + Frag Copy(Frag arg); + + // Given fragment a, returns a+ or a+?; a* or a*?; a? 
or a?? + Frag Plus(Frag a, bool nongreedy); + Frag Star(Frag a, bool nongreedy); + Frag Quest(Frag a, bool nongreedy); + + // Given fragment a, returns (a) capturing as \n. + Frag Capture(Frag a, int n); + + // Given fragments a and b, returns ab; a|b + Frag Cat(Frag a, Frag b); + Frag Alt(Frag a, Frag b); + + // Returns a fragment that can't match anything. + Frag NoMatch(); + + // Returns a fragment that matches the empty string. + Frag Match(int32_t id); + + // Returns a no-op fragment. + Frag Nop(); + + // Returns a fragment matching the byte range lo-hi. + Frag ByteRange(int lo, int hi, bool foldcase); + + // Returns a fragment matching an empty-width special op. + Frag EmptyWidth(EmptyOp op); + + // Adds n instructions to the program. + // Returns the index of the first one. + // Returns -1 if no more instructions are available. + int AllocInst(int n); + + // Rune range compiler. + + // Begins a new alternation. + void BeginRange(); + + // Adds a fragment matching the rune range lo-hi. + void AddRuneRange(Rune lo, Rune hi, bool foldcase); + void AddRuneRangeLatin1(Rune lo, Rune hi, bool foldcase); + void AddRuneRangeUTF8(Rune lo, Rune hi, bool foldcase); + void Add_80_10ffff(); + + // New suffix that matches the byte range lo-hi, then goes to next. + int UncachedRuneByteSuffix(uint8_t lo, uint8_t hi, bool foldcase, int next); + int CachedRuneByteSuffix(uint8_t lo, uint8_t hi, bool foldcase, int next); + + // Returns true iff the suffix is cached. + bool IsCachedRuneByteSuffix(int id); + + // Adds a suffix to alternation. + void AddSuffix(int id); + + // Adds a suffix to the trie starting from the given root node. + // Returns zero iff allocating an instruction fails. Otherwise, returns + // the current root node, which might be different from what was given. + int AddSuffixRecursive(int root, int id); + + // Finds the trie node for the given suffix. Returns a Frag in order to + // distinguish between pointing at the root node directly (end.head == 0) + // and pointing at an Alt's out1 or out (end.head&1 == 1 or 0, respectively). + Frag FindByteRange(int root, int id); + + // Compares two ByteRanges and returns true iff they are equal. + bool ByteRangeEqual(int id1, int id2); + + // Returns the alternation of all the added suffixes. + Frag EndRange(); + + // Single rune. + Frag Literal(Rune r, bool foldcase); + + void Setup(Regexp::ParseFlags flags, int64_t max_mem, RE2::Anchor anchor); + Prog *Finish(Regexp *re); + + // Returns .* where dot = any byte + Frag DotStar(); + +private: + Prog *prog_; // Program being built. + bool failed_; // Did we give up compiling? + Encoding encoding_; // Input encoding + bool reversed_; // Should program run backward over text? + + PODArray inst_; + int ninst_; // Number of instructions used. + int max_ninst_; // Maximum number of instructions. + + int64_t max_mem_; // Total memory budget. 
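+
+  // Rune-range trie state.  rune_cache_ memoizes byte-range suffix
+  // instructions keyed by (next, lo, hi, foldcase) packed into one
+  // uint64_t (see MakeRuneCacheKey later in this file); rune_range_
+  // accumulates the alternation built between BeginRange() and EndRange().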
+ + std::unordered_map rune_cache_; + Frag rune_range_; + + RE2::Anchor anchor_; // anchor mode for RE2::Set + + Compiler(const Compiler &) = delete; + Compiler &operator=(const Compiler &) = delete; +}; + +Compiler::Compiler() { + prog_ = new Prog(); + failed_ = false; + encoding_ = kEncodingUTF8; + reversed_ = false; + ninst_ = 0; + max_ninst_ = 1; // make AllocInst for fail instruction okay + max_mem_ = 0; + int fail = AllocInst(1); + inst_[fail].InitFail(); + max_ninst_ = 0; // Caller must change +} + +Compiler::~Compiler() { delete prog_; } + +int Compiler::AllocInst(int n) { + if (failed_ || ninst_ + n > max_ninst_) { + failed_ = true; + return -1; + } + + if (ninst_ + n > inst_.size()) { + int cap = inst_.size(); + if (cap == 0) + cap = 8; + while (ninst_ + n > cap) + cap *= 2; + PODArray inst(cap); + if (inst_.data() != NULL) + memmove(inst.data(), inst_.data(), ninst_ * sizeof inst_[0]); + memset(inst.data() + ninst_, 0, (cap - ninst_) * sizeof inst_[0]); + inst_ = std::move(inst); + } + int id = ninst_; + ninst_ += n; + return id; +} + +// These routines are somewhat hard to visualize in text -- +// see http://swtch.com/~rsc/regexp/regexp1.html for +// pictures explaining what is going on here. + +// Returns an unmatchable fragment. +Frag Compiler::NoMatch() { return Frag(); } + +// Is a an unmatchable fragment? +static bool IsNoMatch(Frag a) { return a.begin == 0; } + +// Given fragments a and b, returns fragment for ab. +Frag Compiler::Cat(Frag a, Frag b) { + if (IsNoMatch(a) || IsNoMatch(b)) + return NoMatch(); + + // Elide no-op. + Prog::Inst *begin = &inst_[a.begin]; + if (begin->opcode() == kInstNop && a.end.head == (a.begin << 1) && begin->out() == 0) { + // in case refs to a somewhere + PatchList::Patch(inst_.data(), a.end, b.begin); + return b; + } + + // To run backward over string, reverse all concatenations. + if (reversed_) { + PatchList::Patch(inst_.data(), b.end, a.begin); + return Frag(b.begin, a.end, b.nullable && a.nullable); + } + + PatchList::Patch(inst_.data(), a.end, b.begin); + return Frag(a.begin, b.end, a.nullable && b.nullable); +} + +// Given fragments for a and b, returns fragment for a|b. +Frag Compiler::Alt(Frag a, Frag b) { + // Special case for convenience in loops. + if (IsNoMatch(a)) + return b; + if (IsNoMatch(b)) + return a; + + int id = AllocInst(1); + if (id < 0) + return NoMatch(); + + inst_[id].InitAlt(a.begin, b.begin); + return Frag(id, PatchList::Append(inst_.data(), a.end, b.end), a.nullable || b.nullable); +} + +// When capturing submatches in like-Perl mode, a kOpAlt Inst +// treats out_ as the first choice, out1_ as the second. +// +// For *, +, and ?, if out_ causes another repetition, +// then the operator is greedy. If out1_ is the repetition +// (and out_ moves forward), then the operator is non-greedy. + +// Given a fragment for a, returns a fragment for a+ or a+? (if nongreedy) +Frag Compiler::Plus(Frag a, bool nongreedy) { + int id = AllocInst(1); + if (id < 0) + return NoMatch(); + PatchList pl; + if (nongreedy) { + inst_[id].InitAlt(0, a.begin); + pl = PatchList::Mk(id << 1); + } else { + inst_[id].InitAlt(a.begin, 0); + pl = PatchList::Mk((id << 1) | 1); + } + PatchList::Patch(inst_.data(), a.end, id); + return Frag(a.begin, pl, a.nullable); +} + +// Given a fragment for a, returns a fragment for a* or a*? (if nongreedy) +Frag Compiler::Star(Frag a, bool nongreedy) { + // When the subexpression is nullable, one Alt isn't enough to guarantee + // correct priority ordering within the transitive closure. 
The simplest + // solution is to handle it as (a+)? instead, which adds the second Alt. + if (a.nullable) + return Quest(Plus(a, nongreedy), nongreedy); + + int id = AllocInst(1); + if (id < 0) + return NoMatch(); + PatchList pl; + if (nongreedy) { + inst_[id].InitAlt(0, a.begin); + pl = PatchList::Mk(id << 1); + } else { + inst_[id].InitAlt(a.begin, 0); + pl = PatchList::Mk((id << 1) | 1); + } + PatchList::Patch(inst_.data(), a.end, id); + return Frag(id, pl, true); +} + +// Given a fragment for a, returns a fragment for a? or a?? (if nongreedy) +Frag Compiler::Quest(Frag a, bool nongreedy) { + if (IsNoMatch(a)) + return Nop(); + int id = AllocInst(1); + if (id < 0) + return NoMatch(); + PatchList pl; + if (nongreedy) { + inst_[id].InitAlt(0, a.begin); + pl = PatchList::Mk(id << 1); + } else { + inst_[id].InitAlt(a.begin, 0); + pl = PatchList::Mk((id << 1) | 1); + } + return Frag(id, PatchList::Append(inst_.data(), pl, a.end), true); +} + +// Returns a fragment for the byte range lo-hi. +Frag Compiler::ByteRange(int lo, int hi, bool foldcase) { + int id = AllocInst(1); + if (id < 0) + return NoMatch(); + inst_[id].InitByteRange(lo, hi, foldcase, 0); + return Frag(id, PatchList::Mk(id << 1), false); +} + +// Returns a no-op fragment. Sometimes unavoidable. +Frag Compiler::Nop() { + int id = AllocInst(1); + if (id < 0) + return NoMatch(); + inst_[id].InitNop(0); + return Frag(id, PatchList::Mk(id << 1), true); +} + +// Returns a fragment that signals a match. +Frag Compiler::Match(int32_t match_id) { + int id = AllocInst(1); + if (id < 0) + return NoMatch(); + inst_[id].InitMatch(match_id); + return Frag(id, kNullPatchList, false); +} + +// Returns a fragment matching a particular empty-width op (like ^ or $) +Frag Compiler::EmptyWidth(EmptyOp empty) { + int id = AllocInst(1); + if (id < 0) + return NoMatch(); + inst_[id].InitEmptyWidth(empty, 0); + return Frag(id, PatchList::Mk(id << 1), true); +} + +// Given a fragment a, returns a fragment with capturing parens around a. +Frag Compiler::Capture(Frag a, int n) { + if (IsNoMatch(a)) + return NoMatch(); + int id = AllocInst(2); + if (id < 0) + return NoMatch(); + inst_[id].InitCapture(2 * n, a.begin); + inst_[id + 1].InitCapture(2 * n + 1, 0); + PatchList::Patch(inst_.data(), a.end, id + 1); + + return Frag(id, PatchList::Mk((id + 1) << 1), a.nullable); +} + +// A Rune is a name for a Unicode code point. +// Returns maximum rune encoded by UTF-8 sequence of length len. +static int MaxRune(int len) { + int b; // number of Rune bits in len-byte UTF-8 sequence (len < UTFmax) + if (len == 1) + b = 7; + else + b = 8 - (len + 1) + 6 * (len - 1); + return (1 << b) - 1; // maximum Rune for b bits. +} + +// The rune range compiler caches common suffix fragments, +// which are very common in UTF-8 (e.g., [80-bf]). +// The fragment suffixes are identified by their start +// instructions. NULL denotes the eventual end match. +// The Frag accumulates in rune_range_. Caching common +// suffixes reduces the UTF-8 "." from 32 to 24 instructions, +// and it reduces the corresponding one-pass NFA from 16 nodes to 8. 
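+//
+// Illustrative example (not from the original source): UTF-8 encodes
+// U+0080..U+07FF as [C2-DF][80-BF] and, roughly, U+0800..U+FFFF as
+// [E0-EF][80-BF][80-BF].  With the cache, both forms can end in the same
+// [80-BF] suffix instruction instead of each alternative carrying its
+// own copy of the tail.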
+ +void Compiler::BeginRange() { + rune_cache_.clear(); + rune_range_.begin = 0; + rune_range_.end = kNullPatchList; +} + +int Compiler::UncachedRuneByteSuffix(uint8_t lo, uint8_t hi, bool foldcase, int next) { + Frag f = ByteRange(lo, hi, foldcase); + if (next != 0) { + PatchList::Patch(inst_.data(), f.end, next); + } else { + rune_range_.end = PatchList::Append(inst_.data(), rune_range_.end, f.end); + } + return f.begin; +} + +static uint64_t MakeRuneCacheKey(uint8_t lo, uint8_t hi, bool foldcase, int next) { + return (uint64_t)next << 17 | (uint64_t)lo << 9 | (uint64_t)hi << 1 | (uint64_t)foldcase; +} + +int Compiler::CachedRuneByteSuffix(uint8_t lo, uint8_t hi, bool foldcase, int next) { + uint64_t key = MakeRuneCacheKey(lo, hi, foldcase, next); + std::unordered_map::const_iterator it = rune_cache_.find(key); + if (it != rune_cache_.end()) + return it->second; + int id = UncachedRuneByteSuffix(lo, hi, foldcase, next); + rune_cache_[key] = id; + return id; +} + +bool Compiler::IsCachedRuneByteSuffix(int id) { + uint8_t lo = inst_[id].byte_range.lo_; + uint8_t hi = inst_[id].byte_range.hi_; + bool foldcase = inst_[id].foldcase() != 0; + int next = inst_[id].out(); + + uint64_t key = MakeRuneCacheKey(lo, hi, foldcase, next); + return rune_cache_.find(key) != rune_cache_.end(); +} + +void Compiler::AddSuffix(int id) { + if (failed_) + return; + + if (rune_range_.begin == 0) { + rune_range_.begin = id; + return; + } + + if (encoding_ == kEncodingUTF8) { + // Build a trie in order to reduce fanout. + rune_range_.begin = AddSuffixRecursive(rune_range_.begin, id); + return; + } + + int alt = AllocInst(1); + if (alt < 0) { + rune_range_.begin = 0; + return; + } + inst_[alt].InitAlt(rune_range_.begin, id); + rune_range_.begin = alt; +} + +int Compiler::AddSuffixRecursive(int root, int id) { + DCHECK(inst_[root].opcode() == kInstAlt || inst_[root].opcode() == kInstByteRange); + + Frag f = FindByteRange(root, id); + if (IsNoMatch(f)) { + int alt = AllocInst(1); + if (alt < 0) + return 0; + inst_[alt].InitAlt(root, id); + return alt; + } + + int br; + if (f.end.head == 0) + br = root; + else if (f.end.head & 1) + br = inst_[f.begin].out1(); + else + br = inst_[f.begin].out(); + + if (IsCachedRuneByteSuffix(br)) { + // We can't fiddle with cached suffixes, so make a clone of the head. + int byterange = AllocInst(1); + if (byterange < 0) + return 0; + inst_[byterange].InitByteRange(inst_[br].lo(), inst_[br].hi(), inst_[br].foldcase(), inst_[br].out()); + + // Ensure that the parent points to the clone, not to the original. + // Note that this could leave the head unreachable except via the cache. + br = byterange; + if (f.end.head == 0) + root = br; + else if (f.end.head & 1) + inst_[f.begin].out1_ = br; + else + inst_[f.begin].set_out(br); + } + + int out = inst_[id].out(); + if (!IsCachedRuneByteSuffix(id)) { + // The head should be the instruction most recently allocated, so free it + // instead of leaving it unreachable. 
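+ // Illustrative sketch (not from the original RE2 sources): AllocInst()
+ // only ever bumps ninst_, so "freeing" here is just rolling the
+ // allocation back:
+ //
+ //   int id = AllocInst(1);     // ninst_ becomes k+1; inst_[k] is ours
+ //   ...
+ //   inst_[id].out_opcode_ = 0;
+ //   inst_[id].out1_ = 0;
+ //   ninst_--;                  // ninst_ is k again; slot k is reusable
+ //
+ // which is only sound for the most recently allocated instruction,
+ // hence the DCHECK below.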
+ DCHECK_EQ(id, ninst_ - 1); + inst_[id].out_opcode_ = 0; + inst_[id].out1_ = 0; + ninst_--; + } + + out = AddSuffixRecursive(inst_[br].out(), out); + if (out == 0) + return 0; + + inst_[br].set_out(out); + return root; +} + +bool Compiler::ByteRangeEqual(int id1, int id2) { + return inst_[id1].lo() == inst_[id2].lo() && inst_[id1].hi() == inst_[id2].hi() && inst_[id1].foldcase() == inst_[id2].foldcase(); +} + +Frag Compiler::FindByteRange(int root, int id) { + if (inst_[root].opcode() == kInstByteRange) { + if (ByteRangeEqual(root, id)) + return Frag(root, kNullPatchList, false); + else + return NoMatch(); + } + + while (inst_[root].opcode() == kInstAlt) { + int out1 = inst_[root].out1(); + if (ByteRangeEqual(out1, id)) + return Frag(root, PatchList::Mk((root << 1) | 1), false); + + // CharClass is a sorted list of ranges, so if out1 of the root Alt wasn't + // what we're looking for, then we can stop immediately. Unfortunately, we + // can't short-circuit the search in reverse mode. + if (!reversed_) + return NoMatch(); + + int out = inst_[root].out(); + if (inst_[out].opcode() == kInstAlt) + root = out; + else if (ByteRangeEqual(out, id)) + return Frag(root, PatchList::Mk(root << 1), false); + else + return NoMatch(); + } + + LOG(DFATAL) << "should never happen"; + return NoMatch(); +} + +Frag Compiler::EndRange() { return rune_range_; } + +// Converts rune range lo-hi into a fragment that recognizes +// the bytes that would make up those runes in the current +// encoding (Latin 1 or UTF-8). +// This lets the machine work byte-by-byte even when +// using multibyte encodings. + +void Compiler::AddRuneRange(Rune lo, Rune hi, bool foldcase) { + switch (encoding_) { + default: + case kEncodingUTF8: + AddRuneRangeUTF8(lo, hi, foldcase); + break; + case kEncodingLatin1: + AddRuneRangeLatin1(lo, hi, foldcase); + break; + } +} + +void Compiler::AddRuneRangeLatin1(Rune lo, Rune hi, bool foldcase) { + // Latin-1 is easy: runes *are* bytes. + if (lo > hi || lo > 0xFF) + return; + if (hi > 0xFF) + hi = 0xFF; + AddSuffix(UncachedRuneByteSuffix(static_cast(lo), static_cast(hi), foldcase, 0)); +} + +void Compiler::Add_80_10ffff() { + // The 80-10FFFF (Runeself-Runemax) rune range occurs frequently enough + // (for example, for /./ and /[^a-z]/) that it is worth simplifying: by + // permitting overlong encodings in E0 and F0 sequences and code points + // over 10FFFF in F4 sequences, the size of the bytecode and the number + // of equivalence classes are reduced significantly. + int id; + if (reversed_) { + // Prefix factoring matters, but we don't have to handle it here + // because the rune range trie logic takes care of that already. + id = UncachedRuneByteSuffix(0xC2, 0xDF, false, 0); + id = UncachedRuneByteSuffix(0x80, 0xBF, false, id); + AddSuffix(id); + + id = UncachedRuneByteSuffix(0xE0, 0xEF, false, 0); + id = UncachedRuneByteSuffix(0x80, 0xBF, false, id); + id = UncachedRuneByteSuffix(0x80, 0xBF, false, id); + AddSuffix(id); + + id = UncachedRuneByteSuffix(0xF0, 0xF4, false, 0); + id = UncachedRuneByteSuffix(0x80, 0xBF, false, id); + id = UncachedRuneByteSuffix(0x80, 0xBF, false, id); + id = UncachedRuneByteSuffix(0x80, 0xBF, false, id); + AddSuffix(id); + } else { + // Suffix factoring matters - and we do have to handle it here. 
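+ // Illustrative sketch (not from the original RE2 sources) of the
+ // suffix sharing built below, where each longer lead-byte range reuses
+ // the continuation chain of the shorter one:
+ //
+ //   cont1: [80-bf] -> exit
+ //   cont2: [80-bf] -> cont1
+ //   cont3: [80-bf] -> cont2
+ //
+ //   [c2-df] -> cont1        (2-byte sequences)
+ //   [e0-ef] -> cont2        (3-byte sequences)
+ //   [f0-f4] -> cont3        (4-byte sequences)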
+ int cont1 = UncachedRuneByteSuffix(0x80, 0xBF, false, 0); + id = UncachedRuneByteSuffix(0xC2, 0xDF, false, cont1); + AddSuffix(id); + + int cont2 = UncachedRuneByteSuffix(0x80, 0xBF, false, cont1); + id = UncachedRuneByteSuffix(0xE0, 0xEF, false, cont2); + AddSuffix(id); + + int cont3 = UncachedRuneByteSuffix(0x80, 0xBF, false, cont2); + id = UncachedRuneByteSuffix(0xF0, 0xF4, false, cont3); + AddSuffix(id); + } +} + +void Compiler::AddRuneRangeUTF8(Rune lo, Rune hi, bool foldcase) { + if (lo > hi) + return; + + // Pick off 80-10FFFF as a common special case. + if (lo == 0x80 && hi == 0x10ffff) { + Add_80_10ffff(); + return; + } + + // Split range into same-length sized ranges. + for (int i = 1; i < UTFmax; i++) { + Rune max = MaxRune(i); + if (lo <= max && max < hi) { + AddRuneRangeUTF8(lo, max, foldcase); + AddRuneRangeUTF8(max + 1, hi, foldcase); + return; + } + } + + // ASCII range is always a special case. + if (hi < Runeself) { + AddSuffix(UncachedRuneByteSuffix(static_cast(lo), static_cast(hi), foldcase, 0)); + return; + } + + // Split range into sections that agree on leading bytes. + for (int i = 1; i < UTFmax; i++) { + uint32_t m = (1 << (6 * i)) - 1; // last i bytes of a UTF-8 sequence + if ((lo & ~m) != (hi & ~m)) { + if ((lo & m) != 0) { + AddRuneRangeUTF8(lo, lo | m, foldcase); + AddRuneRangeUTF8((lo | m) + 1, hi, foldcase); + return; + } + if ((hi & m) != m) { + AddRuneRangeUTF8(lo, (hi & ~m) - 1, foldcase); + AddRuneRangeUTF8(hi & ~m, hi, foldcase); + return; + } + } + } + + // Finally. Generate byte matching equivalent for lo-hi. + uint8_t ulo[UTFmax], uhi[UTFmax]; + int n = runetochar(reinterpret_cast(ulo), &lo); + int m = runetochar(reinterpret_cast(uhi), &hi); + (void)m; // USED(m) + DCHECK_EQ(n, m); + + // The logic below encodes this thinking: + // + // 1. When we have built the whole suffix, we know that it cannot + // possibly be a suffix of anything longer: in forward mode, nothing + // else can occur before the leading byte; in reverse mode, nothing + // else can occur after the last continuation byte or else the leading + // byte would have to change. Thus, there is no benefit to caching + // the first byte of the suffix whereas there is a cost involved in + // cloning it if it begins a common prefix, which is fairly likely. + // + // 2. Conversely, the last byte of the suffix cannot possibly be a + // prefix of anything because next == 0, so we will never want to + // clone it, but it is fairly likely to be a common suffix. Perhaps + // more so in reverse mode than in forward mode because the former is + // "converging" towards lower entropy, but caching is still worthwhile + // for the latter in cases such as 80-BF. + // + // 3. Handling the bytes between the first and the last is less + // straightforward and, again, the approach depends on whether we are + // "converging" towards lower entropy: in forward mode, a single byte + // is unlikely to be part of a common suffix whereas a byte range + // is more likely so; in reverse mode, a byte range is unlikely to + // be part of a common suffix whereas a single byte is more likely + // so. The same benefit versus cost argument applies here. + int id = 0; + if (reversed_) { + for (int i = 0; i < n; i++) { + // In reverse UTF-8 mode: cache the leading byte; don't cache the last + // continuation byte; cache anything else iff it's a single byte (XX-XX). 
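+ // Illustrative example (not from the original RE2 sources): for the
+ // 3-byte encoding E0 A0-BF 80-BF (runes 0x800-0xFFF), this caches the
+ // single E0 lead byte (i == 0), leaves the middle A0-BF uncached
+ // because it is a range rather than a single byte, and leaves the
+ // final 80-BF continuation byte (i == n - 1) uncached.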
+ if (i == 0 || (ulo[i] == uhi[i] && i != n - 1)) + id = CachedRuneByteSuffix(ulo[i], uhi[i], false, id); + else + id = UncachedRuneByteSuffix(ulo[i], uhi[i], false, id); + } + } else { + for (int i = n - 1; i >= 0; i--) { + // In forward UTF-8 mode: don't cache the leading byte; cache the last + // continuation byte; cache anything else iff it's a byte range (XX-YY). + if (i == n - 1 || (ulo[i] < uhi[i] && i != 0)) + id = CachedRuneByteSuffix(ulo[i], uhi[i], false, id); + else + id = UncachedRuneByteSuffix(ulo[i], uhi[i], false, id); + } + } + AddSuffix(id); +} + +// Should not be called. +Frag Compiler::Copy(Frag arg) { + // We're using WalkExponential; there should be no copying. + failed_ = true; + LOG(DFATAL) << "Compiler::Copy called!"; + return NoMatch(); +} + +// Visits a node quickly; called once WalkExponential has +// decided to cut this walk short. +Frag Compiler::ShortVisit(Regexp *re, Frag) { + failed_ = true; + return NoMatch(); +} + +// Called before traversing a node's children during the walk. +Frag Compiler::PreVisit(Regexp *re, Frag, bool *stop) { + // Cut off walk if we've already failed. + if (failed_) + *stop = true; + + return Frag(); // not used by caller +} + +Frag Compiler::Literal(Rune r, bool foldcase) { + switch (encoding_) { + default: + return Frag(); + + case kEncodingLatin1: + return ByteRange(r, r, foldcase); + + case kEncodingUTF8: { + if (r < Runeself) // Make common case fast. + return ByteRange(r, r, foldcase); + uint8_t buf[UTFmax]; + int n = runetochar(reinterpret_cast(buf), &r); + Frag f = ByteRange((uint8_t)buf[0], buf[0], false); + for (int i = 1; i < n; i++) + f = Cat(f, ByteRange((uint8_t)buf[i], buf[i], false)); + return f; + } + } +} + +// Called after traversing the node's children during the walk. +// Given their frags, build and return the frag for this re. +Frag Compiler::PostVisit(Regexp *re, Frag, Frag, Frag *child_frags, int nchild_frags) { + // If a child failed, don't bother going forward, especially + // since the child_frags might contain Frags with NULLs in them. + if (failed_) + return NoMatch(); + + // Given the child fragments, return the fragment for this node. + switch (re->op()) { + case kRegexpRepeat: + // Should not see; code at bottom of function will print error + break; + + case kRegexpNoMatch: + return NoMatch(); + + case kRegexpEmptyMatch: + return Nop(); + + case kRegexpHaveMatch: { + Frag f = Match(re->match_id()); + if (anchor_ == RE2::ANCHOR_BOTH) { + // Append \z or else the subexpression will effectively be unanchored. + // Complemented by the UNANCHORED case in CompileSet(). + f = Cat(EmptyWidth(kEmptyEndText), f); + } + return f; + } + + case kRegexpConcat: { + Frag f = child_frags[0]; + for (int i = 1; i < nchild_frags; i++) + f = Cat(f, child_frags[i]); + return f; + } + + case kRegexpAlternate: { + Frag f = child_frags[0]; + for (int i = 1; i < nchild_frags; i++) + f = Alt(f, child_frags[i]); + return f; + } + + case kRegexpStar: + return Star(child_frags[0], (re->parse_flags() & Regexp::NonGreedy) != 0); + + case kRegexpPlus: + return Plus(child_frags[0], (re->parse_flags() & Regexp::NonGreedy) != 0); + + case kRegexpQuest: + return Quest(child_frags[0], (re->parse_flags() & Regexp::NonGreedy) != 0); + + case kRegexpLiteral: + return Literal(re->rune(), (re->parse_flags() & Regexp::FoldCase) != 0); + + case kRegexpLiteralString: { + // Concatenation of literals. 
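+ // Illustrative note (not from the original RE2 sources): a literal
+ // string such as "abc" becomes Cat(Cat(Literal('a'), Literal('b')),
+ // Literal('c')), i.e. three one-byte ByteRange instructions chained in
+ // sequence; when reversed_ is set, Cat() links them in the opposite
+ // order so the program can run backward over the text.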
+ if (re->nrunes() == 0) + return Nop(); + Frag f; + for (int i = 0; i < re->nrunes(); i++) { + Frag f1 = Literal(re->runes()[i], (re->parse_flags() & Regexp::FoldCase) != 0); + if (i == 0) + f = f1; + else + f = Cat(f, f1); + } + return f; + } + + case kRegexpAnyChar: + BeginRange(); + AddRuneRange(0, Runemax, false); + return EndRange(); + + case kRegexpAnyByte: + return ByteRange(0x00, 0xFF, false); + + case kRegexpCharClass: { + CharClass *cc = re->cc(); + if (cc->empty()) { + // This can't happen. + failed_ = true; + LOG(DFATAL) << "No ranges in char class"; + return NoMatch(); + } + + // ASCII case-folding optimization: if the char class + // behaves the same on A-Z as it does on a-z, + // discard any ranges wholly contained in A-Z + // and mark the other ranges as foldascii. + // This reduces the size of a program for + // (?i)abc from 3 insts per letter to 1 per letter. + bool foldascii = cc->FoldsASCII(); + + // Character class is just a big OR of the different + // character ranges in the class. + BeginRange(); + for (CharClass::iterator i = cc->begin(); i != cc->end(); ++i) { + // ASCII case-folding optimization (see above). + if (foldascii && 'A' <= i->lo && i->hi <= 'Z') + continue; + + // If this range contains all of A-Za-z or none of it, + // the fold flag is unnecessary; don't bother. + bool fold = foldascii; + if ((i->lo <= 'A' && 'z' <= i->hi) || i->hi < 'A' || 'z' < i->lo || ('Z' < i->lo && i->hi < 'a')) + fold = false; + + AddRuneRange(i->lo, i->hi, fold); + } + return EndRange(); + } + + case kRegexpCapture: + // If this is a non-capturing parenthesis -- (?:foo) -- + // just use the inner expression. + if (re->cap() < 0) + return child_frags[0]; + return Capture(child_frags[0], re->cap()); + + case kRegexpBeginLine: + return EmptyWidth(reversed_ ? kEmptyEndLine : kEmptyBeginLine); + + case kRegexpEndLine: + return EmptyWidth(reversed_ ? kEmptyBeginLine : kEmptyEndLine); + + case kRegexpBeginText: + return EmptyWidth(reversed_ ? kEmptyEndText : kEmptyBeginText); + + case kRegexpEndText: + return EmptyWidth(reversed_ ? kEmptyBeginText : kEmptyEndText); + + case kRegexpWordBoundary: + return EmptyWidth(kEmptyWordBoundary); + + case kRegexpNoWordBoundary: + return EmptyWidth(kEmptyNonWordBoundary); + } + failed_ = true; + LOG(DFATAL) << "Missing case in Compiler: " << re->op(); + return NoMatch(); +} + +// Is this regexp required to start at the beginning of the text? +// Only approximate; can return false for complicated regexps like (\Aa|\Ab), +// but handles (\A(a|b)). Could use the Walker to write a more exact one. +static bool IsAnchorStart(Regexp **pre, int depth) { + Regexp *re = *pre; + Regexp *sub; + // The depth limit makes sure that we don't overflow + // the stack on a deeply nested regexp. As the comment + // above says, IsAnchorStart is conservative, so returning + // a false negative is okay. The exact limit is somewhat arbitrary. 
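+ // Illustrative trace (not from the original RE2 sources): for \Aabc,
+ // the kRegexpConcat case below recurses into sub[0] == \A, the
+ // kRegexpBeginText case rewrites that subexpression to an empty
+ // literal string, and true propagates back up so the caller can set
+ // the program's anchor_start flag. (\Aa|\Ab) yields false because
+ // kRegexpAlternate falls through to the default case.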
+ if (re == NULL || depth >= 4) + return false; + switch (re->op()) { + default: + break; + case kRegexpConcat: + if (re->nsub() > 0) { + sub = re->sub()[0]->Incref(); + if (IsAnchorStart(&sub, depth + 1)) { + PODArray subcopy(re->nsub()); + subcopy[0] = sub; // already have reference + for (int i = 1; i < re->nsub(); i++) + subcopy[i] = re->sub()[i]->Incref(); + *pre = Regexp::Concat(subcopy.data(), re->nsub(), re->parse_flags()); + re->Decref(); + return true; + } + sub->Decref(); + } + break; + case kRegexpCapture: + sub = re->sub()[0]->Incref(); + if (IsAnchorStart(&sub, depth + 1)) { + *pre = Regexp::Capture(sub, re->parse_flags(), re->cap()); + re->Decref(); + return true; + } + sub->Decref(); + break; + case kRegexpBeginText: + *pre = Regexp::LiteralString(NULL, 0, re->parse_flags()); + re->Decref(); + return true; + } + return false; +} + +// Is this regexp required to start at the end of the text? +// Only approximate; can return false for complicated regexps like (a\z|b\z), +// but handles ((a|b)\z). Could use the Walker to write a more exact one. +static bool IsAnchorEnd(Regexp **pre, int depth) { + Regexp *re = *pre; + Regexp *sub; + // The depth limit makes sure that we don't overflow + // the stack on a deeply nested regexp. As the comment + // above says, IsAnchorEnd is conservative, so returning + // a false negative is okay. The exact limit is somewhat arbitrary. + if (re == NULL || depth >= 4) + return false; + switch (re->op()) { + default: + break; + case kRegexpConcat: + if (re->nsub() > 0) { + sub = re->sub()[re->nsub() - 1]->Incref(); + if (IsAnchorEnd(&sub, depth + 1)) { + PODArray subcopy(re->nsub()); + subcopy[re->nsub() - 1] = sub; // already have reference + for (int i = 0; i < re->nsub() - 1; i++) + subcopy[i] = re->sub()[i]->Incref(); + *pre = Regexp::Concat(subcopy.data(), re->nsub(), re->parse_flags()); + re->Decref(); + return true; + } + sub->Decref(); + } + break; + case kRegexpCapture: + sub = re->sub()[0]->Incref(); + if (IsAnchorEnd(&sub, depth + 1)) { + *pre = Regexp::Capture(sub, re->parse_flags(), re->cap()); + re->Decref(); + return true; + } + sub->Decref(); + break; + case kRegexpEndText: + *pre = Regexp::LiteralString(NULL, 0, re->parse_flags()); + re->Decref(); + return true; + } + return false; +} + +void Compiler::Setup(Regexp::ParseFlags flags, int64_t max_mem, RE2::Anchor anchor) { + if (flags & Regexp::Latin1) + encoding_ = kEncodingLatin1; + max_mem_ = max_mem; + if (max_mem <= 0) { + max_ninst_ = 100000; // more than enough + } else if (static_cast(max_mem) <= sizeof(Prog)) { + // No room for anything. + max_ninst_ = 0; + } else { + int64_t m = (max_mem - sizeof(Prog)) / sizeof(Prog::Inst); + // Limit instruction count so that inst->id() fits nicely in an int. + // SparseArray also assumes that the indices (inst->id()) are ints. + // The call to WalkExponential uses 2*max_ninst_ below, + // and other places in the code use 2 or 3 * prog->size(). + // Limiting to 2^24 should avoid overflow in those places. + // (The point of allowing more than 32 bits of memory is to + // have plenty of room for the DFA states, not to use it up + // on the program.) + if (m >= 1 << 24) + m = 1 << 24; + // Inst imposes its own limit (currently bigger than 2^24 but be safe). + if (m > Prog::Inst::kMaxInst) + m = Prog::Inst::kMaxInst; + max_ninst_ = static_cast(m); + } + anchor_ = anchor; +} + +// Compiles re, returning program. +// Caller is responsible for deleting prog_. 
+// If reversed is true, compiles a program that expects +// to run over the input string backward (reverses all concatenations). +// The reversed flag is also recorded in the returned program. +Prog *Compiler::Compile(Regexp *re, bool reversed, int64_t max_mem) { + Compiler c; + c.Setup(re->parse_flags(), max_mem, RE2::UNANCHORED /* unused */); + c.reversed_ = reversed; + + // Simplify to remove things like counted repetitions + // and character classes like \d. + Regexp *sre = re->Simplify(); + if (sre == NULL) + return NULL; + + // Record whether prog is anchored, removing the anchors. + // (They get in the way of other optimizations.) + bool is_anchor_start = IsAnchorStart(&sre, 0); + bool is_anchor_end = IsAnchorEnd(&sre, 0); + + // Generate fragment for entire regexp. + Frag all = c.WalkExponential(sre, Frag(), 2 * c.max_ninst_); + sre->Decref(); + if (c.failed_) + return NULL; + + // Success! Finish by putting Match node at end, and record start. + // Turn off c.reversed_ (if it is set) to force the remaining concatenations + // to behave normally. + c.reversed_ = false; + all = c.Cat(all, c.Match(0)); + + c.prog_->set_reversed(reversed); + if (c.prog_->reversed()) { + c.prog_->set_anchor_start(is_anchor_end); + c.prog_->set_anchor_end(is_anchor_start); + } else { + c.prog_->set_anchor_start(is_anchor_start); + c.prog_->set_anchor_end(is_anchor_end); + } + + c.prog_->set_start(all.begin); + if (!c.prog_->anchor_start()) { + // Also create unanchored version, which starts with a .*? loop. + all = c.Cat(c.DotStar(), all); + } + c.prog_->set_start_unanchored(all.begin); + + // Hand ownership of prog_ to caller. + return c.Finish(re); +} + +Prog *Compiler::Finish(Regexp *re) { + if (failed_) + return NULL; + + if (prog_->start() == 0 && prog_->start_unanchored() == 0) { + // No possible matches; keep Fail instruction only. + ninst_ = 1; + } + + // Hand off the array to Prog. + prog_->inst_ = std::move(inst_); + prog_->size_ = ninst_; + + prog_->Optimize(); + prog_->Flatten(); + prog_->ComputeByteMap(); + + if (!prog_->reversed()) { + std::string prefix; + bool prefix_foldcase; + if (re->RequiredPrefixForAccel(&prefix, &prefix_foldcase)) + prog_->ConfigurePrefixAccel(prefix, prefix_foldcase); + } + + // Record remaining memory for DFA. + if (max_mem_ <= 0) { + prog_->set_dfa_mem(1 << 20); + } else { + int64_t m = max_mem_ - sizeof(Prog); + m -= prog_->size_ * sizeof(Prog::Inst); // account for inst_ + if (prog_->CanBitState()) + m -= prog_->size_ * sizeof(uint16_t); // account for list_heads_ + if (m < 0) + m = 0; + prog_->set_dfa_mem(m); + } + + Prog *p = prog_; + prog_ = NULL; + return p; +} + +// Converts Regexp to Prog. +Prog *Regexp::CompileToProg(int64_t max_mem) { return Compiler::Compile(this, false, max_mem); } + +Prog *Regexp::CompileToReverseProg(int64_t max_mem) { return Compiler::Compile(this, true, max_mem); } + +Frag Compiler::DotStar() { return Star(ByteRange(0x00, 0xff, false), true); } + +// Compiles RE set to Prog. +Prog *Compiler::CompileSet(Regexp *re, RE2::Anchor anchor, int64_t max_mem) { + Compiler c; + c.Setup(re->parse_flags(), max_mem, anchor); + + Regexp *sre = re->Simplify(); + if (sre == NULL) + return NULL; + + Frag all = c.WalkExponential(sre, Frag(), 2 * c.max_ninst_); + sre->Decref(); + if (c.failed_) + return NULL; + + c.prog_->set_anchor_start(true); + c.prog_->set_anchor_end(true); + + if (anchor == RE2::UNANCHORED) { + // Prepend .* or else the expression will effectively be anchored. + // Complemented by the ANCHOR_BOTH case in PostVisit(). 
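+ // Illustrative note (not from the original RE2 sources): DotStar()
+ // compiles a non-greedy [00-ff]*, so at the byte level an unanchored
+ // set behaves roughly like /.*?(re1|re2|...)/s, while each alternative
+ // still ends with an implicit \z when anchor == RE2::ANCHOR_BOTH (the
+ // kRegexpHaveMatch case in PostVisit()).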
+ all = c.Cat(c.DotStar(), all); + } + c.prog_->set_start(all.begin); + c.prog_->set_start_unanchored(all.begin); + + Prog *prog = c.Finish(re); + if (prog == NULL) + return NULL; + + // Make sure DFA has enough memory to operate, + // since we're not going to fall back to the NFA. + bool dfa_failed = false; + StringPiece sp = "hello, world"; + prog->SearchDFA(sp, sp, Prog::kAnchored, Prog::kManyMatch, NULL, &dfa_failed, NULL); + if (dfa_failed) { + delete prog; + return NULL; + } + + return prog; +} + +Prog *Prog::CompileSet(Regexp *re, RE2::Anchor anchor, int64_t max_mem) { return Compiler::CompileSet(re, anchor, max_mem); } + +} // namespace re2 diff --git a/internal/cpp/re2/dfa.cc b/internal/cpp/re2/dfa.cc new file mode 100644 index 00000000000..8ca508097bc --- /dev/null +++ b/internal/cpp/re2/dfa.cc @@ -0,0 +1,1985 @@ +// Copyright 2008 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// A DFA (deterministic finite automaton)-based regular expression search. +// +// The DFA search has two main parts: the construction of the automaton, +// which is represented by a graph of State structures, and the execution +// of the automaton over a given input string. +// +// The basic idea is that the State graph is constructed so that the +// execution can simply start with a state s, and then for each byte c in +// the input string, execute "s = s->next[c]", checking at each point whether +// the current s represents a matching state. +// +// The simple explanation just given does convey the essence of this code, +// but it omits the details of how the State graph gets constructed as well +// as some performance-driven optimizations to the execution of the automaton. +// All these details are explained in the comments for the code following +// the definition of class DFA. +// +// See http://swtch.com/~rsc/regexp/ for a very bare-bones equivalent. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "re2/pod_array.h" +#include "re2/prog.h" +#include "re2/re2.h" +#include "re2/sparse_set.h" +#include "re2/stringpiece.h" +#include "util/logging.h" +#include "util/mix.h" +#include "util/mutex.h" +#include "util/strutil.h" + +// Silence "zero-sized array in struct/union" warning for DFA::State::next_. +#ifdef _MSC_VER +#pragma warning(disable : 4200) +#endif + +namespace re2 { + +// Controls whether the DFA should bail out early if the NFA would be faster. +static bool dfa_should_bail_when_slow = true; + +void Prog::TESTING_ONLY_set_dfa_should_bail_when_slow(bool b) { dfa_should_bail_when_slow = b; } + +// A DFA implementation of a regular expression program. +// Since this is entirely a forward declaration mandated by C++, +// some of the comments here are better understood after reading +// the comments in the sections that follow the DFA definition. +class DFA { +public: + DFA(Prog *prog, Prog::MatchKind kind, int64_t max_mem); + ~DFA(); + bool ok() const { return !init_failed_; } + Prog::MatchKind kind() { return kind_; } + + // Searches for the regular expression in text, which is considered + // as a subsection of context for the purposes of interpreting flags + // like ^ and $ and \A and \z. + // Returns whether a match was found. + // If a match is found, sets *ep to the end point of the best match in text. + // If "anchored", the match must begin at the start of text. 
+ // If "want_earliest_match", the match that ends first is used, not + // necessarily the best one. + // If "run_forward" is true, the DFA runs from text.begin() to text.end(). + // If it is false, the DFA runs from text.end() to text.begin(), + // returning the leftmost end of the match instead of the rightmost one. + // If the DFA cannot complete the search (for example, if it is out of + // memory), it sets *failed and returns false. + bool Search(const StringPiece &text, + const StringPiece &context, + bool anchored, + bool want_earliest_match, + bool run_forward, + bool *failed, + const char **ep, + SparseSet *matches); + + // Builds out all states for the entire DFA. + // If cb is not empty, it receives one callback per state built. + // Returns the number of states built. + // FOR TESTING OR EXPERIMENTAL PURPOSES ONLY. + int BuildAllStates(const Prog::DFAStateCallback &cb); + + // Computes min and max for matching strings. Won't return strings + // bigger than maxlen. + bool PossibleMatchRange(std::string *min, std::string *max, int maxlen); + + // These data structures are logically private, but C++ makes it too + // difficult to mark them as such. + class RWLocker; + class StateSaver; + class Workq; + + // A single DFA state. The DFA is represented as a graph of these + // States, linked by the next_ pointers. If in state s and reading + // byte c, the next state should be s->next_[c]. + struct State { + inline bool IsMatch() const { return (flag_ & kFlagMatch) != 0; } + + int *inst_; // Instruction pointers in the state. + int ninst_; // # of inst_ pointers. + uint32_t flag_; // Empty string bitfield flags in effect on the way + // into this state, along with kFlagMatch if this + // is a matching state. + + // fixes from https://github.com/girishji/re2/commit/80b212f289c4ef75408b1510b9fc85e6cb9a447c + std::atomic *next_; // Outgoing arrows from State, + + // one per input byte class + }; + + enum { + kByteEndText = 256, // imaginary byte at end of text + + kFlagEmptyMask = 0xFF, // State.flag_: bits holding kEmptyXXX flags + kFlagMatch = 0x0100, // State.flag_: this is a matching state + kFlagLastWord = 0x0200, // State.flag_: last byte was a word char + kFlagNeedShift = 16, // needed kEmpty bits are or'ed in shifted left + }; + + struct StateHash { + size_t operator()(const State *a) const { + DCHECK(a != NULL); + HashMix mix(a->flag_); + for (int i = 0; i < a->ninst_; i++) + mix.Mix(a->inst_[i]); + mix.Mix(0); + return mix.get(); + } + }; + + struct StateEqual { + bool operator()(const State *a, const State *b) const { + DCHECK(a != NULL); + DCHECK(b != NULL); + if (a == b) + return true; + if (a->flag_ != b->flag_) + return false; + if (a->ninst_ != b->ninst_) + return false; + for (int i = 0; i < a->ninst_; i++) + if (a->inst_[i] != b->inst_[i]) + return false; + return true; + } + }; + + typedef std::unordered_set StateSet; + +private: + // Make it easier to swap in a scalable reader-writer mutex. + using CacheMutex = Mutex; + + enum { + // Indices into start_ for unanchored searches. + // Add kStartAnchored for anchored searches. + kStartBeginText = 0, // text at beginning of context + kStartBeginLine = 2, // text at beginning of line + kStartAfterWordChar = 4, // text follows a word character + kStartAfterNonWordChar = 6, // text follows non-word character + kMaxStart = 8, + + kStartAnchored = 1, + }; + + // Resets the DFA State cache, flushing all saved State* information. 
+ // Releases and reacquires cache_mutex_ via cache_lock, so any + // State* existing before the call are not valid after the call. + // Use a StateSaver to preserve important states across the call. + // cache_mutex_.r <= L < mutex_ + // After: cache_mutex_.w <= L < mutex_ + void ResetCache(RWLocker *cache_lock); + + // Looks up and returns the State corresponding to a Workq. + // L >= mutex_ + State *WorkqToCachedState(Workq *q, Workq *mq, uint32_t flag); + + // Looks up and returns a State matching the inst, ninst, and flag. + // L >= mutex_ + State *CachedState(int *inst, int ninst, uint32_t flag); + + // Clear the cache entirely. + // Must hold cache_mutex_.w or be in destructor. + void ClearCache(); + + // Converts a State into a Workq: the opposite of WorkqToCachedState. + // L >= mutex_ + void StateToWorkq(State *s, Workq *q); + + // Runs a State on a given byte, returning the next state. + State *RunStateOnByteUnlocked(State *, int); // cache_mutex_.r <= L < mutex_ + State *RunStateOnByte(State *, int); // L >= mutex_ + + // Runs a Workq on a given byte followed by a set of empty-string flags, + // producing a new Workq in nq. If a match instruction is encountered, + // sets *ismatch to true. + // L >= mutex_ + void RunWorkqOnByte(Workq *q, Workq *nq, int c, uint32_t flag, bool *ismatch); + + // Runs a Workq on a set of empty-string flags, producing a new Workq in nq. + // L >= mutex_ + void RunWorkqOnEmptyString(Workq *q, Workq *nq, uint32_t flag); + + // Adds the instruction id to the Workq, following empty arrows + // according to flag. + // L >= mutex_ + void AddToQueue(Workq *q, int id, uint32_t flag); + + // For debugging, returns a text representation of State. + static std::string DumpState(State *state); + + // For debugging, returns a text representation of a Workq. + static std::string DumpWorkq(Workq *q); + + // Search parameters + struct SearchParams { + SearchParams(const StringPiece &text, const StringPiece &context, RWLocker *cache_lock) + : text(text), context(context), anchored(false), can_prefix_accel(false), want_earliest_match(false), run_forward(false), start(NULL), + cache_lock(cache_lock), failed(false), ep(NULL), matches(NULL) {} + + StringPiece text; + StringPiece context; + bool anchored; + bool can_prefix_accel; + bool want_earliest_match; + bool run_forward; + State *start; + RWLocker *cache_lock; + bool failed; // "out" parameter: whether search gave up + const char *ep; // "out" parameter: end pointer for match + SparseSet *matches; + + private: + SearchParams(const SearchParams &) = delete; + SearchParams &operator=(const SearchParams &) = delete; + }; + + // Before each search, the parameters to Search are analyzed by + // AnalyzeSearch to determine the state in which to start. + struct StartInfo { + StartInfo() : start(NULL) {} + std::atomic start; + }; + + // Fills in params->start and params->can_prefix_accel using + // the other search parameters. Returns true on success, + // false on failure. + // cache_mutex_.r <= L < mutex_ + bool AnalyzeSearch(SearchParams *params); + bool AnalyzeSearchHelper(SearchParams *params, StartInfo *info, uint32_t flags); + + // The generic search loop, inlined to create specialized versions. + // cache_mutex_.r <= L < mutex_ + // Might unlock and relock cache_mutex_ via params->cache_lock. + template + inline bool InlinedSearchLoop(SearchParams *params); + + // The specialized versions of InlinedSearchLoop. 
The three letters + // at the ends of the name denote the true/false values used as the + // last three parameters of InlinedSearchLoop. + // cache_mutex_.r <= L < mutex_ + // Might unlock and relock cache_mutex_ via params->cache_lock. + bool SearchFFF(SearchParams *params); + bool SearchFFT(SearchParams *params); + bool SearchFTF(SearchParams *params); + bool SearchFTT(SearchParams *params); + bool SearchTFF(SearchParams *params); + bool SearchTFT(SearchParams *params); + bool SearchTTF(SearchParams *params); + bool SearchTTT(SearchParams *params); + + // The main search loop: calls an appropriate specialized version of + // InlinedSearchLoop. + // cache_mutex_.r <= L < mutex_ + // Might unlock and relock cache_mutex_ via params->cache_lock. + bool FastSearchLoop(SearchParams *params); + + // Looks up bytes in bytemap_ but handles case c == kByteEndText too. + int ByteMap(int c) { + if (c == kByteEndText) + return prog_->bytemap_range(); + return prog_->bytemap()[c]; + } + + // Constant after initialization. + Prog *prog_; // The regular expression program to run. + Prog::MatchKind kind_; // The kind of DFA. + bool init_failed_; // initialization failed (out of memory) + + Mutex mutex_; // mutex_ >= cache_mutex_.r + + // Scratch areas, protected by mutex_. + Workq *q0_; // Two pre-allocated work queues. + Workq *q1_; + PODArray stack_; // Pre-allocated stack for AddToQueue + + // State* cache. Many threads use and add to the cache simultaneously, + // holding cache_mutex_ for reading and mutex_ (above) when adding. + // If the cache fills and needs to be discarded, the discarding is done + // while holding cache_mutex_ for writing, to avoid interrupting other + // readers. Any State* pointers are only valid while cache_mutex_ + // is held. + CacheMutex cache_mutex_; + int64_t mem_budget_; // Total memory budget for all States. + int64_t state_budget_; // Amount of memory remaining for new States. + StateSet state_cache_; // All States computed so far. + StartInfo start_[kMaxStart]; + + DFA(const DFA &) = delete; + DFA &operator=(const DFA &) = delete; +}; + +// Shorthand for casting to uint8_t*. +static inline const uint8_t *BytePtr(const void *v) { return reinterpret_cast(v); } + +// Work queues + +// Marks separate thread groups of different priority +// in the work queue when in leftmost-longest matching mode. +// #define Mark (-1) +constexpr auto Mark = -1; + +// Separates the match IDs from the instructions in inst_. +// Used only for "many match" DFA states. +// #define MatchSep (-2) +constexpr auto MatchSep = -2; + +// Internally, the DFA uses a sparse array of +// program instruction pointers as a work queue. +// In leftmost longest mode, marks separate sections +// of workq that started executing at different +// locations in the string (earlier locations first). +class DFA::Workq : public SparseSet { +public: + // Constructor: n is number of normal slots, maxmark number of mark slots. 
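+ // Illustrative note (not from the original RE2 sources): ids in
+ // [0, n) are instruction ids and ids in [n, n + maxmark) are marks,
+ // so for a program of size 8 in leftmost-longest mode the queue
+ // stores instruction ids 0..7 directly and hands out 8, 9, ... from
+ // nextmark_ as marks.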
+ Workq(int n, int maxmark) : SparseSet(n + maxmark), n_(n), maxmark_(maxmark), nextmark_(n), last_was_mark_(true) {} + + bool is_mark(int i) { return i >= n_; } + + int maxmark() { return maxmark_; } + + void clear() { + SparseSet::clear(); + nextmark_ = n_; + } + + void mark() { + if (last_was_mark_) + return; + last_was_mark_ = false; + SparseSet::insert_new(nextmark_++); + } + + int size() { return n_ + maxmark_; } + + void insert(int id) { + if (contains(id)) + return; + insert_new(id); + } + + void insert_new(int id) { + last_was_mark_ = false; + SparseSet::insert_new(id); + } + +private: + int n_; // size excluding marks + int maxmark_; // maximum number of marks + int nextmark_; // id of next mark + bool last_was_mark_; // last inserted was mark + + Workq(const Workq &) = delete; + Workq &operator=(const Workq &) = delete; +}; + +DFA::DFA(Prog *prog, Prog::MatchKind kind, int64_t max_mem) + : prog_(prog), kind_(kind), init_failed_(false), q0_(NULL), q1_(NULL), mem_budget_(max_mem) { + int nmark = 0; + if (kind_ == Prog::kLongestMatch) + nmark = prog_->size(); + // See DFA::AddToQueue() for why this is so. + int nstack = prog_->inst_count(kInstCapture) + prog_->inst_count(kInstEmptyWidth) + prog_->inst_count(kInstNop) + nmark + 1; // + 1 for start inst + + // Account for space needed for DFA, q0, q1, stack. + mem_budget_ -= sizeof(DFA); + mem_budget_ -= (prog_->size() + nmark) * (sizeof(int) + sizeof(int)) * 2; // q0, q1 + mem_budget_ -= nstack * sizeof(int); // stack + if (mem_budget_ < 0) { + init_failed_ = true; + return; + } + + state_budget_ = mem_budget_; + + // Make sure there is a reasonable amount of working room left. + // At minimum, the search requires room for two states in order + // to limp along, restarting frequently. We'll get better performance + // if there is room for a larger number of states, say 20. + // Note that a state stores list heads only, so we use the program + // list count for the upper bound, not the program size. + int nnext = prog_->bytemap_range() + 1; // + 1 for kByteEndText slot + int64_t one_state = sizeof(State) + nnext * sizeof(std::atomic) + (prog_->list_count() + nmark) * sizeof(int); + if (state_budget_ < 20 * one_state) { + init_failed_ = true; + return; + } + + q0_ = new Workq(prog_->size(), nmark); + q1_ = new Workq(prog_->size(), nmark); + stack_ = PODArray(nstack); +} + +DFA::~DFA() { + delete q0_; + delete q1_; + ClearCache(); +} + +// In the DFA state graph, s->next[c] == NULL means that the +// state has not yet been computed and needs to be. We need +// a different special value to signal that s->next[c] is a +// state that can never lead to a match (and thus the search +// can be called off). Hence DeadState. +#define DeadState reinterpret_cast(1) + +// Signals that the rest of the string matches no matter what it is. +#define FullMatchState reinterpret_cast(2) + +#define SpecialStateMax FullMatchState + +// Debugging printouts + +// For debugging, returns a string representation of the work queue. +std::string DFA::DumpWorkq(Workq *q) { + std::string s; + const char *sep = ""; + for (Workq::iterator it = q->begin(); it != q->end(); ++it) { + if (q->is_mark(*it)) { + s += "|"; + sep = ""; + } else { + s += StringPrintf("%s%d", sep, *it); + sep = ","; + } + } + return s; +} + +// For debugging, returns a string representation of the state. 
+std::string DFA::DumpState(State *state) { + if (state == NULL) + return "_"; + if (state == DeadState) + return "X"; + if (state == FullMatchState) + return "*"; + std::string s; + const char *sep = ""; + s += StringPrintf("(%p)", state); + for (int i = 0; i < state->ninst_; i++) { + if (state->inst_[i] == Mark) { + s += "|"; + sep = ""; + } else if (state->inst_[i] == MatchSep) { + s += "||"; + sep = ""; + } else { + s += StringPrintf("%s%d", sep, state->inst_[i]); + sep = ","; + } + } + s += StringPrintf(" flag=%#x", state->flag_); + return s; +} + +////////////////////////////////////////////////////////////////////// +// +// DFA state graph construction. +// +// The DFA state graph is a heavily-linked collection of State* structures. +// The state_cache_ is a set of all the State structures ever allocated, +// so that if the same state is reached by two different paths, +// the same State structure can be used. This reduces allocation +// requirements and also avoids duplication of effort across the two +// identical states. +// +// A State is defined by an ordered list of instruction ids and a flag word. +// +// The choice of an ordered list of instructions differs from a typical +// textbook DFA implementation, which would use an unordered set. +// Textbook descriptions, however, only care about whether +// the DFA matches, not where it matches in the text. To decide where the +// DFA matches, we need to mimic the behavior of the dominant backtracking +// implementations like PCRE, which try one possible regular expression +// execution, then another, then another, stopping when one of them succeeds. +// The DFA execution tries these many executions in parallel, representing +// each by an instruction id. These pointers are ordered in the State.inst_ +// list in the same order that the executions would happen in a backtracking +// search: if a match is found during execution of inst_[2], inst_[i] for i>=3 +// can be discarded. +// +// Textbooks also typically do not consider context-aware empty string operators +// like ^ or $. These are handled by the flag word, which specifies the set +// of empty-string operators that should be matched when executing at the +// current text position. These flag bits are defined in prog.h. +// The flag word also contains two DFA-specific bits: kFlagMatch if the state +// is a matching state (one that reached a kInstMatch in the program) +// and kFlagLastWord if the last processed byte was a word character, for the +// implementation of \B and \b. +// +// The flag word also contains, shifted up 16 bits, the bits looked for by +// any kInstEmptyWidth instructions in the state. These provide a useful +// summary indicating when new flags might be useful. +// +// The permanent representation of a State's instruction ids is just an array, +// but while a state is being analyzed, these instruction ids are represented +// as a Workq, which is an array that allows iteration in insertion order. + +// NOTE(rsc): The choice of State construction determines whether the DFA +// mimics backtracking implementations (so-called leftmost first matching) or +// traditional DFA implementations (so-called leftmost longest matching as +// prescribed by POSIX). This implementation chooses to mimic the +// backtracking implementations, because we want to replace PCRE. To get +// POSIX behavior, the states would need to be considered not as a simple +// ordered list of instruction ids, but as a list of unordered sets of instruction +// ids. 
A match by a state in one set would inhibit the running of sets +// farther down the list but not other instruction ids in the same set. Each +// set would correspond to matches beginning at a given point in the string. +// This is implemented by separating different sets with Mark pointers. + +// Looks in the State cache for a State matching q, flag. +// If one is found, returns it. If one is not found, allocates one, +// inserts it in the cache, and returns it. +// If mq is not null, MatchSep and the match IDs in mq will be appended +// to the State. +DFA::State *DFA::WorkqToCachedState(Workq *q, Workq *mq, uint32_t flag) { + // mutex_.AssertHeld(); + + // Construct array of instruction ids for the new state. + // Only ByteRange, EmptyWidth, and Match instructions are useful to keep: + // those are the only operators with any effect in + // RunWorkqOnEmptyString or RunWorkqOnByte. + PODArray inst(q->size()); + int n = 0; + uint32_t needflags = 0; // flags needed by kInstEmptyWidth instructions + bool sawmatch = false; // whether queue contains guaranteed kInstMatch + bool sawmark = false; // whether queue contains a Mark + + for (Workq::iterator it = q->begin(); it != q->end(); ++it) { + int id = *it; + if (sawmatch && (kind_ == Prog::kFirstMatch || q->is_mark(id))) + break; + if (q->is_mark(id)) { + if (n > 0 && inst[n - 1] != Mark) { + sawmark = true; + inst[n++] = Mark; + } + continue; + } + Prog::Inst *ip = prog_->inst(id); + switch (ip->opcode()) { + case kInstAltMatch: + // This state will continue to a match no matter what + // the rest of the input is. If it is the highest priority match + // being considered, return the special FullMatchState + // to indicate that it's all matches from here out. + if (kind_ != Prog::kManyMatch && (kind_ != Prog::kFirstMatch || (it == q->begin() && ip->greedy(prog_))) && + (kind_ != Prog::kLongestMatch || !sawmark) && (flag & kFlagMatch)) { + return FullMatchState; + } + FALLTHROUGH_INTENDED; + default: + // Record iff id is the head of its list, which must + // be the case if id-1 is the last of *its* list. :) + if (prog_->inst(id - 1)->last()) + inst[n++] = *it; + if (ip->opcode() == kInstEmptyWidth) + needflags |= ip->empty(); + if (ip->opcode() == kInstMatch && !prog_->anchor_end()) + sawmatch = true; + break; + } + } + DCHECK_LE(n, q->size()); + if (n > 0 && inst[n - 1] == Mark) + n--; + + // If there are no empty-width instructions waiting to execute, + // then the extra flag bits will not be used, so there is no + // point in saving them. (Discarding them reduces the number + // of distinct states.) + if (needflags == 0) + flag &= kFlagMatch; + + // NOTE(rsc): The code above cannot do flag &= needflags, + // because if the right flags were present to pass the current + // kInstEmptyWidth instructions, new kInstEmptyWidth instructions + // might be reached that in turn need different flags. + // The only sure thing is that if there are no kInstEmptyWidth + // instructions at all, no flags will be needed. + // We could do the extra work to figure out the full set of + // possibly needed flags by exploring past the kInstEmptyWidth + // instructions, but the check above -- are any flags needed + // at all? -- handles the most common case. More fine-grained + // analysis can only be justified by measurements showing that + // too many redundant states are being allocated. + + // If there are no Insts in the list, it's a dead state, + // which is useful to signal with a special pointer so that + // the execution loop can stop early. 
This is only okay + // if the state is *not* a matching state. + if (n == 0 && flag == 0) { + return DeadState; + } + + // If we're in longest match mode, the state is a sequence of + // unordered state sets separated by Marks. Sort each set + // to canonicalize, to reduce the number of distinct sets stored. + if (kind_ == Prog::kLongestMatch) { + int *ip = inst.data(); + int *ep = ip + n; + while (ip < ep) { + int *markp = ip; + while (markp < ep && *markp != Mark) + markp++; + std::sort(ip, markp); + if (markp < ep) + markp++; + ip = markp; + } + } + + // If we're in many match mode, canonicalize for similar reasons: + // we have an unordered set of states (i.e. we don't have Marks) + // and sorting will reduce the number of distinct sets stored. + if (kind_ == Prog::kManyMatch) { + int *ip = inst.data(); + int *ep = ip + n; + std::sort(ip, ep); + } + + // Append MatchSep and the match IDs in mq if necessary. + if (mq != NULL) { + inst[n++] = MatchSep; + for (Workq::iterator i = mq->begin(); i != mq->end(); ++i) { + int id = *i; + Prog::Inst *ip = prog_->inst(id); + if (ip->opcode() == kInstMatch) + inst[n++] = ip->match_id(); + } + } + + // Save the needed empty-width flags in the top bits for use later. + flag |= needflags << kFlagNeedShift; + + State *state = CachedState(inst.data(), n, flag); + return state; +} + +// Looks in the State cache for a State matching inst, ninst, flag. +// If one is found, returns it. If one is not found, allocates one, +// inserts it in the cache, and returns it. +DFA::State *DFA::CachedState(int *inst, int ninst, uint32_t flag) { + // mutex_.AssertHeld(); + + // Look in the cache for a pre-existing state. + // We have to initialise the struct like this because otherwise + // MSVC will complain about the flexible array member. :( + State state; + state.inst_ = inst; + state.ninst_ = ninst; + state.flag_ = flag; + StateSet::iterator it = state_cache_.find(&state); + if (it != state_cache_.end()) { + return *it; + } + + // Must have enough memory for new state. + // In addition to what we're going to allocate, + // the state cache hash table seems to incur about 40 bytes per + // State*, empirically. + const int kStateCacheOverhead = 40; + int nnext = prog_->bytemap_range() + 1; // + 1 for kByteEndText slot + int mem = sizeof(State) + nnext * sizeof(std::atomic) + ninst * sizeof(int); + if (mem_budget_ < mem + kStateCacheOverhead) { + mem_budget_ = -1; + return NULL; + } + mem_budget_ -= mem + kStateCacheOverhead; + + // Allocate new state along with room for next_ and inst_. + char *space = std::allocator().allocate(mem); + State *s = new (space) State; + s->next_ = new (space + sizeof(State)) std::atomic[nnext]; + // Work around a unfortunate bug in older versions of libstdc++. + // (https://gcc.gnu.org/bugzilla/show_bug.cgi?id=64658) + for (int i = 0; i < nnext; i++) + (void)new (s->next_ + i) std::atomic(NULL); + s->inst_ = new (s->next_ + nnext) int[ninst]; + memmove(s->inst_, inst, ninst * sizeof s->inst_[0]); + s->ninst_ = ninst; + s->flag_ = flag; + // Put state in cache and return it. + state_cache_.insert(s); + return s; +} + +// Clear the cache. Must hold cache_mutex_.w or be in destructor. +void DFA::ClearCache() { + StateSet::iterator begin = state_cache_.begin(); + StateSet::iterator end = state_cache_.end(); + while (begin != end) { + StateSet::iterator tmp = begin; + ++begin; + // Deallocate the blob of memory that we allocated in DFA::CachedState(). + // We recompute mem in order to benefit from sized delete where possible. 
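+ // Illustrative note (not from the original RE2 sources): the size
+ // recomputed here must match the computation in DFA::CachedState()
+ // exactly, since std::allocator requires deallocate() to be passed
+ // the same size that was given to allocate().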
+ int ninst = (*tmp)->ninst_; + int nnext = prog_->bytemap_range() + 1; // + 1 for kByteEndText slot + int mem = sizeof(State) + nnext * sizeof(std::atomic) + ninst * sizeof(int); + std::allocator().deallocate(reinterpret_cast(*tmp), mem); + } + state_cache_.clear(); +} + +// Copies insts in state s to the work queue q. +void DFA::StateToWorkq(State *s, Workq *q) { + q->clear(); + for (int i = 0; i < s->ninst_; i++) { + if (s->inst_[i] == Mark) { + q->mark(); + } else if (s->inst_[i] == MatchSep) { + // Nothing after this is an instruction! + break; + } else { + // Explore from the head of the list. + AddToQueue(q, s->inst_[i], s->flag_ & kFlagEmptyMask); + } + } +} + +// Adds ip to the work queue, following empty arrows according to flag. +void DFA::AddToQueue(Workq *q, int id, uint32_t flag) { + + // Use stack_ to hold our stack of instructions yet to process. + // It was preallocated as follows: + // one entry per Capture; + // one entry per EmptyWidth; and + // one entry per Nop. + // This reflects the maximum number of stack pushes that each can + // perform. (Each instruction can be processed at most once.) + // When using marks, we also added nmark == prog_->size(). + // (Otherwise, nmark == 0.) + int *stk = stack_.data(); + int nstk = 0; + + stk[nstk++] = id; + while (nstk > 0) { + DCHECK_LE(nstk, stack_.size()); + id = stk[--nstk]; + + Loop: + if (id == Mark) { + q->mark(); + continue; + } + + if (id == 0) + continue; + + // If ip is already on the queue, nothing to do. + // Otherwise add it. We don't actually keep all the + // ones that get added, but adding all of them here + // increases the likelihood of q->contains(id), + // reducing the amount of duplicated work. + if (q->contains(id)) + continue; + q->insert_new(id); + + // Process instruction. + Prog::Inst *ip = prog_->inst(id); + switch (ip->opcode()) { + default: + LOG(DFATAL) << "unhandled opcode: " << ip->opcode(); + break; + + case kInstByteRange: // just save these on the queue + case kInstMatch: + if (ip->last()) + break; + id = id + 1; + goto Loop; + + case kInstCapture: // DFA treats captures as no-ops. + case kInstNop: + if (!ip->last()) + stk[nstk++] = id + 1; + + // If this instruction is the [00-FF]* loop at the beginning of + // a leftmost-longest unanchored search, separate with a Mark so + // that future threads (which will start farther to the right in + // the input string) are lower priority than current threads. + if (ip->opcode() == kInstNop && q->maxmark() > 0 && id == prog_->start_unanchored() && id != prog_->start()) + stk[nstk++] = Mark; + id = ip->out(); + goto Loop; + + case kInstAltMatch: + DCHECK(!ip->last()); + id = id + 1; + goto Loop; + + case kInstEmptyWidth: + if (!ip->last()) + stk[nstk++] = id + 1; + + // Continue on if we have all the right flag bits. + if (ip->empty() & ~flag) + break; + id = ip->out(); + goto Loop; + } + } +} + +// Running of work queues. In the work queue, order matters: +// the queue is sorted in priority order. If instruction i comes before j, +// then the instructions that i produces during the run must come before +// the ones that j produces. In order to keep this invariant, all the +// work queue runners have to take an old queue to process and then +// also a new queue to fill in. It's not acceptable to add to the end of +// an existing queue, because new instructions will not end up in the +// correct position. + +// Runs the work queue, processing the empty strings indicated by flag. 
+// For example, flag == kEmptyBeginLine|kEmptyEndLine means to match +// both ^ and $. It is important that callers pass all flags at once: +// processing both ^ and $ is not the same as first processing only ^ +// and then processing only $. Doing the two-step sequence won't match +// ^$^$^$ but processing ^ and $ simultaneously will (and is the behavior +// exhibited by existing implementations). +void DFA::RunWorkqOnEmptyString(Workq *oldq, Workq *newq, uint32_t flag) { + newq->clear(); + for (Workq::iterator i = oldq->begin(); i != oldq->end(); ++i) { + if (oldq->is_mark(*i)) + AddToQueue(newq, Mark, flag); + else + AddToQueue(newq, *i, flag); + } +} + +// Runs the work queue, processing the single byte c followed by any empty +// strings indicated by flag. For example, c == 'a' and flag == kEmptyEndLine, +// means to match c$. Sets the bool *ismatch to true if the end of the +// regular expression program has been reached (the regexp has matched). +void DFA::RunWorkqOnByte(Workq *oldq, Workq *newq, int c, uint32_t flag, bool *ismatch) { + // mutex_.AssertHeld(); + + newq->clear(); + for (Workq::iterator i = oldq->begin(); i != oldq->end(); ++i) { + if (oldq->is_mark(*i)) { + if (*ismatch) + return; + newq->mark(); + continue; + } + int id = *i; + Prog::Inst *ip = prog_->inst(id); + switch (ip->opcode()) { + default: + LOG(DFATAL) << "unhandled opcode: " << ip->opcode(); + break; + + case kInstFail: // never succeeds + case kInstCapture: // already followed + case kInstNop: // already followed + case kInstAltMatch: // already followed + case kInstEmptyWidth: // already followed + break; + + case kInstByteRange: // can follow if c is in range + if (!ip->Matches(c)) + break; + AddToQueue(newq, ip->out(), flag); + if (ip->hint() != 0) { + // We have a hint, but we must cancel out the + // increment that will occur after the break. + i += ip->hint() - 1; + } else { + // We have no hint, so we must find the end + // of the current list and then skip to it. + Prog::Inst *ip0 = ip; + while (!ip->last()) + ++ip; + i += ip - ip0; + } + break; + + case kInstMatch: + if (prog_->anchor_end() && c != kByteEndText && kind_ != Prog::kManyMatch) + break; + *ismatch = true; + if (kind_ == Prog::kFirstMatch) { + // Can stop processing work queue since we found a match. + return; + } + break; + } + } +} + +// Processes input byte c in state, returning new state. +// Caller does not hold mutex. +DFA::State *DFA::RunStateOnByteUnlocked(State *state, int c) { + // Keep only one RunStateOnByte going + // even if the DFA is being run by multiple threads. + MutexLock l(&mutex_); + return RunStateOnByte(state, c); +} + +// Processes input byte c in state, returning new state. +DFA::State *DFA::RunStateOnByte(State *state, int c) { + // mutex_.AssertHeld(); + + if (state <= SpecialStateMax) { + if (state == FullMatchState) { + // It is convenient for routines like PossibleMatchRange + // if we implement RunStateOnByte for FullMatchState: + // once you get into this state you never get out, + // so it's pretty easy. + return FullMatchState; + } + if (state == DeadState) { + LOG(DFATAL) << "DeadState in RunStateOnByte"; + return NULL; + } + if (state == NULL) { + LOG(DFATAL) << "NULL state in RunStateOnByte"; + return NULL; + } + LOG(DFATAL) << "Unexpected special state in RunStateOnByte"; + return NULL; + } + + // If someone else already computed this, return it. + State *ns = state->next_[ByteMap(c)].load(std::memory_order_relaxed); + if (ns != NULL) + return ns; + + // Convert state into Workq. 
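+ // Illustrative outline (not from the original RE2 sources) of the
+ // slow path below, with the q0_/q1_ swaps between steps elided:
+ //
+ //   StateToWorkq(state, q0_);                  // unpack cached state
+ //   RunWorkqOnEmptyString(q0_, q1_, before);   // only if new flags help
+ //   RunWorkqOnByte(q0_, q1_, c, after, &m);    // consume byte c
+ //   ns = WorkqToCachedState(q0_, ..., flag);   // dedupe via state_cache_
+ //   state->next_[ByteMap(c)].store(ns, ...);   // publish with release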
+  StateToWorkq(state, q0_);
+
+  // Flags marking the kinds of empty-width things (^ $ etc)
+  // around this byte. Before the byte we have the flags recorded
+  // in the State structure itself. After the byte we have
+  // nothing yet (but that will change: read on).
+  uint32_t needflag = state->flag_ >> kFlagNeedShift;
+  uint32_t beforeflag = state->flag_ & kFlagEmptyMask;
+  uint32_t oldbeforeflag = beforeflag;
+  uint32_t afterflag = 0;
+
+  if (c == '\n') {
+    // Insert implicit $ and ^ around \n
+    beforeflag |= kEmptyEndLine;
+    afterflag |= kEmptyBeginLine;
+  }
+
+  if (c == kByteEndText) {
+    // Insert implicit $ and \z before the fake "end text" byte.
+    beforeflag |= kEmptyEndLine | kEmptyEndText;
+  }
+
+  // The state flag kFlagLastWord says whether the last
+  // byte processed was a word character. Use that info to
+  // insert empty-width (non-)word boundaries.
+  bool islastword = (state->flag_ & kFlagLastWord) != 0;
+  bool isword = c != kByteEndText && Prog::IsWordChar(static_cast<uint8_t>(c));
+  if (isword == islastword)
+    beforeflag |= kEmptyNonWordBoundary;
+  else
+    beforeflag |= kEmptyWordBoundary;
+
+  // Okay, finally ready to run.
+  // Only useful to rerun on empty string if there are new, useful flags.
+  if (beforeflag & ~oldbeforeflag & needflag) {
+    RunWorkqOnEmptyString(q0_, q1_, beforeflag);
+    using std::swap;
+    swap(q0_, q1_);
+  }
+  bool ismatch = false;
+  RunWorkqOnByte(q0_, q1_, c, afterflag, &ismatch);
+  using std::swap;
+  swap(q0_, q1_);
+
+  // Save afterflag along with ismatch and isword in new state.
+  uint32_t flag = afterflag;
+  if (ismatch)
+    flag |= kFlagMatch;
+  if (isword)
+    flag |= kFlagLastWord;
+
+  if (ismatch && kind_ == Prog::kManyMatch)
+    ns = WorkqToCachedState(q0_, q1_, flag);
+  else
+    ns = WorkqToCachedState(q0_, NULL, flag);
+
+  // Flush ns before linking to it.
+  // Write barrier before updating state->next_ so that the
+  // main search loop can proceed without any locking, for speed.
+  // (Otherwise it would need one mutex operation per input byte.)
+  state->next_[ByteMap(c)].store(ns, std::memory_order_release);
+  return ns;
+}
+
+//////////////////////////////////////////////////////////////////////
+// DFA cache reset.
+
+// Reader-writer lock helper.
+//
+// The DFA uses a reader-writer mutex to protect the state graph itself.
+// Traversing the state graph requires holding the mutex for reading,
+// and discarding the state graph and starting over requires holding the
+// lock for writing. If a search needs to expand the graph but is out
+// of memory, it will need to drop its read lock and then acquire the
+// write lock. Since it cannot then atomically downgrade from write lock
+// to read lock, it runs the rest of the search holding the write lock.
+// (This probably helps avoid repeated contention, but really the decision
+// is forced by the Mutex interface.) It's a bit complicated to keep
+// track of whether the lock is held for reading or writing and thread
+// that through the search, so instead we encapsulate it in the RWLocker
+// and pass that around.
+
+class DFA::RWLocker {
+public:
+  explicit RWLocker(CacheMutex *mu);
+  ~RWLocker();
+
+  // If the lock is only held for reading right now,
+  // drop the read lock and re-acquire for writing.
+  // Subsequent calls to LockForWriting are no-ops.
+  // Notice that the lock is *released* temporarily.
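+  // Any cache pointers observed before the call must therefore be
+  // revalidated or re-created afterwards, since other threads may
+  // have run in between.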
+ void LockForWriting(); + +private: + CacheMutex *mu_; + bool writing_; + + RWLocker(const RWLocker &) = delete; + RWLocker &operator=(const RWLocker &) = delete; +}; + +DFA::RWLocker::RWLocker(CacheMutex *mu) : mu_(mu), writing_(false) { mu_->ReaderLock(); } + +// This function is marked as NO_THREAD_SAFETY_ANALYSIS because +// the annotations don't support lock upgrade. +void DFA::RWLocker::LockForWriting() NO_THREAD_SAFETY_ANALYSIS { + if (!writing_) { + mu_->ReaderUnlock(); + mu_->WriterLock(); + writing_ = true; + } +} + +DFA::RWLocker::~RWLocker() { + if (!writing_) + mu_->ReaderUnlock(); + else + mu_->WriterUnlock(); +} + +// When the DFA's State cache fills, we discard all the states in the +// cache and start over. Many threads can be using and adding to the +// cache at the same time, so we synchronize using the cache_mutex_ +// to keep from stepping on other threads. Specifically, all the +// threads using the current cache hold cache_mutex_ for reading. +// When a thread decides to flush the cache, it drops cache_mutex_ +// and then re-acquires it for writing. That ensures there are no +// other threads accessing the cache anymore. The rest of the search +// runs holding cache_mutex_ for writing, avoiding any contention +// with or cache pollution caused by other threads. + +void DFA::ResetCache(RWLocker *cache_lock) { + // Re-acquire the cache_mutex_ for writing (exclusive use). + cache_lock->LockForWriting(); + + hooks::GetDFAStateCacheResetHook()({ + state_budget_, + state_cache_.size(), + }); + + // Clear the cache, reset the memory budget. + for (int i = 0; i < kMaxStart; i++) + start_[i].start.store(NULL, std::memory_order_relaxed); + ClearCache(); + mem_budget_ = state_budget_; +} + +// Typically, a couple States do need to be preserved across a cache +// reset, like the State at the current point in the search. +// The StateSaver class helps keep States across cache resets. +// It makes a copy of the state's guts outside the cache (before the reset) +// and then can be asked, after the reset, to recreate the State +// in the new cache. For example, in a DFA method ("this" is a DFA): +// +// StateSaver saver(this, s); +// ResetCache(cache_lock); +// s = saver.Restore(); +// +// The saver should always have room in the cache to re-create the state, +// because resetting the cache locks out all other threads, and the cache +// is known to have room for at least a couple states (otherwise the DFA +// constructor fails). + +class DFA::StateSaver { +public: + explicit StateSaver(DFA *dfa, State *state); + ~StateSaver(); + + // Recreates and returns a state equivalent to the + // original state passed to the constructor. + // Returns NULL if the cache has filled, but + // since the DFA guarantees to have room in the cache + // for a couple states, should never return NULL + // if used right after ResetCache. 
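+  // The returned State lives in the DFA's cache, so the caller does not own it.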
+  State *Restore();
+
+private:
+  DFA *dfa_;        // the DFA to use
+  int *inst_;       // saved info from State
+  int ninst_;
+  uint32_t flag_;
+  bool is_special_; // whether original state was special
+  State *special_;  // if is_special_, the original state
+
+  StateSaver(const StateSaver &) = delete;
+  StateSaver &operator=(const StateSaver &) = delete;
+};
+
+DFA::StateSaver::StateSaver(DFA *dfa, State *state) {
+  dfa_ = dfa;
+  if (state <= SpecialStateMax) {
+    inst_ = NULL;
+    ninst_ = 0;
+    flag_ = 0;
+    is_special_ = true;
+    special_ = state;
+    return;
+  }
+  is_special_ = false;
+  special_ = NULL;
+  flag_ = state->flag_;
+  ninst_ = state->ninst_;
+  inst_ = new int[ninst_];
+  memmove(inst_, state->inst_, ninst_ * sizeof inst_[0]);
+}
+
+DFA::StateSaver::~StateSaver() {
+  if (!is_special_)
+    delete[] inst_;
+}
+
+DFA::State *DFA::StateSaver::Restore() {
+  if (is_special_)
+    return special_;
+  MutexLock l(&dfa_->mutex_);
+  State *s = dfa_->CachedState(inst_, ninst_, flag_);
+  if (s == NULL)
+    LOG(DFATAL) << "StateSaver failed to restore state.";
+  return s;
+}
+
+//////////////////////////////////////////////////////////////////////
+//
+// DFA execution.
+//
+// The basic search loop is easy: start in a state s and then for each
+// byte c in the input, s = s->next[c].
+//
+// This simple description omits a few efficiency-driven complications.
+//
+// First, the State graph is constructed incrementally: it is possible
+// that s->next[c] is null, indicating that that state has not been
+// fully explored. In this case, RunStateOnByte must be invoked to
+// determine the next state, which is cached in s->next[c] to save
+// future effort. An alternative reason for s->next[c] to be null is
+// that the DFA has reached a so-called "dead state", in which any match
+// is no longer possible. In this case RunStateOnByte will return NULL
+// and the processing of the string can stop early.
+//
+// Second, a 256-element pointer array for s->next_ makes each State
+// quite large (2kB on 64-bit machines). Instead, dfa->bytemap_[]
+// maps from bytes to "byte classes" and then next_ only needs to have
+// as many pointers as there are byte classes. A byte class is simply a
+// range of bytes that the regexp never distinguishes between.
+// A regexp looking for a[abc] would have four byte ranges -- 0 to 'a'-1,
+// 'a', 'b' to 'c', and 'd' to 0xFF. The bytemap slows us a little bit
+// but in exchange we typically cut the size of a State (and thus our
+// memory footprint) by about 5-10x. The comments still refer to
+// s->next[c] for simplicity, but code should refer to s->next_[bytemap_[c]].
+//
+// Third, it is common for a DFA for an unanchored match to begin in a
+// state in which only one particular byte value can take the DFA to a
+// different state. That is, s->next[c] != s for only one c. In this
+// situation, the DFA can do better than executing the simple loop.
+// Instead, it can call memchr to search very quickly for the byte c.
+// Whether the start state has this property is determined during a
+// pre-compilation pass and the "can_prefix_accel" argument is set.
+//
+// Fourth, the desired behavior is to search for the leftmost-best match
+// (approximately, the same one that Perl would find), which is not
+// necessarily the match ending earliest in the string. Each time a
+// match is found, it must be noted, but the DFA must continue on in
+// hope of finding a higher-priority match. In some cases, the caller only
+// cares whether there is any match at all, not which one is found.
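+// (A caller that only needs a yes/no answer can, for example, accept the
+// very first match encountered.)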
+// The "want_earliest_match" flag causes the search to stop at the first
+// match found.
+//
+// Fifth, one algorithm that uses the DFA needs it to run over the
+// input string backward, beginning at the end and ending at the beginning.
+// Passing false for the "run_forward" flag causes the DFA to run backward.
+//
+// The checks for these last three cases, which in a naive implementation
+// would be performed once per input byte, slow the general loop enough
+// to merit specialized versions of the search loop for each of the
+// eight possible settings of the three booleans. Rather than write
+// eight different functions, we write one general implementation and then
+// inline it to create the specialized ones.
+//
+// Note that matches are delayed by one byte, to make it easier to
+// accommodate match conditions depending on the next input byte (like $ and \b).
+// When s->next[c]->IsMatch(), it means that there is a match ending just
+// *before* byte c.
+
+// The generic search loop. Searches text for a match, returning
+// the pointer to the end of the chosen match, or NULL if no match.
+// The bools are equal to the same-named variables in params, but
+// making them function arguments lets the inliner specialize
+// this function to each combination (see two paragraphs above).
+template <bool can_prefix_accel, bool want_earliest_match, bool run_forward>
+inline bool DFA::InlinedSearchLoop(SearchParams *params) {
+  State *start = params->start;
+  const uint8_t *bp = BytePtr(params->text.data()); // start of text
+  const uint8_t *p = bp;                            // text scanning point
+  const uint8_t *ep = BytePtr(params->text.data() + params->text.size()); // end of text
+  const uint8_t *resetp = NULL;                     // p at last cache reset
+  if (!run_forward) {
+    using std::swap;
+    swap(p, ep);
+  }
+
+  const uint8_t *bytemap = prog_->bytemap();
+  const uint8_t *lastmatch = NULL; // most recent matching position in text
+  bool matched = false;
+
+  State *s = start;
+
+  if (s->IsMatch()) {
+    matched = true;
+    lastmatch = p;
+    if (params->matches != NULL && kind_ == Prog::kManyMatch) {
+      for (int i = s->ninst_ - 1; i >= 0; i--) {
+        int id = s->inst_[i];
+        if (id == MatchSep)
+          break;
+        params->matches->insert(id);
+      }
+    }
+    if (want_earliest_match) {
+      params->ep = reinterpret_cast<const char *>(lastmatch);
+      return true;
+    }
+  }
+
+  while (p != ep) {
+
+    if (can_prefix_accel && s == start) {
+      // In start state, only way out is to find the prefix,
+      // so we use prefix accel (e.g. memchr) to skip ahead.
+      // If not found, we can skip to the end of the string.
+      p = BytePtr(prog_->PrefixAccel(p, ep - p));
+      if (p == NULL) {
+        p = ep;
+        break;
+      }
+    }
+
+    int c;
+    if (run_forward)
+      c = *p++;
+    else
+      c = *--p;
+
+    // Note that multiple threads might be consulting
+    // s->next_[bytemap[c]] simultaneously.
+    // RunStateOnByte takes care of the appropriate locking,
+    // including a memory barrier so that the unlocked access
+    // (sometimes known as "double-checked locking") is safe.
+    // The alternative would be either one DFA per thread
+    // or one mutex operation per input byte.
+    //
+    // ns == DeadState means the state is known to be dead
+    // (no more matches are possible).
+    // ns == NULL means the state has not yet been computed
+    // (need to call RunStateOnByteUnlocked).
+    // RunStateOnByte returns ns == NULL if it is out of memory.
+    // ns == FullMatchState means the rest of the string matches.
+    //
+    // Okay to use bytemap[] not ByteMap() here, because
+    // c is known to be an actual byte and not kByteEndText.
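+    //
+    // The acquire load below pairs with the release store in RunStateOnByte,
+    // so a cached transition can be read here without holding the mutex.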
+
+    State *ns = s->next_[bytemap[c]].load(std::memory_order_acquire);
+    if (ns == NULL) {
+      ns = RunStateOnByteUnlocked(s, c);
+      if (ns == NULL) {
+        // After we reset the cache, we hold cache_mutex exclusively,
+        // so if resetp != NULL, it means we filled the DFA state
+        // cache with this search alone (without any other threads).
+        // Benchmarks show that doing a state computation on every
+        // byte runs at about 0.2 MB/s, while the NFA (nfa.cc) can do the
+        // same at about 2 MB/s. Unless we're processing an average
+        // of 10 bytes per state computation, fail so that RE2 can
+        // fall back to the NFA. However, RE2::Set cannot fall back,
+        // so we just have to keep on keeping on in that case.
+        if (dfa_should_bail_when_slow && resetp != NULL && static_cast<size_t>(p - resetp) < 10 * state_cache_.size() &&
+            kind_ != Prog::kManyMatch) {
+          params->failed = true;
+          return false;
+        }
+        resetp = p;
+
+        // Prepare to save start and s across the reset.
+        StateSaver save_start(this, start);
+        StateSaver save_s(this, s);
+
+        // Discard all the States in the cache.
+        ResetCache(params->cache_lock);
+
+        // Restore start and s so we can continue.
+        if ((start = save_start.Restore()) == NULL || (s = save_s.Restore()) == NULL) {
+          // Restore already did LOG(DFATAL).
+          params->failed = true;
+          return false;
+        }
+        ns = RunStateOnByteUnlocked(s, c);
+        if (ns == NULL) {
+          LOG(DFATAL) << "RunStateOnByteUnlocked failed after ResetCache";
+          params->failed = true;
+          return false;
+        }
+      }
+    }
+    if (ns <= SpecialStateMax) {
+      if (ns == DeadState) {
+        params->ep = reinterpret_cast<const char *>(lastmatch);
+        return matched;
+      }
+      // FullMatchState
+      params->ep = reinterpret_cast<const char *>(ep);
+      return true;
+    }
+
+    s = ns;
+    if (s->IsMatch()) {
+      matched = true;
+      // The DFA notices the match one byte late,
+      // so adjust p before using it in the match.
+      if (run_forward)
+        lastmatch = p - 1;
+      else
+        lastmatch = p + 1;
+      if (params->matches != NULL && kind_ == Prog::kManyMatch) {
+        for (int i = s->ninst_ - 1; i >= 0; i--) {
+          int id = s->inst_[i];
+          if (id == MatchSep)
+            break;
+          params->matches->insert(id);
+        }
+      }
+      if (want_earliest_match) {
+        params->ep = reinterpret_cast<const char *>(lastmatch);
+        return true;
+      }
+    }
+  }
+
+  // Process one more byte to see if it triggers a match.
+  // (Remember, matches are delayed one byte.)
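+  // The extra byte is the real byte just past text when more context follows,
+  // or the magic kByteEndText value when text ends at the context boundary.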
+
+  int lastbyte;
+  if (run_forward) {
+    if (EndPtr(params->text) == EndPtr(params->context))
+      lastbyte = kByteEndText;
+    else
+      lastbyte = EndPtr(params->text)[0] & 0xFF;
+  } else {
+    if (BeginPtr(params->text) == BeginPtr(params->context))
+      lastbyte = kByteEndText;
+    else
+      lastbyte = BeginPtr(params->text)[-1] & 0xFF;
+  }
+
+  State *ns = s->next_[ByteMap(lastbyte)].load(std::memory_order_acquire);
+  if (ns == NULL) {
+    ns = RunStateOnByteUnlocked(s, lastbyte);
+    if (ns == NULL) {
+      StateSaver save_s(this, s);
+      ResetCache(params->cache_lock);
+      if ((s = save_s.Restore()) == NULL) {
+        params->failed = true;
+        return false;
+      }
+      ns = RunStateOnByteUnlocked(s, lastbyte);
+      if (ns == NULL) {
+        LOG(DFATAL) << "RunStateOnByteUnlocked failed after Reset";
+        params->failed = true;
+        return false;
+      }
+    }
+  }
+  if (ns <= SpecialStateMax) {
+    if (ns == DeadState) {
+      params->ep = reinterpret_cast<const char *>(lastmatch);
+      return matched;
+    }
+    // FullMatchState
+    params->ep = reinterpret_cast<const char *>(ep);
+    return true;
+  }
+
+  s = ns;
+  if (s->IsMatch()) {
+    matched = true;
+    lastmatch = p;
+    if (params->matches != NULL && kind_ == Prog::kManyMatch) {
+      for (int i = s->ninst_ - 1; i >= 0; i--) {
+        int id = s->inst_[i];
+        if (id == MatchSep)
+          break;
+        params->matches->insert(id);
+      }
+    }
+  }
+
+  params->ep = reinterpret_cast<const char *>(lastmatch);
+  return matched;
+}
+
+// Inline specializations of the general loop.
+bool DFA::SearchFFF(SearchParams *params) { return InlinedSearchLoop<false, false, false>(params); }
+bool DFA::SearchFFT(SearchParams *params) { return InlinedSearchLoop<false, false, true>(params); }
+bool DFA::SearchFTF(SearchParams *params) { return InlinedSearchLoop<false, true, false>(params); }
+bool DFA::SearchFTT(SearchParams *params) { return InlinedSearchLoop<false, true, true>(params); }
+bool DFA::SearchTFF(SearchParams *params) { return InlinedSearchLoop<true, false, false>(params); }
+bool DFA::SearchTFT(SearchParams *params) { return InlinedSearchLoop<true, false, true>(params); }
+bool DFA::SearchTTF(SearchParams *params) { return InlinedSearchLoop<true, true, false>(params); }
+bool DFA::SearchTTT(SearchParams *params) { return InlinedSearchLoop<true, true, true>(params); }
+
+// For performance, calls the appropriate specialized version
+// of InlinedSearchLoop.
+bool DFA::FastSearchLoop(SearchParams *params) {
+  // Because the methods are private, the Searches array
+  // cannot be declared at top level.
+  static bool (DFA::*Searches[])(SearchParams *) = {
+      &DFA::SearchFFF,
+      &DFA::SearchFFT,
+      &DFA::SearchFTF,
+      &DFA::SearchFTT,
+      &DFA::SearchTFF,
+      &DFA::SearchTFT,
+      &DFA::SearchTTF,
+      &DFA::SearchTTT,
+  };
+
+  int index = 4 * params->can_prefix_accel + 2 * params->want_earliest_match + 1 * params->run_forward;
+  return (this->*Searches[index])(params);
+}
+
+// The discussion of DFA execution above ignored the question of how
+// to determine the initial state for the search loop. There are two
+// factors that influence the choice of start state.
+//
+// The first factor is whether the search is anchored or not.
+// The regexp program (Prog*) itself has
+// two different entry points: one for anchored searches and one for
+// unanchored searches. (The unanchored version starts with a leading ".*?"
+// and then jumps to the anchored one.)
+//
+// The second factor is where text appears in the larger context, which
+// determines which empty-string operators can be matched at the beginning
+// of execution. If text is at the very beginning of context, \A and ^ match.
+// Otherwise if text is at the beginning of a line, then ^ matches.
+// Otherwise it matters whether the character before text is a word character
+// or a non-word character.
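+// For example, a pattern beginning with \b can start matching right after a
+// space but not right after a letter, so the two cases need different start
+// states.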
+// +// The two cases (unanchored vs not) and four cases (empty-string flags) +// combine to make the eight cases recorded in the DFA's begin_text_[2], +// begin_line_[2], after_wordchar_[2], and after_nonwordchar_[2] cached +// StartInfos. The start state for each is filled in the first time it +// is used for an actual search. + +// Examines text, context, and anchored to determine the right start +// state for the DFA search loop. Fills in params and returns true on success. +// Returns false on failure. +bool DFA::AnalyzeSearch(SearchParams *params) { + const StringPiece &text = params->text; + const StringPiece &context = params->context; + + // Sanity check: make sure that text lies within context. + if (BeginPtr(text) < BeginPtr(context) || EndPtr(text) > EndPtr(context)) { + LOG(DFATAL) << "context does not contain text"; + params->start = DeadState; + return true; + } + + // Determine correct search type. + int start; + uint32_t flags; + if (params->run_forward) { + if (BeginPtr(text) == BeginPtr(context)) { + start = kStartBeginText; + flags = kEmptyBeginText | kEmptyBeginLine; + } else if (BeginPtr(text)[-1] == '\n') { + start = kStartBeginLine; + flags = kEmptyBeginLine; + } else if (Prog::IsWordChar(BeginPtr(text)[-1] & 0xFF)) { + start = kStartAfterWordChar; + flags = kFlagLastWord; + } else { + start = kStartAfterNonWordChar; + flags = 0; + } + } else { + if (EndPtr(text) == EndPtr(context)) { + start = kStartBeginText; + flags = kEmptyBeginText | kEmptyBeginLine; + } else if (EndPtr(text)[0] == '\n') { + start = kStartBeginLine; + flags = kEmptyBeginLine; + } else if (Prog::IsWordChar(EndPtr(text)[0] & 0xFF)) { + start = kStartAfterWordChar; + flags = kFlagLastWord; + } else { + start = kStartAfterNonWordChar; + flags = 0; + } + } + if (params->anchored) + start |= kStartAnchored; + StartInfo *info = &start_[start]; + + // Try once without cache_lock for writing. + // Try again after resetting the cache + // (ResetCache will relock cache_lock for writing). + if (!AnalyzeSearchHelper(params, info, flags)) { + ResetCache(params->cache_lock); + if (!AnalyzeSearchHelper(params, info, flags)) { + params->failed = true; + LOG(DFATAL) << "Failed to analyze start state."; + return false; + } + } + + params->start = info->start.load(std::memory_order_acquire); + + // Even if we could prefix accel, we cannot do so when anchored and, + // less obviously, we cannot do so when we are going to need flags. + // This trick works only when there is a single byte that leads to a + // different state! + if (prog_->can_prefix_accel() && !params->anchored && params->start > SpecialStateMax && params->start->flag_ >> kFlagNeedShift == 0) + params->can_prefix_accel = true; + + return true; +} + +// Fills in info if needed. Returns true on success, false on failure. +bool DFA::AnalyzeSearchHelper(SearchParams *params, StartInfo *info, uint32_t flags) { + // Quick check. + State *start = info->start.load(std::memory_order_acquire); + if (start != NULL) + return true; + + MutexLock l(&mutex_); + start = info->start.load(std::memory_order_relaxed); + if (start != NULL) + return true; + + q0_->clear(); + AddToQueue(q0_, params->anchored ? prog_->start() : prog_->start_unanchored(), flags); + start = WorkqToCachedState(q0_, NULL, flags); + if (start == NULL) + return false; + + // Synchronize with "quick check" above. + info->start.store(start, std::memory_order_release); + return true; +} + +// The actual DFA search: calls AnalyzeSearch and then FastSearchLoop. 
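+// Returns true if a match was found; *epp is set to the end of the match
+// (or NULL if there is no match).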
+bool DFA::Search(const StringPiece &text,
+                 const StringPiece &context,
+                 bool anchored,
+                 bool want_earliest_match,
+                 bool run_forward,
+                 bool *failed,
+                 const char **epp,
+                 SparseSet *matches) {
+  *epp = NULL;
+  if (!ok()) {
+    *failed = true;
+    return false;
+  }
+  *failed = false;
+
+  RWLocker l(&cache_mutex_);
+  SearchParams params(text, context, &l);
+  params.anchored = anchored;
+  params.want_earliest_match = want_earliest_match;
+  params.run_forward = run_forward;
+  params.matches = matches;
+
+  if (!AnalyzeSearch(&params)) {
+    *failed = true;
+    return false;
+  }
+  if (params.start == DeadState)
+    return false;
+  if (params.start == FullMatchState) {
+    if (run_forward == want_earliest_match)
+      *epp = text.data();
+    else
+      *epp = text.data() + text.size();
+    return true;
+  }
+  bool ret = FastSearchLoop(&params);
+  if (params.failed) {
+    *failed = true;
+    return false;
+  }
+  *epp = params.ep;
+  return ret;
+}
+
+DFA *Prog::GetDFA(MatchKind kind) {
+  // For a forward DFA, half the memory goes to each DFA.
+  // However, if it is a "many match" DFA, then there is
+  // no counterpart with which the memory must be shared.
+  //
+  // For a reverse DFA, all the memory goes to the
+  // "longest match" DFA, because RE2 never does reverse
+  // "first match" searches.
+  if (kind == kFirstMatch) {
+    std::call_once(dfa_first_once_, [](Prog *prog) { prog->dfa_first_ = new DFA(prog, kFirstMatch, prog->dfa_mem_ / 2); }, this);
+    return dfa_first_;
+  } else if (kind == kManyMatch) {
+    std::call_once(dfa_first_once_, [](Prog *prog) { prog->dfa_first_ = new DFA(prog, kManyMatch, prog->dfa_mem_); }, this);
+    return dfa_first_;
+  } else {
+    std::call_once(
+        dfa_longest_once_,
+        [](Prog *prog) {
+          if (!prog->reversed_)
+            prog->dfa_longest_ = new DFA(prog, kLongestMatch, prog->dfa_mem_ / 2);
+          else
+            prog->dfa_longest_ = new DFA(prog, kLongestMatch, prog->dfa_mem_);
+        },
+        this);
+    return dfa_longest_;
+  }
+}
+
+void Prog::DeleteDFA(DFA *dfa) { delete dfa; }
+
+// Executes the regexp program to search in text,
+// which itself is inside the larger context. (As a convenience,
+// passing a NULL context is equivalent to passing text.)
+// Returns true if a match is found, false if not.
+// If a match is found, fills in match0->end() to point at the end of the match
+// and sets match0->begin() to text.begin(), since the DFA can't track
+// where the match actually began.
+//
+// This is the only external interface (class DFA only exists in this file).
+//
+bool Prog::SearchDFA(const StringPiece &text,
+                     const StringPiece &const_context,
+                     Anchor anchor,
+                     MatchKind kind,
+                     StringPiece *match0,
+                     bool *failed,
+                     SparseSet *matches) {
+  *failed = false;
+
+  StringPiece context = const_context;
+  if (context.data() == NULL)
+    context = text;
+  bool caret = anchor_start();
+  bool dollar = anchor_end();
+  if (reversed_) {
+    using std::swap;
+    swap(caret, dollar);
+  }
+  if (caret && BeginPtr(context) != BeginPtr(text))
+    return false;
+  if (dollar && EndPtr(context) != EndPtr(text))
+    return false;
+
+  // Handle full match by running an anchored longest match
+  // and then checking if it covers all of text.
+  bool anchored = anchor == kAnchored || anchor_start() || kind == kFullMatch;
+  bool endmatch = false;
+  if (kind == kManyMatch) {
+    // This is split out in order to avoid clobbering kind.
+  } else if (kind == kFullMatch || anchor_end()) {
+    endmatch = true;
+    kind = kLongestMatch;
+  }
+
+  // If the caller doesn't care where the match is (just whether one exists),
+  // then we can stop at the very first match we find, the so-called
+  // "earliest match".
+  bool want_earliest_match = false;
+  if (kind == kManyMatch) {
+    // This is split out in order to avoid clobbering kind.
+    if (matches == NULL) {
+      want_earliest_match = true;
+    }
+  } else if (match0 == NULL && !endmatch) {
+    want_earliest_match = true;
+    kind = kLongestMatch;
+  }
+
+  DFA *dfa = GetDFA(kind);
+  const char *ep;
+  bool matched = dfa->Search(text, context, anchored, want_earliest_match, !reversed_, failed, &ep, matches);
+  if (*failed) {
+    hooks::GetDFASearchFailureHook()({
+        // Nothing yet...
+    });
+    return false;
+  }
+  if (!matched)
+    return false;
+  if (endmatch && ep != (reversed_ ? text.data() : text.data() + text.size()))
+    return false;
+
+  // If caller cares, record the boundary of the match.
+  // We only know where it ends, so use the boundary of text
+  // as the beginning.
+  if (match0) {
+    if (reversed_)
+      *match0 = StringPiece(ep, static_cast<size_t>(text.data() + text.size() - ep));
+    else
+      *match0 = StringPiece(text.data(), static_cast<size_t>(ep - text.data()));
+  }
+  return true;
+}
+
+// Build out all states in DFA. Returns number of states.
+int DFA::BuildAllStates(const Prog::DFAStateCallback &cb) {
+  if (!ok())
+    return 0;
+
+  // Pick out start state for unanchored search
+  // at beginning of text.
+  RWLocker l(&cache_mutex_);
+  SearchParams params(StringPiece(), StringPiece(), &l);
+  params.anchored = false;
+  if (!AnalyzeSearch(&params) || params.start == NULL || params.start == DeadState)
+    return 0;
+
+  // Add start state to work queue.
+  // Note that any State* that we handle here must point into the cache,
+  // so we can simply depend on pointer-as-a-number hashing and equality.
+  std::unordered_map<State *, int> m;
+  std::deque<State *> q;
+  m.emplace(params.start, static_cast<int>(m.size()));
+  q.push_back(params.start);
+
+  // Compute the input bytes needed to cover all of the next pointers.
+  int nnext = prog_->bytemap_range() + 1; // + 1 for kByteEndText slot
+  std::vector<int> input(nnext);
+  for (int c = 0; c < 256; c++) {
+    int b = prog_->bytemap()[c];
+    while (c < 256 - 1 && prog_->bytemap()[c + 1] == b)
+      c++;
+    input[b] = c;
+  }
+  input[prog_->bytemap_range()] = kByteEndText;
+
+  // Scratch space for the output.
+  std::vector<int> output(nnext);
+
+  // Flood to expand every state.
+  bool oom = false;
+  while (!q.empty()) {
+    State *s = q.front();
+    q.pop_front();
+    for (int c : input) {
+      State *ns = RunStateOnByteUnlocked(s, c);
+      if (ns == NULL) {
+        oom = true;
+        break;
+      }
+      if (ns == DeadState) {
+        output[ByteMap(c)] = -1;
+        continue;
+      }
+      if (m.find(ns) == m.end()) {
+        m.emplace(ns, static_cast<int>(m.size()));
+        q.push_back(ns);
+      }
+      output[ByteMap(c)] = m[ns];
+    }
+    if (cb)
+      cb(oom ? NULL : output.data(), s == FullMatchState || s->IsMatch());
+    if (oom)
+      break;
+  }
+
+  return static_cast<int>(m.size());
+}
+
+// Build out all states in DFA for kind. Returns number of states.
+int Prog::BuildEntireDFA(MatchKind kind, const DFAStateCallback &cb) { return GetDFA(kind)->BuildAllStates(cb); }
+
+// Computes min and max for matching string.
+// Won't return strings bigger than maxlen.
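+// For example, given "a+" and maxlen 2, min would be "a" and max "ab": the
+// "aa" prefix is rounded up by PrefixSuccessor because a longer match could
+// still follow it.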
+bool DFA::PossibleMatchRange(std::string *min, std::string *max, int maxlen) {
+  if (!ok())
+    return false;
+
+  // NOTE: if future users of PossibleMatchRange want more precision when
+  // presented with infinitely repeated elements, consider making this a
+  // parameter to PossibleMatchRange.
+  static int kMaxEltRepetitions = 1;
+
+  // Keep track of the number of times we've visited states previously. We only
+  // revisit a given state if it's part of a repeated group, so if the value
+  // portion of the map tuple exceeds kMaxEltRepetitions we bail out and set
+  // |*max| to |PrefixSuccessor(*max)|.
+  //
+  // Also note that previously_visited_states[UnseenStatePtr] will, in the STL
+  // tradition, implicitly insert a '0' value at first use. We take advantage
+  // of that property below.
+  std::unordered_map<State *, int> previously_visited_states;
+
+  // Pick out start state for anchored search at beginning of text.
+  RWLocker l(&cache_mutex_);
+  SearchParams params(StringPiece(), StringPiece(), &l);
+  params.anchored = true;
+  if (!AnalyzeSearch(&params))
+    return false;
+  if (params.start == DeadState) { // No matching strings
+    *min = "";
+    *max = "";
+    return true;
+  }
+  if (params.start == FullMatchState) // Every string matches: no max
+    return false;
+
+  // The DFA is essentially a big graph rooted at params.start,
+  // and paths in the graph correspond to accepted strings.
+  // Each node in the graph has potentially 256+1 arrows
+  // coming out, one for each byte plus the magic end of
+  // text character kByteEndText.
+
+  // To find the smallest possible prefix of an accepted
+  // string, we just walk the graph preferring to follow
+  // arrows with the lowest bytes possible. To find the
+  // largest possible prefix, we follow the largest bytes
+  // possible.
+
+  // The test for whether there is an arrow from s on byte j is
+  //    ns = RunStateOnByteUnlocked(s, j);
+  //    if (ns == NULL)
+  //      return false;
+  //    if (ns != DeadState && ns->ninst > 0)
+  // The RunStateOnByteUnlocked call asks the DFA to build out the graph.
+  // It returns NULL only if the DFA has run out of memory,
+  // in which case we can't be sure of anything.
+  // The second check sees whether there was graph built
+  // and whether it is interesting graph. Nodes might have
+  // ns->ninst == 0 if they exist only to represent the fact
+  // that a match was found on the previous byte.
+
+  // Build minimum prefix.
+  State *s = params.start;
+  min->clear();
+  MutexLock lock(&mutex_);
+  for (int i = 0; i < maxlen; i++) {
+    if (previously_visited_states[s] > kMaxEltRepetitions)
+      break;
+    previously_visited_states[s]++;
+
+    // Stop if min is a match.
+    State *ns = RunStateOnByte(s, kByteEndText);
+    if (ns == NULL) // DFA out of memory
+      return false;
+    if (ns != DeadState && (ns == FullMatchState || ns->IsMatch()))
+      break;
+
+    // Try to extend the string with low bytes.
+    bool extended = false;
+    for (int j = 0; j < 256; j++) {
+      ns = RunStateOnByte(s, j);
+      if (ns == NULL) // DFA out of memory
+        return false;
+      if (ns == FullMatchState || (ns > SpecialStateMax && ns->ninst_ > 0)) {
+        extended = true;
+        min->append(1, static_cast<char>(j));
+        s = ns;
+        break;
+      }
+    }
+    if (!extended)
+      break;
+  }
+
+  // Build maximum prefix.
+  previously_visited_states.clear();
+  s = params.start;
+  max->clear();
+  for (int i = 0; i < maxlen; i++) {
+    if (previously_visited_states[s] > kMaxEltRepetitions)
+      break;
+    previously_visited_states[s] += 1;
+
+    // Try to extend the string with high bytes.
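+    // Note that, unlike the min loop above, there is no early match check
+    // here: a longer matching string only compares greater, so extending
+    // further never lowers the maximum.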
+    bool extended = false;
+    for (int j = 255; j >= 0; j--) {
+      State *ns = RunStateOnByte(s, j);
+      if (ns == NULL)
+        return false;
+      if (ns == FullMatchState || (ns > SpecialStateMax && ns->ninst_ > 0)) {
+        extended = true;
+        max->append(1, static_cast<char>(j));
+        s = ns;
+        break;
+      }
+    }
+    if (!extended) {
+      // Done, no need for PrefixSuccessor.
+      return true;
+    }
+  }
+
+  // Stopped while still adding to *max - round aaaaaaaaaa... to aaaa...b
+  PrefixSuccessor(max);
+
+  // If there are no bytes left, we have no way to say "there is no maximum
+  // string". We could make the interface more complicated and be able to
+  // return "there is no maximum but here is a minimum", but that seems like
+  // overkill -- the most common no-max case is all possible strings, so not
+  // telling the caller that the empty string is the minimum match isn't a
+  // great loss.
+  if (max->empty())
+    return false;
+
+  return true;
+}
+
+// PossibleMatchRange for a Prog.
+bool Prog::PossibleMatchRange(std::string *min, std::string *max, int maxlen) {
+  // Have to use dfa_longest_ to get all strings for full matches.
+  // For example, (a|aa) never matches aa in first-match mode.
+  return GetDFA(kLongestMatch)->PossibleMatchRange(min, max, maxlen);
+}
+
+} // namespace re2
diff --git a/internal/cpp/re2/filtered_re2.cc b/internal/cpp/re2/filtered_re2.cc
new file mode 100644
index 00000000000..beada0f6246
--- /dev/null
+++ b/internal/cpp/re2/filtered_re2.cc
@@ -0,0 +1,118 @@
+// Copyright 2009 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "re2/filtered_re2.h"
+
+#include <stddef.h>
+#include <string>
+#include <utility>
+
+#include "re2/prefilter.h"
+#include "re2/prefilter_tree.h"
+#include "util/logging.h"
+#include "util/util.h"
+
+namespace re2 {
+
+FilteredRE2::FilteredRE2() : compiled_(false), prefilter_tree_(new PrefilterTree()) {}
+
+FilteredRE2::FilteredRE2(int min_atom_len) : compiled_(false), prefilter_tree_(new PrefilterTree(min_atom_len)) {}
+
+FilteredRE2::~FilteredRE2() {
+  for (size_t i = 0; i < re2_vec_.size(); i++)
+    delete re2_vec_[i];
+}
+
+FilteredRE2::FilteredRE2(FilteredRE2 &&other)
+    : re2_vec_(std::move(other.re2_vec_)), compiled_(other.compiled_), prefilter_tree_(std::move(other.prefilter_tree_)) {
+  other.re2_vec_.clear();
+  other.re2_vec_.shrink_to_fit();
+  other.compiled_ = false;
+  other.prefilter_tree_.reset(new PrefilterTree());
+}
+
+FilteredRE2 &FilteredRE2::operator=(FilteredRE2 &&other) {
+  this->~FilteredRE2();
+  (void)new (this) FilteredRE2(std::move(other));
+  return *this;
+}
+
+RE2::ErrorCode FilteredRE2::Add(const StringPiece &pattern, const RE2::Options &options, int *id) {
+  RE2 *re = new RE2(pattern, options);
+  RE2::ErrorCode code = re->error_code();
+
+  if (!re->ok()) {
+    if (options.log_errors()) {
+      LOG(ERROR) << "Couldn't compile regular expression, skipping: " << pattern << " due to error " << re->error();
+    }
+    delete re;
+  } else {
+    *id = static_cast<int>(re2_vec_.size());
+    re2_vec_.push_back(re);
+  }
+
+  return code;
+}
+
+void FilteredRE2::Compile(std::vector<std::string> *atoms) {
+  if (compiled_) {
+    LOG(ERROR) << "Compile called already.";
+    return;
+  }
+
+  if (re2_vec_.empty()) {
+    LOG(ERROR) << "Compile called before Add.";
+    return;
+  }
+
+  for (size_t i = 0; i < re2_vec_.size(); i++) {
+    Prefilter *prefilter = Prefilter::FromRE2(re2_vec_[i]);
+    prefilter_tree_->Add(prefilter);
+  }
+  atoms->clear();
+  prefilter_tree_->Compile(atoms);
+  compiled_ = true;
+}
+
+int FilteredRE2::SlowFirstMatch(const StringPiece &text) const {
+  for (size_t i = 0; i < re2_vec_.size(); i++)
+    if (RE2::PartialMatch(text, *re2_vec_[i]))
+      return static_cast<int>(i);
+  return -1;
+}
+
+int FilteredRE2::FirstMatch(const StringPiece &text, const std::vector<int> &atoms) const {
+  if (!compiled_) {
+    LOG(DFATAL) << "FirstMatch called before Compile.";
+    return -1;
+  }
+  std::vector<int> regexps;
+  prefilter_tree_->RegexpsGivenStrings(atoms, &regexps);
+  for (size_t i = 0; i < regexps.size(); i++)
+    if (RE2::PartialMatch(text, *re2_vec_[regexps[i]]))
+      return regexps[i];
+  return -1;
+}
+
+bool FilteredRE2::AllMatches(const StringPiece &text, const std::vector<int> &atoms, std::vector<int> *matching_regexps) const {
+  matching_regexps->clear();
+  std::vector<int> regexps;
+  prefilter_tree_->RegexpsGivenStrings(atoms, &regexps);
+  for (size_t i = 0; i < regexps.size(); i++)
+    if (RE2::PartialMatch(text, *re2_vec_[regexps[i]]))
+      matching_regexps->push_back(regexps[i]);
+  return !matching_regexps->empty();
+}
+
+void FilteredRE2::AllPotentials(const std::vector<int> &atoms, std::vector<int> *potential_regexps) const {
+  prefilter_tree_->RegexpsGivenStrings(atoms, potential_regexps);
+}
+
+void FilteredRE2::RegexpsGivenStrings(const std::vector<int> &matched_atoms, std::vector<int> *passed_regexps) {
+  prefilter_tree_->RegexpsGivenStrings(matched_atoms, passed_regexps);
+}
+
+void FilteredRE2::PrintPrefilter(int regexpid) { prefilter_tree_->PrintPrefilter(regexpid); }
+
+} // namespace re2
diff --git a/internal/cpp/re2/filtered_re2.h b/internal/cpp/re2/filtered_re2.h
new file mode 100644
index 00000000000..5174a8c305f
--- /dev/null
+++ b/internal/cpp/re2/filtered_re2.h
@@ -0,0 +1,107 @@
+// Copyright 2009 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef RE2_FILTERED_RE2_H_
+#define RE2_FILTERED_RE2_H_
+
+// The class FilteredRE2 is used as a wrapper to multiple RE2 regexps.
+// It provides a prefilter mechanism that helps in cutting down the
+// number of regexps that need to be actually searched.
+//
+// By design, it does not include a string matching engine. This is to
+// allow the user of the class to use their favorite string matching
+// engine. The overall flow is: Add all the regexps using Add, then
+// Compile the FilteredRE2. Compile returns strings that need to be
+// matched. Note that the returned strings are lowercased and distinct.
+// For applying regexps to a search text, the caller does the string
+// matching using the returned strings. When doing the string match,
+// note that the caller has to do that in a case-insensitive way or
+// on a lowercased version of the search text. Then call FirstMatch
+// or AllMatches with a vector of indices of strings that were found
+// in the text to get the actual regexp matches.
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "re2/re2.h"
+
+namespace re2 {
+
+class PrefilterTree;
+
+class FilteredRE2 {
+public:
+  FilteredRE2();
+  explicit FilteredRE2(int min_atom_len);
+  ~FilteredRE2();
+
+  // Not copyable.
+  FilteredRE2(const FilteredRE2 &) = delete;
+  FilteredRE2 &operator=(const FilteredRE2 &) = delete;
+  // Movable.
+  FilteredRE2(FilteredRE2 &&other);
+  FilteredRE2 &operator=(FilteredRE2 &&other);
+
+  // Uses RE2 constructor to create a RE2 object (re). Returns
+  // re->error_code(). If error_code is other than NoError, then re is
+  // deleted and not added to re2_vec_.
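+  // On success, *id is filled in with the index of the added regexp.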
+  RE2::ErrorCode Add(const StringPiece &pattern, const RE2::Options &options, int *id);
+
+  // Prepares the regexps added by Add for filtering. Returns a set
+  // of strings that the caller should check for in candidate texts.
+  // The returned strings are lowercased and distinct. When doing
+  // string matching, it should be performed in a case-insensitive
+  // way or the search text should be lowercased first. Call after
+  // all Add calls are done.
+  void Compile(std::vector<std::string> *strings_to_match);
+
+  // Returns the index of the first matching regexp.
+  // Returns -1 on no match. Can be called prior to Compile.
+  // Does not do any filtering: simply tries to Match the
+  // regexps in a loop.
+  int SlowFirstMatch(const StringPiece &text) const;
+
+  // Returns the index of the first matching regexp.
+  // Returns -1 on no match. Compile has to be called before
+  // calling this.
+  int FirstMatch(const StringPiece &text, const std::vector<int> &atoms) const;
+
+  // Returns the indices of all matching regexps, after first clearing
+  // matched_regexps.
+  bool AllMatches(const StringPiece &text, const std::vector<int> &atoms, std::vector<int> *matching_regexps) const;
+
+  // Returns the indices of all potentially matching regexps after first
+  // clearing potential_regexps.
+  // A regexp is potentially matching if it passes the filter.
+  // If a regexp passes the filter it may still not match.
+  // A regexp that does not pass the filter is guaranteed to not match.
+  void AllPotentials(const std::vector<int> &atoms, std::vector<int> *potential_regexps) const;
+
+  // The number of regexps added.
+  int NumRegexps() const { return static_cast<int>(re2_vec_.size()); }
+
+  // Get the individual RE2 objects.
+  const RE2 &GetRE2(int regexpid) const { return *re2_vec_[regexpid]; }
+
+private:
+  // Print prefilter.
+  void PrintPrefilter(int regexpid);
+
+  // Useful for testing and debugging.
+  void RegexpsGivenStrings(const std::vector<int> &matched_atoms, std::vector<int> *passed_regexps);
+
+  // All the regexps in the FilteredRE2.
+  std::vector<RE2 *> re2_vec_;
+
+  // Has the FilteredRE2 been compiled using Compile()
+  bool compiled_;
+
+  // An AND-OR tree of string atoms used for filtering regexps.
+  std::unique_ptr<PrefilterTree> prefilter_tree_;
+};
+
+} // namespace re2
+
+#endif // RE2_FILTERED_RE2_H_
diff --git a/internal/cpp/re2/mimics_pcre.cc b/internal/cpp/re2/mimics_pcre.cc
new file mode 100644
index 00000000000..88bc55627ad
--- /dev/null
+++ b/internal/cpp/re2/mimics_pcre.cc
@@ -0,0 +1,192 @@
+// Copyright 2008 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Determine whether this library should match PCRE exactly
+// for a particular Regexp. (If so, the testing framework can
+// check that it does.)
+//
+// This library matches PCRE except in these cases:
+// * the regexp contains a repetition of an empty string,
+//   like (a*)* or (a*)+. In this case, PCRE will treat
+//   the repetition sequence as ending with an empty string,
+//   while this library does not.
+// * Perl and PCRE differ on whether \v matches \n.
+//   For historical reasons, this library implements the Perl behavior.
+// * Perl and PCRE allow $ in one-line mode to match either the very
+//   end of the text or just before a \n at the end of the text.
+//   This library requires it to match only the end of the text.
+// * Similarly, Perl and PCRE do not allow ^ in multi-line mode to
+//   match the end of the text if the last character is a \n.
+//   This library does allow it.
+//
+// Regexp::MimicsPCRE checks for any of these conditions.
+
+#include "re2/regexp.h"
+#include "re2/walker-inl.h"
+#include "util/logging.h"
+#include "util/util.h"
+
+namespace re2 {
+
+// Returns whether re might match an empty string.
+static bool CanBeEmptyString(Regexp *re);
+
+// Walker class to compute whether library handles a regexp
+// exactly as PCRE would. See comment at top for conditions.
+
+class PCREWalker : public Regexp::Walker<bool> {
+public:
+  PCREWalker() {}
+
+  virtual bool PostVisit(Regexp *re, bool parent_arg, bool pre_arg, bool *child_args, int nchild_args);
+
+  virtual bool ShortVisit(Regexp *re, bool a) {
+    // Should never be called: we use Walk(), not WalkExponential().
+#ifndef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
+    LOG(DFATAL) << "PCREWalker::ShortVisit called";
+#endif
+    return a;
+  }
+
+private:
+  PCREWalker(const PCREWalker &) = delete;
+  PCREWalker &operator=(const PCREWalker &) = delete;
+};
+
+// Called after visiting each of re's children and accumulating
+// the return values in child_args. So child_args contains whether
+// this library mimics PCRE for those subexpressions.
+bool PCREWalker::PostVisit(Regexp *re, bool parent_arg, bool pre_arg, bool *child_args, int nchild_args) {
+  // If children failed, so do we.
+  for (int i = 0; i < nchild_args; i++)
+    if (!child_args[i])
+      return false;
+
+  // Otherwise look for other reasons to fail.
+  switch (re->op()) {
+    // Look for repeated empty string.
+    case kRegexpStar:
+    case kRegexpPlus:
+    case kRegexpQuest:
+      if (CanBeEmptyString(re->sub()[0]))
+        return false;
+      break;
+    case kRegexpRepeat:
+      if (re->max() == -1 && CanBeEmptyString(re->sub()[0]))
+        return false;
+      break;
+
+    // Look for \v
+    case kRegexpLiteral:
+      if (re->rune() == '\v')
+        return false;
+      break;
+
+    // Look for $ in single-line mode.
+    case kRegexpEndText:
+    case kRegexpEmptyMatch:
+      if (re->parse_flags() & Regexp::WasDollar)
+        return false;
+      break;
+
+    // Look for ^ in multi-line mode.
+    case kRegexpBeginLine:
+      // No condition: in single-line mode ^ becomes kRegexpBeginText.
+      return false;
+
+    default:
+      break;
+  }
+
+  // Not proven guilty.
+  return true;
+}
+
+// Returns whether this regexp's behavior will mimic PCRE's exactly.
+bool Regexp::MimicsPCRE() {
+  PCREWalker w;
+  return w.Walk(this, true);
+}
+
+// Walker class to compute whether a Regexp can match an empty string.
+// It is okay to overestimate. For example, \b\B cannot match an empty
+// string, because \b and \B are mutually exclusive, but this isn't
+// that smart and will say it can. Spurious empty strings
+// will reduce the number of regexps we sanity check against PCRE,
+// but they won't break anything.
+
+class EmptyStringWalker : public Regexp::Walker<bool> {
+public:
+  EmptyStringWalker() {}
+
+  virtual bool PostVisit(Regexp *re, bool parent_arg, bool pre_arg, bool *child_args, int nchild_args);
+
+  virtual bool ShortVisit(Regexp *re, bool a) {
+    // Should never be called: we use Walk(), not WalkExponential().
+#ifndef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
+    LOG(DFATAL) << "EmptyStringWalker::ShortVisit called";
+#endif
+    return a;
+  }
+
+private:
+  EmptyStringWalker(const EmptyStringWalker &) = delete;
+  EmptyStringWalker &operator=(const EmptyStringWalker &) = delete;
+};
+
+// Called after visiting re's children. child_args contains the return
+// value from each of the children's PostVisits (i.e., whether each child
+// can match an empty string). Returns whether this clause can match an
+// empty string.
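+// For example, (a*)(b?) can match empty because both children can be empty;
+// a(b*) cannot, because the literal 'a' cannot be empty.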
+bool EmptyStringWalker::PostVisit(Regexp *re, bool parent_arg, bool pre_arg, bool *child_args, int nchild_args) {
+  switch (re->op()) {
+    case kRegexpNoMatch: // never empty
+    case kRegexpLiteral:
+    case kRegexpAnyChar:
+    case kRegexpAnyByte:
+    case kRegexpCharClass:
+    case kRegexpLiteralString:
+      return false;
+
+    case kRegexpEmptyMatch: // always empty
+    case kRegexpBeginLine:  // always empty, when they match
+    case kRegexpEndLine:
+    case kRegexpNoWordBoundary:
+    case kRegexpWordBoundary:
+    case kRegexpBeginText:
+    case kRegexpEndText:
+    case kRegexpStar: // can always be empty
+    case kRegexpQuest:
+    case kRegexpHaveMatch:
+      return true;
+
+    case kRegexpConcat: // can be empty if all children can
+      for (int i = 0; i < nchild_args; i++)
+        if (!child_args[i])
+          return false;
+      return true;
+
+    case kRegexpAlternate: // can be empty if any child can
+      for (int i = 0; i < nchild_args; i++)
+        if (child_args[i])
+          return true;
+      return false;
+
+    case kRegexpPlus: // can be empty if the child can
+    case kRegexpCapture:
+      return child_args[0];
+
+    case kRegexpRepeat: // can be empty if child can or is x{0}
+      return child_args[0] || re->min() == 0;
+  }
+  return false;
+}
+
+// Returns whether re can match an empty string.
+static bool CanBeEmptyString(Regexp *re) {
+  EmptyStringWalker w;
+  return w.Walk(re, true);
+}
+
+} // namespace re2
diff --git a/internal/cpp/re2/nfa.cc b/internal/cpp/re2/nfa.cc
new file mode 100644
index 00000000000..865c41579d6
--- /dev/null
+++ b/internal/cpp/re2/nfa.cc
@@ -0,0 +1,651 @@
+// Copyright 2006-2007 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tested by search_test.cc.
+//
+// Prog::SearchNFA, an NFA search.
+// This is an actual NFA like the theorists talk about,
+// not the pseudo-NFA found in backtracking regexp implementations.
+//
+// IMPLEMENTATION
+//
+// This algorithm is a variant of one that appeared in Rob Pike's sam editor,
+// which is a variant of the one described in Thompson's 1968 CACM paper.
+// See http://swtch.com/~rsc/regexp/ for various history. The main feature
+// over the DFA implementation is that it tracks submatch boundaries.
+//
+// When the choice of submatch boundaries is ambiguous, this particular
+// implementation makes the same choices that traditional backtracking
+// implementations (in particular, Perl and PCRE) do.
+// Note that unlike in Perl and PCRE, this algorithm *cannot* take exponential
+// time in the length of the input.
+//
+// Like Thompson's original machine and like the DFA implementation, this
+// implementation notices a match only once it is one byte past it.
+
+#include <stdio.h>
+#include <string.h>
+#include <algorithm>
+#include <deque>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "re2/pod_array.h"
+#include "re2/prog.h"
+#include "re2/regexp.h"
+#include "re2/sparse_array.h"
+#include "re2/sparse_set.h"
+#include "util/logging.h"
+#include "util/strutil.h"
+
+namespace re2 {
+
+class NFA {
+public:
+  NFA(Prog *prog);
+  ~NFA();
+
+  // Searches for a matching string.
+  // * If anchored is true, only considers matches starting at offset.
+  //   Otherwise finds leftmost match at or after offset.
+  // * If longest is true, returns the longest match starting
+  //   at the chosen start point. Otherwise returns the so-called
+  //   left-biased match, the one traditional backtracking engines
+  //   (like Perl and PCRE) find.
+  // Records submatch boundaries in submatch[1..nsubmatch-1].
+  // Submatch[0] is the entire match. When there is a choice in
+  // which text matches each subexpression, the submatch boundaries
+  // are chosen to match what a backtracking implementation would choose.
+  bool Search(const StringPiece &text, const StringPiece &context, bool anchored, bool longest, StringPiece *submatch, int nsubmatch);
+
+private:
+  struct Thread {
+    union {
+      int ref;
+      Thread *next; // when on free list
+    };
+    const char **capture;
+  };
+
+  // State for explicit stack in AddToThreadq.
+  struct AddState {
+    int id;    // Inst to process
+    Thread *t; // if not null, set t0 = t before processing id
+  };
+
+  // Threadq is a list of threads. The list is sorted by the order
+  // in which Perl would explore that particular state -- the earlier
+  // choices appear earlier in the list.
+  typedef SparseArray<Thread *> Threadq;
+
+  inline Thread *AllocThread();
+  inline Thread *Incref(Thread *t);
+  inline void Decref(Thread *t);
+
+  // Follows all empty arrows from id0 and enqueues all the states reached.
+  // Enqueues only the ByteRange instructions that match byte c.
+  // context is used (with p) for evaluating empty-width specials.
+  // p is the current input position, and t0 is the current thread.
+  void AddToThreadq(Threadq *q, int id0, int c, const StringPiece &context, const char *p, Thread *t0);
+
+  // Run runq on byte c, appending new states to nextq.
+  // Updates matched_ and match_ as new, better matches are found.
+  // context is used (with p) for evaluating empty-width specials.
+  // p is the position of byte c in the input string for AddToThreadq;
+  // p-1 will be used when processing Match instructions.
+  // Frees all the threads on runq.
+  // If there is a shortcut to the end, returns that shortcut.
+  int Step(Threadq *runq, Threadq *nextq, int c, const StringPiece &context, const char *p);
+
+  // Returns text version of capture information, for debugging.
+  std::string FormatCapture(const char **capture);
+
+  void CopyCapture(const char **dst, const char **src) { memmove(dst, src, ncapture_ * sizeof src[0]); }
+
+  Prog *prog_;          // underlying program
+  int start_;           // start instruction in program
+  int ncapture_;        // number of submatches to track
+  bool longest_;        // whether searching for longest match
+  bool endmatch_;       // whether match must end at text.end()
+  const char *btext_;   // beginning of text (for FormatSubmatch)
+  const char *etext_;   // end of text (for endmatch_)
+  Threadq q0_, q1_;     // pre-allocated for Search.
+  PODArray<AddState> stack_; // pre-allocated for AddToThreadq
+  std::deque<Thread> arena_; // thread arena
+  Thread *freelist_;    // thread freelist
+  const char **match_;  // best match so far
+  bool matched_;        // any match so far?
+
+  NFA(const NFA &) = delete;
+  NFA &operator=(const NFA &) = delete;
+};
+
+NFA::NFA(Prog *prog) {
+  prog_ = prog;
+  start_ = prog_->start();
+  ncapture_ = 0;
+  longest_ = false;
+  endmatch_ = false;
+  btext_ = NULL;
+  etext_ = NULL;
+  q0_.resize(prog_->size());
+  q1_.resize(prog_->size());
+  // See NFA::AddToThreadq() for why this is so.
+  int nstack = 2 * prog_->inst_count(kInstCapture) + prog_->inst_count(kInstEmptyWidth) + prog_->inst_count(kInstNop) + 1; // + 1 for start inst
+  stack_ = PODArray<AddState>(nstack);
+  freelist_ = NULL;
+  match_ = NULL;
+  matched_ = false;
+}
+
+NFA::~NFA() {
+  delete[] match_;
+  for (const Thread &t : arena_)
+    delete[] t.capture;
+}
+
+NFA::Thread *NFA::AllocThread() {
+  Thread *t = freelist_;
+  if (t != NULL) {
+    freelist_ = t->next;
+    t->ref = 1;
+    // We don't need to touch t->capture because
+    // the caller will immediately overwrite it.
+ return t; + } + arena_.emplace_back(); + t = &arena_.back(); + t->ref = 1; + t->capture = new const char *[ncapture_]; + return t; +} + +NFA::Thread *NFA::Incref(Thread *t) { + DCHECK(t != NULL); + t->ref++; + return t; +} + +void NFA::Decref(Thread *t) { + DCHECK(t != NULL); + t->ref--; + if (t->ref > 0) + return; + DCHECK_EQ(t->ref, 0); + t->next = freelist_; + freelist_ = t; +} + +// Follows all empty arrows from id0 and enqueues all the states reached. +// Enqueues only the ByteRange instructions that match byte c. +// context is used (with p) for evaluating empty-width specials. +// p is the current input position, and t0 is the current thread. +void NFA::AddToThreadq(Threadq *q, int id0, int c, const StringPiece &context, const char *p, Thread *t0) { + if (id0 == 0) + return; + + // Use stack_ to hold our stack of instructions yet to process. + // It was preallocated as follows: + // two entries per Capture; + // one entry per EmptyWidth; and + // one entry per Nop. + // This reflects the maximum number of stack pushes that each can + // perform. (Each instruction can be processed at most once.) + AddState *stk = stack_.data(); + int nstk = 0; + + stk[nstk++] = {id0, NULL}; + while (nstk > 0) { + DCHECK_LE(nstk, stack_.size()); + AddState a = stk[--nstk]; + + Loop: + if (a.t != NULL) { + // t0 was a thread that we allocated and copied in order to + // record the capture, so we must now decref it. + Decref(t0); + t0 = a.t; + } + + int id = a.id; + if (id == 0) + continue; + if (q->has_index(id)) { + continue; + } + + // Create entry in q no matter what. We might fill it in below, + // or we might not. Even if not, it is necessary to have it, + // so that we don't revisit id0 during the recursion. + q->set_new(id, NULL); + Thread **tp = &q->get_existing(id); + int j; + Thread *t; + Prog::Inst *ip = prog_->inst(id); + switch (ip->opcode()) { + default: + LOG(DFATAL) << "unhandled " << ip->opcode() << " in AddToThreadq"; + break; + + case kInstFail: + break; + + case kInstAltMatch: + // Save state; will pick up at next byte. + t = Incref(t0); + *tp = t; + + DCHECK(!ip->last()); + a = {id + 1, NULL}; + goto Loop; + + case kInstNop: + if (!ip->last()) + stk[nstk++] = {id + 1, NULL}; + + // Continue on. + a = {ip->out(), NULL}; + goto Loop; + + case kInstCapture: + if (!ip->last()) + stk[nstk++] = {id + 1, NULL}; + + if ((j = ip->cap()) < ncapture_) { + // Push a dummy whose only job is to restore t0 + // once we finish exploring this possibility. + stk[nstk++] = {0, t0}; + + // Record capture. + t = AllocThread(); + CopyCapture(t->capture, t0->capture); + t->capture[j] = p; + t0 = t; + } + a = {ip->out(), NULL}; + goto Loop; + + case kInstByteRange: + if (!ip->Matches(c)) + goto Next; + + // Save state; will pick up at next byte. + t = Incref(t0); + *tp = t; + + if (ip->hint() == 0) + break; + a = {id + ip->hint(), NULL}; + goto Loop; + + case kInstMatch: + // Save state; will pick up at next byte. + t = Incref(t0); + *tp = t; + + Next: + if (ip->last()) + break; + a = {id + 1, NULL}; + goto Loop; + + case kInstEmptyWidth: + if (!ip->last()) + stk[nstk++] = {id + 1, NULL}; + + // Continue on if we have all the right flag bits. + if (ip->empty() & ~Prog::EmptyFlags(context, p)) + break; + a = {ip->out(), NULL}; + goto Loop; + } + } +} + +// Run runq on byte c, appending new states to nextq. +// Updates matched_ and match_ as new, better matches are found. +// context is used (with p) for evaluating empty-width specials. 
+// p is the position of byte c in the input string for AddToThreadq; +// p-1 will be used when processing Match instructions. +// Frees all the threads on runq. +// If there is a shortcut to the end, returns that shortcut. +int NFA::Step(Threadq *runq, Threadq *nextq, int c, const StringPiece &context, const char *p) { + nextq->clear(); + + for (Threadq::iterator i = runq->begin(); i != runq->end(); ++i) { + Thread *t = i->value(); + if (t == NULL) + continue; + + if (longest_) { + // Can skip any threads started after our current best match. + if (matched_ && match_[0] < t->capture[0]) { + Decref(t); + continue; + } + } + + int id = i->index(); + Prog::Inst *ip = prog_->inst(id); + + switch (ip->opcode()) { + default: + // Should only see the values handled below. + LOG(DFATAL) << "Unhandled " << ip->opcode() << " in step"; + break; + + case kInstByteRange: + AddToThreadq(nextq, ip->out(), c, context, p, t); + break; + + case kInstAltMatch: + if (i != runq->begin()) + break; + // The match is ours if we want it. + if (ip->greedy(prog_) || longest_) { + CopyCapture(match_, t->capture); + matched_ = true; + + Decref(t); + for (++i; i != runq->end(); ++i) { + if (i->value() != NULL) + Decref(i->value()); + } + runq->clear(); + if (ip->greedy(prog_)) + return ip->out1(); + return ip->out(); + } + break; + + case kInstMatch: { + // Avoid invoking undefined behavior (arithmetic on a null pointer) + // by storing p instead of p-1. (What would the latter even mean?!) + // This complements the special case in NFA::Search(). + if (p == NULL) { + CopyCapture(match_, t->capture); + match_[1] = p; + matched_ = true; + break; + } + + if (endmatch_ && p - 1 != etext_) + break; + + if (longest_) { + // Leftmost-longest mode: save this match only if + // it is either farther to the left or at the same + // point but longer than an existing match. + if (!matched_ || t->capture[0] < match_[0] || (t->capture[0] == match_[0] && p - 1 > match_[1])) { + CopyCapture(match_, t->capture); + match_[1] = p - 1; + matched_ = true; + } + } else { + // Leftmost-biased mode: this match is by definition + // better than what we've already found (see next line). + CopyCapture(match_, t->capture); + match_[1] = p - 1; + matched_ = true; + + // Cut off the threads that can only find matches + // worse than the one we just found: don't run the + // rest of the current Threadq. + Decref(t); + for (++i; i != runq->end(); ++i) { + if (i->value() != NULL) + Decref(i->value()); + } + runq->clear(); + return 0; + } + break; + } + } + Decref(t); + } + runq->clear(); + return 0; +} + +std::string NFA::FormatCapture(const char **capture) { + std::string s; + for (int i = 0; i < ncapture_; i += 2) { + if (capture[i] == NULL) + s += "(?,?)"; + else if (capture[i + 1] == NULL) + s += StringPrintf("(%td,?)", capture[i] - btext_); + else + s += StringPrintf("(%td,%td)", capture[i] - btext_, capture[i + 1] - btext_); + } + return s; +} + +bool NFA::Search(const StringPiece &text, const StringPiece &const_context, bool anchored, bool longest, StringPiece *submatch, int nsubmatch) { + if (start_ == 0) + return false; + + StringPiece context = const_context; + if (context.data() == NULL) + context = text; + + // Sanity check: make sure that text lies within context. 
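+    // (Added illustration: with text = "cat" sliced out of context =
+    // "concatenate", empty-width assertions are evaluated against the
+    // surrounding context bytes, so \bcat fails here because the word
+    // character 'n' immediately precedes the text.)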
+    if (BeginPtr(text) < BeginPtr(context) || EndPtr(text) > EndPtr(context)) {
+        LOG(DFATAL) << "context does not contain text";
+        return false;
+    }
+
+    if (prog_->anchor_start() && BeginPtr(context) != BeginPtr(text))
+        return false;
+    if (prog_->anchor_end() && EndPtr(context) != EndPtr(text))
+        return false;
+    anchored |= prog_->anchor_start();
+    if (prog_->anchor_end()) {
+        longest = true;
+        endmatch_ = true;
+    }
+
+    if (nsubmatch < 0) {
+        LOG(DFATAL) << "Bad args: nsubmatch=" << nsubmatch;
+        return false;
+    }
+
+    // Save search parameters.
+    ncapture_ = 2 * nsubmatch;
+    longest_ = longest;
+
+    if (nsubmatch == 0) {
+        // We need to maintain match[0], both to distinguish the
+        // longest match (if longest is true) and also to tell
+        // whether we've seen any matches at all.
+        ncapture_ = 2;
+    }
+
+    match_ = new const char *[ncapture_];
+    memset(match_, 0, ncapture_ * sizeof match_[0]);
+    matched_ = false;
+
+    // For debugging prints.
+    btext_ = context.data();
+    // For convenience.
+    etext_ = text.data() + text.size();
+
+    // Set up search.
+    Threadq *runq = &q0_;
+    Threadq *nextq = &q1_;
+    runq->clear();
+    nextq->clear();
+
+    // Loop over the text, stepping the machine.
+    for (const char *p = text.data();; p++) {
+        // This is a no-op the first time around the loop because runq is empty.
+        int id = Step(runq, nextq, p < etext_ ? p[0] & 0xFF : -1, context, p);
+        DCHECK_EQ(runq->size(), 0);
+        using std::swap;
+        swap(nextq, runq);
+        nextq->clear();
+        if (id != 0) {
+            // We're done: full match ahead.
+            p = etext_;
+            for (;;) {
+                Prog::Inst *ip = prog_->inst(id);
+                switch (ip->opcode()) {
+                default:
+                    LOG(DFATAL) << "Unexpected opcode in short circuit: " << ip->opcode();
+                    break;
+
+                case kInstCapture:
+                    if (ip->cap() < ncapture_)
+                        match_[ip->cap()] = p;
+                    id = ip->out();
+                    continue;
+
+                case kInstNop:
+                    id = ip->out();
+                    continue;
+
+                case kInstMatch:
+                    match_[1] = p;
+                    matched_ = true;
+                    break;
+                }
+                break;
+            }
+            break;
+        }
+
+        if (p > etext_)
+            break;
+
+        // Start a new thread if there have not been any matches.
+        // (No point in starting a new thread if there have been
+        // matches, since it would be to the right of the match
+        // we already found.)
+        if (!matched_ && (!anchored || p == text.data())) {
+            // Try to use prefix accel (e.g. memchr) to skip ahead.
+            // The search must be unanchored and there must be zero
+            // possible matches already.
+            if (!anchored && runq->size() == 0 && p < etext_ && prog_->can_prefix_accel()) {
+                p = reinterpret_cast<const char *>(prog_->PrefixAccel(p, etext_ - p));
+                if (p == NULL)
+                    p = etext_;
+            }
+
+            Thread *t = AllocThread();
+            CopyCapture(t->capture, match_);
+            t->capture[0] = p;
+            AddToThreadq(runq, start_, p < etext_ ? p[0] & 0xFF : -1, context, p, t);
+            Decref(t);
+        }
+
+        // If all the threads have died, stop early.
+        if (runq->size() == 0) {
+            break;
+        }
+
+        // Avoid invoking undefined behavior (arithmetic on a null pointer)
+        // by simply not continuing the loop.
+        // This complements the special case in NFA::Step().
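+        // (Added clarification: p can only be NULL here when the caller
+        // searched a default-constructed StringPiece, i.e. text.data() ==
+        // NULL; then etext_ == NULL too and this is the empty input.)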
+        if (p == NULL) {
+            (void)Step(runq, nextq, -1, context, p);
+            DCHECK_EQ(runq->size(), 0);
+            using std::swap;
+            swap(nextq, runq);
+            nextq->clear();
+            break;
+        }
+    }
+
+    for (Threadq::iterator i = runq->begin(); i != runq->end(); ++i) {
+        if (i->value() != NULL)
+            Decref(i->value());
+    }
+
+    if (matched_) {
+        for (int i = 0; i < nsubmatch; i++)
+            submatch[i] = StringPiece(match_[2 * i], static_cast<size_t>(match_[2 * i + 1] - match_[2 * i]));
+        return true;
+    }
+    return false;
+}
+
+bool Prog::SearchNFA(const StringPiece &text, const StringPiece &context, Anchor anchor, MatchKind kind, StringPiece *match, int nmatch) {
+
+    NFA nfa(this);
+    StringPiece sp;
+    if (kind == kFullMatch) {
+        anchor = kAnchored;
+        if (nmatch == 0) {
+            match = &sp;
+            nmatch = 1;
+        }
+    }
+    if (!nfa.Search(text, context, anchor == kAnchored, kind != kFirstMatch, match, nmatch))
+        return false;
+    if (kind == kFullMatch && EndPtr(match[0]) != EndPtr(text))
+        return false;
+    return true;
+}
+
+// For each instruction i in the program reachable from the start, compute the
+// number of instructions reachable from i by following only empty transitions
+// and record that count as fanout[i].
+//
+// fanout holds the results and is also the work queue for the outer iteration.
+// reachable holds the reached nodes for the inner iteration.
+void Prog::Fanout(SparseArray<int> *fanout) {
+    DCHECK_EQ(fanout->max_size(), size());
+    SparseSet reachable(size());
+    fanout->clear();
+    fanout->set_new(start(), 0);
+    for (SparseArray<int>::iterator i = fanout->begin(); i != fanout->end(); ++i) {
+        int *count = &i->value();
+        reachable.clear();
+        reachable.insert(i->index());
+        for (SparseSet::iterator j = reachable.begin(); j != reachable.end(); ++j) {
+            int id = *j;
+            Prog::Inst *ip = inst(id);
+            switch (ip->opcode()) {
+            default:
+                LOG(DFATAL) << "unhandled " << ip->opcode() << " in Prog::Fanout()";
+                break;
+
+            case kInstByteRange:
+                if (!ip->last())
+                    reachable.insert(id + 1);
+
+                (*count)++;
+                if (!fanout->has_index(ip->out())) {
+                    fanout->set_new(ip->out(), 0);
+                }
+                break;
+
+            case kInstAltMatch:
+                DCHECK(!ip->last());
+                reachable.insert(id + 1);
+                break;
+
+            case kInstCapture:
+            case kInstEmptyWidth:
+            case kInstNop:
+                if (!ip->last())
+                    reachable.insert(id + 1);
+
+                reachable.insert(ip->out());
+                break;
+
+            case kInstMatch:
+                if (!ip->last())
+                    reachable.insert(id + 1);
+                break;
+
+            case kInstFail:
+                break;
+            }
+        }
+    }
+}
+
+} // namespace re2
diff --git a/internal/cpp/re2/onepass.cc b/internal/cpp/re2/onepass.cc
new file mode 100644
index 00000000000..01c331b340a
--- /dev/null
+++ b/internal/cpp/re2/onepass.cc
@@ -0,0 +1,577 @@
+// Copyright 2008 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tested by search_test.cc.
+//
+// Prog::SearchOnePass is an efficient implementation of
+// regular expression search with submatch tracking for
+// what I call "one-pass regular expressions". (An alternate
+// name might be "backtracking-free regular expressions".)
+//
+// One-pass regular expressions have the property that
+// at each input byte during an anchored match, there may be
+// multiple alternatives but only one can proceed for any
+// given input byte.
+//
+// For example, the regexp /x*yx*/ is one-pass: you read
+// x's until a y, then you read the y, then you keep reading x's.
+// At no point do you have to guess what to do or back up
+// and try a different guess.
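+//
+// (An added walk-through in the same spirit: matching /x*yx*/ against
+// "xxyx" deterministically consumes "xx" in the first x*, then the
+// mandatory "y", then "x" in the second x* -- at every byte exactly
+// one alternative can proceed.)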
+//
+// On the other hand, /x*x/ is not one-pass: when you're
+// looking at an input "x", it's not clear whether you should
+// use it to extend the x* or as the final x.
+//
+// More examples: /([^ ]*) (.*)/ is one-pass; /(.*) (.*)/ is not.
+// /(\d+)-(\d+)/ is one-pass; /(\d+).(\d+)/ is not.
+//
+// A simple intuition for identifying one-pass regular expressions
+// is that it's always immediately obvious when a repetition ends.
+// It must also be immediately obvious which branch of an | to take:
+//
+// /x(y|z)/ is one-pass, but /(xy|xz)/ is not.
+//
+// The NFA-based search in nfa.cc does some bookkeeping to
+// avoid the need for backtracking and its associated exponential blowup.
+// But if we have a one-pass regular expression, there is no
+// possibility of backtracking, so there is no need for the
+// extra bookkeeping. Hence, this code.
+//
+// On a one-pass regular expression, the NFA code in nfa.cc
+// runs at about 1/20 of the backtracking-based PCRE speed.
+// In contrast, the code in this file runs at about the same
+// speed as PCRE.
+//
+// One-pass regular expressions get used a lot when RE is
+// used for parsing simple strings, so it pays off to
+// notice them and handle them efficiently.
+//
+// See also Anne Brüggemann-Klein and Derick Wood,
+// "One-unambiguous regular languages", Information and Computation 142(2).
+
+#include <stdint.h>
+#include <string.h>
+#include <algorithm>
+#include <map>
+#include <string>
+#include <vector>
+
+#include "util/util.h"
+#include "util/logging.h"
+#include "util/strutil.h"
+#include "util/utf.h"
+#include "re2/pod_array.h"
+#include "re2/prog.h"
+#include "re2/sparse_set.h"
+#include "re2/stringpiece.h"
+
+// Silence "zero-sized array in struct/union" warning for OneState::action.
+#ifdef _MSC_VER
+#pragma warning(disable: 4200)
+#endif
+
+namespace re2 {
+
+// The key insight behind this implementation is that the
+// non-determinism in an NFA for a one-pass regular expression
+// is contained. To explain what that means, first a
+// refresher about what regular expression programs look like
+// and how the usual NFA execution runs.
+//
+// In a regular expression program, only the kInstByteRange
+// instruction processes an input byte c and moves on to the
+// next byte in the string (it does so if c is in the given range).
+// The kInstByteRange instructions correspond to literal characters
+// and character classes in the regular expression.
+//
+// The kInstAlt instructions are used as wiring to connect the
+// kInstByteRange instructions together in interesting ways when
+// implementing | + and *.
+// The kInstAlt instruction forks execution, like a goto that
+// jumps to ip->out() and ip->out1() in parallel. Each of the
+// resulting computation paths is called a thread.
+//
+// The other instructions -- kInstEmptyWidth, kInstMatch, kInstCapture --
+// are interesting in their own right but like kInstAlt they don't
+// advance the input pointer. Only kInstByteRange does.
+//
+// The automaton execution in nfa.cc runs all the possible
+// threads of execution in lock-step over the input. To process
+// a particular byte, each thread gets run until it either dies
+// or finds a kInstByteRange instruction matching the byte.
+// If the latter happens, the thread stops just past the
+// kInstByteRange instruction (at ip->out()) and waits for
+// the other threads to finish processing the input byte.
+// Then, once all the threads have processed that input byte,
+// the whole process repeats.
The kInstAlt instructions
+// might create new threads during input processing, but no
+// matter what, all the threads stop after a kInstByteRange and
+// wait for the other threads to "catch up".
+// Running in lock step like this ensures that the NFA reads
+// the input string only once.
+//
+// Each thread maintains its own set of capture registers
+// (the string positions at which it executed the kInstCapture
+// instructions corresponding to capturing parentheses in the
+// regular expression). Repeated copying of the capture registers
+// is the main performance bottleneck in the NFA implementation.
+//
+// A regular expression program is "one-pass" if, no matter what
+// the input string, there is only one thread that makes it
+// past a kInstByteRange instruction at each input byte. This means
+// that there is in some sense only one active thread throughout
+// the execution. Other threads might be created during the
+// processing of an input byte, but they are ephemeral: only one
+// thread is left to start processing the next input byte.
+// This is what I meant above when I said the non-determinism
+// was "contained".
+//
+// To execute a one-pass regular expression program, we can build
+// a DFA (no non-determinism) that has at most as many states as
+// the NFA (compare this to the possibly exponential number of states
+// in the general case). Each state records, for each possible
+// input byte, the next state along with the conditions required
+// before entering that state -- empty-width flags that must be true
+// and capture operations that must be performed. It also records,
+// for each state, the conditions required to finish a match at that
+// point in the input rather than process the next byte.
+
+// A state in the one-pass NFA - just an array of actions indexed
+// by the bytemap_[] of the next input byte. (The bytemap
+// maps next input bytes into equivalence classes, to reduce
+// the memory footprint.)
+struct OneState {
+  uint32_t matchcond;   // conditions to match right now.
+  uint32_t action[256];
+};
+
+// The uint32_t conditions in the action are a combination of
+// condition and capture bits and the next state. The bottom 16 bits
+// are the condition and capture bits, and the top 16 are the index of
+// the next state.
+//
+// Bits 0-5 are the empty-width flags from prog.h.
+// Bit 6 is kMatchWins, which means the match takes
+// priority over moving to next in a first-match search.
+// The remaining bits mark capture registers that should
+// be set to the current input position. The capture bits
+// start at index 2, since the search loop can take care of
+// cap[0], cap[1] (the overall match position).
+// That means we can handle up to 5 capturing parens: $1 through $4, plus $0.
+// No input position can satisfy both kEmptyWordBoundary
+// and kEmptyNonWordBoundary, so we can use that as a sentinel
+// instead of needing an extra bit.
+
+static const int kIndexShift = 16;   // number of bits below index
+static const int kEmptyShift = 6;    // number of empty flags in prog.h
+static const int kRealCapShift = kEmptyShift + 1;
+static const int kRealMaxCap = (kIndexShift - kRealCapShift) / 2 * 2;
+
+// Parameters used to skip over cap[0], cap[1].
+static const int kCapShift = kRealCapShift - 2;
+static const int kMaxCap = kRealMaxCap + 2;
+
+static const uint32_t kMatchWins = 1 << kEmptyShift;
+static const uint32_t kCapMask = ((1 << kRealMaxCap) - 1) << kRealCapShift;
+
+static const uint32_t kImpossible = kEmptyWordBoundary | kEmptyNonWordBoundary;
+
+// Check, at compile time, that prog.h agrees with math above.
+// This function is never called.
+void OnePass_Checks() {
+  static_assert((1<<kEmptyShift)-1 == kEmptyAllFlags,
+                "kEmptyShift disagrees with kEmptyAllFlags");
+  // kMaxCap counts pointers, kMaxOnePassCapture counts pairs.
+  static_assert(kMaxCap == Prog::kMaxOnePassCapture*2,
+                "kMaxCap disagrees with kMaxOnePassCapture");
+}
+
+static bool Satisfy(uint32_t cond, const StringPiece& context, const char* p) {
+  uint32_t satisfied = Prog::EmptyFlags(context, p);
+  if (cond & kEmptyAllFlags & ~satisfied)
+    return false;
+  return true;
+}
+
+// Apply the capture bits in cond, saving p to the appropriate
+// locations in cap[].
+static void ApplyCaptures(uint32_t cond, const char* p,
+                          const char** cap, int ncap) {
+  for (int i = 2; i < ncap; i++)
+    if (cond & ((1 << kCapShift) << i))
+      cap[i] = p;
+}
+
+// Computes the OneState* for the given nodeindex.
+static inline OneState* IndexToNode(uint8_t* nodes, int statesize,
+                                    int nodeindex) {
+  return reinterpret_cast<OneState*>(nodes + statesize*nodeindex);
+}
+
+bool Prog::SearchOnePass(const StringPiece& text,
+                         const StringPiece& const_context,
+                         Anchor anchor, MatchKind kind,
+                         StringPiece* match, int nmatch) {
+  if (anchor != kAnchored && kind != kFullMatch) {
+    LOG(DFATAL) << "Cannot use SearchOnePass for unanchored matches.";
+    return false;
+  }
+
+  // Make sure we have at least cap[1],
+  // because we use it to tell if we matched.
+  int ncap = 2*nmatch;
+  if (ncap < 2)
+    ncap = 2;
+
+  const char* cap[kMaxCap];
+  for (int i = 0; i < ncap; i++)
+    cap[i] = NULL;
+
+  const char* matchcap[kMaxCap];
+  for (int i = 0; i < ncap; i++)
+    matchcap[i] = NULL;
+
+  StringPiece context = const_context;
+  if (context.data() == NULL)
+    context = text;
+  if (anchor_start() && BeginPtr(context) != BeginPtr(text))
+    return false;
+  if (anchor_end() && EndPtr(context) != EndPtr(text))
+    return false;
+  if (anchor_end())
+    kind = kFullMatch;
+
+  uint8_t* nodes = onepass_nodes_.data();
+  int statesize = sizeof(uint32_t) + bytemap_range()*sizeof(uint32_t);
+  // start() is always mapped to the zeroth OneState.
+  OneState* state = IndexToNode(nodes, statesize, 0);
+  uint8_t* bytemap = bytemap_;
+  const char* bp = text.data();
+  const char* ep = text.data() + text.size();
+  const char* p;
+  bool matched = false;
+  matchcap[0] = bp;
+  cap[0] = bp;
+  uint32_t nextmatchcond = state->matchcond;
+  for (p = bp; p < ep; p++) {
+    int c = bytemap[*p & 0xFF];
+    uint32_t matchcond = nextmatchcond;
+    uint32_t cond = state->action[c];
+
+    // Determine whether we can reach act->next.
+    // If so, advance state and nextmatchcond.
+    if ((cond & kEmptyAllFlags) == 0 || Satisfy(cond, context, p)) {
+      uint32_t nextindex = cond >> kIndexShift;
+      state = IndexToNode(nodes, statesize, nextindex);
+      nextmatchcond = state->matchcond;
+    } else {
+      state = NULL;
+      nextmatchcond = kImpossible;
+    }
+
+    // This code section is carefully tuned.
+    // The goto sequence is about 10% faster than the
+    // obvious rewrite as a large if statement in the
+    // ASCIIMatchRE2 and DotMatchRE2 benchmarks.
+
+    // Saving the match capture registers is expensive.
+    // Is this intermediate match worth thinking about?
+
+    // Not if we want a full match.
+    if (kind == kFullMatch)
+      goto skipmatch;
+
+    // Not if it's impossible.
+    if (matchcond == kImpossible)
+      goto skipmatch;
+
+    // Not if the possible match is beaten by the certain
+    // match at the next byte. When this test is useless
+    // (e.g., HTTPPartialMatchRE2) it slows the loop by
+    // about 10%, but when it avoids work (e.g., DotMatchRE2),
+    // it cuts the loop execution by about 45%.
+    if ((cond & kMatchWins) == 0 && (nextmatchcond & kEmptyAllFlags) == 0)
+      goto skipmatch;
+
+    // Finally, the match conditions must be satisfied.
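+    // (Added bit-layout reminder, restating the constants above: a cond
+    // word has the form (nextindex << kIndexShift) | capture bits |
+    // kMatchWins | empty-width flags such as kEmptyBeginLine -- the
+    // next-state index in the top 16 bits, requirements below.)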
+    if ((matchcond & kEmptyAllFlags) == 0 || Satisfy(matchcond, context, p)) {
+      for (int i = 2; i < 2*nmatch; i++)
+        matchcap[i] = cap[i];
+      if (nmatch > 1 && (matchcond & kCapMask))
+        ApplyCaptures(matchcond, p, matchcap, ncap);
+      matchcap[1] = p;
+      matched = true;
+
+      // If we're in longest match mode, we have to keep
+      // going and see if we find a longer match.
+      // In first match mode, we can stop if the match
+      // takes priority over the next state for this input byte.
+      // That bit is per-input byte and thus in cond, not matchcond.
+      if (kind == kFirstMatch && (cond & kMatchWins))
+        goto done;
+    }
+
+  skipmatch:
+    if (state == NULL)
+      goto done;
+    if ((cond & kCapMask) && nmatch > 1)
+      ApplyCaptures(cond, p, cap, ncap);
+  }
+
+  // Look for match at end of input.
+  {
+    uint32_t matchcond = state->matchcond;
+    if (matchcond != kImpossible &&
+        ((matchcond & kEmptyAllFlags) == 0 || Satisfy(matchcond, context, p))) {
+      if (nmatch > 1 && (matchcond & kCapMask))
+        ApplyCaptures(matchcond, p, cap, ncap);
+      for (int i = 2; i < ncap; i++)
+        matchcap[i] = cap[i];
+      matchcap[1] = p;
+      matched = true;
+    }
+  }
+
+done:
+  if (!matched)
+    return false;
+  for (int i = 0; i < nmatch; i++)
+    match[i] =
+        StringPiece(matchcap[2 * i],
+                    static_cast<size_t>(matchcap[2 * i + 1] - matchcap[2 * i]));
+  return true;
+}
+
+
+// Analysis to determine whether a given regexp program is one-pass.
+
+// If ip is not on workq, adds ip to work queue and returns true.
+// If ip is already on work queue, does nothing and returns false.
+// If ip is NULL, does nothing and returns true (pretends to add it).
+typedef SparseSet Instq;
+static bool AddQ(Instq *q, int id) {
+  if (id == 0)
+    return true;
+  if (q->contains(id))
+    return false;
+  q->insert(id);
+  return true;
+}
+
+struct InstCond {
+  int id;
+  uint32_t cond;
+};
+
+// Returns whether this is a one-pass program; that is,
+// returns whether it is safe to use SearchOnePass on this program.
+// These conditions must be true for any instruction ip:
+//
+//   (1) for any other Inst nip, there is at most one input-free
+//       path from ip to nip.
+//   (2) there is at most one kInstByte instruction reachable from
+//       ip that matches any particular byte c.
+//   (3) there is at most one input-free path from ip to a kInstMatch
+//       instruction.
+//
+// This is actually just a conservative approximation: it might
+// return false when the answer is true, when kInstEmptyWidth
+// instructions are involved.
+// Constructs and saves corresponding one-pass NFA on success.
+bool Prog::IsOnePass() {
+  if (did_onepass_)
+    return onepass_nodes_.data() != NULL;
+  did_onepass_ = true;
+
+  if (start() == 0)  // no match
+    return false;
+
+  // Steal memory for the one-pass NFA from the overall DFA budget.
+  // Willing to use at most 1/4 of the DFA budget (heuristic).
+  // Limit max node count to 65000 as a conservative estimate to
+  // avoid overflowing 16-bit node index in encoding.
+  int maxnodes = 2 + inst_count(kInstByteRange);
+  int statesize = sizeof(uint32_t) + bytemap_range()*sizeof(uint32_t);
+  if (maxnodes >= 65000 || dfa_mem_ / 4 / statesize < maxnodes)
+    return false;
+
+  // Flood the graph starting at the start state, and check
+  // that in each reachable state, each possible byte leads
+  // to a unique next state.
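+  // (Worked example, added for illustration: for /x*x/ the state after
+  // the star can reach two distinct kInstByteRange instructions that
+  // both match 'x' -- one looping into the star and one for the final
+  // literal -- so the act != newact conflict check below fires and
+  // IsOnePass() returns false.)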
+  int stacksize = inst_count(kInstCapture) +
+                  inst_count(kInstEmptyWidth) +
+                  inst_count(kInstNop) + 1;  // + 1 for start inst
+  PODArray<InstCond> stack(stacksize);
+
+  int size = this->size();
+  PODArray<int> nodebyid(size);  // indexed by ip
+  memset(nodebyid.data(), 0xFF, size*sizeof nodebyid[0]);
+
+  // Originally, nodes was a uint8_t[maxnodes*statesize], but that was
+  // unnecessarily optimistic: why allocate a large amount of memory
+  // upfront for a large program when it is unlikely to be one-pass?
+  std::vector<uint8_t> nodes;
+
+  Instq tovisit(size), workq(size);
+  AddQ(&tovisit, start());
+  nodebyid[start()] = 0;
+  int nalloc = 1;
+  nodes.insert(nodes.end(), statesize, 0);
+  for (Instq::iterator it = tovisit.begin(); it != tovisit.end(); ++it) {
+    int id = *it;
+    int nodeindex = nodebyid[id];
+    OneState* node = IndexToNode(nodes.data(), statesize, nodeindex);
+
+    // Flood graph using manual stack, filling in actions as found.
+    // Default is none.
+    for (int b = 0; b < bytemap_range_; b++)
+      node->action[b] = kImpossible;
+    node->matchcond = kImpossible;
+
+    workq.clear();
+    bool matched = false;
+    int nstack = 0;
+    stack[nstack].id = id;
+    stack[nstack++].cond = 0;
+    while (nstack > 0) {
+      int id = stack[--nstack].id;
+      uint32_t cond = stack[nstack].cond;
+
+    Loop:
+      Prog::Inst* ip = inst(id);
+      switch (ip->opcode()) {
+        default:
+          LOG(DFATAL) << "unhandled opcode: " << ip->opcode();
+          break;
+
+        case kInstAltMatch:
+          // TODO(rsc): Ignoring kInstAltMatch optimization.
+          // Should implement it in this engine, but it's subtle.
+          DCHECK(!ip->last());
+          // If already on work queue, (1) is violated: bail out.
+          if (!AddQ(&workq, id+1))
+            goto fail;
+          id = id+1;
+          goto Loop;
+
+        case kInstByteRange: {
+          int nextindex = nodebyid[ip->out()];
+          if (nextindex == -1) {
+            if (nalloc >= maxnodes) {
+              goto fail;
+            }
+            nextindex = nalloc;
+            AddQ(&tovisit, ip->out());
+            nodebyid[ip->out()] = nalloc;
+            nalloc++;
+            nodes.insert(nodes.end(), statesize, 0);
+            // Update node because it might have been invalidated.
+            node = IndexToNode(nodes.data(), statesize, nodeindex);
+          }
+          for (int c = ip->lo(); c <= ip->hi(); c++) {
+            int b = bytemap_[c];
+            // Skip any bytes immediately after c that are also in b.
+            while (c < 256-1 && bytemap_[c+1] == b)
+              c++;
+            uint32_t act = node->action[b];
+            uint32_t newact = (nextindex << kIndexShift) | cond;
+            if (matched)
+              newact |= kMatchWins;
+            if ((act & kImpossible) == kImpossible) {
+              node->action[b] = newact;
+            } else if (act != newact) {
+              goto fail;
+            }
+          }
+          if (ip->foldcase()) {
+            Rune lo = std::max<Rune>(ip->lo(), 'a') + 'A' - 'a';
+            Rune hi = std::min<Rune>(ip->hi(), 'z') + 'A' - 'a';
+            for (int c = lo; c <= hi; c++) {
+              int b = bytemap_[c];
+              // Skip any bytes immediately after c that are also in b.
+              while (c < 256-1 && bytemap_[c+1] == b)
+                c++;
+              uint32_t act = node->action[b];
+              uint32_t newact = (nextindex << kIndexShift) | cond;
+              if (matched)
+                newact |= kMatchWins;
+              if ((act & kImpossible) == kImpossible) {
+                node->action[b] = newact;
+              } else if (act != newact) {
+                goto fail;
+              }
+            }
+          }
+
+          if (ip->last())
+            break;
+          // If already on work queue, (1) is violated: bail out.
+          if (!AddQ(&workq, id+1))
+            goto fail;
+          id = id+1;
+          goto Loop;
+        }
+
+        case kInstCapture:
+        case kInstEmptyWidth:
+        case kInstNop:
+          if (!ip->last()) {
+            // If already on work queue, (1) is violated: bail out.
+            if (!AddQ(&workq, id+1))
+              goto fail;
+            stack[nstack].id = id+1;
+            stack[nstack++].cond = cond;
+          }
+
+          if (ip->opcode() == kInstCapture && ip->cap() < kMaxCap)
+            cond |= (1 << kCapShift) << ip->cap();
+          if (ip->opcode() == kInstEmptyWidth)
+            cond |= ip->empty();
+
+          // kInstCapture and kInstNop always proceed to ip->out().
+          // kInstEmptyWidth only sometimes proceeds to ip->out(),
+          // but as a conservative approximation we assume it always does.
+          // We could be a little more precise by looking at what c
+          // is, but that seems like overkill.
+
+          // If already on work queue, (1) is violated: bail out.
+          if (!AddQ(&workq, ip->out())) {
+            goto fail;
+          }
+          id = ip->out();
+          goto Loop;
+
+        case kInstMatch:
+          if (matched) {
+            // (3) is violated
+            goto fail;
+          }
+          matched = true;
+          node->matchcond = cond;
+
+          if (ip->last())
+            break;
+          // If already on work queue, (1) is violated: bail out.
+          if (!AddQ(&workq, id+1))
+            goto fail;
+          id = id+1;
+          goto Loop;
+
+        case kInstFail:
+          break;
+      }
+    }
+  }
+
+  dfa_mem_ -= nalloc*statesize;
+  onepass_nodes_ = PODArray<uint8_t>(nalloc*statesize);
+  memmove(onepass_nodes_.data(), nodes.data(), nalloc*statesize);
+  return true;
+
+fail:
+  return false;
+}
+
+}  // namespace re2
diff --git a/internal/cpp/re2/parse.cc b/internal/cpp/re2/parse.cc
new file mode 100644
index 00000000000..2350af0ecd8
--- /dev/null
+++ b/internal/cpp/re2/parse.cc
@@ -0,0 +1,2481 @@
+// Copyright 2006 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Regular expression parser.
+
+// The parser is a simple precedence-based parser with a
+// manual stack. The parsing work is done by the methods
+// of the ParseState class. The Regexp::Parse function is
+// essentially just a lexer that calls the ParseState method
+// for each token.
+
+// The parser recognizes POSIX extended regular expressions
+// excluding backreferences, collating elements, and collating
+// classes. It also allows the empty string as a regular expression
+// and recognizes the Perl escape sequences \d, \s, \w, \D, \S, and \W.
+// See regexp.h for rationale.
+
+#include <ctype.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+#include <algorithm>
+#include <map>
+#include <string>
+#include <vector>
+
+#include "util/util.h"
+#include "util/logging.h"
+#include "util/strutil.h"
+#include "util/utf.h"
+#include "re2/pod_array.h"
+#include "re2/regexp.h"
+#include "re2/stringpiece.h"
+#include "re2/unicode_casefold.h"
+#include "re2/unicode_groups.h"
+#include "re2/walker-inl.h"
+
+#if defined(RE2_USE_ICU)
+//#include "unicode/uniset.h"
+//#include "unicode/unistr.h"
+//#include "unicode/utypes.h"
+#endif
+
+namespace re2 {
+
+// Controls the maximum repeat count permitted by the parser.
+static int maximum_repeat_count = 1000;
+
+void Regexp::FUZZING_ONLY_set_maximum_repeat_count(int i) {
+  maximum_repeat_count = i;
+}
+
+// Regular expression parse state.
+// The list of parsed regexps so far is maintained as a vector of
+// Regexp pointers called the stack. Left parenthesis and vertical
+// bar markers are also placed on the stack, as Regexps with
+// non-standard opcodes.
+// Scanning a left parenthesis causes the parser to push a left parenthesis
+// marker on the stack.
+// Scanning a vertical bar causes the parser to pop the stack until it finds a
+// vertical bar or left parenthesis marker (not popping the marker),
+// concatenate all the popped results, and push them back on
+// the stack (DoConcatenation).
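+// (Added example trace: parsing "ab|cd" pushes Literal(a), Literal(b);
+// the '|' concatenates them into "ab" and pushes a VerticalBar marker;
+// Literal(c), Literal(d) follow, and the final DoAlternation collapses
+// the stack to Alternate("ab", "cd").)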
+// Scanning a right parenthesis causes the parser to act as though it +// has seen a vertical bar, which then leaves the top of the stack in the +// form LeftParen regexp VerticalBar regexp VerticalBar ... regexp VerticalBar. +// The parser pops all this off the stack and creates an alternation of the +// regexps (DoAlternation). + +class Regexp::ParseState { + public: + ParseState(ParseFlags flags, const StringPiece& whole_regexp, + RegexpStatus* status); + ~ParseState(); + + ParseFlags flags() { return flags_; } + int rune_max() { return rune_max_; } + + // Parse methods. All public methods return a bool saying + // whether parsing should continue. If a method returns + // false, it has set fields in *status_, and the parser + // should return NULL. + + // Pushes the given regular expression onto the stack. + // Could check for too much memory used here. + bool PushRegexp(Regexp* re); + + // Pushes the literal rune r onto the stack. + bool PushLiteral(Rune r); + + // Pushes a regexp with the given op (and no args) onto the stack. + bool PushSimpleOp(RegexpOp op); + + // Pushes a ^ onto the stack. + bool PushCaret(); + + // Pushes a \b (word == true) or \B (word == false) onto the stack. + bool PushWordBoundary(bool word); + + // Pushes a $ onto the stack. + bool PushDollar(); + + // Pushes a . onto the stack + bool PushDot(); + + // Pushes a repeat operator regexp onto the stack. + // A valid argument for the operator must already be on the stack. + // s is the name of the operator, for use in error messages. + bool PushRepeatOp(RegexpOp op, const StringPiece& s, bool nongreedy); + + // Pushes a repetition regexp onto the stack. + // A valid argument for the operator must already be on the stack. + bool PushRepetition(int min, int max, const StringPiece& s, bool nongreedy); + + // Checks whether a particular regexp op is a marker. + bool IsMarker(RegexpOp op); + + // Processes a left parenthesis in the input. + // Pushes a marker onto the stack. + bool DoLeftParen(const StringPiece& name); + bool DoLeftParenNoCapture(); + + // Processes a vertical bar in the input. + bool DoVerticalBar(); + + // Processes a right parenthesis in the input. + bool DoRightParen(); + + // Processes the end of input, returning the final regexp. + Regexp* DoFinish(); + + // Finishes the regexp if necessary, preparing it for use + // in a more complicated expression. + // If it is a CharClassBuilder, converts into a CharClass. + Regexp* FinishRegexp(Regexp*); + + // These routines don't manipulate the parse stack + // directly, but they do need to look at flags_. + // ParseCharClass also manipulates the internals of Regexp + // while creating *out_re. + + // Parse a character class into *out_re. + // Removes parsed text from s. + bool ParseCharClass(StringPiece* s, Regexp** out_re, + RegexpStatus* status); + + // Parse a character class character into *rp. + // Removes parsed text from s. + bool ParseCCCharacter(StringPiece* s, Rune *rp, + const StringPiece& whole_class, + RegexpStatus* status); + + // Parse a character class range into rr. + // Removes parsed text from s. + bool ParseCCRange(StringPiece* s, RuneRange* rr, + const StringPiece& whole_class, + RegexpStatus* status); + + // Parse a Perl flag set or non-capturing group from s. + bool ParsePerlFlags(StringPiece* s); + + + // Finishes the current concatenation, + // collapsing it into a single regexp on the stack. + void DoConcatenation(); + + // Finishes the current alternation, + // collapsing it to a single regexp on the stack. 
+  void DoAlternation();
+
+  // Generalized DoAlternation/DoConcatenation.
+  void DoCollapse(RegexpOp op);
+
+  // Maybe concatenate Literals into LiteralString.
+  bool MaybeConcatString(int r, ParseFlags flags);
+
+ private:
+  ParseFlags flags_;
+  StringPiece whole_regexp_;
+  RegexpStatus* status_;
+  Regexp* stacktop_;
+  int ncap_;  // number of capturing parens seen
+  int rune_max_;  // maximum char value for this encoding
+
+  ParseState(const ParseState&) = delete;
+  ParseState& operator=(const ParseState&) = delete;
+};
+
+// Pseudo-operators - only on parse stack.
+const RegexpOp kLeftParen = static_cast<RegexpOp>(kMaxRegexpOp+1);
+const RegexpOp kVerticalBar = static_cast<RegexpOp>(kMaxRegexpOp+2);
+
+Regexp::ParseState::ParseState(ParseFlags flags,
+                               const StringPiece& whole_regexp,
+                               RegexpStatus* status)
+    : flags_(flags), whole_regexp_(whole_regexp),
+      status_(status), stacktop_(NULL), ncap_(0) {
+  if (flags_ & Latin1)
+    rune_max_ = 0xFF;
+  else
+    rune_max_ = Runemax;
+}
+
+// Cleans up by freeing all the regexps on the stack.
+Regexp::ParseState::~ParseState() {
+  Regexp* next;
+  for (Regexp* re = stacktop_; re != NULL; re = next) {
+    next = re->down_;
+    re->down_ = NULL;
+    if (re->op() == kLeftParen)
+      delete re->arguments.capture.name_;
+    re->Decref();
+  }
+}
+
+// Finishes the regexp if necessary, preparing it for use in
+// a more complex expression.
+// If it is a CharClassBuilder, converts into a CharClass.
+Regexp* Regexp::ParseState::FinishRegexp(Regexp* re) {
+  if (re == NULL)
+    return NULL;
+  re->down_ = NULL;
+
+  if (re->op_ == kRegexpCharClass && re->arguments.char_class.ccb_ != NULL) {
+    CharClassBuilder* ccb = re->arguments.char_class.ccb_;
+    re->arguments.char_class.ccb_ = NULL;
+    re->arguments.char_class.cc_ = ccb->GetCharClass();
+    delete ccb;
+  }
+
+  return re;
+}
+
+// Pushes the given regular expression onto the stack.
+// Could check for too much memory used here.
+bool Regexp::ParseState::PushRegexp(Regexp* re) {
+  MaybeConcatString(-1, NoParseFlags);
+
+  // Special case: a character class of one character is just
+  // a literal. This is a common idiom for escaping
+  // single characters (e.g., [.] instead of \.), and some
+  // analysis does better with fewer character classes.
+  // Similarly, [Aa] can be rewritten as a literal A with ASCII case folding.
+  if (re->op_ == kRegexpCharClass && re->arguments.char_class.ccb_ != NULL) {
+    re->arguments.char_class.ccb_->RemoveAbove(rune_max_);
+    if (re->arguments.char_class.ccb_->size() == 1) {
+      Rune r = re->arguments.char_class.ccb_->begin()->lo;
+      re->Decref();
+      re = new Regexp(kRegexpLiteral, flags_);
+      re->arguments.rune_ = r;
+    } else if (re->arguments.char_class.ccb_->size() == 2) {
+      Rune r = re->arguments.char_class.ccb_->begin()->lo;
+      if ('A' <= r && r <= 'Z' && re->arguments.char_class.ccb_->Contains(r + 'a' - 'A')) {
+        re->Decref();
+        re = new Regexp(kRegexpLiteral, flags_ | FoldCase);
+        re->arguments.rune_ = r + 'a' - 'A';
+      }
+    }
+  }
+
+  if (!IsMarker(re->op()))
+    re->simple_ = re->ComputeSimple();
+  re->down_ = stacktop_;
+  stacktop_ = re;
+  return true;
+}
+
+// Searches the case folding tables and returns the CaseFold* that contains r.
+// If there isn't one, returns the CaseFold* with smallest f->lo bigger than r.
+// If there isn't one, returns NULL.
+const CaseFold* LookupCaseFold(const CaseFold *f, int n, Rune r) {
+  const CaseFold* ef = f + n;
+
+  // Binary search for entry containing r.
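+  // (Added illustration: for r == 'k' the search lands on the entry
+  // whose range covers 'k'; ApplyFold/CycleFoldRune below then walk the
+  // fold orbit 'k' -> 0x212A (Kelvin K) -> 'K' -> 'k', matching the
+  // examples documented at CycleFoldRune.)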
+ while (n > 0) { + int m = n/2; + if (f[m].lo <= r && r <= f[m].hi) + return &f[m]; + if (r < f[m].lo) { + n = m; + } else { + f += m+1; + n -= m+1; + } + } + + // There is no entry that contains r, but f points + // where it would have been. Unless f points at + // the end of the array, it points at the next entry + // after r. + if (f < ef) + return f; + + // No entry contains r; no entry contains runes > r. + return NULL; +} + +// Returns the result of applying the fold f to the rune r. +Rune ApplyFold(const CaseFold *f, Rune r) { + switch (f->delta) { + default: + return r + f->delta; + + case EvenOddSkip: // even <-> odd but only applies to every other + if ((r - f->lo) % 2) + return r; + FALLTHROUGH_INTENDED; + case EvenOdd: // even <-> odd + if (r%2 == 0) + return r + 1; + return r - 1; + + case OddEvenSkip: // odd <-> even but only applies to every other + if ((r - f->lo) % 2) + return r; + FALLTHROUGH_INTENDED; + case OddEven: // odd <-> even + if (r%2 == 1) + return r + 1; + return r - 1; + } +} + +// Returns the next Rune in r's folding cycle (see unicode_casefold.h). +// Examples: +// CycleFoldRune('A') = 'a' +// CycleFoldRune('a') = 'A' +// +// CycleFoldRune('K') = 'k' +// CycleFoldRune('k') = 0x212A (Kelvin) +// CycleFoldRune(0x212A) = 'K' +// +// CycleFoldRune('?') = '?' +Rune CycleFoldRune(Rune r) { + const CaseFold* f = LookupCaseFold(unicode_casefold, num_unicode_casefold, r); + if (f == NULL || r < f->lo) + return r; + return ApplyFold(f, r); +} + +// Add lo-hi to the class, along with their fold-equivalent characters. +// If lo-hi is already in the class, assume that the fold-equivalent +// chars are there too, so there's no work to do. +static void AddFoldedRange(CharClassBuilder* cc, Rune lo, Rune hi, int depth) { + // AddFoldedRange calls itself recursively for each rune in the fold cycle. + // Most folding cycles are small: there aren't any bigger than four in the + // current Unicode tables. make_unicode_casefold.py checks that + // the cycles are not too long, and we double-check here using depth. + if (depth > 10) { + LOG(DFATAL) << "AddFoldedRange recurses too much."; + return; + } + + if (!cc->AddRange(lo, hi)) // lo-hi was already there? we're done + return; + + while (lo <= hi) { + const CaseFold* f = LookupCaseFold(unicode_casefold, num_unicode_casefold, lo); + if (f == NULL) // lo has no fold, nor does anything above lo + break; + if (lo < f->lo) { // lo has no fold; next rune with a fold is f->lo + lo = f->lo; + continue; + } + + // Add in the result of folding the range lo - f->hi + // and that range's fold, recursively. + Rune lo1 = lo; + Rune hi1 = std::min(hi, f->hi); + switch (f->delta) { + default: + lo1 += f->delta; + hi1 += f->delta; + break; + case EvenOdd: + if (lo1%2 == 1) + lo1--; + if (hi1%2 == 0) + hi1++; + break; + case OddEven: + if (lo1%2 == 0) + lo1--; + if (hi1%2 == 1) + hi1++; + break; + } + AddFoldedRange(cc, lo1, hi1, depth+1); + + // Pick up where this fold left off. + lo = f->hi + 1; + } +} + +// Pushes the literal rune r onto the stack. +bool Regexp::ParseState::PushLiteral(Rune r) { + // Do case folding if needed. + if ((flags_ & FoldCase) && CycleFoldRune(r) != r) { + Regexp* re = new Regexp(kRegexpCharClass, flags_ & ~FoldCase); + re->arguments.char_class.ccb_ = new CharClassBuilder; + Rune r1 = r; + do { + if (!(flags_ & NeverNL) || r != '\n') { + re->arguments.char_class.ccb_->AddRange(r, r); + } + r = CycleFoldRune(r); + } while (r != r1); + return PushRegexp(re); + } + + // Exclude newline if applicable. 
+ if ((flags_ & NeverNL) && r == '\n') + return PushRegexp(new Regexp(kRegexpNoMatch, flags_)); + + // No fancy stuff worked. Ordinary literal. + if (MaybeConcatString(r, flags_)) + return true; + + Regexp* re = new Regexp(kRegexpLiteral, flags_); + re->arguments.rune_ = r; + return PushRegexp(re); +} + +// Pushes a ^ onto the stack. +bool Regexp::ParseState::PushCaret() { + if (flags_ & OneLine) { + return PushSimpleOp(kRegexpBeginText); + } + return PushSimpleOp(kRegexpBeginLine); +} + +// Pushes a \b or \B onto the stack. +bool Regexp::ParseState::PushWordBoundary(bool word) { + if (word) + return PushSimpleOp(kRegexpWordBoundary); + return PushSimpleOp(kRegexpNoWordBoundary); +} + +// Pushes a $ onto the stack. +bool Regexp::ParseState::PushDollar() { + if (flags_ & OneLine) { + // Clumsy marker so that MimicsPCRE() can tell whether + // this kRegexpEndText was a $ and not a \z. + Regexp::ParseFlags oflags = flags_; + flags_ = flags_ | WasDollar; + bool ret = PushSimpleOp(kRegexpEndText); + flags_ = oflags; + return ret; + } + return PushSimpleOp(kRegexpEndLine); +} + +// Pushes a . onto the stack. +bool Regexp::ParseState::PushDot() { + if ((flags_ & DotNL) && !(flags_ & NeverNL)) + return PushSimpleOp(kRegexpAnyChar); + // Rewrite . into [^\n] + Regexp* re = new Regexp(kRegexpCharClass, flags_ & ~FoldCase); + re->arguments.char_class.ccb_ = new CharClassBuilder; + re->arguments.char_class.ccb_->AddRange(0, '\n' - 1); + re->arguments.char_class.ccb_->AddRange('\n' + 1, rune_max_); + return PushRegexp(re); +} + +// Pushes a regexp with the given op (and no args) onto the stack. +bool Regexp::ParseState::PushSimpleOp(RegexpOp op) { + Regexp* re = new Regexp(op, flags_); + return PushRegexp(re); +} + +// Pushes a repeat operator regexp onto the stack. +// A valid argument for the operator must already be on the stack. +// The char c is the name of the operator, for use in error messages. +bool Regexp::ParseState::PushRepeatOp(RegexpOp op, const StringPiece& s, + bool nongreedy) { + if (stacktop_ == NULL || IsMarker(stacktop_->op())) { + status_->set_code(kRegexpRepeatArgument); + status_->set_error_arg(s); + return false; + } + Regexp::ParseFlags fl = flags_; + if (nongreedy) + fl = fl ^ NonGreedy; + + // Squash **, ++ and ??. Regexp::Star() et al. handle this too, but + // they're mostly for use during simplification, not during parsing. + if (op == stacktop_->op() && fl == stacktop_->parse_flags()) + return true; + + // Squash *+, *?, +*, +?, ?* and ?+. They all squash to *, so because + // op is a repeat, we just have to check that stacktop_->op() is too, + // then adjust stacktop_. + if ((stacktop_->op() == kRegexpStar || + stacktop_->op() == kRegexpPlus || + stacktop_->op() == kRegexpQuest) && + fl == stacktop_->parse_flags()) { + stacktop_->op_ = kRegexpStar; + return true; + } + + Regexp* re = new Regexp(op, fl); + re->AllocSub(1); + re->down_ = stacktop_->down_; + re->sub()[0] = FinishRegexp(stacktop_); + re->simple_ = re->ComputeSimple(); + stacktop_ = re; + return true; +} + +// RepetitionWalker reports whether the repetition regexp is valid. +// Valid means that the combination of the top-level repetition +// and any inner repetitions does not exceed n copies of the +// innermost thing. +// This rewalks the regexp tree and is called for every repetition, +// so we have to worry about inducing quadratic behavior in the parser. +// We avoid this by only using RepetitionWalker when min or max >= 2. 
+// In that case the depth of any >= 2 nesting can only get to 9 without
+// triggering a parse error, so each subtree can only be rewalked 9 times.
+class RepetitionWalker : public Regexp::Walker<int> {
+ public:
+  RepetitionWalker() {}
+  virtual int PreVisit(Regexp* re, int parent_arg, bool* stop);
+  virtual int PostVisit(Regexp* re, int parent_arg, int pre_arg,
+                        int* child_args, int nchild_args);
+  virtual int ShortVisit(Regexp* re, int parent_arg);
+
+ private:
+  RepetitionWalker(const RepetitionWalker&) = delete;
+  RepetitionWalker& operator=(const RepetitionWalker&) = delete;
+};
+
+int RepetitionWalker::PreVisit(Regexp* re, int parent_arg, bool* stop) {
+  int arg = parent_arg;
+  if (re->op() == kRegexpRepeat) {
+    int m = re->max();
+    if (m < 0) {
+      m = re->min();
+    }
+    if (m > 0) {
+      arg /= m;
+    }
+  }
+  return arg;
+}
+
+int RepetitionWalker::PostVisit(Regexp* re, int parent_arg, int pre_arg,
+                                int* child_args, int nchild_args) {
+  int arg = pre_arg;
+  for (int i = 0; i < nchild_args; i++) {
+    if (child_args[i] < arg) {
+      arg = child_args[i];
+    }
+  }
+  return arg;
+}
+
+int RepetitionWalker::ShortVisit(Regexp* re, int parent_arg) {
+  // Should never be called: we use Walk(), not WalkExponential().
+#ifndef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
+  LOG(DFATAL) << "RepetitionWalker::ShortVisit called";
+#endif
+  return 0;
+}
+
+// Pushes a repetition regexp onto the stack.
+// A valid argument for the operator must already be on the stack.
+bool Regexp::ParseState::PushRepetition(int min, int max,
+                                        const StringPiece& s,
+                                        bool nongreedy) {
+  if ((max != -1 && max < min) ||
+      min > maximum_repeat_count ||
+      max > maximum_repeat_count) {
+    status_->set_code(kRegexpRepeatSize);
+    status_->set_error_arg(s);
+    return false;
+  }
+  if (stacktop_ == NULL || IsMarker(stacktop_->op())) {
+    status_->set_code(kRegexpRepeatArgument);
+    status_->set_error_arg(s);
+    return false;
+  }
+  Regexp::ParseFlags fl = flags_;
+  if (nongreedy)
+    fl = fl ^ NonGreedy;
+  Regexp* re = new Regexp(kRegexpRepeat, fl);
+  re->arguments.repeat.min_ = min;
+  re->arguments.repeat.max_ = max;
+  re->AllocSub(1);
+  re->down_ = stacktop_->down_;
+  re->sub()[0] = FinishRegexp(stacktop_);
+  re->simple_ = re->ComputeSimple();
+  stacktop_ = re;
+  if (min >= 2 || max >= 2) {
+    RepetitionWalker w;
+    if (w.Walk(stacktop_, maximum_repeat_count) == 0) {
+      status_->set_code(kRegexpRepeatSize);
+      status_->set_error_arg(s);
+      return false;
+    }
+  }
+  return true;
+}
+
+// Checks whether a particular regexp op is a marker.
+bool Regexp::ParseState::IsMarker(RegexpOp op) {
+  return op >= kLeftParen;
+}
+
+// Processes a left parenthesis in the input.
+// Pushes a marker onto the stack.
+bool Regexp::ParseState::DoLeftParen(const StringPiece& name) {
+  Regexp* re = new Regexp(kLeftParen, flags_);
+  re->arguments.capture.cap_ = ++ncap_;
+  if (name.data() != NULL)
+    re->arguments.capture.name_ = new std::string(name);
+  return PushRegexp(re);
+}
+
+// Pushes a non-capturing marker onto the stack.
+bool Regexp::ParseState::DoLeftParenNoCapture() {
+  Regexp* re = new Regexp(kLeftParen, flags_);
+  re->arguments.capture.cap_ = -1;
+  return PushRegexp(re);
+}
+
+// Processes a vertical bar in the input.
+bool Regexp::ParseState::DoVerticalBar() {
+  MaybeConcatString(-1, NoParseFlags);
+  DoConcatenation();
+
+  // Below the vertical bar is a list to alternate.
+  // Above the vertical bar is a list to concatenate.
+  // We just did the concatenation, so either swap
+  // the result below the vertical bar or push a new
+  // vertical bar on the stack.
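+  // (Added sketch of the stack shapes handled below:
+  //   ... VerticalBar <concat>    -> swap to ... <concat> VerticalBar
+  //   ... VerticalBar AnyChar ... -> AnyChar subsumes a Literal/CharClass
+  //                                  on the other arm, so drop that arm.
+  // Anything else just pushes a fresh VerticalBar marker.)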
+ Regexp* r1; + Regexp* r2; + if ((r1 = stacktop_) != NULL && + (r2 = r1->down_) != NULL && + r2->op() == kVerticalBar) { + Regexp* r3; + if ((r3 = r2->down_) != NULL && + (r1->op() == kRegexpAnyChar || r3->op() == kRegexpAnyChar)) { + // AnyChar is above or below the vertical bar. Let it subsume + // the other when the other is Literal, CharClass or AnyChar. + if (r3->op() == kRegexpAnyChar && + (r1->op() == kRegexpLiteral || + r1->op() == kRegexpCharClass || + r1->op() == kRegexpAnyChar)) { + // Discard r1. + stacktop_ = r2; + r1->Decref(); + return true; + } + if (r1->op() == kRegexpAnyChar && + (r3->op() == kRegexpLiteral || + r3->op() == kRegexpCharClass || + r3->op() == kRegexpAnyChar)) { + // Rearrange the stack and discard r3. + r1->down_ = r3->down_; + r2->down_ = r1; + stacktop_ = r2; + r3->Decref(); + return true; + } + } + // Swap r1 below vertical bar (r2). + r1->down_ = r2->down_; + r2->down_ = r1; + stacktop_ = r2; + return true; + } + return PushSimpleOp(kVerticalBar); +} + +// Processes a right parenthesis in the input. +bool Regexp::ParseState::DoRightParen() { + // Finish the current concatenation and alternation. + DoAlternation(); + + // The stack should be: LeftParen regexp + // Remove the LeftParen, leaving the regexp, + // parenthesized. + Regexp* r1; + Regexp* r2; + if ((r1 = stacktop_) == NULL || + (r2 = r1->down_) == NULL || + r2->op() != kLeftParen) { + status_->set_code(kRegexpUnexpectedParen); + status_->set_error_arg(whole_regexp_); + return false; + } + + // Pop off r1, r2. Will Decref or reuse below. + stacktop_ = r2->down_; + + // Restore flags from when paren opened. + Regexp* re = r2; + flags_ = re->parse_flags(); + + // Rewrite LeftParen as capture if needed. + if (re->arguments.capture.cap_ > 0) { + re->op_ = kRegexpCapture; + // re->cap_ is already set + re->AllocSub(1); + re->sub()[0] = FinishRegexp(r1); + re->simple_ = re->ComputeSimple(); + } else { + re->Decref(); + re = r1; + } + return PushRegexp(re); +} + +// Processes the end of input, returning the final regexp. +Regexp* Regexp::ParseState::DoFinish() { + DoAlternation(); + Regexp* re = stacktop_; + if (re != NULL && re->down_ != NULL) { + status_->set_code(kRegexpMissingParen); + status_->set_error_arg(whole_regexp_); + return NULL; + } + stacktop_ = NULL; + return FinishRegexp(re); +} + +// Returns the leading regexp that re starts with. +// The returned Regexp* points into a piece of re, +// so it must not be used after the caller calls re->Decref(). +Regexp* Regexp::LeadingRegexp(Regexp* re) { + if (re->op() == kRegexpEmptyMatch) + return NULL; + if (re->op() == kRegexpConcat && re->nsub() >= 2) { + Regexp** sub = re->sub(); + if (sub[0]->op() == kRegexpEmptyMatch) + return NULL; + return sub[0]; + } + return re; +} + +// Removes LeadingRegexp(re) from re and returns what's left. +// Consumes the reference to re and may edit it in place. +// If caller wants to hold on to LeadingRegexp(re), +// must have already Incref'ed it. +Regexp* Regexp::RemoveLeadingRegexp(Regexp* re) { + if (re->op() == kRegexpEmptyMatch) + return re; + if (re->op() == kRegexpConcat && re->nsub() >= 2) { + Regexp** sub = re->sub(); + if (sub[0]->op() == kRegexpEmptyMatch) + return re; + sub[0]->Decref(); + sub[0] = NULL; + if (re->nsub() == 2) { + // Collapse concatenation to single regexp. + Regexp* nre = sub[1]; + sub[1] = NULL; + re->Decref(); + return nre; + } + // 3 or more -> 2 or more. 
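+    // (Added note: e.g. a Concat of [A, B, C] whose leading A was
+    // removed becomes a Concat of [B, C] via the shift below.)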
+    re->nsub_--;
+    memmove(sub, sub + 1, re->nsub_ * sizeof sub[0]);
+    return re;
+  }
+  Regexp::ParseFlags pf = re->parse_flags();
+  re->Decref();
+  return new Regexp(kRegexpEmptyMatch, pf);
+}
+
+// Returns the leading string that re starts with.
+// The returned Rune* points into a piece of re,
+// so it must not be used after the caller calls re->Decref().
+Rune* Regexp::LeadingString(Regexp* re, int *nrune,
+                            Regexp::ParseFlags *flags) {
+  while (re->op() == kRegexpConcat && re->nsub() > 0)
+    re = re->sub()[0];
+
+  *flags = static_cast<Regexp::ParseFlags>(re->parse_flags_ & Regexp::FoldCase);
+
+  if (re->op() == kRegexpLiteral) {
+    *nrune = 1;
+    return &re->arguments.rune_;
+  }
+
+  if (re->op() == kRegexpLiteralString) {
+    *nrune = re->arguments.literal_string.nrunes_;
+    return re->arguments.literal_string.runes_;
+  }
+
+  *nrune = 0;
+  return NULL;
+}
+
+// Removes the first n leading runes from the beginning of re.
+// Edits re in place.
+void Regexp::RemoveLeadingString(Regexp* re, int n) {
+  // Chase down concats to find first string.
+  // For regexps generated by parser, nested concats are
+  // flattened except when doing so would overflow the 16-bit
+  // limit on the size of a concatenation, so we should never
+  // see more than two here.
+  Regexp* stk[4];
+  size_t d = 0;
+  while (re->op() == kRegexpConcat) {
+    if (d < arraysize(stk))
+      stk[d++] = re;
+    re = re->sub()[0];
+  }
+
+  // Remove leading string from re.
+  if (re->op() == kRegexpLiteral) {
+    re->arguments.rune_ = 0;
+    re->op_ = kRegexpEmptyMatch;
+  } else if (re->op() == kRegexpLiteralString) {
+    if (n >= re->arguments.literal_string.nrunes_) {
+      delete[] re->arguments.literal_string.runes_;
+      re->arguments.literal_string.runes_ = NULL;
+      re->arguments.literal_string.nrunes_ = 0;
+      re->op_ = kRegexpEmptyMatch;
+    } else if (n == re->arguments.literal_string.nrunes_ - 1) {
+      Rune rune = re->arguments.literal_string.runes_[re->arguments.literal_string.nrunes_ - 1];
+      delete[] re->arguments.literal_string.runes_;
+      re->arguments.literal_string.runes_ = NULL;
+      re->arguments.literal_string.nrunes_ = 0;
+      re->arguments.rune_ = rune;
+      re->op_ = kRegexpLiteral;
+    } else {
+      re->arguments.literal_string.nrunes_ -= n;
+      memmove(re->arguments.literal_string.runes_, re->arguments.literal_string.runes_ + n, re->arguments.literal_string.nrunes_ * sizeof re->arguments.literal_string.runes_[0]);
+    }
+  }
+
+  // If re is now empty, concatenations might simplify too.
+  while (d > 0) {
+    re = stk[--d];
+    Regexp** sub = re->sub();
+    if (sub[0]->op() == kRegexpEmptyMatch) {
+      sub[0]->Decref();
+      sub[0] = NULL;
+      // Delete first element of concat.
+      switch (re->nsub()) {
+        case 0:
+        case 1:
+          // Impossible.
+          LOG(DFATAL) << "Concat of " << re->nsub();
+          re->submany_ = NULL;
+          re->op_ = kRegexpEmptyMatch;
+          break;
+
+        case 2: {
+          // Replace re with sub[1].
+          Regexp* old = sub[1];
+          sub[1] = NULL;
+          re->Swap(old);
+          old->Decref();
+          break;
+        }
+
+        default:
+          // Slide down.
+          re->nsub_--;
+          memmove(sub, sub + 1, re->nsub_ * sizeof sub[0]);
+          break;
+      }
+    }
+  }
+}
+
+// In the context of factoring alternations, a Splice is: a factored prefix or
+// merged character class computed by one iteration of one round of factoring;
+// the span of subexpressions of the alternation to be "spliced" (i.e. removed
+// and replaced); and, for a factored prefix, the number of suffixes after any
+// factoring that might have subsequently been performed on them. For a merged
+// character class, there are no suffixes, of course, so the field is ignored.
+struct Splice {
+  Splice(Regexp* prefix, Regexp** sub, int nsub)
+      : prefix(prefix),
+        sub(sub),
+        nsub(nsub),
+        nsuffix(-1) {}
+
+  Regexp* prefix;
+  Regexp** sub;
+  int nsub;
+  int nsuffix;
+};
+
+// Named so because it is used to implement an explicit stack, a Frame is: the
+// span of subexpressions of the alternation to be factored; the current round
+// of factoring; any Splices computed; and, for a factored prefix, an iterator
+// to the next Splice to be factored (i.e. in another Frame) because suffixes.
+struct Frame {
+  Frame(Regexp** sub, int nsub)
+      : sub(sub),
+        nsub(nsub),
+        round(0) {}
+
+  Regexp** sub;
+  int nsub;
+  int round;
+  std::vector<Splice> splices;
+  int spliceidx;
+};
+
+// Bundled into a class for friend access to Regexp without needing to declare
+// (or define) Splice in regexp.h.
+class FactorAlternationImpl {
+ public:
+  static void Round1(Regexp** sub, int nsub,
+                     Regexp::ParseFlags flags,
+                     std::vector<Splice>* splices);
+  static void Round2(Regexp** sub, int nsub,
+                     Regexp::ParseFlags flags,
+                     std::vector<Splice>* splices);
+  static void Round3(Regexp** sub, int nsub,
+                     Regexp::ParseFlags flags,
+                     std::vector<Splice>* splices);
+};
+
+// Factors common prefixes from alternation.
+// For example,
+//     ABC|ABD|AEF|BCX|BCY
+// simplifies to
+//     A(B(C|D)|EF)|BC(X|Y)
+// and thence to
+//     A(B[CD]|EF)|BC[XY]
+//
+// Rewrites sub to contain simplified list to alternate and returns
+// the new length of sub. Adjusts reference counts accordingly
+// (incoming sub[i] decremented, outgoing sub[i] incremented).
+int Regexp::FactorAlternation(Regexp** sub, int nsub, ParseFlags flags) {
+  std::vector<Frame> stk;
+  stk.emplace_back(sub, nsub);
+
+  for (;;) {
+    auto& sub = stk.back().sub;
+    auto& nsub = stk.back().nsub;
+    auto& round = stk.back().round;
+    auto& splices = stk.back().splices;
+    auto& spliceidx = stk.back().spliceidx;
+
+    if (splices.empty()) {
+      // Advance to the next round of factoring. Note that this covers
+      // the initialised state: when splices is empty and round is 0.
+      round++;
+    } else if (spliceidx < static_cast<int>(splices.size())) {
+      // We have at least one more Splice to factor. Recurse logically.
+      stk.emplace_back(splices[spliceidx].sub, splices[spliceidx].nsub);
+      continue;
+    } else {
+      // We have no more Splices to factor. Apply them.
+      auto iter = splices.begin();
+      int out = 0;
+      for (int i = 0; i < nsub; ) {
+        // Copy until we reach where the next Splice begins.
+        while (sub + i < iter->sub)
+          sub[out++] = sub[i++];
+        switch (round) {
+          case 1:
+          case 2: {
+            // Assemble the Splice prefix and the suffixes.
+            Regexp* re[2];
+            re[0] = iter->prefix;
+            re[1] = Regexp::AlternateNoFactor(iter->sub, iter->nsuffix, flags);
+            sub[out++] = Regexp::Concat(re, 2, flags);
+            i += iter->nsub;
+            break;
+          }
+          case 3:
+            // Just use the Splice prefix.
+            sub[out++] = iter->prefix;
+            i += iter->nsub;
+            break;
+          default:
+            LOG(DFATAL) << "unknown round: " << round;
+            break;
+        }
+        // If we are done, copy until the end of sub.
+        if (++iter == splices.end()) {
+          while (i < nsub)
+            sub[out++] = sub[i++];
+        }
+      }
+      splices.clear();
+      nsub = out;
+      // Advance to the next round of factoring.
+      round++;
+    }
+
+    switch (round) {
+      case 1:
+        FactorAlternationImpl::Round1(sub, nsub, flags, &splices);
+        break;
+      case 2:
+        FactorAlternationImpl::Round2(sub, nsub, flags, &splices);
+        break;
+      case 3:
+        FactorAlternationImpl::Round3(sub, nsub, flags, &splices);
+        break;
+      case 4:
+        if (stk.size() == 1) {
+          // We are at the top of the stack. Just return.
+          return nsub;
+        } else {
+          // Pop the stack and set the number of suffixes.
+          // (Note that references will be invalidated!)
+          int nsuffix = nsub;
+          stk.pop_back();
+          stk.back().splices[stk.back().spliceidx].nsuffix = nsuffix;
+          ++stk.back().spliceidx;
+          continue;
+        }
+      default:
+        LOG(DFATAL) << "unknown round: " << round;
+        break;
+    }
+
+    // Set spliceidx depending on whether we have Splices to factor.
+    if (splices.empty() || round == 3) {
+      spliceidx = static_cast<int>(splices.size());
+    } else {
+      spliceidx = 0;
+    }
+  }
+}
+
+void FactorAlternationImpl::Round1(Regexp** sub, int nsub,
+                                   Regexp::ParseFlags flags,
+                                   std::vector<Splice>* splices) {
+  // Round 1: Factor out common literal prefixes.
+  int start = 0;
+  Rune* rune = NULL;
+  int nrune = 0;
+  Regexp::ParseFlags runeflags = Regexp::NoParseFlags;
+  for (int i = 0; i <= nsub; i++) {
+    // Invariant: sub[start:i] consists of regexps that all
+    // begin with rune[0:nrune].
+    Rune* rune_i = NULL;
+    int nrune_i = 0;
+    Regexp::ParseFlags runeflags_i = Regexp::NoParseFlags;
+    if (i < nsub) {
+      rune_i = Regexp::LeadingString(sub[i], &nrune_i, &runeflags_i);
+      if (runeflags_i == runeflags) {
+        int same = 0;
+        while (same < nrune && same < nrune_i && rune[same] == rune_i[same])
+          same++;
+        if (same > 0) {
+          // Matches at least one rune in current range. Keep going around.
+          nrune = same;
+          continue;
+        }
+      }
+    }
+
+    // Found end of a run with common leading literal string:
+    // sub[start:i] all begin with rune[0:nrune],
+    // but sub[i] does not even begin with rune[0].
+    if (i == start) {
+      // Nothing to do - first iteration.
+    } else if (i == start+1) {
+      // Just one: don't bother factoring.
+    } else {
+      Regexp* prefix = Regexp::LiteralString(rune, nrune, runeflags);
+      for (int j = start; j < i; j++)
+        Regexp::RemoveLeadingString(sub[j], nrune);
+      splices->emplace_back(prefix, sub + start, i - start);
+    }
+
+    // Prepare for next iteration (if there is one).
+    if (i < nsub) {
+      start = i;
+      rune = rune_i;
+      nrune = nrune_i;
+      runeflags = runeflags_i;
+    }
+  }
+}
+
+void FactorAlternationImpl::Round2(Regexp** sub, int nsub,
+                                   Regexp::ParseFlags flags,
+                                   std::vector<Splice>* splices) {
+  // Round 2: Factor out common simple prefixes,
+  // just the first piece of each concatenation.
+  // This will be good enough a lot of the time.
+  //
+  // Complex subexpressions (e.g. involving quantifiers)
+  // are not safe to factor because that collapses their
+  // distinct paths through the automaton, which affects
+  // correctness in some cases.
+  int start = 0;
+  Regexp* first = NULL;
+  for (int i = 0; i <= nsub; i++) {
+    // Invariant: sub[start:i] consists of regexps that all
+    // begin with first.
+    Regexp* first_i = NULL;
+    if (i < nsub) {
+      first_i = Regexp::LeadingRegexp(sub[i]);
+      if (first != NULL &&
+          // first must be an empty-width op
+          // OR a char class, any char or any byte
+          // OR a fixed repeat of a literal, char class, any char or any byte.
+          (first->op() == kRegexpBeginLine ||
+           first->op() == kRegexpEndLine ||
+           first->op() == kRegexpWordBoundary ||
+           first->op() == kRegexpNoWordBoundary ||
+           first->op() == kRegexpBeginText ||
+           first->op() == kRegexpEndText ||
+           first->op() == kRegexpCharClass ||
+           first->op() == kRegexpAnyChar ||
+           first->op() == kRegexpAnyByte ||
+           (first->op() == kRegexpRepeat &&
+            first->min() == first->max() &&
+            (first->sub()[0]->op() == kRegexpLiteral ||
+             first->sub()[0]->op() == kRegexpCharClass ||
+             first->sub()[0]->op() == kRegexpAnyChar ||
+             first->sub()[0]->op() == kRegexpAnyByte))) &&
+          Regexp::Equal(first, first_i))
+        continue;
+    }
+
+    // Found end of a run with common leading regexp:
+    // sub[start:i] all begin with first,
+    // but sub[i] does not.
+    if (i == start) {
+      // Nothing to do - first iteration.
+    } else if (i == start+1) {
+      // Just one: don't bother factoring.
+    } else {
+      Regexp* prefix = first->Incref();
+      for (int j = start; j < i; j++)
+        sub[j] = Regexp::RemoveLeadingRegexp(sub[j]);
+      splices->emplace_back(prefix, sub + start, i - start);
+    }
+
+    // Prepare for next iteration (if there is one).
+    if (i < nsub) {
+      start = i;
+      first = first_i;
+    }
+  }
+}
+
+void FactorAlternationImpl::Round3(Regexp** sub, int nsub,
+                                   Regexp::ParseFlags flags,
+                                   std::vector<Splice>* splices) {
+  // Round 3: Merge runs of literals and/or character classes.
+  int start = 0;
+  Regexp* first = NULL;
+  for (int i = 0; i <= nsub; i++) {
+    // Invariant: sub[start:i] consists of regexps that all
+    // are either literals (i.e. runes) or character classes.
+    Regexp* first_i = NULL;
+    if (i < nsub) {
+      first_i = sub[i];
+      if (first != NULL &&
+          (first->op() == kRegexpLiteral ||
+           first->op() == kRegexpCharClass) &&
+          (first_i->op() == kRegexpLiteral ||
+           first_i->op() == kRegexpCharClass))
+        continue;
+    }
+
+    // Found end of a run of Literal/CharClass:
+    // sub[start:i] all are either one or the other,
+    // but sub[i] is not.
+    if (i == start) {
+      // Nothing to do - first iteration.
+    } else if (i == start+1) {
+      // Just one: don't bother factoring.
+    } else {
+      CharClassBuilder ccb;
+      for (int j = start; j < i; j++) {
+        Regexp* re = sub[j];
+        if (re->op() == kRegexpCharClass) {
+          CharClass* cc = re->cc();
+          for (CharClass::iterator it = cc->begin(); it != cc->end(); ++it)
+            ccb.AddRange(it->lo, it->hi);
+        } else if (re->op() == kRegexpLiteral) {
+          ccb.AddRangeFlags(re->rune(), re->rune(), re->parse_flags());
+        } else {
+          LOG(DFATAL) << "RE2: unexpected op: " << re->op() << " "
+                      << re->ToString();
+        }
+        re->Decref();
+      }
+      Regexp* re = Regexp::NewCharClass(ccb.GetCharClass(), flags);
+      splices->emplace_back(re, sub + start, i - start);
+    }
+
+    // Prepare for next iteration (if there is one).
+    if (i < nsub) {
+      start = i;
+      first = first_i;
+    }
+  }
+}
+
+// Collapse the regexps on top of the stack, down to the
+// first marker, into a new op node (op == kRegexpAlternate
+// or op == kRegexpConcat).
+void Regexp::ParseState::DoCollapse(RegexpOp op) {
+  // Scan backward to marker, counting children of composite.
+  int n = 0;
+  Regexp* next = NULL;
+  Regexp* sub;
+  for (sub = stacktop_; sub != NULL && !IsMarker(sub->op()); sub = next) {
+    next = sub->down_;
+    if (sub->op_ == op)
+      n += sub->nsub_;
+    else
+      n++;
+  }
+
+  // If there's just one child, leave it alone.
+  // (Concat of one thing is that one thing; alternate of one thing is same.)
+  if (stacktop_ != NULL && stacktop_->down_ == next)
+    return;
+
+  // Construct op (alternation or concatenation), flattening op of op.
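+  // For example (illustrative): collapsing stack entries a, b, (c|d) with
+  // op == kRegexpAlternate yields one alternation with subs {a, b, c, d};
+  // the nested alternation is flattened rather than kept as a child.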
+  PODArray<Regexp*> subs(n);
+  next = NULL;
+  int i = n;
+  for (sub = stacktop_; sub != NULL && !IsMarker(sub->op()); sub = next) {
+    next = sub->down_;
+    if (sub->op_ == op) {
+      Regexp** sub_subs = sub->sub();
+      for (int k = sub->nsub_ - 1; k >= 0; k--)
+        subs[--i] = sub_subs[k]->Incref();
+      sub->Decref();
+    } else {
+      subs[--i] = FinishRegexp(sub);
+    }
+  }
+
+  Regexp* re = ConcatOrAlternate(op, subs.data(), n, flags_, true);
+  re->simple_ = re->ComputeSimple();
+  re->down_ = next;
+  stacktop_ = re;
+}
+
+// Finishes the current concatenation,
+// collapsing it into a single regexp on the stack.
+void Regexp::ParseState::DoConcatenation() {
+  Regexp* r1 = stacktop_;
+  if (r1 == NULL || IsMarker(r1->op())) {
+    // empty concatenation is special case
+    Regexp* re = new Regexp(kRegexpEmptyMatch, flags_);
+    PushRegexp(re);
+  }
+  DoCollapse(kRegexpConcat);
+}
+
+// Finishes the current alternation,
+// collapsing it to a single regexp on the stack.
+void Regexp::ParseState::DoAlternation() {
+  DoVerticalBar();
+  // Now stack top is kVerticalBar.
+  Regexp* r1 = stacktop_;
+  stacktop_ = r1->down_;
+  r1->Decref();
+  DoCollapse(kRegexpAlternate);
+}
+
+// Incremental conversion of concatenated literals into strings.
+// If top two elements on stack are both literal or string,
+// collapse into single string.
+// Don't walk down the stack -- the parser calls this frequently
+// enough that below the bottom two is known to be collapsed.
+// Only called when another regexp is about to be pushed
+// on the stack, so that the topmost literal is not being considered.
+// (Otherwise ab* would turn into (ab)*.)
+// If r >= 0, consider pushing a literal r on the stack.
+// Return whether that happened.
+bool Regexp::ParseState::MaybeConcatString(int r, ParseFlags flags) {
+  Regexp* re1;
+  Regexp* re2;
+  if ((re1 = stacktop_) == NULL || (re2 = re1->down_) == NULL)
+    return false;
+
+  if (re1->op_ != kRegexpLiteral && re1->op_ != kRegexpLiteralString)
+    return false;
+  if (re2->op_ != kRegexpLiteral && re2->op_ != kRegexpLiteralString)
+    return false;
+  if ((re1->parse_flags_ & FoldCase) != (re2->parse_flags_ & FoldCase))
+    return false;
+
+  if (re2->op_ == kRegexpLiteral) {
+    // convert into string
+    Rune rune = re2->arguments.rune_;
+    re2->op_ = kRegexpLiteralString;
+    re2->arguments.literal_string.nrunes_ = 0;
+    re2->arguments.literal_string.runes_ = NULL;
+    re2->AddRuneToString(rune);
+  }
+
+  // push re1 into re2.
+  if (re1->op_ == kRegexpLiteral) {
+    re2->AddRuneToString(re1->arguments.rune_);
+  } else {
+    for (int i = 0; i < re1->arguments.literal_string.nrunes_; i++)
+      re2->AddRuneToString(re1->arguments.literal_string.runes_[i]);
+    re1->arguments.literal_string.nrunes_ = 0;
+    delete[] re1->arguments.literal_string.runes_;
+    re1->arguments.literal_string.runes_ = NULL;
+  }
+
+  // reuse re1 if possible
+  if (r >= 0) {
+    re1->op_ = kRegexpLiteral;
+    re1->arguments.rune_ = r;
+    re1->parse_flags_ = static_cast<uint16_t>(flags);
+    return true;
+  }
+
+  stacktop_ = re2;
+  re1->Decref();
+  return false;
+}
+
+// Lexing routines.
+
+// Parses a decimal integer, storing it in *np.
+// Sets *s to span the remainder of the string.
+static bool ParseInteger(StringPiece* s, int* np) {
+  if (s->empty() || !isdigit((*s)[0] & 0xFF))
+    return false;
+  // Disallow leading zeros.
+  if (s->size() >= 2 && (*s)[0] == '0' && isdigit((*s)[1] & 0xFF))
+    return false;
+  int n = 0;
+  int c;
+  while (!s->empty() && isdigit(c = (*s)[0] & 0xFF)) {
+    // Avoid overflow.
+    if (n >= 100000000)
+      return false;
+    n = n*10 + c - '0';
+    s->remove_prefix(1);  // digit
+  }
+  *np = n;
+  return true;
+}
+
+// Parses a repetition suffix like {1,2} or {2} or {2,}.
+// Sets *s to span the remainder of the string on success.
+// Sets *lo and *hi to the given range.
+// In the case of {2,}, the high number is unbounded;
+// sets *hi to -1 to signify this.
+// {,2} is NOT a valid suffix.
+// The Maybe in the name signifies that the regexp parse
+// doesn't fail even if ParseRepetition does, so the StringPiece
+// s must NOT be edited unless MaybeParseRepetition returns true.
+static bool MaybeParseRepetition(StringPiece* sp, int* lo, int* hi) {
+  StringPiece s = *sp;
+  if (s.empty() || s[0] != '{')
+    return false;
+  s.remove_prefix(1);  // '{'
+  if (!ParseInteger(&s, lo))
+    return false;
+  if (s.empty())
+    return false;
+  if (s[0] == ',') {
+    s.remove_prefix(1);  // ','
+    if (s.empty())
+      return false;
+    if (s[0] == '}') {
+      // {2,} means at least 2
+      *hi = -1;
+    } else {
+      // {2,4} means 2, 3, or 4.
+      if (!ParseInteger(&s, hi))
+        return false;
+    }
+  } else {
+    // {2} means exactly two
+    *hi = *lo;
+  }
+  if (s.empty() || s[0] != '}')
+    return false;
+  s.remove_prefix(1);  // '}'
+  *sp = s;
+  return true;
+}
+
+// Removes the next Rune from the StringPiece and stores it in *r.
+// Returns number of bytes removed from sp.
+// Behaves as though there is a terminating NUL at the end of sp.
+// Argument order is backwards from usual Google style
+// but consistent with chartorune.
+static int StringPieceToRune(Rune *r, StringPiece *sp, RegexpStatus* status) {
+  // fullrune() takes int, not size_t. However, it just looks
+  // at the leading byte and treats any length >= 4 the same.
+  if (fullrune(sp->data(), static_cast<int>(std::min(size_t{4}, sp->size())))) {
+    int n = chartorune(r, sp->data());
+    // Some copies of chartorune have a bug that accepts
+    // encodings of values in (10FFFF, 1FFFFF] as valid.
+    // Those values break the character class algorithm,
+    // which assumes Runemax is the largest rune.
+    if (*r > Runemax) {
+      n = 1;
+      *r = Runeerror;
+    }
+    if (!(n == 1 && *r == Runeerror)) {  // no decoding error
+      sp->remove_prefix(n);
+      return n;
+    }
+  }
+
+  if (status != NULL) {
+    status->set_code(kRegexpBadUTF8);
+    status->set_error_arg(StringPiece());
+  }
+  return -1;
+}
+
+// Returns whether name is valid UTF-8.
+// If not, sets status to kRegexpBadUTF8.
+static bool IsValidUTF8(const StringPiece& s, RegexpStatus* status) {
+  StringPiece t = s;
+  Rune r;
+  while (!t.empty()) {
+    if (StringPieceToRune(&r, &t, status) < 0)
+      return false;
+  }
+  return true;
+}
+
+// Is c a hex digit?
+static int IsHex(int c) {
+  return ('0' <= c && c <= '9') ||
+         ('A' <= c && c <= 'F') ||
+         ('a' <= c && c <= 'f');
+}
+
+// Convert hex digit to value.
+static int UnHex(int c) {
+  if ('0' <= c && c <= '9')
+    return c - '0';
+  if ('A' <= c && c <= 'F')
+    return c - 'A' + 10;
+  if ('a' <= c && c <= 'f')
+    return c - 'a' + 10;
+  LOG(DFATAL) << "Bad hex digit " << c;
+  return 0;
+}
+
+// Parse an escape sequence (e.g., \n, \{).
+// Sets *s to span the remainder of the string.
+// Sets *rp to the named character.
+static bool ParseEscape(StringPiece* s, Rune* rp,
+                        RegexpStatus* status, int rune_max) {
+  const char* begin = s->data();
+  if (s->empty() || (*s)[0] != '\\') {
+    // Should not happen - caller always checks.
+    status->set_code(kRegexpInternalError);
+    status->set_error_arg(StringPiece());
+    return false;
+  }
+  if (s->size() == 1) {
+    status->set_code(kRegexpTrailingBackslash);
+    status->set_error_arg(StringPiece());
+    return false;
+  }
+  Rune c, c1;
+  s->remove_prefix(1);  // backslash
+  if (StringPieceToRune(&c, s, status) < 0)
+    return false;
+  int code;
+  switch (c) {
+    default:
+      if (c < Runeself && !isalpha(c) && !isdigit(c)) {
+        // Escaped non-word characters are always themselves.
+        // PCRE is not quite so rigorous: it accepts things like
+        // \q, but we don't. We once rejected \_, but too many
+        // programs and people insist on using it, so allow \_.
+        *rp = c;
+        return true;
+      }
+      goto BadEscape;
+
+    // Octal escapes.
+    case '1':
+    case '2':
+    case '3':
+    case '4':
+    case '5':
+    case '6':
+    case '7':
+      // Single non-zero octal digit is a backreference; not supported.
+      if (s->empty() || (*s)[0] < '0' || (*s)[0] > '7')
+        goto BadEscape;
+      FALLTHROUGH_INTENDED;
+    case '0':
+      // consume up to three octal digits; already have one.
+      code = c - '0';
+      if (!s->empty() && '0' <= (c = (*s)[0]) && c <= '7') {
+        code = code * 8 + c - '0';
+        s->remove_prefix(1);  // digit
+        if (!s->empty()) {
+          c = (*s)[0];
+          if ('0' <= c && c <= '7') {
+            code = code * 8 + c - '0';
+            s->remove_prefix(1);  // digit
+          }
+        }
+      }
+      if (code > rune_max)
+        goto BadEscape;
+      *rp = code;
+      return true;
+
+    // Hexadecimal escapes
+    case 'x':
+      if (s->empty())
+        goto BadEscape;
+      if (StringPieceToRune(&c, s, status) < 0)
+        return false;
+      if (c == '{') {
+        // Any number of digits in braces.
+        // Update n as we consume the string, so that
+        // the whole thing gets shown in the error message.
+        // Perl accepts any text at all; it ignores all text
+        // after the first non-hex digit. We require only hex digits,
+        // and at least one.
+        if (StringPieceToRune(&c, s, status) < 0)
+          return false;
+        int nhex = 0;
+        code = 0;
+        while (IsHex(c)) {
+          nhex++;
+          code = code * 16 + UnHex(c);
+          if (code > rune_max)
+            goto BadEscape;
+          if (s->empty())
+            goto BadEscape;
+          if (StringPieceToRune(&c, s, status) < 0)
+            return false;
+        }
+        if (c != '}' || nhex == 0)
+          goto BadEscape;
+        *rp = code;
+        return true;
+      }
+      // Easy case: two hex digits.
+      if (s->empty())
+        goto BadEscape;
+      if (StringPieceToRune(&c1, s, status) < 0)
+        return false;
+      if (!IsHex(c) || !IsHex(c1))
+        goto BadEscape;
+      *rp = UnHex(c) * 16 + UnHex(c1);
+      return true;
+
+    // C escapes.
+    case 'n':
+      *rp = '\n';
+      return true;
+    case 'r':
+      *rp = '\r';
+      return true;
+    case 't':
+      *rp = '\t';
+      return true;
+
+    // Less common C escapes.
+    case 'a':
+      *rp = '\a';
+      return true;
+    case 'f':
+      *rp = '\f';
+      return true;
+    case 'v':
+      *rp = '\v';
+      return true;
+
+    // This code is disabled to avoid misparsing
+    // the Perl word-boundary \b as a backspace
+    // when in POSIX regexp mode. Surprisingly,
+    // in Perl, \b means word-boundary but [\b]
+    // means backspace. We don't support that:
+    // if you want a backspace embed a literal
+    // backspace character or use \x08.
+    //
+    // case 'b':
+    //   *rp = '\b';
+    //   return true;
+  }
+
+BadEscape:
+  // Unrecognized escape sequence.
+  status->set_code(kRegexpBadEscape);
+  status->set_error_arg(
+      StringPiece(begin, static_cast<size_t>(s->data() - begin)));
+  return false;
+}
+
+// Add a range to the character class, but exclude newline if asked.
+// Also handle case folding.
+void CharClassBuilder::AddRangeFlags(
+    Rune lo, Rune hi, Regexp::ParseFlags parse_flags) {
+
+  // Take out \n if the flags say so.
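+  // E.g. (illustrative): adding the range \x09-\x0b without Regexp::ClassNL
+  // set splits it into \x09-\x09 and \x0b-\x0b, so \n stays excluded.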
+ bool cutnl = !(parse_flags & Regexp::ClassNL) || + (parse_flags & Regexp::NeverNL); + if (cutnl && lo <= '\n' && '\n' <= hi) { + if (lo < '\n') + AddRangeFlags(lo, '\n' - 1, parse_flags); + if (hi > '\n') + AddRangeFlags('\n' + 1, hi, parse_flags); + return; + } + + // If folding case, add fold-equivalent characters too. + if (parse_flags & Regexp::FoldCase) + AddFoldedRange(this, lo, hi, 0); + else + AddRange(lo, hi); +} + +// Look for a group with the given name. +static const UGroup* LookupGroup(const StringPiece& name, + const UGroup *groups, int ngroups) { + // Simple name lookup. + for (int i = 0; i < ngroups; i++) + if (StringPiece(groups[i].name) == name) + return &groups[i]; + return NULL; +} + +// Look for a POSIX group with the given name (e.g., "[:^alpha:]") +static const UGroup* LookupPosixGroup(const StringPiece& name) { + return LookupGroup(name, posix_groups, num_posix_groups); +} + +static const UGroup* LookupPerlGroup(const StringPiece& name) { + return LookupGroup(name, perl_groups, num_perl_groups); +} + +#if !defined(RE2_USE_ICU) +// Fake UGroup containing all Runes +static URange16 any16[] = { { 0, 65535 } }; +static URange32 any32[] = { { 65536, Runemax } }; +static UGroup anygroup = { "Any", +1, any16, 1, any32, 1 }; + +// Look for a Unicode group with the given name (e.g., "Han") +static const UGroup* LookupUnicodeGroup(const StringPiece& name) { + // Special case: "Any" means any. + if (name == StringPiece("Any")) + return &anygroup; + return LookupGroup(name, unicode_groups, num_unicode_groups); +} +#endif + +// Add a UGroup or its negation to the character class. +static void AddUGroup(CharClassBuilder *cc, const UGroup *g, int sign, + Regexp::ParseFlags parse_flags) { + if (sign == +1) { + for (int i = 0; i < g->nr16; i++) { + cc->AddRangeFlags(g->r16[i].lo, g->r16[i].hi, parse_flags); + } + for (int i = 0; i < g->nr32; i++) { + cc->AddRangeFlags(g->r32[i].lo, g->r32[i].hi, parse_flags); + } + } else { + if (parse_flags & Regexp::FoldCase) { + // Normally adding a case-folded group means + // adding all the extra fold-equivalent runes too. + // But if we're adding the negation of the group, + // we have to exclude all the runes that are fold-equivalent + // to what's already missing. Too hard, so do in two steps. + CharClassBuilder ccb1; + AddUGroup(&ccb1, g, +1, parse_flags); + // If the flags say to take out \n, put it in, so that negating will take it out. + // Normally AddRangeFlags does this, but we're bypassing AddRangeFlags. + bool cutnl = !(parse_flags & Regexp::ClassNL) || + (parse_flags & Regexp::NeverNL); + if (cutnl) { + ccb1.AddRange('\n', '\n'); + } + ccb1.Negate(); + cc->AddCharClass(&ccb1); + return; + } + int next = 0; + for (int i = 0; i < g->nr16; i++) { + if (next < g->r16[i].lo) + cc->AddRangeFlags(next, g->r16[i].lo - 1, parse_flags); + next = g->r16[i].hi + 1; + } + for (int i = 0; i < g->nr32; i++) { + if (next < g->r32[i].lo) + cc->AddRangeFlags(next, g->r32[i].lo - 1, parse_flags); + next = g->r32[i].hi + 1; + } + if (next <= Runemax) + cc->AddRangeFlags(next, Runemax, parse_flags); + } +} + +// Maybe parse a Perl character class escape sequence. +// Only recognizes the Perl character classes (\d \s \w \D \S \W), +// not the Perl empty-string classes (\b \B \A \Z \z). +// On success, sets *s to span the remainder of the string +// and returns the corresponding UGroup. +// The StringPiece must *NOT* be edited unless the call succeeds. 
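+// E.g. (illustrative): given s = "\d+", returns the group for \d and
+// advances s to "+"; given s = "\q", returns NULL and leaves s untouched.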
+const UGroup* MaybeParsePerlCCEscape(StringPiece* s, Regexp::ParseFlags parse_flags) {
+  if (!(parse_flags & Regexp::PerlClasses))
+    return NULL;
+  if (s->size() < 2 || (*s)[0] != '\\')
+    return NULL;
+  // Could use StringPieceToRune, but there aren't
+  // any non-ASCII Perl group names.
+  StringPiece name(s->data(), 2);
+  const UGroup *g = LookupPerlGroup(name);
+  if (g == NULL)
+    return NULL;
+  s->remove_prefix(name.size());
+  return g;
+}
+
+enum ParseStatus {
+  kParseOk,       // Did some parsing.
+  kParseError,    // Found an error.
+  kParseNothing,  // Decided not to parse.
+};
+
+// Maybe parses a Unicode character group like \p{Han} or \P{Han}
+// (the latter is a negated group).
+ParseStatus ParseUnicodeGroup(StringPiece* s, Regexp::ParseFlags parse_flags,
+                              CharClassBuilder *cc,
+                              RegexpStatus* status) {
+  // Decide whether to parse.
+  if (!(parse_flags & Regexp::UnicodeGroups))
+    return kParseNothing;
+  if (s->size() < 2 || (*s)[0] != '\\')
+    return kParseNothing;
+  Rune c = (*s)[1];
+  if (c != 'p' && c != 'P')
+    return kParseNothing;
+
+  // Committed to parse. Results:
+  int sign = +1;  // -1 = negated char class
+  if (c == 'P')
+    sign = -sign;
+  StringPiece seq = *s;  // \p{Han} or \pL
+  StringPiece name;  // Han or L
+  s->remove_prefix(2);  // '\\', 'p'
+
+  if (StringPieceToRune(&c, s, status) < 0)
+    return kParseError;
+  if (c != '{') {
+    // Name is the bit of string we just skipped over for c.
+    const char* p = seq.data() + 2;
+    name = StringPiece(p, static_cast<size_t>(s->data() - p));
+  } else {
+    // Name is in braces. Look for closing }
+    size_t end = s->find('}', 0);
+    if (end == StringPiece::npos) {
+      if (!IsValidUTF8(seq, status))
+        return kParseError;
+      status->set_code(kRegexpBadCharRange);
+      status->set_error_arg(seq);
+      return kParseError;
+    }
+    name = StringPiece(s->data(), end);  // without '}'
+    s->remove_prefix(end + 1);  // with '}'
+    if (!IsValidUTF8(name, status))
+      return kParseError;
+  }
+
+  // Chop seq where s now begins.
+  seq = StringPiece(seq.data(), static_cast<size_t>(s->data() - seq.data()));
+
+  if (!name.empty() && name[0] == '^') {
+    sign = -sign;
+    name.remove_prefix(1);  // '^'
+  }
+
+#if !defined(RE2_USE_ICU)
+  // Look up the group in the RE2 Unicode data.
+  const UGroup *g = LookupUnicodeGroup(name);
+  if (g == NULL) {
+    status->set_code(kRegexpBadCharRange);
+    status->set_error_arg(seq);
+    return kParseError;
+  }
+
+  AddUGroup(cc, g, sign, parse_flags);
+#else
+  // Look up the group in the ICU Unicode data. Because ICU provides full
+  // Unicode properties support, this could be more than a lookup by name.
+  ::icu::UnicodeString ustr = ::icu::UnicodeString::fromUTF8(
+      std::string("\\p{") + std::string(name) + std::string("}"));
+  UErrorCode uerr = U_ZERO_ERROR;
+  ::icu::UnicodeSet uset(ustr, uerr);
+  if (U_FAILURE(uerr)) {
+    status->set_code(kRegexpBadCharRange);
+    status->set_error_arg(seq);
+    return kParseError;
+  }
+
+  // Convert the UnicodeSet to a URange32 and UGroup that we can add.
+  int nr = uset.getRangeCount();
+  PODArray<URange32> r(nr);
+  for (int i = 0; i < nr; i++) {
+    r[i].lo = uset.getRangeStart(i);
+    r[i].hi = uset.getRangeEnd(i);
+  }
+  UGroup g = {"", +1, 0, 0, r.data(), nr};
+  AddUGroup(cc, &g, sign, parse_flags);
+#endif
+
+  return kParseOk;
+}
+
+// Parses a character class name like [:alnum:].
+// Sets *s to span the remainder of the string.
+// Adds the ranges corresponding to the class to ranges.
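+// E.g. (illustrative): for s = "[:digit:]]", consumes "[:digit:]" and adds
+// the range 0-9; an unknown name such as "[:foo:]" yields kParseError.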
+static ParseStatus ParseCCName(StringPiece* s, Regexp::ParseFlags parse_flags,
+                               CharClassBuilder *cc,
+                               RegexpStatus* status) {
+  // Check begins with [:
+  const char* p = s->data();
+  const char* ep = s->data() + s->size();
+  if (ep - p < 2 || p[0] != '[' || p[1] != ':')
+    return kParseNothing;
+
+  // Look for closing :].
+  const char* q;
+  for (q = p+2; q <= ep-2 && (*q != ':' || *(q+1) != ']'); q++)
+    ;
+
+  // If no closing :], then ignore.
+  if (q > ep-2)
+    return kParseNothing;
+
+  // Got it. Check that it's valid.
+  q += 2;
+  StringPiece name(p, static_cast<size_t>(q - p));
+
+  const UGroup *g = LookupPosixGroup(name);
+  if (g == NULL) {
+    status->set_code(kRegexpBadCharRange);
+    status->set_error_arg(name);
+    return kParseError;
+  }
+
+  s->remove_prefix(name.size());
+  AddUGroup(cc, g, g->sign, parse_flags);
+  return kParseOk;
+}
+
+// Parses a character inside a character class.
+// There are fewer special characters here than in the rest of the regexp.
+// Sets *s to span the remainder of the string.
+// Sets *rp to the character.
+bool Regexp::ParseState::ParseCCCharacter(StringPiece* s, Rune *rp,
+                                          const StringPiece& whole_class,
+                                          RegexpStatus* status) {
+  if (s->empty()) {
+    status->set_code(kRegexpMissingBracket);
+    status->set_error_arg(whole_class);
+    return false;
+  }
+
+  // Allow regular escape sequences even though
+  // many need not be escaped in this context.
+  if ((*s)[0] == '\\')
+    return ParseEscape(s, rp, status, rune_max_);
+
+  // Otherwise take the next rune.
+  return StringPieceToRune(rp, s, status) >= 0;
+}
+
+// Parses a character class character, or, if the character
+// is followed by a hyphen, parses a character class range.
+// For single characters, rr->lo == rr->hi.
+// Sets *s to span the remainder of the string.
+// Sets *rp to the character.
+bool Regexp::ParseState::ParseCCRange(StringPiece* s, RuneRange* rr,
+                                      const StringPiece& whole_class,
+                                      RegexpStatus* status) {
+  StringPiece os = *s;
+  if (!ParseCCCharacter(s, &rr->lo, whole_class, status))
+    return false;
+  // [a-] means (a|-), so check for final ].
+  if (s->size() >= 2 && (*s)[0] == '-' && (*s)[1] != ']') {
+    s->remove_prefix(1);  // '-'
+    if (!ParseCCCharacter(s, &rr->hi, whole_class, status))
+      return false;
+    if (rr->hi < rr->lo) {
+      status->set_code(kRegexpBadCharRange);
+      status->set_error_arg(
+          StringPiece(os.data(), static_cast<size_t>(s->data() - os.data())));
+      return false;
+    }
+  } else {
+    rr->hi = rr->lo;
+  }
+  return true;
+}
+
+// Parses a possibly-negated character class expression like [^abx-z[:digit:]].
+// Sets *s to span the remainder of the string.
+// Sets *out_re to the regexp for the class.
+bool Regexp::ParseState::ParseCharClass(StringPiece* s,
+                                        Regexp** out_re,
+                                        RegexpStatus* status) {
+  StringPiece whole_class = *s;
+  if (s->empty() || (*s)[0] != '[') {
+    // Caller checked this.
+    status->set_code(kRegexpInternalError);
+    status->set_error_arg(StringPiece());
+    return false;
+  }
+  bool negated = false;
+  Regexp* re = new Regexp(kRegexpCharClass, flags_ & ~FoldCase);
+  re->arguments.char_class.ccb_ = new CharClassBuilder;
+  s->remove_prefix(1);  // '['
+  if (!s->empty() && (*s)[0] == '^') {
+    s->remove_prefix(1);  // '^'
+    negated = true;
+    if (!(flags_ & ClassNL) || (flags_ & NeverNL)) {
+      // If NL can't match implicitly, then pretend
+      // negated classes include a leading \n.
+ re->arguments.char_class.ccb_->AddRange('\n', '\n'); + } + } + bool first = true; // ] is okay as first char in class + while (!s->empty() && ((*s)[0] != ']' || first)) { + // - is only okay unescaped as first or last in class. + // Except that Perl allows - anywhere. + if ((*s)[0] == '-' && !first && !(flags_&PerlX) && + (s->size() == 1 || (*s)[1] != ']')) { + StringPiece t = *s; + t.remove_prefix(1); // '-' + Rune r; + int n = StringPieceToRune(&r, &t, status); + if (n < 0) { + re->Decref(); + return false; + } + status->set_code(kRegexpBadCharRange); + status->set_error_arg(StringPiece(s->data(), 1+n)); + re->Decref(); + return false; + } + first = false; + + // Look for [:alnum:] etc. + if (s->size() > 2 && (*s)[0] == '[' && (*s)[1] == ':') { + switch (ParseCCName(s, flags_, re->arguments.char_class.ccb_, status)) { + case kParseOk: + continue; + case kParseError: + re->Decref(); + return false; + case kParseNothing: + break; + } + } + + // Look for Unicode character group like \p{Han} + if (s->size() > 2 && + (*s)[0] == '\\' && + ((*s)[1] == 'p' || (*s)[1] == 'P')) { + switch (ParseUnicodeGroup(s, flags_, re->arguments.char_class.ccb_, status)) { + case kParseOk: + continue; + case kParseError: + re->Decref(); + return false; + case kParseNothing: + break; + } + } + + // Look for Perl character class symbols (extension). + const UGroup *g = MaybeParsePerlCCEscape(s, flags_); + if (g != NULL) { + AddUGroup(re->arguments.char_class.ccb_, g, g->sign, flags_); + continue; + } + + // Otherwise assume single character or simple range. + RuneRange rr; + if (!ParseCCRange(s, &rr, whole_class, status)) { + re->Decref(); + return false; + } + // AddRangeFlags is usually called in response to a class like + // \p{Foo} or [[:foo:]]; for those, it filters \n out unless + // Regexp::ClassNL is set. In an explicit range or singleton + // like we just parsed, we do not filter \n out, so set ClassNL + // in the flags. + re->arguments.char_class.ccb_->AddRangeFlags(rr.lo, rr.hi, flags_ | Regexp::ClassNL); + } + if (s->empty()) { + status->set_code(kRegexpMissingBracket); + status->set_error_arg(whole_class); + re->Decref(); + return false; + } + s->remove_prefix(1); // ']' + + if (negated) + re->arguments.char_class.ccb_->Negate(); + + *out_re = re; + return true; +} + +// Returns whether name is a valid capture name. +static bool IsValidCaptureName(const StringPiece& name) { + if (name.empty()) + return false; + + // Historically, we effectively used [0-9A-Za-z_]+ to validate; that + // followed Python 2 except for not restricting the first character. + // As of Python 3, Unicode characters beyond ASCII are also allowed; + // accordingly, we permit the Lu, Ll, Lt, Lm, Lo, Nl, Mn, Mc, Nd and + // Pc categories, but again without restricting the first character. + // Also, Unicode normalization (e.g. NFKC) isn't performed: Python 3 + // performs it for identifiers, but seemingly not for capture names; + // if they start doing that for capture names, we won't follow suit. 
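+  // E.g. (illustrative): "foo_1", "1foo" and "日本語" are all accepted;
+  // "foo-bar" is rejected because '-' is not in any permitted category.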
+  static const CharClass* const cc = []() {
+    CharClassBuilder ccb;
+    for (StringPiece group :
+         {"Lu", "Ll", "Lt", "Lm", "Lo", "Nl", "Mn", "Mc", "Nd", "Pc"})
+      AddUGroup(&ccb, LookupGroup(group, unicode_groups, num_unicode_groups),
+                +1, Regexp::NoParseFlags);
+    return ccb.GetCharClass();
+  }();
+
+  StringPiece t = name;
+  Rune r;
+  while (!t.empty()) {
+    if (StringPieceToRune(&r, &t, NULL) < 0)
+      return false;
+    if (cc->Contains(r))
+      continue;
+    return false;
+  }
+  return true;
+}
+
+// Parses a Perl flag setting or non-capturing group or both,
+// like (?i) or (?: or (?i:. Removes from s, updates parse state.
+// The caller must check that s begins with "(?".
+// Returns true on success. If the Perl flag is not
+// well-formed or not supported, sets status_ and returns false.
+bool Regexp::ParseState::ParsePerlFlags(StringPiece* s) {
+  StringPiece t = *s;
+
+  // Caller is supposed to check this.
+  if (!(flags_ & PerlX) || t.size() < 2 || t[0] != '(' || t[1] != '?') {
+    status_->set_code(kRegexpInternalError);
+    LOG(DFATAL) << "Bad call to ParseState::ParsePerlFlags";
+    return false;
+  }
+
+  t.remove_prefix(2);  // "(?"
+
+  // Check for named captures, first introduced in Python's regexp library.
+  // As usual, there are three slightly different syntaxes:
+  //
+  //   (?P<name>expr)   the original, introduced by Python
+  //   (?<name>expr)    the .NET alteration, adopted by Perl 5.10
+  //   (?'name'expr)    another .NET alteration, adopted by Perl 5.10
+  //
+  // Perl 5.10 gave in and implemented the Python version too,
+  // but they claim that the last two are the preferred forms.
+  // PCRE and languages based on it (specifically, PHP and Ruby)
+  // support all three as well. EcmaScript 4 uses only the Python form.
+  //
+  // In both the open source world (via Code Search) and the
+  // Google source tree, (?P<name>expr) is the dominant form,
+  // so that's the one we implement. One is enough.
+  if (t.size() > 2 && t[0] == 'P' && t[1] == '<') {
+    // Pull out name.
+    size_t end = t.find('>', 2);
+    if (end == StringPiece::npos) {
+      if (!IsValidUTF8(*s, status_))
+        return false;
+      status_->set_code(kRegexpBadNamedCapture);
+      status_->set_error_arg(*s);
+      return false;
+    }
+
+    // t is "P<name>...", t[end] == '>'
+    StringPiece capture(t.data()-2, end+3);  // "(?P<name>"
+    StringPiece name(t.data()+2, end-2);     // "name"
+    if (!IsValidUTF8(name, status_))
+      return false;
+    if (!IsValidCaptureName(name)) {
+      status_->set_code(kRegexpBadNamedCapture);
+      status_->set_error_arg(capture);
+      return false;
+    }
+
+    if (!DoLeftParen(name)) {
+      // DoLeftParen's failure set status_.
+      return false;
+    }
+
+    s->remove_prefix(
+        static_cast<size_t>(capture.data() + capture.size() - s->data()));
+    return true;
+  }
+
+  bool negated = false;
+  bool sawflags = false;
+  int nflags = flags_;
+  Rune c;
+  for (bool done = false; !done; ) {
+    if (t.empty())
+      goto BadPerlOp;
+    if (StringPieceToRune(&c, &t, status_) < 0)
+      return false;
+    switch (c) {
+      default:
+        goto BadPerlOp;
+
+      // Parse flags.
+      case 'i':
+        sawflags = true;
+        if (negated)
+          nflags &= ~FoldCase;
+        else
+          nflags |= FoldCase;
+        break;
+
+      case 'm':  // opposite of our OneLine
+        sawflags = true;
+        if (negated)
+          nflags |= OneLine;
+        else
+          nflags &= ~OneLine;
+        break;
+
+      case 's':
+        sawflags = true;
+        if (negated)
+          nflags &= ~DotNL;
+        else
+          nflags |= DotNL;
+        break;
+
+      case 'U':
+        sawflags = true;
+        if (negated)
+          nflags &= ~NonGreedy;
+        else
+          nflags |= NonGreedy;
+        break;
+
+      // Negation
+      case '-':
+        if (negated)
+          goto BadPerlOp;
+        negated = true;
+        sawflags = false;
+        break;
+
+      // Open new group.
+      case ':':
+        if (!DoLeftParenNoCapture()) {
+          // DoLeftParenNoCapture's failure set status_.
+          return false;
+        }
+        done = true;
+        break;
+
+      // Finish flags.
+      case ')':
+        done = true;
+        break;
+    }
+  }
+
+  if (negated && !sawflags)
+    goto BadPerlOp;
+
+  flags_ = static_cast<ParseFlags>(nflags);
+  *s = t;
+  return true;
+
+BadPerlOp:
+  status_->set_code(kRegexpBadPerlOp);
+  status_->set_error_arg(
+      StringPiece(s->data(), static_cast<size_t>(t.data() - s->data())));
+  return false;
+}
+
+// Converts latin1 (assumed to be encoded as Latin1 bytes)
+// into UTF8 encoding in string.
+// Can't use EncodingUtils::EncodeLatin1AsUTF8 because it is
+// deprecated and because it rejects code points 0x80-0x9F.
+void ConvertLatin1ToUTF8(const StringPiece& latin1, std::string* utf) {
+  char buf[UTFmax];
+
+  utf->clear();
+  for (size_t i = 0; i < latin1.size(); i++) {
+    Rune r = latin1[i] & 0xFF;
+    int n = runetochar(buf, &r);
+    utf->append(buf, n);
+  }
+}
+
+// Parses the regular expression given by s,
+// returning the corresponding Regexp tree.
+// The caller must Decref the return value when done with it.
+// Returns NULL on error.
+Regexp* Regexp::Parse(const StringPiece& s, ParseFlags global_flags,
+                      RegexpStatus* status) {
+  // Make status non-NULL (easier on everyone else).
+  RegexpStatus xstatus;
+  if (status == NULL)
+    status = &xstatus;
+
+  ParseState ps(global_flags, s, status);
+  StringPiece t = s;
+
+  // Convert regexp to UTF-8 (easier on the rest of the parser).
+  if (global_flags & Latin1) {
+    std::string* tmp = new std::string;
+    ConvertLatin1ToUTF8(t, tmp);
+    status->set_tmp(tmp);
+    t = *tmp;
+  }
+
+  if (global_flags & Literal) {
+    // Special parse loop for literal string.
+    while (!t.empty()) {
+      Rune r;
+      if (StringPieceToRune(&r, &t, status) < 0)
+        return NULL;
+      if (!ps.PushLiteral(r))
+        return NULL;
+    }
+    return ps.DoFinish();
+  }
+
+  StringPiece lastunary = StringPiece();
+  while (!t.empty()) {
+    StringPiece isunary = StringPiece();
+    switch (t[0]) {
+      default: {
+        Rune r;
+        if (StringPieceToRune(&r, &t, status) < 0)
+          return NULL;
+        if (!ps.PushLiteral(r))
+          return NULL;
+        break;
+      }
+
+      case '(':
+        // "(?" introduces Perl escape.
+        if ((ps.flags() & PerlX) && (t.size() >= 2 && t[1] == '?')) {
+          // Flag changes and non-capturing groups.
+          if (!ps.ParsePerlFlags(&t))
+            return NULL;
+          break;
+        }
+        if (ps.flags() & NeverCapture) {
+          if (!ps.DoLeftParenNoCapture())
+            return NULL;
+        } else {
+          if (!ps.DoLeftParen(StringPiece()))
+            return NULL;
+        }
+        t.remove_prefix(1);  // '('
+        break;
+
+      case '|':
+        if (!ps.DoVerticalBar())
+          return NULL;
+        t.remove_prefix(1);  // '|'
+        break;
+
+      case ')':
+        if (!ps.DoRightParen())
+          return NULL;
+        t.remove_prefix(1);  // ')'
+        break;
+
+      case '^':  // Beginning of line.
+        if (!ps.PushCaret())
+          return NULL;
+        t.remove_prefix(1);  // '^'
+        break;
+
+      case '$':  // End of line.
+        if (!ps.PushDollar())
+          return NULL;
+        t.remove_prefix(1);  // '$'
+        break;
+
+      case '.':  // Any character (possibly except newline).
+        if (!ps.PushDot())
+          return NULL;
+        t.remove_prefix(1);  // '.'
+        break;
+
+      case '[': {  // Character class.
+        Regexp* re;
+        if (!ps.ParseCharClass(&t, &re, status))
+          return NULL;
+        if (!ps.PushRegexp(re))
+          return NULL;
+        break;
+      }
+
+      case '*': {  // Zero or more.
+        RegexpOp op;
+        op = kRegexpStar;
+        goto Rep;
+      case '+':  // One or more.
+        op = kRegexpPlus;
+        goto Rep;
+      case '?':  // Zero or one.
+        op = kRegexpQuest;
+        goto Rep;
+      Rep:
+        StringPiece opstr = t;
+        bool nongreedy = false;
+        t.remove_prefix(1);  // '*' or '+' or '?'
+        if (ps.flags() & PerlX) {
+          if (!t.empty() && t[0] == '?') {
+            nongreedy = true;
+            t.remove_prefix(1);  // '?'
+          }
+          if (!lastunary.empty()) {
+            // In Perl it is not allowed to stack repetition operators:
+            // a** is a syntax error, not a double-star.
+            // (and a++ means something else entirely, which we don't support!)
+            status->set_code(kRegexpRepeatOp);
+            status->set_error_arg(StringPiece(
+                lastunary.data(),
+                static_cast<size_t>(t.data() - lastunary.data())));
+            return NULL;
+          }
+        }
+        opstr = StringPiece(opstr.data(),
+                            static_cast<size_t>(t.data() - opstr.data()));
+        if (!ps.PushRepeatOp(op, opstr, nongreedy))
+          return NULL;
+        isunary = opstr;
+        break;
+      }
+
+      case '{': {  // Counted repetition.
+        int lo, hi;
+        StringPiece opstr = t;
+        if (!MaybeParseRepetition(&t, &lo, &hi)) {
+          // Treat like a literal.
+          if (!ps.PushLiteral('{'))
+            return NULL;
+          t.remove_prefix(1);  // '{'
+          break;
+        }
+        bool nongreedy = false;
+        if (ps.flags() & PerlX) {
+          if (!t.empty() && t[0] == '?') {
+            nongreedy = true;
+            t.remove_prefix(1);  // '?'
+          }
+          if (!lastunary.empty()) {
+            // Not allowed to stack repetition operators.
+            status->set_code(kRegexpRepeatOp);
+            status->set_error_arg(StringPiece(
+                lastunary.data(),
+                static_cast<size_t>(t.data() - lastunary.data())));
+            return NULL;
+          }
+        }
+        opstr = StringPiece(opstr.data(),
+                            static_cast<size_t>(t.data() - opstr.data()));
+        if (!ps.PushRepetition(lo, hi, opstr, nongreedy))
+          return NULL;
+        isunary = opstr;
+        break;
+      }
+
+      case '\\': {  // Escaped character or Perl sequence.
+        // \b and \B: word boundary or not
+        if ((ps.flags() & Regexp::PerlB) &&
+            t.size() >= 2 && (t[1] == 'b' || t[1] == 'B')) {
+          if (!ps.PushWordBoundary(t[1] == 'b'))
+            return NULL;
+          t.remove_prefix(2);  // '\\', 'b'
+          break;
+        }
+
+        if ((ps.flags() & Regexp::PerlX) && t.size() >= 2) {
+          if (t[1] == 'A') {
+            if (!ps.PushSimpleOp(kRegexpBeginText))
+              return NULL;
+            t.remove_prefix(2);  // '\\', 'A'
+            break;
+          }
+          if (t[1] == 'z') {
+            if (!ps.PushSimpleOp(kRegexpEndText))
+              return NULL;
+            t.remove_prefix(2);  // '\\', 'z'
+            break;
+          }
+          // Do not recognize \Z, because this library can't
+          // implement the exact Perl/PCRE semantics.
+          // (This library treats "(?-m)$" as \z, even though
+          // in Perl and PCRE it is equivalent to \Z.)
+
+          if (t[1] == 'C') {  // \C: any byte [sic]
+            if (!ps.PushSimpleOp(kRegexpAnyByte))
+              return NULL;
+            t.remove_prefix(2);  // '\\', 'C'
+            break;
+          }
+
+          if (t[1] == 'Q') {  // \Q ... \E: the ... is always literals
+            t.remove_prefix(2);  // '\\', 'Q'
+            while (!t.empty()) {
+              if (t.size() >= 2 && t[0] == '\\' && t[1] == 'E') {
+                t.remove_prefix(2);  // '\\', 'E'
+                break;
+              }
+              Rune r;
+              if (StringPieceToRune(&r, &t, status) < 0)
+                return NULL;
+              if (!ps.PushLiteral(r))
+                return NULL;
+            }
+            break;
+          }
+        }
+
+        if (t.size() >= 2 && (t[1] == 'p' || t[1] == 'P')) {
+          Regexp* re = new Regexp(kRegexpCharClass, ps.flags() & ~FoldCase);
+          re->arguments.char_class.ccb_ = new CharClassBuilder;
+          switch (ParseUnicodeGroup(&t, ps.flags(), re->arguments.char_class.ccb_, status)) {
+            case kParseOk:
+              if (!ps.PushRegexp(re))
+                return NULL;
+              goto Break2;
+            case kParseError:
+              re->Decref();
+              return NULL;
+            case kParseNothing:
+              re->Decref();
+              break;
+          }
+        }
+
+        const UGroup *g = MaybeParsePerlCCEscape(&t, ps.flags());
+        if (g != NULL) {
+          Regexp* re = new Regexp(kRegexpCharClass, ps.flags() & ~FoldCase);
+          re->arguments.char_class.ccb_ = new CharClassBuilder;
+          AddUGroup(re->arguments.char_class.ccb_, g, g->sign, ps.flags());
+          if (!ps.PushRegexp(re))
+            return NULL;
+          break;
+        }
+
+        Rune r;
+        if (!ParseEscape(&t, &r, status, ps.rune_max()))
+          return NULL;
+        if (!ps.PushLiteral(r))
+          return NULL;
+        break;
+      }
+    }
+  Break2:
+    lastunary = isunary;
+  }
+  return ps.DoFinish();
+}
+
+}  // namespace re2
diff --git a/internal/cpp/re2/perl_groups.cc b/internal/cpp/re2/perl_groups.cc
new file mode 100644
index 00000000000..643c1c3ca77
--- /dev/null
+++ b/internal/cpp/re2/perl_groups.cc
@@ -0,0 +1,118 @@
+// GENERATED BY make_perl_groups.pl; DO NOT EDIT.
+// make_perl_groups.pl >perl_groups.cc
+
+#include "re2/unicode_groups.h"
+
+namespace re2 {
+
+static const URange16 code1[] = {
+    /* \d */
+    {0x30, 0x39},
+};
+static const URange16 code2[] = {
+    /* \s */
+    {0x9, 0xa},
+    {0xc, 0xd},
+    {0x20, 0x20},
+};
+static const URange16 code3[] = {
+    /* \w */
+    {0x30, 0x39},
+    {0x41, 0x5a},
+    {0x5f, 0x5f},
+    {0x61, 0x7a},
+};
+const UGroup perl_groups[] = {
+    {"\\d", +1, code1, 1, 0, 0},
+    {"\\D", -1, code1, 1, 0, 0},
+    {"\\s", +1, code2, 3, 0, 0},
+    {"\\S", -1, code2, 3, 0, 0},
+    {"\\w", +1, code3, 4, 0, 0},
+    {"\\W", -1, code3, 4, 0, 0},
+};
+const int num_perl_groups = 6;
+static const URange16 code4[] = {
+    /* [:alnum:] */
+    {0x30, 0x39},
+    {0x41, 0x5a},
+    {0x61, 0x7a},
+};
+static const URange16 code5[] = {
+    /* [:alpha:] */
+    {0x41, 0x5a},
+    {0x61, 0x7a},
+};
+static const URange16 code6[] = {
+    /* [:ascii:] */
+    {0x0, 0x7f},
+};
+static const URange16 code7[] = {
+    /* [:blank:] */
+    {0x9, 0x9},
+    {0x20, 0x20},
+};
+static const URange16 code8[] = {
+    /* [:cntrl:] */
+    {0x0, 0x1f},
+    {0x7f, 0x7f},
+};
+static const URange16 code9[] = {
+    /* [:digit:] */
+    {0x30, 0x39},
+};
+static const URange16 code10[] = {
+    /* [:graph:] */
+    {0x21, 0x7e},
+};
+static const URange16 code11[] = {
+    /* [:lower:] */
+    {0x61, 0x7a},
+};
+static const URange16 code12[] = {
+    /* [:print:] */
+    {0x20, 0x7e},
+};
+static const URange16 code13[] = {
+    /* [:punct:] */
+    {0x21, 0x2f},
+    {0x3a, 0x40},
+    {0x5b, 0x60},
+    {0x7b, 0x7e},
+};
+static const URange16 code14[] = {
+    /* [:space:] */
+    {0x9, 0xd},
+    {0x20, 0x20},
+};
+static const URange16 code15[] = {
+    /* [:upper:] */
+    {0x41, 0x5a},
+};
+static const URange16 code16[] = {
+    /* [:word:] */
+    {0x30, 0x39},
+    {0x41, 0x5a},
+    {0x5f, 0x5f},
+    {0x61, 0x7a},
+};
+static const URange16 code17[] = {
+    /* [:xdigit:] */
+    {0x30, 0x39},
+    {0x41, 0x46},
+    {0x61, 0x66},
+};
+const UGroup posix_groups[] = {
+    {"[:alnum:]", +1, code4, 3, 0, 0},   {"[:^alnum:]", -1, code4, 3, 0, 0},   {"[:alpha:]", +1, code5, 2, 0, 0},
+    {"[:^alpha:]", -1, code5, 2, 0, 0},  {"[:ascii:]", +1, code6, 1, 0, 0},    {"[:^ascii:]", -1, code6, 1, 0, 0},
+    {"[:blank:]", +1, code7, 2, 0, 0},   {"[:^blank:]", -1, code7, 2, 0, 0},   {"[:cntrl:]", +1, code8, 2, 0, 0},
+    {"[:^cntrl:]", -1, code8, 2, 0, 0},  {"[:digit:]", +1, code9, 1, 0, 0},    {"[:^digit:]", -1, code9, 1, 0, 0},
+    {"[:graph:]", +1, code10, 1, 0, 0},  {"[:^graph:]", -1, code10, 1, 0, 0},  {"[:lower:]", +1, code11, 1, 0, 0},
+    {"[:^lower:]", -1, code11, 1, 0, 0}, {"[:print:]", +1, code12, 1, 0, 0},   {"[:^print:]", -1, code12, 1, 0, 0},
+    {"[:punct:]", +1, code13, 4, 0, 0},  {"[:^punct:]", -1, code13, 4, 0, 0},  {"[:space:]", +1, code14, 2, 0, 0},
+    {"[:^space:]", -1, code14, 2, 0, 0}, {"[:upper:]", +1, code15, 1, 0, 0},   {"[:^upper:]", -1, code15, 1, 0, 0},
+    {"[:word:]", +1, code16, 4, 0, 0},   {"[:^word:]", -1, code16, 4, 0, 0},   {"[:xdigit:]", +1, code17, 3, 0, 0},
+    {"[:^xdigit:]", -1, code17, 3, 0, 0},
+};
+const int num_posix_groups = 28;
+
+}  // namespace re2
diff --git a/internal/cpp/re2/pod_array.h b/internal/cpp/re2/pod_array.h
new file mode 100644
index 00000000000..f234e976f40
--- /dev/null
+++ b/internal/cpp/re2/pod_array.h
@@ -0,0 +1,55 @@
+// Copyright 2018 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef RE2_POD_ARRAY_H_
+#define RE2_POD_ARRAY_H_
+
+#include <memory>
+#include <type_traits>
+
+namespace re2 {
+
+template <typename T>
+class PODArray {
+ public:
+  static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
+                "T must be POD");
+
+  PODArray()
+      : ptr_() {}
+  explicit PODArray(int len)
+      : ptr_(std::allocator<T>().allocate(len), Deleter(len)) {}
+
+  T* data() const {
+    return ptr_.get();
+  }
+
+  int size() const {
+    return ptr_.get_deleter().len_;
+  }
+
+  T& operator[](int pos) const {
+    return ptr_[pos];
+  }
+
+ private:
+  struct Deleter {
+    Deleter()
+        : len_(0) {}
+    explicit Deleter(int len)
+        : len_(len) {}
+
+    void operator()(T* ptr) const {
+      std::allocator<T>().deallocate(ptr, len_);
+    }
+
+    int len_;
+  };
+
+  std::unique_ptr<T[], Deleter> ptr_;
+};
+
+}  // namespace re2
+
+#endif  // RE2_POD_ARRAY_H_
diff --git a/internal/cpp/re2/prefilter.cc b/internal/cpp/re2/prefilter.cc
new file mode 100644
index 00000000000..d20e5711aaf
--- /dev/null
+++ b/internal/cpp/re2/prefilter.cc
@@ -0,0 +1,663 @@
+// Copyright 2009 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "re2/prefilter.h"
+
+#include <stddef.h>
+#include <stdint.h>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "re2/re2.h"
+#include "re2/unicode_casefold.h"
+#include "re2/walker-inl.h"
+#include "util/logging.h"
+#include "util/strutil.h"
+#include "util/utf.h"
+#include "util/util.h"
+
+namespace re2 {
+
+// Initializes a Prefilter, allocating subs_ as necessary.
+Prefilter::Prefilter(Op op) {
+  op_ = op;
+  subs_ = NULL;
+  if (op_ == AND || op_ == OR)
+    subs_ = new std::vector<Prefilter*>;
+}
+
+// Destroys a Prefilter.
+Prefilter::~Prefilter() {
+  if (subs_) {
+    for (size_t i = 0; i < subs_->size(); i++)
+      delete (*subs_)[i];
+    delete subs_;
+    subs_ = NULL;
+  }
+}
+
+// Simplify if the node is an empty Or or And.
+Prefilter *Prefilter::Simplify() {
+  if (op_ != AND && op_ != OR) {
+    return this;
+  }
+
+  // Nothing left in the AND/OR.
+  if (subs_->empty()) {
+    if (op_ == AND)
+      op_ = ALL;  // AND of nothing is true
+    else
+      op_ = NONE;  // OR of nothing is false
+
+    return this;
+  }
+
+  // Just one subnode: throw away wrapper.
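+  // E.g. (illustrative): AND("abc") simplifies to the bare ATOM "abc".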
+ if (subs_->size() == 1) { + Prefilter *a = (*subs_)[0]; + subs_->clear(); + delete this; + return a->Simplify(); + } + + return this; +} + +// Combines two Prefilters together to create an "op" (AND or OR). +// The passed Prefilters will be part of the returned Prefilter or deleted. +// Does lots of work to avoid creating unnecessarily complicated structures. +Prefilter *Prefilter::AndOr(Op op, Prefilter *a, Prefilter *b) { + // If a, b can be rewritten as op, do so. + a = a->Simplify(); + b = b->Simplify(); + + // Canonicalize: a->op <= b->op. + if (a->op() > b->op()) { + Prefilter *t = a; + a = b; + b = t; + } + + // Trivial cases. + // ALL AND b = b + // NONE OR b = b + // ALL OR b = ALL + // NONE AND b = NONE + // Don't need to look at b, because of canonicalization above. + // ALL and NONE are smallest opcodes. + if (a->op() == ALL || a->op() == NONE) { + if ((a->op() == ALL && op == AND) || (a->op() == NONE && op == OR)) { + delete a; + return b; + } else { + delete b; + return a; + } + } + + // If a and b match op, merge their contents. + if (a->op() == op && b->op() == op) { + for (size_t i = 0; i < b->subs()->size(); i++) { + Prefilter *bb = (*b->subs())[i]; + a->subs()->push_back(bb); + } + b->subs()->clear(); + delete b; + return a; + } + + // If a already has the same op as the op that is under construction + // add in b (similarly if b already has the same op, add in a). + if (b->op() == op) { + Prefilter *t = a; + a = b; + b = t; + } + if (a->op() == op) { + a->subs()->push_back(b); + return a; + } + + // Otherwise just return the op. + Prefilter *c = new Prefilter(op); + c->subs()->push_back(a); + c->subs()->push_back(b); + return c; +} + +Prefilter *Prefilter::And(Prefilter *a, Prefilter *b) { return AndOr(AND, a, b); } + +Prefilter *Prefilter::Or(Prefilter *a, Prefilter *b) { return AndOr(OR, a, b); } + +void Prefilter::SimplifyStringSet(SSet *ss) { + // Now make sure that the strings aren't redundant. For example, if + // we know "ab" is a required string, then it doesn't help at all to + // know that "abc" is also a required string, so delete "abc". This + // is because, when we are performing a string search to filter + // regexps, matching "ab" will already allow this regexp to be a + // candidate for match, so further matching "abc" is redundant. + // Note that we must ignore "" because find() would find it at the + // start of everything and thus we would end up erasing everything. + // + // The SSet sorts strings by length, then lexicographically. Note that + // smaller strings appear first and all strings must be unique. These + // observations let us skip string comparisons when possible. 
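+  // E.g. (illustrative): {"ab", "abc", "cab"} reduces to {"ab"}, since both
+  // longer strings contain "ab" as a substring.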
+ SSIter i = ss->begin(); + if (i != ss->end() && i->empty()) { + ++i; + } + for (; i != ss->end(); ++i) { + SSIter j = i; + ++j; + while (j != ss->end()) { + if (j->size() > i->size() && j->find(*i) != std::string::npos) { + j = ss->erase(j); + continue; + } + ++j; + } + } +} + +Prefilter *Prefilter::OrStrings(SSet *ss) { + Prefilter *or_prefilter = new Prefilter(NONE); + SimplifyStringSet(ss); + for (SSIter i = ss->begin(); i != ss->end(); ++i) + or_prefilter = Or(or_prefilter, FromString(*i)); + return or_prefilter; +} + +static Rune ToLowerRune(Rune r) { + if (r < Runeself) { + if ('A' <= r && r <= 'Z') + r += 'a' - 'A'; + return r; + } + + const CaseFold *f = LookupCaseFold(unicode_tolower, num_unicode_tolower, r); + if (f == NULL || r < f->lo) + return r; + return ApplyFold(f, r); +} + +static Rune ToLowerRuneLatin1(Rune r) { + if ('A' <= r && r <= 'Z') + r += 'a' - 'A'; + return r; +} + +Prefilter *Prefilter::FromString(const std::string &str) { + Prefilter *m = new Prefilter(Prefilter::ATOM); + m->atom_ = str; + return m; +} + +// Information about a regexp used during computation of Prefilter. +// Can be thought of as information about the set of strings matching +// the given regular expression. +class Prefilter::Info { +public: + Info(); + ~Info(); + + // More constructors. They delete their Info* arguments. + static Info *Alt(Info *a, Info *b); + static Info *Concat(Info *a, Info *b); + static Info *And(Info *a, Info *b); + static Info *Star(Info *a); + static Info *Plus(Info *a); + static Info *Quest(Info *a); + static Info *EmptyString(); + static Info *NoMatch(); + static Info *AnyCharOrAnyByte(); + static Info *CClass(CharClass *cc, bool latin1); + static Info *Literal(Rune r); + static Info *LiteralLatin1(Rune r); + static Info *AnyMatch(); + + // Format Info as a string. + std::string ToString(); + + // Caller takes ownership of the Prefilter. + Prefilter *TakeMatch(); + + SSet &exact() { return exact_; } + + bool is_exact() const { return is_exact_; } + + class Walker; + +private: + SSet exact_; + + // When is_exact_ is true, the strings that match + // are placed in exact_. When it is no longer an exact + // set of strings that match this RE, then is_exact_ + // is false and the match_ contains the required match + // criteria. + bool is_exact_; + + // Accumulated Prefilter query that any + // match for this regexp is guaranteed to match. + Prefilter *match_; +}; + +Prefilter::Info::Info() : is_exact_(false), match_(NULL) {} + +Prefilter::Info::~Info() { delete match_; } + +Prefilter *Prefilter::Info::TakeMatch() { + if (is_exact_) { + match_ = Prefilter::OrStrings(&exact_); + is_exact_ = false; + } + Prefilter *m = match_; + match_ = NULL; + return m; +} + +// Format a Info in string form. +std::string Prefilter::Info::ToString() { + if (is_exact_) { + int n = 0; + std::string s; + for (SSIter i = exact_.begin(); i != exact_.end(); ++i) { + if (n++ > 0) + s += ","; + s += *i; + } + return s; + } + + if (match_) + return match_->DebugString(); + + return ""; +} + +void Prefilter::CrossProduct(const SSet &a, const SSet &b, SSet *dst) { + for (ConstSSIter i = a.begin(); i != a.end(); ++i) + for (ConstSSIter j = b.begin(); j != b.end(); ++j) + dst->insert(*i + *j); +} + +// Concats a and b. Requires that both are exact sets. +// Forms an exact set that is a crossproduct of a and b. 
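+// E.g. (illustrative): {"a", "b"} x {"c", "d"} -> {"ac", "ad", "bc", "bd"}.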
+Prefilter::Info *Prefilter::Info::Concat(Info *a, Info *b) { + if (a == NULL) + return b; + DCHECK(a->is_exact_); + DCHECK(b && b->is_exact_); + Info *ab = new Info(); + + CrossProduct(a->exact_, b->exact_, &ab->exact_); + ab->is_exact_ = true; + + delete a; + delete b; + return ab; +} + +// Constructs an inexact Info for ab given a and b. +// Used only when a or b is not exact or when the +// exact cross product is likely to be too big. +Prefilter::Info *Prefilter::Info::And(Info *a, Info *b) { + if (a == NULL) + return b; + if (b == NULL) + return a; + + Info *ab = new Info(); + + ab->match_ = Prefilter::And(a->TakeMatch(), b->TakeMatch()); + ab->is_exact_ = false; + delete a; + delete b; + return ab; +} + +// Constructs Info for a|b given a and b. +Prefilter::Info *Prefilter::Info::Alt(Info *a, Info *b) { + Info *ab = new Info(); + + if (a->is_exact_ && b->is_exact_) { + // Avoid string copies by moving the larger exact_ set into + // ab directly, then merge in the smaller set. + if (a->exact_.size() < b->exact_.size()) { + using std::swap; + swap(a, b); + } + ab->exact_ = std::move(a->exact_); + ab->exact_.insert(b->exact_.begin(), b->exact_.end()); + ab->is_exact_ = true; + } else { + // Either a or b has is_exact_ = false. If the other + // one has is_exact_ = true, we move it to match_ and + // then create a OR of a,b. The resulting Info has + // is_exact_ = false. + ab->match_ = Prefilter::Or(a->TakeMatch(), b->TakeMatch()); + ab->is_exact_ = false; + } + + delete a; + delete b; + return ab; +} + +// Constructs Info for a? given a. +Prefilter::Info *Prefilter::Info::Quest(Info *a) { + Info *ab = new Info(); + + ab->is_exact_ = false; + ab->match_ = new Prefilter(ALL); + delete a; + return ab; +} + +// Constructs Info for a* given a. +// Same as a? -- not much to do. +Prefilter::Info *Prefilter::Info::Star(Info *a) { return Quest(a); } + +// Constructs Info for a+ given a. If a was exact set, it isn't +// anymore. +Prefilter::Info *Prefilter::Info::Plus(Info *a) { + Info *ab = new Info(); + + ab->match_ = a->TakeMatch(); + ab->is_exact_ = false; + + delete a; + return ab; +} + +static std::string RuneToString(Rune r) { + char buf[UTFmax]; + int n = runetochar(buf, &r); + return std::string(buf, n); +} + +static std::string RuneToStringLatin1(Rune r) { + char c = r & 0xff; + return std::string(&c, 1); +} + +// Constructs Info for literal rune. +Prefilter::Info *Prefilter::Info::Literal(Rune r) { + Info *info = new Info(); + info->exact_.insert(RuneToString(ToLowerRune(r))); + info->is_exact_ = true; + return info; +} + +// Constructs Info for literal rune for Latin1 encoded string. +Prefilter::Info *Prefilter::Info::LiteralLatin1(Rune r) { + Info *info = new Info(); + info->exact_.insert(RuneToStringLatin1(ToLowerRuneLatin1(r))); + info->is_exact_ = true; + return info; +} + +// Constructs Info for dot (any character) or \C (any byte). +Prefilter::Info *Prefilter::Info::AnyCharOrAnyByte() { + Prefilter::Info *info = new Prefilter::Info(); + info->match_ = new Prefilter(ALL); + return info; +} + +// Constructs Prefilter::Info for no possible match. +Prefilter::Info *Prefilter::Info::NoMatch() { + Prefilter::Info *info = new Prefilter::Info(); + info->match_ = new Prefilter(NONE); + return info; +} + +// Constructs Prefilter::Info for any possible match. +// This Prefilter::Info is valid for any regular expression, +// since it makes no assertions whatsoever about the +// strings being matched. 
+Prefilter::Info *Prefilter::Info::AnyMatch() {
+  Prefilter::Info *info = new Prefilter::Info();
+  info->match_ = new Prefilter(ALL);
+  return info;
+}
+
+// Constructs Prefilter::Info for just the empty string.
+Prefilter::Info *Prefilter::Info::EmptyString() {
+  Prefilter::Info *info = new Prefilter::Info();
+  info->is_exact_ = true;
+  info->exact_.insert("");
+  return info;
+}
+
+// Constructs Prefilter::Info for a character class.
+typedef CharClass::iterator CCIter;
+Prefilter::Info *Prefilter::Info::CClass(CharClass *cc, bool latin1) {
+
+  // If the class is too large, it's okay to overestimate.
+  if (cc->size() > 10)
+    return AnyCharOrAnyByte();
+
+  Prefilter::Info *a = new Prefilter::Info();
+  for (CCIter i = cc->begin(); i != cc->end(); ++i)
+    for (Rune r = i->lo; r <= i->hi; r++) {
+      if (latin1) {
+        a->exact_.insert(RuneToStringLatin1(ToLowerRuneLatin1(r)));
+      } else {
+        a->exact_.insert(RuneToString(ToLowerRune(r)));
+      }
+    }
+
+  a->is_exact_ = true;
+  return a;
+}
+
+class Prefilter::Info::Walker : public Regexp::Walker<Info*> {
+public:
+  Walker(bool latin1) : latin1_(latin1) {}
+
+  virtual Info *PostVisit(Regexp *re, Info *parent_arg, Info *pre_arg, Info **child_args, int nchild_args);
+
+  virtual Info *ShortVisit(Regexp *re, Info *parent_arg);
+
+  bool latin1() { return latin1_; }
+
+private:
+  bool latin1_;
+
+  Walker(const Walker &) = delete;
+  Walker &operator=(const Walker &) = delete;
+};
+
+Prefilter::Info *Prefilter::BuildInfo(Regexp *re) {
+  bool latin1 = (re->parse_flags() & Regexp::Latin1) != 0;
+  Prefilter::Info::Walker w(latin1);
+  Prefilter::Info *info = w.WalkExponential(re, NULL, 100000);
+
+  if (w.stopped_early()) {
+    delete info;
+    return NULL;
+  }
+
+  return info;
+}
+
+Prefilter::Info *Prefilter::Info::Walker::ShortVisit(Regexp *re, Prefilter::Info *parent_arg) { return AnyMatch(); }
+
+// Constructs the Prefilter::Info for the given regular expression.
+// Assumes re is simplified.
+Prefilter::Info *
+Prefilter::Info::Walker::PostVisit(Regexp *re, Prefilter::Info *parent_arg, Prefilter::Info *pre_arg, Prefilter::Info **child_args, int nchild_args) {
+  Prefilter::Info *info;
+  switch (re->op()) {
+  default:
+  case kRegexpRepeat:
+    info = EmptyString();
+    LOG(DFATAL) << "Bad regexp op " << re->op();
+    break;
+
+  case kRegexpNoMatch:
+    info = NoMatch();
+    break;
+
+  // These ops match the empty string:
+  case kRegexpEmptyMatch:      // anywhere
+  case kRegexpBeginLine:       // at beginning of line
+  case kRegexpEndLine:         // at end of line
+  case kRegexpBeginText:       // at beginning of text
+  case kRegexpEndText:         // at end of text
+  case kRegexpWordBoundary:    // at word boundary
+  case kRegexpNoWordBoundary:  // not at word boundary
+    info = EmptyString();
+    break;
+
+  case kRegexpLiteral:
+    if (latin1()) {
+      info = LiteralLatin1(re->rune());
+    } else {
+      info = Literal(re->rune());
+    }
+    break;
+
+  case kRegexpLiteralString:
+    if (re->nrunes() == 0) {
+      info = NoMatch();
+      break;
+    }
+    if (latin1()) {
+      info = LiteralLatin1(re->runes()[0]);
+      for (int i = 1; i < re->nrunes(); i++) {
+        info = Concat(info, LiteralLatin1(re->runes()[i]));
+      }
+    } else {
+      info = Literal(re->runes()[0]);
+      for (int i = 1; i < re->nrunes(); i++) {
+        info = Concat(info, Literal(re->runes()[i]));
+      }
+    }
+    break;
+
+  case kRegexpConcat: {
+    // Accumulate in info.
+    // Exact is concat of recent contiguous exact nodes.
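+    // E.g. (illustrative): for "ab.*cd", the exact run {"ab"} is flushed
+    // when the inexact .* child arrives, and {"cd"} follows; the resulting
+    // prefilter is AND("ab", "cd").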
+ info = NULL; + Info *exact = NULL; + for (int i = 0; i < nchild_args; i++) { + Info *ci = child_args[i]; // child info + if (!ci->is_exact() || (exact && ci->exact().size() * exact->exact().size() > 16)) { + // Exact run is over. + info = And(info, exact); + exact = NULL; + // Add this child's info. + info = And(info, ci); + } else { + // Append to exact run. + exact = Concat(exact, ci); + } + } + info = And(info, exact); + } break; + + case kRegexpAlternate: + info = child_args[0]; + for (int i = 1; i < nchild_args; i++) + info = Alt(info, child_args[i]); + break; + + case kRegexpStar: + info = Star(child_args[0]); + break; + + case kRegexpQuest: + info = Quest(child_args[0]); + break; + + case kRegexpPlus: + info = Plus(child_args[0]); + break; + + case kRegexpAnyChar: + case kRegexpAnyByte: + // Claim nothing, except that it's not empty. + info = AnyCharOrAnyByte(); + break; + + case kRegexpCharClass: + info = CClass(re->cc(), latin1()); + break; + + case kRegexpCapture: + // These don't affect the set of matching strings. + info = child_args[0]; + break; + } + + return info; +} + +Prefilter *Prefilter::FromRegexp(Regexp *re) { + if (re == NULL) + return NULL; + + Regexp *simple = re->Simplify(); + if (simple == NULL) + return NULL; + + Prefilter::Info *info = BuildInfo(simple); + simple->Decref(); + if (info == NULL) + return NULL; + + Prefilter *m = info->TakeMatch(); + delete info; + return m; +} + +std::string Prefilter::DebugString() const { + switch (op_) { + default: + LOG(DFATAL) << "Bad op in Prefilter::DebugString: " << op_; + return StringPrintf("op%d", op_); + case NONE: + return "*no-matches*"; + case ATOM: + return atom_; + case ALL: + return ""; + case AND: { + std::string s = ""; + for (size_t i = 0; i < subs_->size(); i++) { + if (i > 0) + s += " "; + Prefilter *sub = (*subs_)[i]; + s += sub ? sub->DebugString() : ""; + } + return s; + } + case OR: { + std::string s = "("; + for (size_t i = 0; i < subs_->size(); i++) { + if (i > 0) + s += "|"; + Prefilter *sub = (*subs_)[i]; + s += sub ? sub->DebugString() : ""; + } + s += ")"; + return s; + } + } +} + +Prefilter *Prefilter::FromRE2(const RE2 *re2) { + if (re2 == NULL) + return NULL; + + Regexp *regexp = re2->Regexp(); + if (regexp == NULL) + return NULL; + + return FromRegexp(regexp); +} + +} // namespace re2 diff --git a/internal/cpp/re2/prefilter.h b/internal/cpp/re2/prefilter.h new file mode 100644 index 00000000000..e149e59a866 --- /dev/null +++ b/internal/cpp/re2/prefilter.h @@ -0,0 +1,130 @@ +// Copyright 2009 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef RE2_PREFILTER_H_ +#define RE2_PREFILTER_H_ + +// Prefilter is the class used to extract string guards from regexps. +// Rather than using Prefilter class directly, use FilteredRE2. 
+// See filtered_re2.h
+
+#include <set>
+#include <string>
+#include <vector>
+
+#include "util/util.h"
+#include "util/logging.h"
+
+namespace re2 {
+
+class RE2;
+
+class Regexp;
+
+class Prefilter {
+  // Instead of using Prefilter directly, use FilteredRE2; see filtered_re2.h
+ public:
+  enum Op {
+    ALL = 0,  // Everything matches
+    NONE,     // Nothing matches
+    ATOM,     // The string atom() must match
+    AND,      // All in subs() must match
+    OR,       // One of subs() must match
+  };
+
+  explicit Prefilter(Op op);
+  ~Prefilter();
+
+  Op op() { return op_; }
+  const std::string& atom() const { return atom_; }
+  void set_unique_id(int id) { unique_id_ = id; }
+  int unique_id() const { return unique_id_; }
+
+  // The children of the Prefilter node.
+  std::vector<Prefilter*>* subs() {
+    DCHECK(op_ == AND || op_ == OR);
+    return subs_;
+  }
+
+  // Set the children vector. Prefilter takes ownership of subs and
+  // subs_ will be deleted when Prefilter is deleted.
+  void set_subs(std::vector<Prefilter*>* subs) { subs_ = subs; }
+
+  // Given a RE2, return a Prefilter. The caller takes ownership of
+  // the Prefilter and should deallocate it. Returns NULL if Prefilter
+  // cannot be formed.
+  static Prefilter* FromRE2(const RE2* re2);
+
+  // Returns a readable debug string of the prefilter.
+  std::string DebugString() const;
+
+ private:
+  // A comparator used to store exact strings. We compare by length,
+  // then lexicographically. This ordering makes it easier to reduce the
+  // set of strings in SimplifyStringSet.
+  struct LengthThenLex {
+    bool operator()(const std::string& a, const std::string& b) const {
+      return (a.size() < b.size()) || (a.size() == b.size() && a < b);
+    }
+  };
+
+  class Info;
+
+  using SSet = std::set<std::string, LengthThenLex>;
+  using SSIter = SSet::iterator;
+  using ConstSSIter = SSet::const_iterator;
+
+  // Combines two prefilters together to create an AND. The passed
+  // Prefilters will be part of the returned Prefilter or deleted.
+  static Prefilter* And(Prefilter* a, Prefilter* b);
+
+  // Combines two prefilters together to create an OR. The passed
+  // Prefilters will be part of the returned Prefilter or deleted.
+  static Prefilter* Or(Prefilter* a, Prefilter* b);
+
+  // Generalized And/Or
+  static Prefilter* AndOr(Op op, Prefilter* a, Prefilter* b);
+
+  static Prefilter* FromRegexp(Regexp* a);
+
+  static Prefilter* FromString(const std::string& str);
+
+  static Prefilter* OrStrings(SSet* ss);
+
+  static Info* BuildInfo(Regexp* re);
+
+  Prefilter* Simplify();
+
+  // Removes redundant strings from the set. A string is redundant if
+  // any of the other strings appear as a substring. The empty string
+  // is a special case, which is ignored.
+  static void SimplifyStringSet(SSet* ss);
+
+  // Adds the cross-product of a and b to dst.
+  // (For each string i in a and j in b, add i+j.)
+  static void CrossProduct(const SSet& a, const SSet& b, SSet* dst);
+
+  // Kind of Prefilter.
+  Op op_;
+
+  // Sub-matches for AND or OR Prefilter.
+  std::vector<Prefilter*>* subs_;
+
+  // Actual string to match in leaf node.
+  std::string atom_;
+
+  // If different prefilters have the same string atom, or if they are
+  // structurally the same (e.g., OR of same atom strings) they are
+  // considered the same unique nodes. This is the id for each unique
+  // node. This field is populated with a unique id for every node,
+  // and -1 for duplicate nodes.
+  int unique_id_;
+
+  Prefilter(const Prefilter&) = delete;
+  Prefilter& operator=(const Prefilter&) = delete;
+};
+
+}  // namespace re2
+
+#endif  // RE2_PREFILTER_H_
diff --git a/internal/cpp/re2/prefilter_tree.cc b/internal/cpp/re2/prefilter_tree.cc
new file mode 100644
index 00000000000..755395309f5
--- /dev/null
+++ b/internal/cpp/re2/prefilter_tree.cc
@@ -0,0 +1,370 @@
+// Copyright 2009 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "re2/prefilter_tree.h"
+
+#include <stddef.h>
+#include <algorithm>
+#include <cmath>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "re2/prefilter.h"
+#include "re2/re2.h"
+#include "util/logging.h"
+#include "util/strutil.h"
+#include "util/util.h"
+
+namespace re2 {
+
+PrefilterTree::PrefilterTree() : compiled_(false), min_atom_len_(3) {}
+
+PrefilterTree::PrefilterTree(int min_atom_len) : compiled_(false), min_atom_len_(min_atom_len) {}
+
+PrefilterTree::~PrefilterTree() {
+  for (size_t i = 0; i < prefilter_vec_.size(); i++)
+    delete prefilter_vec_[i];
+}
+
+void PrefilterTree::Add(Prefilter *prefilter) {
+  if (compiled_) {
+    LOG(DFATAL) << "Add called after Compile.";
+    return;
+  }
+  if (prefilter != NULL && !KeepNode(prefilter)) {
+    delete prefilter;
+    prefilter = NULL;
+  }
+
+  prefilter_vec_.push_back(prefilter);
+}
+
+void PrefilterTree::Compile(std::vector<std::string> *atom_vec) {
+  if (compiled_) {
+    LOG(DFATAL) << "Compile called already.";
+    return;
+  }
+
+  // Some legacy users of PrefilterTree call Compile() before
+  // adding any regexps and expect Compile() to have no effect.
+  if (prefilter_vec_.empty())
+    return;
+
+  compiled_ = true;
+
+  NodeMap nodes;
+  AssignUniqueIds(&nodes, atom_vec);
+}
+
+Prefilter *PrefilterTree::CanonicalNode(NodeMap *nodes, Prefilter *node) {
+  std::string node_string = NodeString(node);
+  NodeMap::iterator iter = nodes->find(node_string);
+  if (iter == nodes->end())
+    return NULL;
+  return (*iter).second;
+}
+
+std::string PrefilterTree::NodeString(Prefilter *node) const {
+  // Adding the operation disambiguates AND/OR/atom nodes.
+  std::string s = StringPrintf("%d", node->op()) + ":";
+  if (node->op() == Prefilter::ATOM) {
+    s += node->atom();
+  } else {
+    for (size_t i = 0; i < node->subs()->size(); i++) {
+      if (i > 0)
+        s += ',';
+      s += StringPrintf("%d", (*node->subs())[i]->unique_id());
+    }
+  }
+  return s;
+}
+
+bool PrefilterTree::KeepNode(Prefilter *node) const {
+  if (node == NULL)
+    return false;
+
+  switch (node->op()) {
+    default:
+      LOG(DFATAL) << "Unexpected op in KeepNode: " << node->op();
+      return false;
+
+    case Prefilter::ALL:
+    case Prefilter::NONE:
+      return false;
+
+    case Prefilter::ATOM:
+      return node->atom().size() >= static_cast<size_t>(min_atom_len_);
+
+    case Prefilter::AND: {
+      int j = 0;
+      std::vector<Prefilter*> *subs = node->subs();
+      for (size_t i = 0; i < subs->size(); i++)
+        if (KeepNode((*subs)[i]))
+          (*subs)[j++] = (*subs)[i];
+        else
+          delete (*subs)[i];
+
+      subs->resize(j);
+      return j > 0;
+    }
+
+    case Prefilter::OR:
+      for (size_t i = 0; i < node->subs()->size(); i++)
+        if (!KeepNode((*node->subs())[i]))
+          return false;
+      return true;
+  }
+}
+
+void PrefilterTree::AssignUniqueIds(NodeMap *nodes, std::vector<std::string> *atom_vec) {
+  atom_vec->clear();
+
+  // Build vector of all filter nodes, sorted topologically
+  // from top to bottom in v.
+  std::vector<Prefilter*> v;
+
+  // Add the top level nodes of each regexp prefilter.
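+  // For example, with the prefilters for /foo.*bar/ and /baz/ added in
+  // that order, v starts out as [AND(foo, bar), ATOM(baz)]; the loops
+  // below first record the top-level nodes (NULL for regexps with no
+  // prefilter) and then append each node's children, so v ends up
+  // ordered from roots down to leaves.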
+  for (size_t i = 0; i < prefilter_vec_.size(); i++) {
+    Prefilter *f = prefilter_vec_[i];
+    if (f == NULL)
+      unfiltered_.push_back(static_cast<int>(i));
+
+    // We push NULL also on to v, so that we maintain the
+    // mapping of index==regexpid for level=0 prefilter nodes.
+    v.push_back(f);
+  }
+
+  // Now add all the descendant nodes.
+  for (size_t i = 0; i < v.size(); i++) {
+    Prefilter *f = v[i];
+    if (f == NULL)
+      continue;
+    if (f->op() == Prefilter::AND || f->op() == Prefilter::OR) {
+      const std::vector<Prefilter*> &subs = *f->subs();
+      for (size_t j = 0; j < subs.size(); j++)
+        v.push_back(subs[j]);
+    }
+  }
+
+  // Identify unique nodes.
+  int unique_id = 0;
+  for (int i = static_cast<int>(v.size()) - 1; i >= 0; i--) {
+    Prefilter *node = v[i];
+    if (node == NULL)
+      continue;
+    node->set_unique_id(-1);
+    Prefilter *canonical = CanonicalNode(nodes, node);
+    if (canonical == NULL) {
+      // Any further nodes that have the same node string
+      // will find this node as the canonical node.
+      nodes->emplace(NodeString(node), node);
+      if (node->op() == Prefilter::ATOM) {
+        atom_vec->push_back(node->atom());
+        atom_index_to_id_.push_back(unique_id);
+      }
+      node->set_unique_id(unique_id++);
+    } else {
+      node->set_unique_id(canonical->unique_id());
+    }
+  }
+  entries_.resize(unique_id);
+
+  // Fill the entries.
+  for (int i = static_cast<int>(v.size()) - 1; i >= 0; i--) {
+    Prefilter *prefilter = v[i];
+    if (prefilter == NULL)
+      continue;
+    if (CanonicalNode(nodes, prefilter) != prefilter)
+      continue;
+    int id = prefilter->unique_id();
+    switch (prefilter->op()) {
+      default:
+        LOG(DFATAL) << "Unexpected op: " << prefilter->op();
+        return;
+
+      case Prefilter::ATOM:
+        entries_[id].propagate_up_at_count = 1;
+        break;
+
+      case Prefilter::OR:
+      case Prefilter::AND: {
+        // For each child, we append our id to the child's list of
+        // parent ids... unless we happen to have done so already.
+        // The number of appends is the number of unique children,
+        // which allows correct upward propagation from AND nodes.
+        int up_count = 0;
+        for (size_t j = 0; j < prefilter->subs()->size(); j++) {
+          int child_id = (*prefilter->subs())[j]->unique_id();
+          std::vector<int> &parents = entries_[child_id].parents;
+          if (parents.empty() || parents.back() != id) {
+            parents.push_back(id);
+            up_count++;
+          }
+        }
+        entries_[id].propagate_up_at_count = prefilter->op() == Prefilter::AND ? up_count : 1;
+        break;
+      }
+    }
+  }
+
+  // For top level nodes, populate regexp id.
+  for (size_t i = 0; i < prefilter_vec_.size(); i++) {
+    if (prefilter_vec_[i] == NULL)
+      continue;
+    int id = CanonicalNode(nodes, prefilter_vec_[i])->unique_id();
+    DCHECK_LE(0, id);
+    Entry *entry = &entries_[id];
+    entry->regexps.push_back(static_cast<int>(i));
+  }
+
+  // Lastly, using probability-based heuristics, we identify nodes
+  // that trigger too many parents and then we try to prune edges.
+  // We use logarithms below to avoid the likelihood of underflow.
+  double log_num_regexps = std::log(prefilter_vec_.size() - unfiltered_.size());
+  // Hoisted this above the loop so that we don't thrash the heap.
+  std::vector<std::pair<size_t, int>> entries_by_num_edges;
+  for (int i = static_cast<int>(v.size()) - 1; i >= 0; i--) {
+    Prefilter *prefilter = v[i];
+    // Pruning applies only to AND nodes because it "just" reduces
+    // precision; applied to OR nodes, it would break correctness.
+    if (prefilter == NULL || prefilter->op() != Prefilter::AND)
+      continue;
+    if (CanonicalNode(nodes, prefilter) != prefilter)
+      continue;
+    int id = prefilter->unique_id();
+
+    // Sort the current node's children by the numbers of parents.
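+    // Processing children in ascending order of fan-in means the
+    // cheap, rarely shared children are counted into the estimate
+    // first, and any pruning falls on the children that already have
+    // the most parent edges.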
+    entries_by_num_edges.clear();
+    for (size_t j = 0; j < prefilter->subs()->size(); j++) {
+      int child_id = (*prefilter->subs())[j]->unique_id();
+      const std::vector<int> &parents = entries_[child_id].parents;
+      entries_by_num_edges.emplace_back(parents.size(), child_id);
+    }
+    std::stable_sort(entries_by_num_edges.begin(), entries_by_num_edges.end());
+
+    // A running estimate of how many regexps will be triggered by
+    // pruning the remaining children's edges to the current node.
+    // Our nominal target is one, so the threshold is log(1) == 0;
+    // pruning occurs iff the child has more than nine edges left.
+    double log_num_triggered = log_num_regexps;
+    for (const auto &pair : entries_by_num_edges) {
+      int child_id = pair.second;
+      std::vector<int> &parents = entries_[child_id].parents;
+      if (log_num_triggered > 0.) {
+        log_num_triggered += std::log(parents.size());
+        log_num_triggered -= log_num_regexps;
+      } else if (parents.size() > 9) {
+        auto it = std::find(parents.begin(), parents.end(), id);
+        if (it != parents.end()) {
+          parents.erase(it);
+          entries_[id].propagate_up_at_count--;
+        }
+      }
+    }
+  }
+}
+
+// Functions for triggering during search.
+void PrefilterTree::RegexpsGivenStrings(const std::vector<int> &matched_atoms, std::vector<int> *regexps) const {
+  regexps->clear();
+  if (!compiled_) {
+    // Some legacy users of PrefilterTree call Compile() before
+    // adding any regexps and expect Compile() to have no effect.
+    // This kludge is a counterpart to that kludge.
+    if (prefilter_vec_.empty())
+      return;
+
+    LOG(ERROR) << "RegexpsGivenStrings called before Compile.";
+    for (size_t i = 0; i < prefilter_vec_.size(); i++)
+      regexps->push_back(static_cast<int>(i));
+  } else {
+    IntMap regexps_map(static_cast<int>(prefilter_vec_.size()));
+    std::vector<int> matched_atom_ids;
+    for (size_t j = 0; j < matched_atoms.size(); j++)
+      matched_atom_ids.push_back(atom_index_to_id_[matched_atoms[j]]);
+    PropagateMatch(matched_atom_ids, &regexps_map);
+    for (IntMap::iterator it = regexps_map.begin(); it != regexps_map.end(); ++it)
+      regexps->push_back(it->index());
+
+    regexps->insert(regexps->end(), unfiltered_.begin(), unfiltered_.end());
+  }
+  std::sort(regexps->begin(), regexps->end());
+}
+
+void PrefilterTree::PropagateMatch(const std::vector<int> &atom_ids, IntMap *regexps) const {
+  IntMap count(static_cast<int>(entries_.size()));
+  IntMap work(static_cast<int>(entries_.size()));
+  for (size_t i = 0; i < atom_ids.size(); i++)
+    work.set(atom_ids[i], 1);
+  for (IntMap::iterator it = work.begin(); it != work.end(); ++it) {
+    const Entry &entry = entries_[it->index()];
+    // Record regexps triggered.
+    for (size_t i = 0; i < entry.regexps.size(); i++)
+      regexps->set(entry.regexps[i], 1);
+    int c;
+    // Pass trigger up to parents.
+    for (int j : entry.parents) {
+      const Entry &parent = entries_[j];
+      // Delay until all the children have succeeded.
+      if (parent.propagate_up_at_count > 1) {
+        if (count.has_index(j)) {
+          c = count.get_existing(j) + 1;
+          count.set_existing(j, c);
+        } else {
+          c = 1;
+          count.set_new(j, c);
+        }
+        if (c < parent.propagate_up_at_count)
+          continue;
+      }
+      // Trigger the parent.
+      work.set(j, 1);
+    }
+  }
+}
+
+// Debugging help.
+void PrefilterTree::PrintPrefilter(int regexpid) { LOG(ERROR) << DebugNodeString(prefilter_vec_[regexpid]); }
+
+void PrefilterTree::PrintDebugInfo(NodeMap *nodes) {
+  LOG(ERROR) << "#Unique Atoms: " << atom_index_to_id_.size();
+  LOG(ERROR) << "#Unique Nodes: " << entries_.size();
+
+  for (size_t i = 0; i < entries_.size(); i++) {
+    const std::vector<int> &parents = entries_[i].parents;
+    const std::vector<int> &regexps = entries_[i].regexps;
+    LOG(ERROR) << "EntryId: " << i << " N: " << parents.size() << " R: " << regexps.size();
+    for (int parent : parents)
+      LOG(ERROR) << parent;
+  }
+  LOG(ERROR) << "Map:";
+  for (NodeMap::const_iterator iter = nodes->begin(); iter != nodes->end(); ++iter)
+    LOG(ERROR) << "NodeId: " << (*iter).second->unique_id() << " Str: " << (*iter).first;
+}
+
+std::string PrefilterTree::DebugNodeString(Prefilter *node) const {
+  std::string node_string = "";
+  if (node->op() == Prefilter::ATOM) {
+    DCHECK(!node->atom().empty());
+    node_string += node->atom();
+  } else {
+    // Adding the operation disambiguates AND and OR nodes.
+    node_string += node->op() == Prefilter::AND ? "AND" : "OR";
+    node_string += "(";
+    for (size_t i = 0; i < node->subs()->size(); i++) {
+      if (i > 0)
+        node_string += ',';
+      node_string += StringPrintf("%d", (*node->subs())[i]->unique_id());
+      node_string += ":";
+      node_string += DebugNodeString((*node->subs())[i]);
+    }
+    node_string += ")";
+  }
+  return node_string;
+}
+
+}  // namespace re2
diff --git a/internal/cpp/re2/prefilter_tree.h b/internal/cpp/re2/prefilter_tree.h
new file mode 100644
index 00000000000..2a293ed7ff0
--- /dev/null
+++ b/internal/cpp/re2/prefilter_tree.h
@@ -0,0 +1,138 @@
+// Copyright 2009 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef RE2_PREFILTER_TREE_H_
+#define RE2_PREFILTER_TREE_H_
+
+// The PrefilterTree class is used to form an AND-OR tree of strings
+// that would trigger each regexp. The 'prefilter' of each regexp is
+// added to PrefilterTree, and then PrefilterTree is used to find all
+// the unique strings across the prefilters. During search, by using
+// matches from a string matching engine, PrefilterTree deduces the
+// set of regexps that are to be triggered. The 'string matching
+// engine' itself is outside of this class, and the caller can use any
+// favorite engine. PrefilterTree provides a set of strings (called
+// atoms) that the user of this class should use to do the string
+// matching.
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "re2/prefilter.h"
+#include "re2/sparse_array.h"
+#include "util/util.h"
+
+namespace re2 {
+
+class PrefilterTree {
+public:
+  PrefilterTree();
+  explicit PrefilterTree(int min_atom_len);
+  ~PrefilterTree();
+
+  // Adds the prefilter for the next regexp. Note that we assume that
+  // Add is called sequentially for all regexps. All Add calls
+  // must precede Compile.
+  void Add(Prefilter *prefilter);
+
+  // Compile returns a vector of strings in atom_vec.
+  // Call this after all the prefilters are added through Add.
+  // No calls to Add after Compile are allowed.
+  // The caller should use the returned set of strings to do string matching.
+  // Each time a string matches, the corresponding index then has to be
+  // passed to RegexpsGivenStrings below.
+  void Compile(std::vector<std::string> *atom_vec);
+
+  // Given the indices of the atoms that matched, returns the indexes
+  // of regexps that should be searched.
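+  // (Typically the atoms returned by Compile are handed to a
+  // multi-pattern string matcher such as Aho-Corasick, and the indices
+  // of the atoms it reports found are what get passed in here.)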
+  // The matched_atoms should contain all the ids of string atoms that
+  // were found to match the content. The caller can use any string
+  // match engine to perform this function. This function is thread
+  // safe.
+  void RegexpsGivenStrings(const std::vector<int> &matched_atoms, std::vector<int> *regexps) const;
+
+  // Print debug prefilter. Also prints unique ids associated with
+  // nodes of the prefilter of the regexp.
+  void PrintPrefilter(int regexpid);
+
+private:
+  typedef SparseArray<int> IntMap;
+  // TODO(junyer): Use std::unordered_set instead?
+  // It should be trivial to get rid of the stringification...
+  typedef std::map<std::string, Prefilter*> NodeMap;
+
+  // Each unique node has a corresponding Entry that helps in
+  // passing the matching trigger information along the tree.
+  struct Entry {
+  public:
+    // How many children should match before this node triggers the
+    // parent. For an atom and an OR node, this is 1 and for an AND
+    // node, it is the number of unique children.
+    int propagate_up_at_count;
+
+    // When this node is ready to trigger the parent, what are the indices
+    // of the parent nodes to trigger. The reason there may be more than
+    // one is because of sharing. For example (abc | def) and (xyz | def)
+    // are two different nodes, but they share the atom 'def'. So when
+    // 'def' matches, it triggers two parents, corresponding to the two
+    // different OR nodes.
+    std::vector<int> parents;
+
+    // When this node is ready to trigger the parent, what are the
+    // regexps that are triggered.
+    std::vector<int> regexps;
+  };
+
+  // Returns true if the prefilter node should be kept.
+  bool KeepNode(Prefilter *node) const;
+
+  // This function assigns unique ids to various parts of the
+  // prefilter, by looking at if these nodes are already in the
+  // PrefilterTree.
+  void AssignUniqueIds(NodeMap *nodes, std::vector<std::string> *atom_vec);
+
+  // Given the matching atoms, find the regexps to be triggered.
+  void PropagateMatch(const std::vector<int> &atom_ids, IntMap *regexps) const;
+
+  // Returns the prefilter node that has the same NodeString as this
+  // node. For the canonical node, returns node.
+  Prefilter *CanonicalNode(NodeMap *nodes, Prefilter *node);
+
+  // A string that uniquely identifies the node. Assumes that the
+  // children of node have already been assigned unique ids.
+  std::string NodeString(Prefilter *node) const;
+
+  // Recursively constructs a readable prefilter string.
+  std::string DebugNodeString(Prefilter *node) const;
+
+  // Used for debugging.
+  void PrintDebugInfo(NodeMap *nodes);
+
+  // These are all the nodes formed by Compile. Essentially, there is
+  // one node for each unique atom and each unique AND/OR node.
+  std::vector<Entry> entries_;
+
+  // Indices of regexps that always pass through the filter (since we
+  // found no required literals in these regexps).
+  std::vector<int> unfiltered_;
+
+  // Vector of Prefilter for all regexps.
+  std::vector<Prefilter*> prefilter_vec_;
+
+  // Atom index in returned strings to entry id mapping.
+  std::vector<int> atom_index_to_id_;
+
+  // Has the prefilter tree been compiled.
+  bool compiled_;
+
+  // Strings less than this length are not stored as atoms.
+  const int min_atom_len_;
+
+  PrefilterTree(const PrefilterTree &) = delete;
+  PrefilterTree &operator=(const PrefilterTree &) = delete;
+};
+
+}  // namespace re2
+
+#endif  // RE2_PREFILTER_TREE_H_
diff --git a/internal/cpp/re2/prog.cc b/internal/cpp/re2/prog.cc
new file mode 100644
index 00000000000..ad7661deefa
--- /dev/null
+++ b/internal/cpp/re2/prog.cc
@@ -0,0 +1,1158 @@
+// Copyright 2007 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Compiled regular expression representation.
+// Tested by compile_test.cc
+
+#include "re2/prog.h"
+
+#if defined(__AVX2__)
+#include <immintrin.h>
+#ifdef _MSC_VER
+#include <intrin.h>
+#endif
+#endif
+#include <stdint.h>
+#include <string.h>
+#include <algorithm>
+#include <memory>
+#include <utility>
+
+#include "re2/bitmap256.h"
+#include "re2/stringpiece.h"
+#include "util/logging.h"
+#include "util/strutil.h"
+#include "util/util.h"
+
+namespace re2 {
+
+// Constructors per Inst opcode
+
+void Prog::Inst::InitAlt(uint32_t out, uint32_t out1) {
+  DCHECK_EQ(out_opcode_, 0);
+  set_out_opcode(out, kInstAlt);
+  out1_ = out1;
+}
+
+void Prog::Inst::InitByteRange(int lo, int hi, int foldcase, uint32_t out) {
+  DCHECK_EQ(out_opcode_, 0);
+  set_out_opcode(out, kInstByteRange);
+  byte_range.lo_ = lo & 0xFF;
+  byte_range.hi_ = hi & 0xFF;
+  byte_range.hint_foldcase_ = foldcase & 1;
+}
+
+void Prog::Inst::InitCapture(int cap, uint32_t out) {
+  DCHECK_EQ(out_opcode_, 0);
+  set_out_opcode(out, kInstCapture);
+  cap_ = cap;
+}
+
+void Prog::Inst::InitEmptyWidth(EmptyOp empty, uint32_t out) {
+  DCHECK_EQ(out_opcode_, 0);
+  set_out_opcode(out, kInstEmptyWidth);
+  empty_ = empty;
+}
+
+void Prog::Inst::InitMatch(int32_t id) {
+  DCHECK_EQ(out_opcode_, 0);
+  set_opcode(kInstMatch);
+  match_id_ = id;
+}
+
+void Prog::Inst::InitNop(uint32_t out) {
+  DCHECK_EQ(out_opcode_, 0);
+  set_opcode(kInstNop);
+}
+
+void Prog::Inst::InitFail() {
+  DCHECK_EQ(out_opcode_, 0);
+  set_opcode(kInstFail);
+}
+
+std::string Prog::Inst::Dump() {
+  switch (opcode()) {
+    default:
+      return StringPrintf("opcode %d", static_cast<int>(opcode()));
+
+    case kInstAlt:
+      return StringPrintf("alt -> %d | %d", out(), out1_);
+
+    case kInstAltMatch:
+      return StringPrintf("altmatch -> %d | %d", out(), out1_);
+
+    case kInstByteRange:
+      return StringPrintf("byte%s [%02x-%02x] %d -> %d", foldcase() ? "/i" : "", byte_range.lo_, byte_range.hi_, hint(), out());
+
+    case kInstCapture:
+      return StringPrintf("capture %d -> %d", cap_, out());
+
+    case kInstEmptyWidth:
+      return StringPrintf("emptywidth %#x -> %d", static_cast<int>(empty_), out());
+
+    case kInstMatch:
+      return StringPrintf("match! %d", match_id());
+
+    case kInstNop:
+      return StringPrintf("nop -> %d", out());
+
+    case kInstFail:
+      return StringPrintf("fail");
+  }
+}
+
+Prog::Prog()
+    : anchor_start_(false), anchor_end_(false), reversed_(false), did_flatten_(false), did_onepass_(false), start_(0), start_unanchored_(0), size_(0),
+      bytemap_range_(0), prefix_foldcase_(false), prefix_size_(0), list_count_(0), bit_state_text_max_size_(0), dfa_mem_(0), dfa_first_(NULL),
+      dfa_longest_(NULL) {}
+
+Prog::~Prog() {
+  DeleteDFA(dfa_longest_);
+  DeleteDFA(dfa_first_);
+  if (prefix_foldcase_)
+    delete[] prefix_dfa_;
+}
+
+typedef SparseSet Workq;
+
+static inline void AddToQueue(Workq* q, int id) {
+  if (id != 0)
+    q->insert(id);
+}
+
+static std::string ProgToString(Prog* prog, Workq* q) {
+  std::string s;
+  for (Workq::iterator i = q->begin(); i != q->end(); ++i) {
+    int id = *i;
+    Prog::Inst* ip = prog->inst(id);
+    s += StringPrintf("%d. %s\n", id, ip->Dump().c_str());
+    AddToQueue(q, ip->out());
+    if (ip->opcode() == kInstAlt || ip->opcode() == kInstAltMatch)
+      AddToQueue(q, ip->out1());
+  }
+  return s;
+}
+
+static std::string FlattenedProgToString(Prog* prog, int start) {
+  std::string s;
+  for (int id = start; id < prog->size(); id++) {
+    Prog::Inst* ip = prog->inst(id);
+    if (ip->last())
+      s += StringPrintf("%d. 
%s\n", id, ip->Dump().c_str()); + else + s += StringPrintf("%d+ %s\n", id, ip->Dump().c_str()); + } + return s; +} + +std::string Prog::Dump() { + if (did_flatten_) + return FlattenedProgToString(this, start_); + + Workq q(size_); + AddToQueue(&q, start_); + return ProgToString(this, &q); +} + +std::string Prog::DumpUnanchored() { + if (did_flatten_) + return FlattenedProgToString(this, start_unanchored_); + + Workq q(size_); + AddToQueue(&q, start_unanchored_); + return ProgToString(this, &q); +} + +std::string Prog::DumpByteMap() { + std::string map; + for (int c = 0; c < 256; c++) { + int b = bytemap_[c]; + int lo = c; + while (c < 256-1 && bytemap_[c+1] == b) + c++; + int hi = c; + map += StringPrintf("[%02x-%02x] -> %d\n", lo, hi, b); + } + return map; +} + +// Is ip a guaranteed match at end of text, perhaps after some capturing? +static bool IsMatch(Prog* prog, Prog::Inst* ip) { + for (;;) { + switch (ip->opcode()) { + default: + LOG(DFATAL) << "Unexpected opcode in IsMatch: " << ip->opcode(); + return false; + + case kInstAlt: + case kInstAltMatch: + case kInstByteRange: + case kInstFail: + case kInstEmptyWidth: + return false; + + case kInstCapture: + case kInstNop: + ip = prog->inst(ip->out()); + break; + + case kInstMatch: + return true; + } + } +} + +// Peep-hole optimizer. +void Prog::Optimize() { + Workq q(size_); + + // Eliminate nops. Most are taken out during compilation + // but a few are hard to avoid. + q.clear(); + AddToQueue(&q, start_); + for (Workq::iterator i = q.begin(); i != q.end(); ++i) { + int id = *i; + + Inst* ip = inst(id); + int j = ip->out(); + Inst* jp; + while (j != 0 && (jp=inst(j))->opcode() == kInstNop) { + j = jp->out(); + } + ip->set_out(j); + AddToQueue(&q, ip->out()); + + if (ip->opcode() == kInstAlt) { + j = ip->out1(); + while (j != 0 && (jp=inst(j))->opcode() == kInstNop) { + j = jp->out(); + } + ip->out1_ = j; + AddToQueue(&q, ip->out1()); + } + } + + // Insert kInstAltMatch instructions + // Look for + // ip: Alt -> j | k + // j: ByteRange [00-FF] -> ip + // k: Match + // or the reverse (the above is the greedy one). + // Rewrite Alt to AltMatch. 
+  q.clear();
+  AddToQueue(&q, start_);
+  for (Workq::iterator i = q.begin(); i != q.end(); ++i) {
+    int id = *i;
+    Inst* ip = inst(id);
+    AddToQueue(&q, ip->out());
+    if (ip->opcode() == kInstAlt)
+      AddToQueue(&q, ip->out1());
+
+    if (ip->opcode() == kInstAlt) {
+      Inst* j = inst(ip->out());
+      Inst* k = inst(ip->out1());
+      if (j->opcode() == kInstByteRange && j->out() == id &&
+          j->lo() == 0x00 && j->hi() == 0xFF &&
+          IsMatch(this, k)) {
+        ip->set_opcode(kInstAltMatch);
+        continue;
+      }
+      if (IsMatch(this, j) &&
+          k->opcode() == kInstByteRange && k->out() == id &&
+          k->lo() == 0x00 && k->hi() == 0xFF) {
+        ip->set_opcode(kInstAltMatch);
+      }
+    }
+  }
+}
+
+uint32_t Prog::EmptyFlags(const StringPiece& text, const char* p) {
+  int flags = 0;
+
+  // ^ and \A
+  if (p == text.data())
+    flags |= kEmptyBeginText | kEmptyBeginLine;
+  else if (p[-1] == '\n')
+    flags |= kEmptyBeginLine;
+
+  // $ and \z
+  if (p == text.data() + text.size())
+    flags |= kEmptyEndText | kEmptyEndLine;
+  else if (p < text.data() + text.size() && p[0] == '\n')
+    flags |= kEmptyEndLine;
+
+  // \b and \B
+  if (p == text.data() && p == text.data() + text.size()) {
+    // no word boundary here
+  } else if (p == text.data()) {
+    if (IsWordChar(p[0]))
+      flags |= kEmptyWordBoundary;
+  } else if (p == text.data() + text.size()) {
+    if (IsWordChar(p[-1]))
+      flags |= kEmptyWordBoundary;
+  } else {
+    if (IsWordChar(p[-1]) != IsWordChar(p[0]))
+      flags |= kEmptyWordBoundary;
+  }
+  if (!(flags & kEmptyWordBoundary))
+    flags |= kEmptyNonWordBoundary;
+
+  return flags;
+}
+
+// ByteMapBuilder implements a coloring algorithm.
+//
+// The first phase is a series of "mark and merge" batches: we mark one or more
+// [lo-hi] ranges, then merge them into our internal state. Batching is not for
+// performance; rather, it means that the ranges are treated indistinguishably.
+//
+// Internally, the ranges are represented using a bitmap that stores the splits
+// and a vector that stores the colors; both of them are indexed by the ranges'
+// last bytes. Thus, in order to merge a [lo-hi] range, we split at lo-1 and at
+// hi (if not already split), then recolor each range in between. The color map
+// (i.e. from the old color to the new color) is maintained for the lifetime of
+// the batch and so underpins this somewhat obscure approach to set operations.
+//
+// The second phase builds the bytemap from our internal state: we recolor each
+// range, then store the new color (which is now the byte class) in each of the
+// corresponding array elements. Finally, we output the number of byte classes.
+class ByteMapBuilder {
+ public:
+  ByteMapBuilder() {
+    // Initial state: the [0-255] range has color 256.
+    // This will avoid problems during the second phase,
+    // in which we assign byte classes numbered from 0.
+    splits_.Set(255);
+    colors_[255] = 256;
+    nextcolor_ = 257;
+  }
+
+  void Mark(int lo, int hi);
+  void Merge();
+  void Build(uint8_t* bytemap, int* bytemap_range);
+
+ private:
+  int Recolor(int oldcolor);
+
+  Bitmap256 splits_;
+  int colors_[256];
+  int nextcolor_;
+  std::vector<std::pair<int, int>> colormap_;
+  std::vector<std::pair<int, int>> ranges_;
+
+  ByteMapBuilder(const ByteMapBuilder&) = delete;
+  ByteMapBuilder& operator=(const ByteMapBuilder&) = delete;
+};
+
+void ByteMapBuilder::Mark(int lo, int hi) {
+  DCHECK_GE(lo, 0);
+  DCHECK_GE(hi, 0);
+  DCHECK_LE(lo, 255);
+  DCHECK_LE(hi, 255);
+  DCHECK_LE(lo, hi);
+
+  // Ignore any [0-255] ranges. They cause us to recolor every range, which
+  // has no effect on the eventual result and is therefore a waste of time.
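+  // (Recoloring all of [0-255] would rename every color in the current
+  // batch without splitting any range, so the resulting partition into
+  // byte classes would come out the same.)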
+  if (lo == 0 && hi == 255)
+    return;
+
+  ranges_.emplace_back(lo, hi);
+}
+
+void ByteMapBuilder::Merge() {
+  for (std::vector<std::pair<int, int>>::const_iterator it = ranges_.begin();
+       it != ranges_.end();
+       ++it) {
+    int lo = it->first-1;
+    int hi = it->second;
+
+    if (0 <= lo && !splits_.Test(lo)) {
+      splits_.Set(lo);
+      int next = splits_.FindNextSetBit(lo+1);
+      colors_[lo] = colors_[next];
+    }
+    if (!splits_.Test(hi)) {
+      splits_.Set(hi);
+      int next = splits_.FindNextSetBit(hi+1);
+      colors_[hi] = colors_[next];
+    }
+
+    int c = lo+1;
+    while (c < 256) {
+      int next = splits_.FindNextSetBit(c);
+      colors_[next] = Recolor(colors_[next]);
+      if (next == hi)
+        break;
+      c = next+1;
+    }
+  }
+  colormap_.clear();
+  ranges_.clear();
+}
+
+void ByteMapBuilder::Build(uint8_t* bytemap, int* bytemap_range) {
+  // Assign byte classes numbered from 0.
+  nextcolor_ = 0;
+
+  int c = 0;
+  while (c < 256) {
+    int next = splits_.FindNextSetBit(c);
+    uint8_t b = static_cast<uint8_t>(Recolor(colors_[next]));
+    while (c <= next) {
+      bytemap[c] = b;
+      c++;
+    }
+  }
+
+  *bytemap_range = nextcolor_;
+}
+
+int ByteMapBuilder::Recolor(int oldcolor) {
+  // Yes, this is a linear search. There can be at most 256
+  // colors and there will typically be far fewer than that.
+  // Also, we need to consider keys *and* values in order to
+  // avoid recoloring a given range more than once per batch.
+  std::vector<std::pair<int, int>>::const_iterator it =
+      std::find_if(colormap_.begin(), colormap_.end(),
+                   [=](const std::pair<int, int>& kv) -> bool {
+                     return kv.first == oldcolor || kv.second == oldcolor;
+                   });
+  if (it != colormap_.end())
+    return it->second;
+  int newcolor = nextcolor_;
+  nextcolor_++;
+  colormap_.emplace_back(oldcolor, newcolor);
+  return newcolor;
+}
+
+void Prog::ComputeByteMap() {
+  // Fill in bytemap with byte classes for the program.
+  // Ranges of bytes that are treated indistinguishably
+  // will be mapped to a single byte class.
+  ByteMapBuilder builder;
+
+  // Don't repeat the work for ^ and $.
+  bool marked_line_boundaries = false;
+  // Don't repeat the work for \b and \B.
+  bool marked_word_boundaries = false;
+
+  for (int id = 0; id < size(); id++) {
+    Inst* ip = inst(id);
+    if (ip->opcode() == kInstByteRange) {
+      int lo = ip->lo();
+      int hi = ip->hi();
+      builder.Mark(lo, hi);
+      if (ip->foldcase() && lo <= 'z' && hi >= 'a') {
+        int foldlo = lo;
+        int foldhi = hi;
+        if (foldlo < 'a')
+          foldlo = 'a';
+        if (foldhi > 'z')
+          foldhi = 'z';
+        if (foldlo <= foldhi) {
+          foldlo += 'A' - 'a';
+          foldhi += 'A' - 'a';
+          builder.Mark(foldlo, foldhi);
+        }
+      }
+      // If this Inst is not the last Inst in its list AND the next Inst is
+      // also a ByteRange AND the Insts have the same out, defer the merge.
+      if (!ip->last() &&
+          inst(id+1)->opcode() == kInstByteRange &&
+          ip->out() == inst(id+1)->out())
+        continue;
+      builder.Merge();
+    } else if (ip->opcode() == kInstEmptyWidth) {
+      if (ip->empty() & (kEmptyBeginLine|kEmptyEndLine) &&
+          !marked_line_boundaries) {
+        builder.Mark('\n', '\n');
+        builder.Merge();
+        marked_line_boundaries = true;
+      }
+      if (ip->empty() & (kEmptyWordBoundary|kEmptyNonWordBoundary) &&
+          !marked_word_boundaries) {
+        // We require two batches here: the first for ranges that are word
+        // characters, the second for ranges that are not word characters.
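+        // Keeping the two batches separate leaves word and non-word
+        // bytes in distinct classes: the [0-9A-Za-z_] runs merge into
+        // one class and the gaps between them into another.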
+        for (bool isword : {true, false}) {
+          int j;
+          for (int i = 0; i < 256; i = j) {
+            for (j = i + 1; j < 256 &&
+                            Prog::IsWordChar(static_cast<uint8_t>(i)) ==
+                                Prog::IsWordChar(static_cast<uint8_t>(j));
+                 j++)
+              ;
+            if (Prog::IsWordChar(static_cast<uint8_t>(i)) == isword)
+              builder.Mark(i, j - 1);
+          }
+          builder.Merge();
+        }
+        marked_word_boundaries = true;
+      }
+    }
+  }
+
+  builder.Build(bytemap_, &bytemap_range_);
+
+  if ((0)) {  // For debugging, use trivial bytemap.
+    LOG(ERROR) << "Using trivial bytemap.";
+    for (int i = 0; i < 256; i++)
+      bytemap_[i] = static_cast<uint8_t>(i);
+    bytemap_range_ = 256;
+  }
+}
+
+// Prog::Flatten() implements a graph rewriting algorithm.
+//
+// The overall process is similar to epsilon removal, but retains some epsilon
+// transitions: those from Capture and EmptyWidth instructions; and those from
+// nullable subexpressions. (The latter avoids quadratic blowup in transitions
+// in the worst case.) It might be best thought of as Alt instruction elision.
+//
+// In conceptual terms, it divides the Prog into "trees" of instructions, then
+// traverses the "trees" in order to produce "lists" of instructions. A "tree"
+// is one or more instructions that grow from one "root" instruction to one or
+// more "leaf" instructions; if a "tree" has exactly one instruction, then the
+// "root" is also the "leaf". In most cases, a "root" is the successor of some
+// "leaf" (i.e. the "leaf" instruction's out() returns the "root" instruction)
+// and is considered a "successor root". A "leaf" can be a ByteRange, Capture,
+// EmptyWidth or Match instruction. However, this is insufficient for handling
+// nested nullable subexpressions correctly, so in some cases, a "root" is the
+// dominator of the instructions reachable from some "successor root" (i.e. it
+// has an unreachable predecessor) and is considered a "dominator root". Since
+// only Alt instructions can be "dominator roots" (other instructions would be
+// "leaves"), only Alt instructions are required to be marked as predecessors.
+//
+// Dividing the Prog into "trees" comprises two passes: marking the "successor
+// roots" and the predecessors; and marking the "dominator roots". Sorting the
+// "successor roots" by their bytecode offsets enables iteration in order from
+// greatest to least during the second pass; by working backwards in this case
+// and flooding the graph no further than "leaves" and already marked "roots",
+// it becomes possible to mark "dominator roots" without doing excessive work.
+//
+// Traversing the "trees" is just iterating over the "roots" in order of their
+// marking and flooding the graph no further than "leaves" and "roots". When a
+// "leaf" is reached, the instruction is copied with its successor remapped to
+// its "root" number. When a "root" is reached, a Nop instruction is generated
+// with its successor remapped similarly. As each "list" is produced, its last
+// instruction is marked as such. After all of the "lists" have been produced,
+// a pass over their instructions remaps their successors to bytecode offsets.
+void Prog::Flatten() {
+  if (did_flatten_)
+    return;
+  did_flatten_ = true;
+
+  // Scratch structures. It's important that these are reused by functions
+  // that we call in loops because they would thrash the heap otherwise.
+  SparseSet reachable(size());
+  std::vector<int> stk;
+  stk.reserve(size());
+
+  // First pass: Marks "successor roots" and predecessors.
+  // Builds the mapping from inst-ids to root-ids.
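+  // In what follows, rootmap maps an instruction id to its "root"
+  // number, predmap/predvec record the Alt predecessors of each
+  // instruction, and reachable/stk are the scratch space reused by
+  // every depth-first traversal.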
+  SparseArray<int> rootmap(size());
+  SparseArray<int> predmap(size());
+  std::vector<std::vector<int>> predvec;
+  MarkSuccessors(&rootmap, &predmap, &predvec, &reachable, &stk);
+
+  // Second pass: Marks "dominator roots".
+  SparseArray<int> sorted(rootmap);
+  std::sort(sorted.begin(), sorted.end(), sorted.less);
+  for (SparseArray<int>::const_iterator i = sorted.end() - 1;
+       i != sorted.begin();
+       --i) {
+    if (i->index() != start_unanchored() && i->index() != start())
+      MarkDominator(i->index(), &rootmap, &predmap, &predvec, &reachable, &stk);
+  }
+
+  // Third pass: Emits "lists". Remaps outs to root-ids.
+  // Builds the mapping from root-ids to flat-ids.
+  std::vector<int> flatmap(rootmap.size());
+  std::vector<Inst> flat;
+  flat.reserve(size());
+  for (SparseArray<int>::const_iterator i = rootmap.begin();
+       i != rootmap.end();
+       ++i) {
+    flatmap[i->value()] = static_cast<int>(flat.size());
+    EmitList(i->index(), &rootmap, &flat, &reachable, &stk);
+    flat.back().set_last();
+    // We have the bounds of the "list", so this is the
+    // most convenient point at which to compute hints.
+    ComputeHints(&flat, flatmap[i->value()], static_cast<int>(flat.size()));
+  }
+
+  list_count_ = static_cast<int>(flatmap.size());
+  for (int i = 0; i < kNumInst; i++)
+    inst_count_[i] = 0;
+
+  // Fourth pass: Remaps outs to flat-ids.
+  // Counts instructions by opcode.
+  for (int id = 0; id < static_cast<int>(flat.size()); id++) {
+    Inst* ip = &flat[id];
+    if (ip->opcode() != kInstAltMatch)  // handled in EmitList()
+      ip->set_out(flatmap[ip->out()]);
+    inst_count_[ip->opcode()]++;
+  }
+
+#if !defined(NDEBUG)
+  // Address a `-Wunused-but-set-variable' warning from Clang 13.x.
+  size_t total = 0;
+  for (int i = 0; i < kNumInst; i++)
+    total += inst_count_[i];
+  CHECK_EQ(total, flat.size());
+#endif
+
+  // Remap start_unanchored and start.
+  if (start_unanchored() == 0) {
+    DCHECK_EQ(start(), 0);
+  } else if (start_unanchored() == start()) {
+    set_start_unanchored(flatmap[1]);
+    set_start(flatmap[1]);
+  } else {
+    set_start_unanchored(flatmap[1]);
+    set_start(flatmap[2]);
+  }
+
+  // Finally, replace the old instructions with the new instructions.
+  size_ = static_cast<int>(flat.size());
+  inst_ = PODArray<Inst>(size_);
+  memmove(inst_.data(), flat.data(), size_*sizeof inst_[0]);
+
+  // Populate the list heads for BitState.
+  // 512 instructions limits the memory footprint to 1KiB.
+  if (size_ <= 512) {
+    list_heads_ = PODArray<uint16_t>(size_);
+    // 0xFF makes it more obvious if we try to look up a non-head.
+    memset(list_heads_.data(), 0xFF, size_*sizeof list_heads_[0]);
+    for (int i = 0; i < list_count_; ++i)
+      list_heads_[flatmap[i]] = i;
+  }
+
+  // BitState allocates a bitmap of size list_count_ * (text.size()+1)
+  // for tracking pairs of possibilities that it has already explored.
+  const size_t kBitStateBitmapMaxSize = 256*1024;  // max size in bits
+  bit_state_text_max_size_ = kBitStateBitmapMaxSize / list_count_ - 1;
+}
+
+void Prog::MarkSuccessors(SparseArray<int>* rootmap,
+                          SparseArray<int>* predmap,
+                          std::vector<std::vector<int>>* predvec,
+                          SparseSet* reachable, std::vector<int>* stk) {
+  // Mark the kInstFail instruction.
+  rootmap->set_new(0, rootmap->size());
+
+  // Mark the start_unanchored and start instructions.
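+  // (For fully anchored patterns, start_unanchored and start can be
+  // the same instruction; the has_index() checks below keep it from
+  // being marked twice.)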
+  if (!rootmap->has_index(start_unanchored()))
+    rootmap->set_new(start_unanchored(), rootmap->size());
+  if (!rootmap->has_index(start()))
+    rootmap->set_new(start(), rootmap->size());
+
+  reachable->clear();
+  stk->clear();
+  stk->push_back(start_unanchored());
+  while (!stk->empty()) {
+    int id = stk->back();
+    stk->pop_back();
+  Loop:
+    if (reachable->contains(id))
+      continue;
+    reachable->insert_new(id);
+
+    Inst* ip = inst(id);
+    switch (ip->opcode()) {
+      default:
+        LOG(DFATAL) << "unhandled opcode: " << ip->opcode();
+        break;
+
+      case kInstAltMatch:
+      case kInstAlt:
+        // Mark this instruction as a predecessor of each out.
+        for (int out : {ip->out(), ip->out1()}) {
+          if (!predmap->has_index(out)) {
+            predmap->set_new(out, static_cast<int>(predvec->size()));
+            predvec->emplace_back();
+          }
+          (*predvec)[predmap->get_existing(out)].emplace_back(id);
+        }
+        stk->push_back(ip->out1());
+        id = ip->out();
+        goto Loop;
+
+      case kInstByteRange:
+      case kInstCapture:
+      case kInstEmptyWidth:
+        // Mark the out of this instruction as a "root".
+        if (!rootmap->has_index(ip->out()))
+          rootmap->set_new(ip->out(), rootmap->size());
+        id = ip->out();
+        goto Loop;
+
+      case kInstNop:
+        id = ip->out();
+        goto Loop;
+
+      case kInstMatch:
+      case kInstFail:
+        break;
+    }
+  }
+}
+
+void Prog::MarkDominator(int root, SparseArray<int>* rootmap,
+                         SparseArray<int>* predmap,
+                         std::vector<std::vector<int>>* predvec,
+                         SparseSet* reachable, std::vector<int>* stk) {
+  reachable->clear();
+  stk->clear();
+  stk->push_back(root);
+  while (!stk->empty()) {
+    int id = stk->back();
+    stk->pop_back();
+  Loop:
+    if (reachable->contains(id))
+      continue;
+    reachable->insert_new(id);
+
+    if (id != root && rootmap->has_index(id)) {
+      // We reached another "tree" via epsilon transition.
+      continue;
+    }
+
+    Inst* ip = inst(id);
+    switch (ip->opcode()) {
+      default:
+        LOG(DFATAL) << "unhandled opcode: " << ip->opcode();
+        break;
+
+      case kInstAltMatch:
+      case kInstAlt:
+        stk->push_back(ip->out1());
+        id = ip->out();
+        goto Loop;
+
+      case kInstByteRange:
+      case kInstCapture:
+      case kInstEmptyWidth:
+        break;
+
+      case kInstNop:
+        id = ip->out();
+        goto Loop;
+
+      case kInstMatch:
+      case kInstFail:
+        break;
+    }
+  }
+
+  for (SparseSet::const_iterator i = reachable->begin();
+       i != reachable->end();
+       ++i) {
+    int id = *i;
+    if (predmap->has_index(id)) {
+      for (int pred : (*predvec)[predmap->get_existing(id)]) {
+        if (!reachable->contains(pred)) {
+          // id has a predecessor that cannot be reached from root!
+          // Therefore, id must be a "root" too - mark it as such.
+          if (!rootmap->has_index(id))
+            rootmap->set_new(id, rootmap->size());
+        }
+      }
+    }
+  }
+}
+
+void Prog::EmitList(int root, SparseArray<int>* rootmap,
+                    std::vector<Inst>* flat,
+                    SparseSet* reachable, std::vector<int>* stk) {
+  reachable->clear();
+  stk->clear();
+  stk->push_back(root);
+  while (!stk->empty()) {
+    int id = stk->back();
+    stk->pop_back();
+  Loop:
+    if (reachable->contains(id))
+      continue;
+    reachable->insert_new(id);
+
+    if (id != root && rootmap->has_index(id)) {
+      // We reached another "tree" via epsilon transition. Emit a kInstNop
+      // instruction so that the Prog does not become quadratically larger.
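+      // The Nop acts as a cheap cross-reference: rather than copying
+      // the other tree's instructions into this list, we emit a single
+      // hop to its root id, which gets remapped to a flat id later.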
+      flat->emplace_back();
+      flat->back().set_opcode(kInstNop);
+      flat->back().set_out(rootmap->get_existing(id));
+      continue;
+    }
+
+    Inst* ip = inst(id);
+    switch (ip->opcode()) {
+      default:
+        LOG(DFATAL) << "unhandled opcode: " << ip->opcode();
+        break;
+
+      case kInstAltMatch:
+        flat->emplace_back();
+        flat->back().set_opcode(kInstAltMatch);
+        flat->back().set_out(static_cast<int>(flat->size()));
+        flat->back().out1_ = static_cast<uint32_t>(flat->size())+1;
+        FALLTHROUGH_INTENDED;
+
+      case kInstAlt:
+        stk->push_back(ip->out1());
+        id = ip->out();
+        goto Loop;
+
+      case kInstByteRange:
+      case kInstCapture:
+      case kInstEmptyWidth:
+        flat->emplace_back();
+        memmove(&flat->back(), ip, sizeof *ip);
+        flat->back().set_out(rootmap->get_existing(ip->out()));
+        break;
+
+      case kInstNop:
+        id = ip->out();
+        goto Loop;
+
+      case kInstMatch:
+      case kInstFail:
+        flat->emplace_back();
+        memmove(&flat->back(), ip, sizeof *ip);
+        break;
+    }
+  }
+}
+
+// For each ByteRange instruction in [begin, end), computes a hint to execution
+// engines: the delta to the next instruction (in flat) worth exploring iff the
+// current instruction matched.
+//
+// Implements a coloring algorithm related to ByteMapBuilder, but in this case,
+// colors are instructions and recoloring ranges precisely identifies conflicts
+// between instructions. Iterating backwards over [begin, end) is guaranteed to
+// identify the nearest conflict (if any) with only linear complexity.
+void Prog::ComputeHints(std::vector<Inst>* flat, int begin, int end) {
+  Bitmap256 splits;
+  int colors[256];
+
+  bool dirty = false;
+  for (int id = end; id >= begin; --id) {
+    if (id == end ||
+        (*flat)[id].opcode() != kInstByteRange) {
+      if (dirty) {
+        dirty = false;
+        splits.Clear();
+      }
+      splits.Set(255);
+      colors[255] = id;
+      // At this point, the [0-255] range is colored with id.
+      // Thus, hints cannot point beyond id; and if id == end,
+      // hints that would have pointed to id will be 0 instead.
+      continue;
+    }
+    dirty = true;
+
+    // We recolor the [lo-hi] range with id. Note that first ratchets backwards
+    // from end to the nearest conflict (if any) during recoloring.
+    int first = end;
+    auto Recolor = [&](int lo, int hi) {
+      // Like ByteMapBuilder, we split at lo-1 and at hi.
+      --lo;
+
+      if (0 <= lo && !splits.Test(lo)) {
+        splits.Set(lo);
+        int next = splits.FindNextSetBit(lo+1);
+        colors[lo] = colors[next];
+      }
+      if (!splits.Test(hi)) {
+        splits.Set(hi);
+        int next = splits.FindNextSetBit(hi+1);
+        colors[hi] = colors[next];
+      }
+
+      int c = lo+1;
+      while (c < 256) {
+        int next = splits.FindNextSetBit(c);
+        // Ratchet backwards...
+        first = std::min(first, colors[next]);
+        // Recolor with id - because it's the new nearest conflict!
+        colors[next] = id;
+        if (next == hi)
+          break;
+        c = next+1;
+      }
+    };
+
+    Inst* ip = &(*flat)[id];
+    int lo = ip->lo();
+    int hi = ip->hi();
+    Recolor(lo, hi);
+    if (ip->foldcase() && lo <= 'z' && hi >= 'a') {
+      int foldlo = lo;
+      int foldhi = hi;
+      if (foldlo < 'a')
+        foldlo = 'a';
+      if (foldhi > 'z')
+        foldhi = 'z';
+      if (foldlo <= foldhi) {
+        foldlo += 'A' - 'a';
+        foldhi += 'A' - 'a';
+        Recolor(foldlo, foldhi);
+      }
+    }
+
+    if (first != end) {
+      uint16_t hint = static_cast<uint16_t>(std::min(first - id, 32767));
+      ip->byte_range.hint_foldcase_ |= hint<<1;
+    }
+  }
+}
+
+// The final state will always be this, which frees up a register for the hot
+// loop and thus avoids the spilling that can occur when building with Clang.
+static const size_t kShiftDFAFinal = 9;
+
+// This function takes the prefix as std::string (i.e.
not const std::string& +// as normal) because it's going to clobber it, so a temporary is convenient. +static uint64_t* BuildShiftDFA(std::string prefix) { + // This constant is for convenience now and also for correctness later when + // we clobber the prefix, but still need to know how long it was initially. + const size_t size = prefix.size(); + + // Construct the NFA. + // The table is indexed by input byte; each element is a bitfield of states + // reachable by the input byte. Given a bitfield of the current states, the + // bitfield of states reachable from those is - for this specific purpose - + // always ((ncurr << 1) | 1). Intersecting the reachability bitfields gives + // the bitfield of the next states reached by stepping over the input byte. + // Credits for this technique: the Hyperscan paper by Geoff Langdale et al. + uint16_t nfa[256]{}; + for (size_t i = 0; i < size; ++i) { + uint8_t b = prefix[i]; + nfa[b] |= 1 << (i+1); + } + // This is the `\C*?` for unanchored search. + for (int b = 0; b < 256; ++b) + nfa[b] |= 1; + + // This maps from DFA state to NFA states; the reverse mapping is used when + // recording transitions and gets implemented with plain old linear search. + // The "Shift DFA" technique limits this to ten states when using uint64_t; + // to allow for the initial state, we use at most nine bytes of the prefix. + // That same limit is also why uint16_t is sufficient for the NFA bitfield. + uint16_t states[kShiftDFAFinal+1]{}; + states[0] = 1; + for (size_t dcurr = 0; dcurr < size; ++dcurr) { + uint8_t b = prefix[dcurr]; + uint16_t ncurr = states[dcurr]; + uint16_t nnext = nfa[b] & ((ncurr << 1) | 1); + size_t dnext = dcurr+1; + if (dnext == size) + dnext = kShiftDFAFinal; + states[dnext] = nnext; + } + + // Sort and unique the bytes of the prefix to avoid repeating work while we + // record transitions. This clobbers the prefix, but it's no longer needed. + std::sort(prefix.begin(), prefix.end()); + prefix.erase(std::unique(prefix.begin(), prefix.end()), prefix.end()); + + // Construct the DFA. + // The table is indexed by input byte; each element is effectively a packed + // array of uint6_t; each array value will be multiplied by six in order to + // avoid having to do so later in the hot loop as well as masking/shifting. + // Credits for this technique: "Shift-based DFAs" on GitHub by Per Vognsen. + uint64_t* dfa = new uint64_t[256]{}; + // Record a transition from each state for each of the bytes of the prefix. + // Note that all other input bytes go back to the initial state by default. + for (size_t dcurr = 0; dcurr < size; ++dcurr) { + for (uint8_t b : prefix) { + uint16_t ncurr = states[dcurr]; + uint16_t nnext = nfa[b] & ((ncurr << 1) | 1); + size_t dnext = 0; + while (states[dnext] != nnext) + ++dnext; + dfa[b] |= static_cast(dnext * 6) << (dcurr * 6); + // Convert ASCII letters to uppercase and record the extra transitions. + // Note that ASCII letters are guaranteed to be lowercase at this point + // because that's how the parser normalises them. #FunFact: 'k' and 's' + // match U+212A and U+017F, respectively, so they won't occur here when + // using UTF-8 encoding because the parser will emit character classes. 
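+      // (So for prefix "ab", both 'a' and 'A' carry the transition out
+      // of each state; the resulting DFA is case-insensitive over
+      // ASCII, which is all that prefix_foldcase_ promises.)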
+      if ('a' <= b && b <= 'z') {
+        b -= 'a' - 'A';
+        dfa[b] |= static_cast<uint64_t>(dnext * 6) << (dcurr * 6);
+      }
+    }
+  }
+  // This lets the final state "saturate", which will matter for performance:
+  // in the hot loop, we check for a match only at the end of each iteration,
+  // so we must keep signalling the match until we get around to checking it.
+  for (int b = 0; b < 256; ++b)
+    dfa[b] |= static_cast<uint64_t>(kShiftDFAFinal * 6) << (kShiftDFAFinal * 6);
+
+  return dfa;
+}
+
+void Prog::ConfigurePrefixAccel(const std::string& prefix,
+                                bool prefix_foldcase) {
+  prefix_foldcase_ = prefix_foldcase;
+  prefix_size_ = prefix.size();
+  if (prefix_foldcase_) {
+    // Use PrefixAccel_ShiftDFA().
+    // ... and no more than nine bytes of the prefix. (See above for details.)
+    prefix_size_ = std::min(prefix_size_, kShiftDFAFinal);
+    prefix_dfa_ = BuildShiftDFA(prefix.substr(0, prefix_size_));
+  } else if (prefix_size_ != 1) {
+    // Use PrefixAccel_FrontAndBack().
+    prefix_front_back.prefix_front_ = prefix.front();
+    prefix_front_back.prefix_back_ = prefix.back();
+  } else {
+    // Use memchr(3).
+    prefix_front_back.prefix_front_ = prefix.front();
+  }
+}
+
+const void* Prog::PrefixAccel_ShiftDFA(const void* data, size_t size) {
+  if (size < prefix_size_)
+    return NULL;
+
+  uint64_t curr = 0;
+
+  // At the time of writing, rough benchmarks on a Broadwell machine showed
+  // that this unroll factor (i.e. eight) achieves a speedup factor of two.
+  if (size >= 8) {
+    const uint8_t* p = reinterpret_cast<const uint8_t*>(data);
+    const uint8_t* endp = p + (size&~7);
+    do {
+      uint8_t b0 = p[0];
+      uint8_t b1 = p[1];
+      uint8_t b2 = p[2];
+      uint8_t b3 = p[3];
+      uint8_t b4 = p[4];
+      uint8_t b5 = p[5];
+      uint8_t b6 = p[6];
+      uint8_t b7 = p[7];
+
+      uint64_t next0 = prefix_dfa_[b0];
+      uint64_t next1 = prefix_dfa_[b1];
+      uint64_t next2 = prefix_dfa_[b2];
+      uint64_t next3 = prefix_dfa_[b3];
+      uint64_t next4 = prefix_dfa_[b4];
+      uint64_t next5 = prefix_dfa_[b5];
+      uint64_t next6 = prefix_dfa_[b6];
+      uint64_t next7 = prefix_dfa_[b7];
+
+      uint64_t curr0 = next0 >> (curr & 63);
+      uint64_t curr1 = next1 >> (curr0 & 63);
+      uint64_t curr2 = next2 >> (curr1 & 63);
+      uint64_t curr3 = next3 >> (curr2 & 63);
+      uint64_t curr4 = next4 >> (curr3 & 63);
+      uint64_t curr5 = next5 >> (curr4 & 63);
+      uint64_t curr6 = next6 >> (curr5 & 63);
+      uint64_t curr7 = next7 >> (curr6 & 63);
+
+      if ((curr7 & 63) == kShiftDFAFinal * 6) {
+        // At the time of writing, using the same masking subexpressions from
+        // the preceding lines caused Clang to clutter the hot loop computing
+        // them - even though they aren't actually needed for shifting! Hence
+        // these rewritten conditions, which achieve a speedup factor of two.
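+        // Each test below asks: had curr_i already reached the final
+        // state? If so, the prefix ended at p[i], so it began at
+        // p+i+1-prefix_size_.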
+        if (((curr7-curr0) & 63) == 0) return p+1-prefix_size_;
+        if (((curr7-curr1) & 63) == 0) return p+2-prefix_size_;
+        if (((curr7-curr2) & 63) == 0) return p+3-prefix_size_;
+        if (((curr7-curr3) & 63) == 0) return p+4-prefix_size_;
+        if (((curr7-curr4) & 63) == 0) return p+5-prefix_size_;
+        if (((curr7-curr5) & 63) == 0) return p+6-prefix_size_;
+        if (((curr7-curr6) & 63) == 0) return p+7-prefix_size_;
+        if (((curr7-curr7) & 63) == 0) return p+8-prefix_size_;
+      }
+
+      curr = curr7;
+      p += 8;
+    } while (p != endp);
+    data = p;
+    size = size&7;
+  }
+
+  const uint8_t* p = reinterpret_cast<const uint8_t*>(data);
+  const uint8_t* endp = p + size;
+  while (p != endp) {
+    uint8_t b = *p++;
+    uint64_t next = prefix_dfa_[b];
+    curr = next >> (curr & 63);
+    if ((curr & 63) == kShiftDFAFinal * 6)
+      return p-prefix_size_;
+  }
+  return NULL;
+}
+
+#if defined(__AVX2__)
+// Finds the least significant non-zero bit in n.
+static int FindLSBSet(uint32_t n) {
+  DCHECK_NE(n, 0);
+#if defined(__GNUC__)
+  return __builtin_ctz(n);
+#elif defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))
+  unsigned long c;
+  _BitScanForward(&c, n);
+  return static_cast<int>(c);
+#else
+  int c = 31;
+  for (int shift = 1 << 4; shift != 0; shift >>= 1) {
+    uint32_t word = n << shift;
+    if (word != 0) {
+      n = word;
+      c -= shift;
+    }
+  }
+  return c;
+#endif
+}
+#endif
+
+const void* Prog::PrefixAccel_FrontAndBack(const void* data, size_t size) {
+  DCHECK_GE(prefix_size_, 2);
+  if (size < prefix_size_)
+    return NULL;
+  // Don't bother searching the last prefix_size_-1 bytes for prefix_front_.
+  // This also means that probing for prefix_back_ doesn't go out of bounds.
+  size -= prefix_size_-1;
+
+#if defined(__AVX2__)
+  // Use AVX2 to look for prefix_front_ and prefix_back_ 32 bytes at a time.
+  if (size >= sizeof(__m256i)) {
+    const __m256i* fp = reinterpret_cast<const __m256i*>(
+        reinterpret_cast<const char*>(data));
+    const __m256i* bp = reinterpret_cast<const __m256i*>(
+        reinterpret_cast<const char*>(data) + prefix_size_-1);
+    const __m256i* endfp = fp + size/sizeof(__m256i);
+    const __m256i f_set1 = _mm256_set1_epi8(prefix_front_back.prefix_front_);
+    const __m256i b_set1 = _mm256_set1_epi8(prefix_front_back.prefix_back_);
+    do {
+      const __m256i f_loadu = _mm256_loadu_si256(fp++);
+      const __m256i b_loadu = _mm256_loadu_si256(bp++);
+      const __m256i f_cmpeq = _mm256_cmpeq_epi8(f_set1, f_loadu);
+      const __m256i b_cmpeq = _mm256_cmpeq_epi8(b_set1, b_loadu);
+      const int fb_testz = _mm256_testz_si256(f_cmpeq, b_cmpeq);
+      if (fb_testz == 0) {  // ZF: 1 means zero, 0 means non-zero.
+        const __m256i fb_and = _mm256_and_si256(f_cmpeq, b_cmpeq);
+        const int fb_movemask = _mm256_movemask_epi8(fb_and);
+        const int fb_ctz = FindLSBSet(fb_movemask);
+        return reinterpret_cast<const char*>(fp-1) + fb_ctz;
+      }
+    } while (fp != endfp);
+    data = fp;
+    size = size%sizeof(__m256i);
+  }
+#endif
+
+  const char* p0 = reinterpret_cast<const char*>(data);
+  for (const char* p = p0;; p++) {
+    DCHECK_GE(size, static_cast<size_t>(p-p0));
+    p = reinterpret_cast<const char*>(memchr(p, prefix_front_back.prefix_front_, size - (p-p0)));
+    if (p == NULL || p[prefix_size_-1] == prefix_front_back.prefix_back_)
+      return p;
+  }
+}
+
+}  // namespace re2
diff --git a/internal/cpp/re2/prog.h b/internal/cpp/re2/prog.h
new file mode 100644
index 00000000000..c78beacf55f
--- /dev/null
+++ b/internal/cpp/re2/prog.h
@@ -0,0 +1,469 @@
+// Copyright 2007 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef RE2_PROG_H_
+#define RE2_PROG_H_
+
+// Compiled representation of regular expressions.
+// See regexp.h for the Regexp class, which represents a regular
+// expression symbolically.
+
+#include <stdint.h>
+#include <functional>
+#include <mutex>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "re2/pod_array.h"
+#include "re2/re2.h"
+#include "re2/sparse_array.h"
+#include "re2/sparse_set.h"
+#include "util/logging.h"
+#include "util/util.h"
+
+namespace re2 {
+
+// Opcodes for Inst
+enum InstOp {
+  kInstAlt = 0,     // choose between out_ and out1_
+  kInstAltMatch,    // Alt: out_ is [00-FF] and back, out1_ is match; or vice versa.
+  kInstByteRange,   // next (possible case-folded) byte must be in [lo_, hi_]
+  kInstCapture,     // capturing parenthesis number cap_
+  kInstEmptyWidth,  // empty-width special (^ $ ...); bit(s) set in empty_
+  kInstMatch,       // found a match!
+  kInstNop,         // no-op; occasionally unavoidable
+  kInstFail,        // never match; occasionally unavoidable
+  kNumInst,
+};
+
+// Bit flags for empty-width specials
+enum EmptyOp {
+  kEmptyBeginLine = 1 << 0,        // ^ - beginning of line
+  kEmptyEndLine = 1 << 1,          // $ - end of line
+  kEmptyBeginText = 1 << 2,        // \A - beginning of text
+  kEmptyEndText = 1 << 3,          // \z - end of text
+  kEmptyWordBoundary = 1 << 4,     // \b - word boundary
+  kEmptyNonWordBoundary = 1 << 5,  // \B - not \b
+  kEmptyAllFlags = (1 << 6) - 1,
+};
+
+class DFA;
+class Regexp;
+
+// Compiled form of regexp program.
+class Prog {
+public:
+  Prog();
+  ~Prog();
+
+  // Single instruction in regexp program.
+  class Inst {
+  public:
+    // See the assertion below for why this is so.
+    Inst() = default;
+
+    // Copyable.
+    Inst(const Inst &) = default;
+    Inst &operator=(const Inst &) = default;
+
+    // Constructors per opcode
+    void InitAlt(uint32_t out, uint32_t out1);
+    void InitByteRange(int lo, int hi, int foldcase, uint32_t out);
+    void InitCapture(int cap, uint32_t out);
+    void InitEmptyWidth(EmptyOp empty, uint32_t out);
+    void InitMatch(int id);
+    void InitNop(uint32_t out);
+    void InitFail();
+
+    // Getters
+    int id(Prog *p) { return static_cast<int>(this - p->inst_.data()); }
+    InstOp opcode() { return static_cast<InstOp>(out_opcode_ & 7); }
+    int last() { return (out_opcode_ >> 3) & 1; }
+    int out() { return out_opcode_ >> 4; }
+    int out1() {
+      DCHECK(opcode() == kInstAlt || opcode() == kInstAltMatch);
+      return out1_;
+    }
+    int cap() {
+      DCHECK_EQ(opcode(), kInstCapture);
+      return cap_;
+    }
+    int lo() {
+      DCHECK_EQ(opcode(), kInstByteRange);
+      return byte_range.lo_;
+    }
+    int hi() {
+      DCHECK_EQ(opcode(), kInstByteRange);
+      return byte_range.hi_;
+    }
+    int foldcase() {
+      DCHECK_EQ(opcode(), kInstByteRange);
+      return byte_range.hint_foldcase_ & 1;
+    }
+    int hint() {
+      DCHECK_EQ(opcode(), kInstByteRange);
+      return byte_range.hint_foldcase_ >> 1;
+    }
+    int match_id() {
+      DCHECK_EQ(opcode(), kInstMatch);
+      return match_id_;
+    }
+    EmptyOp empty() {
+      DCHECK_EQ(opcode(), kInstEmptyWidth);
+      return empty_;
+    }
+
+    bool greedy(Prog *p) {
+      DCHECK_EQ(opcode(), kInstAltMatch);
+      return p->inst(out())->opcode() == kInstByteRange ||
+             (p->inst(out())->opcode() == kInstNop && p->inst(p->inst(out())->out())->opcode() == kInstByteRange);
+    }
+
+    // Does this inst (a kInstByteRange) match c?
+    inline bool Matches(int c) {
+      DCHECK_EQ(opcode(), kInstByteRange);
+      if (foldcase() && 'A' <= c && c <= 'Z')
+        c += 'a' - 'A';
+      return byte_range.lo_ <= c && c <= byte_range.hi_;
+    }
+
+    // Returns string representation for debugging.
+    std::string Dump();
+
+    // Maximum instruction id.
+    // (Must fit in out_opcode_. PatchList/last steal another bit.)
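+    // (Illustrative note, not from the original source: out_opcode_ is 32
+    // bits wide, with the 3 low bits holding the opcode and 1 bit holding
+    // last(), which leaves 28 bits for out(); hence (1 << 28) - 1.)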
+    static const int kMaxInst = (1 << 28) - 1;
+
+  private:
+    void set_opcode(InstOp opcode) { out_opcode_ = (out() << 4) | (last() << 3) | opcode; }
+
+    void set_last() { out_opcode_ = (out() << 4) | (1 << 3) | opcode(); }
+
+    void set_out(int out) { out_opcode_ = (out << 4) | (last() << 3) | opcode(); }
+
+    void set_out_opcode(int out, InstOp opcode) { out_opcode_ = (out << 4) | (last() << 3) | opcode; }
+
+    uint32_t out_opcode_;  // 28 bits: out, 1 bit: last, 3 (low) bits: opcode
+    union {                // additional instruction arguments:
+      uint32_t out1_;      // opcode == kInstAlt
+                           //   alternate next instruction
+
+      int32_t cap_;        // opcode == kInstCapture
+                           //   Index of capture register (holds text
+                           //   position recorded by capturing parentheses).
+                           //   For \n (the submatch for the nth parentheses),
+                           //   the left parenthesis captures into register 2*n
+                           //   and the right one captures into register 2*n+1.
+
+      int32_t match_id_;   // opcode == kInstMatch
+                           //   Match ID to identify this match (for re2::Set).
+
+      struct {             // opcode == kInstByteRange
+        uint8_t lo_;       //   byte range is lo_-hi_ inclusive
+        uint8_t hi_;       //
+        uint16_t hint_foldcase_;  // 15 bits: hint, 1 (low) bit: foldcase
+                           //   hint to execution engines: the delta to the
+                           //   next instruction (in the current list) worth
+                           //   exploring iff this instruction matched; 0
+                           //   means there are no remaining possibilities,
+                           //   which is most likely for character classes.
+                           //   foldcase: A-Z -> a-z before checking range.
+      } byte_range;
+
+      EmptyOp empty_;      // opcode == kInstEmptyWidth
+                           //   empty_ is bitwise OR of kEmpty* flags above.
+    };
+
+    friend class Compiler;
+    friend struct PatchList;
+    friend class Prog;
+  };
+
+  // Inst must be trivial so that we can freely clear it with memset(3).
+  // Arrays of Inst are initialised by copying the initial elements with
+  // memmove(3) and then clearing any remaining elements with memset(3).
+  static_assert(std::is_trivial<Inst>::value, "Inst must be trivial");
+
+  // Whether to anchor the search.
+  enum Anchor {
+    kUnanchored,  // match anywhere
+    kAnchored,    // match only starting at beginning of text
+  };
+
+  // Kind of match to look for (for anchor != kFullMatch)
+  //
+  // kLongestMatch mode finds the overall longest
+  // match but still makes its submatch choices the way
+  // Perl would, not in the way prescribed by POSIX.
+  // The POSIX rules are much more expensive to implement,
+  // and no one has needed them.
+  //
+  // kFullMatch is not strictly necessary -- we could use
+  // kLongestMatch and then check the length of the match -- but
+  // the matching code can run faster if it knows to consider only
+  // full matches.
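+  //
+  // (Illustrative note, not from the original source: against the text
+  // "aaa", the pattern "a|aa" matches "a" under kFirstMatch
+  // (leftmost-first, as in Perl) but "aa" under kLongestMatch
+  // (leftmost-longest, as in POSIX egrep).)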
+ enum MatchKind { + kFirstMatch, // like Perl, PCRE + kLongestMatch, // like egrep or POSIX + kFullMatch, // match only entire text; implies anchor==kAnchored + kManyMatch // for SearchDFA, records set of matches + }; + + Inst *inst(int id) { return &inst_[id]; } + int start() { return start_; } + void set_start(int start) { start_ = start; } + int start_unanchored() { return start_unanchored_; } + void set_start_unanchored(int start) { start_unanchored_ = start; } + int size() { return size_; } + bool reversed() { return reversed_; } + void set_reversed(bool reversed) { reversed_ = reversed; } + int list_count() { return list_count_; } + int inst_count(InstOp op) { return inst_count_[op]; } + uint16_t *list_heads() { return list_heads_.data(); } + size_t bit_state_text_max_size() { return bit_state_text_max_size_; } + int64_t dfa_mem() { return dfa_mem_; } + void set_dfa_mem(int64_t dfa_mem) { dfa_mem_ = dfa_mem; } + bool anchor_start() { return anchor_start_; } + void set_anchor_start(bool b) { anchor_start_ = b; } + bool anchor_end() { return anchor_end_; } + void set_anchor_end(bool b) { anchor_end_ = b; } + int bytemap_range() { return bytemap_range_; } + const uint8_t *bytemap() { return bytemap_; } + bool can_prefix_accel() { return prefix_size_ != 0; } + + // Accelerates to the first likely occurrence of the prefix. + // Returns a pointer to the first byte or NULL if not found. + const void *PrefixAccel(const void *data, size_t size) { + DCHECK(can_prefix_accel()); + if (prefix_foldcase_) { + return PrefixAccel_ShiftDFA(data, size); + } else if (prefix_size_ != 1) { + return PrefixAccel_FrontAndBack(data, size); + } else { + return memchr(data, prefix_front_back.prefix_front_, size); + } + } + + // Configures prefix accel using the analysis performed during compilation. + void ConfigurePrefixAccel(const std::string &prefix, bool prefix_foldcase); + + // An implementation of prefix accel that uses prefix_dfa_ to perform + // case-insensitive search. + const void *PrefixAccel_ShiftDFA(const void *data, size_t size); + + // An implementation of prefix accel that looks for prefix_front_ and + // prefix_back_ to return fewer false positives than memchr(3) alone. + const void *PrefixAccel_FrontAndBack(const void *data, size_t size); + + // Returns string representation of program for debugging. + std::string Dump(); + std::string DumpUnanchored(); + std::string DumpByteMap(); + + // Returns the set of kEmpty flags that are in effect at + // position p within context. + static uint32_t EmptyFlags(const StringPiece &context, const char *p); + + // Returns whether byte c is a word character: ASCII only. + // Used by the implementation of \b and \B. + // This is not right for Unicode, but: + // - it's hard to get right in a byte-at-a-time matching world + // (the DFA has only one-byte lookahead). + // - even if the lookahead were possible, the Progs would be huge. + // This crude approximation is the same one PCRE uses. + static bool IsWordChar(uint8_t c) { return ('A' <= c && c <= 'Z') || ('a' <= c && c <= 'z') || ('0' <= c && c <= '9') || c == '_'; } + + // Execution engines. They all search for the regexp (run the prog) + // in text, which is in the larger context (used for ^ $ \b etc). + // Anchor and kind control the kind of search. + // Returns true if match found, false if not. + // If match found, fills match[0..nmatch-1] with submatch info. + // match[0] is overall match, match[1] is first set of parens, etc. 
+  // If a particular submatch is not matched during the regexp match,
+  // it is set to NULL.
+  //
+  // Matching text == StringPiece(NULL, 0) is treated as any other empty
+  // string, but note that on return, it will not be possible to distinguish
+  // submatches that matched that empty string from submatches that didn't
+  // match anything.  Either way, match[i] == NULL.
+
+  // Search using NFA: can find submatches but kind of slow.
+  bool SearchNFA(const StringPiece &text, const StringPiece &context, Anchor anchor, MatchKind kind, StringPiece *match, int nmatch);
+
+  // Search using DFA: much faster than NFA but only finds
+  // end of match and can use a lot more memory.
+  // Returns whether a match was found.
+  // If the DFA runs out of memory, sets *failed to true and returns false.
+  // If matches != NULL and kind == kManyMatch and there is a match,
+  // SearchDFA fills matches with the match IDs of the final matching state.
+  bool SearchDFA(const StringPiece &text,
+                 const StringPiece &context,
+                 Anchor anchor,
+                 MatchKind kind,
+                 StringPiece *match0,
+                 bool *failed,
+                 SparseSet *matches);
+
+  // The callback issued after building each DFA state with BuildEntireDFA().
+  // If next is null, then the memory budget has been exhausted and building
+  // will halt. Otherwise, the state has been built and next points to an array
+  // of bytemap_range()+1 slots holding the next states as per the bytemap and
+  // kByteEndText. The number of the state is implied by the callback sequence:
+  // the first callback is for state 0, the second callback is for state 1, ...
+  // match indicates whether the state is a matching state.
+  using DFAStateCallback = std::function<void(const int *next, bool match)>;
+
+  // Build the entire DFA for the given match kind.
+  // Usually the DFA is built out incrementally, as needed, which
+  // avoids lots of unnecessary work.
+  // If cb is not empty, it receives one callback per state built.
+  // Returns the number of states built.
+  // FOR TESTING OR EXPERIMENTAL PURPOSES ONLY.
+  int BuildEntireDFA(MatchKind kind, const DFAStateCallback &cb);
+
+  // Compute bytemap.
+  void ComputeByteMap();
+
+  // Run peep-hole optimizer on program.
+  void Optimize();
+
+  // One-pass NFA: only correct if IsOnePass() is true,
+  // but much faster than NFA (competitive with PCRE)
+  // for those expressions.
+  bool IsOnePass();
+  bool SearchOnePass(const StringPiece &text, const StringPiece &context, Anchor anchor, MatchKind kind, StringPiece *match, int nmatch);
+
+  // Bit-state backtracking.  Fast on small cases but uses memory
+  // proportional to the product of the list count and the text size.
+  bool CanBitState() { return list_heads_.data() != NULL; }
+  bool SearchBitState(const StringPiece &text, const StringPiece &context, Anchor anchor, MatchKind kind, StringPiece *match, int nmatch);
+
+  static const int kMaxOnePassCapture = 5;  // $0 through $4
+
+  // Backtracking search: the gold standard against which the other
+  // implementations are checked.  FOR TESTING ONLY.
+  // It allocates a ton of memory to avoid running forever.
+  // It is also recursive, so can't use in production (will overflow stacks).
+  // The name "Unsafe" here is supposed to be a flag that
+  // you should not be using this function.
+  bool UnsafeSearchBacktrack(const StringPiece &text, const StringPiece &context, Anchor anchor, MatchKind kind, StringPiece *match, int nmatch);
+
+  // Computes range for any strings matching regexp.  The min and max can in
+  // some cases be arbitrarily precise, so the caller gets to specify the
+  // maximum desired length of string returned.
+  //
+  // Assuming PossibleMatchRange(&min, &max, N) returns successfully, any
+  // string s that is an anchored match for this regexp satisfies
+  //   min <= s && s <= max.
+  //
+  // Note that PossibleMatchRange() will only consider the first copy of an
+  // infinitely repeated element (i.e., any regexp element followed by a '*' or
+  // '+' operator). Regexps with "{N}" constructions are not affected, as those
+  // do not compile down to infinite repetitions.
+  //
+  // Returns true on success, false on error.
+  bool PossibleMatchRange(std::string *min, std::string *max, int maxlen);
+
+  // Outputs the program fanout into the given sparse array.
+  void Fanout(SparseArray<int> *fanout);
+
+  // Compiles a collection of regexps to Prog.  Each regexp will have
+  // its own Match instruction recording the index in the output vector.
+  static Prog *CompileSet(Regexp *re, RE2::Anchor anchor, int64_t max_mem);
+
+  // Flattens the Prog from "tree" form to "list" form. This is an in-place
+  // operation in the sense that the old instructions are lost.
+  void Flatten();
+
+  // Walks the Prog; the "successor roots" or predecessors of the reachable
+  // instructions are marked in rootmap or predmap/predvec, respectively.
+  // reachable and stk are preallocated scratch structures.
+  void MarkSuccessors(SparseArray<int> *rootmap,
+                      SparseArray<int> *predmap,
+                      std::vector<std::vector<int>> *predvec,
+                      SparseSet *reachable,
+                      std::vector<int> *stk);
+
+  // Walks the Prog from the given "root" instruction; the "dominator root"
+  // of the reachable instructions (if such exists) is marked in rootmap.
+  // reachable and stk are preallocated scratch structures.
+  void MarkDominator(int root,
+                     SparseArray<int> *rootmap,
+                     SparseArray<int> *predmap,
+                     std::vector<std::vector<int>> *predvec,
+                     SparseSet *reachable,
+                     std::vector<int> *stk);
+
+  // Walks the Prog from the given "root" instruction; the reachable
+  // instructions are emitted in "list" form and appended to flat.
+  // reachable and stk are preallocated scratch structures.
+  void EmitList(int root, SparseArray<int> *rootmap, std::vector<Inst> *flat, SparseSet *reachable, std::vector<int> *stk);
+
+  // Computes hints for ByteRange instructions in [begin, end).
+  void ComputeHints(std::vector<Inst> *flat, int begin, int end);
+
+  // Controls whether the DFA should bail out early if the NFA would be faster.
+  // FOR TESTING ONLY.
+  static void TESTING_ONLY_set_dfa_should_bail_when_slow(bool b);
+
+private:
+  friend class Compiler;
+
+  DFA *GetDFA(MatchKind kind);
+  void DeleteDFA(DFA *dfa);
+
+  bool anchor_start_;  // regexp has explicit start anchor
+  bool anchor_end_;    // regexp has explicit end anchor
+  bool reversed_;      // whether program runs backward over input
+  bool did_flatten_;   // has Flatten been called?
+  bool did_onepass_;   // has IsOnePass been called?
+
+  int start_;             // entry point for program
+  int start_unanchored_;  // unanchored entry point for program
+  int size_;              // number of instructions
+  int bytemap_range_;     // bytemap_[x] < bytemap_range_
+
+  bool prefix_foldcase_;  // whether prefix is case-insensitive
+  size_t prefix_size_;    // size of prefix (0 if no prefix)
+  union {
+    uint64_t *prefix_dfa_;  // "Shift DFA" for prefix
+    struct {
+      int prefix_front_;    // first byte of prefix
+      int prefix_back_;     // last byte of prefix
+    } prefix_front_back;
+  };
+
+  int list_count_;                  // count of lists (see above)
+  int inst_count_[kNumInst];        // count of instructions by opcode
+  PODArray<uint16_t> list_heads_;   // sparse array enumerating list heads
+                                    // not populated if size_ is overly large
+  size_t bit_state_text_max_size_;  // upper bound (inclusive) on text.size()
+
+  PODArray<Inst> inst_;              // pointer to instruction array
+  PODArray<uint8_t> onepass_nodes_;  // data for OnePass nodes
+
+  int64_t dfa_mem_;   // Maximum memory for DFAs.
+  DFA *dfa_first_;    // DFA cached for kFirstMatch/kManyMatch
+  DFA *dfa_longest_;  // DFA cached for kLongestMatch/kFullMatch
+
+  uint8_t bytemap_[256];  // map from input bytes to byte classes
+
+  std::once_flag dfa_first_once_;
+  std::once_flag dfa_longest_once_;
+
+  Prog(const Prog &) = delete;
+  Prog &operator=(const Prog &) = delete;
+};
+
+// std::string_view in MSVC has iterators that aren't just pointers and
+// that don't allow comparisons between different objects - not even if
+// those objects are views into the same string!  Thus, we provide these
+// conversion functions for convenience.
+static inline const char *BeginPtr(const StringPiece &s) { return s.data(); }
+static inline const char *EndPtr(const StringPiece &s) { return s.data() + s.size(); }
+
+}  // namespace re2
+
+#endif  // RE2_PROG_H_
diff --git a/internal/cpp/re2/re2.cc b/internal/cpp/re2/re2.cc
new file mode 100644
index 00000000000..80ec4b08dc8
--- /dev/null
+++ b/internal/cpp/re2/re2.cc
@@ -0,0 +1,1326 @@
+// Copyright 2003-2009 The RE2 Authors.  All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Regular expression interface RE2.
+//
+// Originally the PCRE C++ wrapper, but adapted to use
+// the new automata-based regular expression engines.
+
+#include "re2/re2.h"
+
+#include <assert.h>
+#include <ctype.h>
+#include <errno.h>
+#ifdef _MSC_VER
+#include <intrin.h>
+#endif
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+#include <algorithm>
+#include <atomic>
+#include <iterator>
+#include <mutex>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "re2/prog.h"
+#include "re2/regexp.h"
+#include "re2/sparse_array.h"
+#include "util/logging.h"
+#include "util/strutil.h"
+#include "util/utf.h"
+#include "util/util.h"
+
+namespace re2 {
+
+// Controls the maximum count permitted by GlobalReplace(); -1 is unlimited.
+static int maximum_global_replace_count = -1;
+
+void RE2::FUZZING_ONLY_set_maximum_global_replace_count(int i) { maximum_global_replace_count = i; }
+
+// Maximum number of args we can set
+static const int kMaxArgs = 16;
+static const int kVecSize = 1 + kMaxArgs;
+
+const int RE2::Options::kDefaultMaxMem;  // initialized in re2.h
+
+RE2::Options::Options(RE2::CannedOptions opt)
+    : max_mem_(kDefaultMaxMem), encoding_(opt == RE2::Latin1 ? EncodingLatin1 : EncodingUTF8), posix_syntax_(opt == RE2::POSIX),
+      longest_match_(opt == RE2::POSIX), log_errors_(opt != RE2::Quiet), literal_(false), never_nl_(false), dot_nl_(false), never_capture_(false),
+      case_sensitive_(true), perl_classes_(false), word_boundary_(false), one_line_(false) {}
+
+// Empty objects for use as const references.
+// Statically allocating the storage and then
+// lazily constructing the objects (in a once
+// in RE2::Init()) avoids global constructors
+// and the false positives (thanks, Valgrind)
+// about memory leaks at program termination.
+struct EmptyStorage {
+  std::string empty_string;
+  std::map<std::string, int> empty_named_groups;
+  std::map<int, std::string> empty_group_names;
+};
+alignas(EmptyStorage) static char empty_storage[sizeof(EmptyStorage)];
+
+static inline std::string *empty_string() { return &reinterpret_cast<EmptyStorage *>(empty_storage)->empty_string; }
+
+static inline std::map<std::string, int> *empty_named_groups() { return &reinterpret_cast<EmptyStorage *>(empty_storage)->empty_named_groups; }
+
+static inline std::map<int, std::string> *empty_group_names() { return &reinterpret_cast<EmptyStorage *>(empty_storage)->empty_group_names; }
+
+// Converts from Regexp error code to RE2 error code.
+// Maybe some day they will diverge.  In any event, this
+// hides the existence of Regexp from RE2 users.
+static RE2::ErrorCode RegexpErrorToRE2(re2::RegexpStatusCode code) {
+  switch (code) {
+    case re2::kRegexpSuccess:
+      return RE2::NoError;
+    case re2::kRegexpInternalError:
+      return RE2::ErrorInternal;
+    case re2::kRegexpBadEscape:
+      return RE2::ErrorBadEscape;
+    case re2::kRegexpBadCharClass:
+      return RE2::ErrorBadCharClass;
+    case re2::kRegexpBadCharRange:
+      return RE2::ErrorBadCharRange;
+    case re2::kRegexpMissingBracket:
+      return RE2::ErrorMissingBracket;
+    case re2::kRegexpMissingParen:
+      return RE2::ErrorMissingParen;
+    case re2::kRegexpUnexpectedParen:
+      return RE2::ErrorUnexpectedParen;
+    case re2::kRegexpTrailingBackslash:
+      return RE2::ErrorTrailingBackslash;
+    case re2::kRegexpRepeatArgument:
+      return RE2::ErrorRepeatArgument;
+    case re2::kRegexpRepeatSize:
+      return RE2::ErrorRepeatSize;
+    case re2::kRegexpRepeatOp:
+      return RE2::ErrorRepeatOp;
+    case re2::kRegexpBadPerlOp:
+      return RE2::ErrorBadPerlOp;
+    case re2::kRegexpBadUTF8:
+      return RE2::ErrorBadUTF8;
+    case re2::kRegexpBadNamedCapture:
+      return RE2::ErrorBadNamedCapture;
+  }
+  return RE2::ErrorInternal;
+}
+
+static std::string trunc(const StringPiece &pattern) {
+  if (pattern.size() < 100)
+    return std::string(pattern);
+  return std::string(pattern.substr(0, 100)) + "...";
+}
+
+RE2::RE2(const char *pattern) { Init(pattern, DefaultOptions); }
+
+RE2::RE2(const std::string &pattern) { Init(pattern, DefaultOptions); }
+
+RE2::RE2(const StringPiece &pattern) { Init(pattern, DefaultOptions); }
+
+RE2::RE2(const StringPiece &pattern, const Options &options) { Init(pattern, options); }
+
+int RE2::Options::ParseFlags() const {
+  int flags = Regexp::ClassNL;
+  switch (encoding()) {
+    default:
+      if (log_errors())
+        LOG(ERROR) << "Unknown encoding " << encoding();
+      break;
+    case RE2::Options::EncodingUTF8:
+      break;
+    case RE2::Options::EncodingLatin1:
+      flags |= Regexp::Latin1;
+      break;
+  }
+
+  if (!posix_syntax())
+    flags |= Regexp::LikePerl;
+
+  if (literal())
+    flags |= Regexp::Literal;
+
+  if (never_nl())
+    flags |= Regexp::NeverNL;
+
+  if (dot_nl())
+    flags |= Regexp::DotNL;
+
+  if (never_capture())
+    flags |= Regexp::NeverCapture;
+
+  if (!case_sensitive())
+    flags |= Regexp::FoldCase;
+
+  if (perl_classes())
+    flags |= Regexp::PerlClasses;
+
+  if (word_boundary())
+    flags |= Regexp::PerlB;
+
+  if (one_line())
+    flags |= Regexp::OneLine;
+
+  return flags;
+}
+
+void RE2::Init(const StringPiece &pattern, const Options &options) {
+  static std::once_flag empty_once;
+  std::call_once(empty_once, []() { (void)new (empty_storage) EmptyStorage; });
+
+  pattern_ = new std::string(pattern);
+  options_.Copy(options);
+  entire_regexp_ = NULL;
+  suffix_regexp_ = NULL;
+  error_ = empty_string();
+  error_arg_ = empty_string();
+
+  num_captures_ = -1;
+  error_code_ = NoError;
+  longest_match_ = options_.longest_match();
+  is_one_pass_ = false;
+  prefix_foldcase_ = false;
+  prefix_.clear();
+  prog_ = NULL;
+
+  rprog_ = NULL;
+  named_groups_ = NULL;
+  group_names_ = NULL;
+
+  RegexpStatus status;
+  entire_regexp_ = Regexp::Parse(*pattern_, static_cast<Regexp::ParseFlags>(options_.ParseFlags()), &status);
+  if (entire_regexp_ == NULL) {
+    if (options_.log_errors()) {
+      LOG(ERROR) << "Error parsing '" << trunc(*pattern_) << "': " << status.Text();
+    }
+    error_ = new std::string(status.Text());
+    error_code_ = RegexpErrorToRE2(status.code());
+    error_arg_ = new std::string(status.error_arg());
+    return;
+  }
+
+  bool foldcase;
+  re2::Regexp *suffix;
+  if (entire_regexp_->RequiredPrefix(&prefix_, &foldcase, &suffix)) {
+    prefix_foldcase_ = foldcase;
+    suffix_regexp_ = suffix;
+  } else {
+    suffix_regexp_ = entire_regexp_->Incref();
+  }
+
+  // Two thirds of the memory goes to the forward Prog,
+  // one third to the reverse prog, because the forward
+  // Prog has two DFAs but the reverse prog has one.
+  prog_ = suffix_regexp_->CompileToProg(options_.max_mem() * 2 / 3);
+  if (prog_ == NULL) {
+    if (options_.log_errors())
+      LOG(ERROR) << "Error compiling '" << trunc(*pattern_) << "'";
+    error_ = new std::string("pattern too large - compile failed");
+    error_code_ = RE2::ErrorPatternTooLarge;
+    return;
+  }
+
+  // We used to compute this lazily, but it's used during the
+  // typical control flow for a match call, so we now compute
+  // it eagerly, which avoids the overhead of std::once_flag.
+  num_captures_ = suffix_regexp_->NumCaptures();
+
+  // Could delay this until the first match call that
+  // cares about submatch information, but the one-pass
+  // machine's memory gets cut from the DFA memory budget,
+  // and that is harder to do if the DFA has already
+  // been built.
+  is_one_pass_ = prog_->IsOnePass();
+}
+
+// Returns rprog_, computing it if needed.
+re2::Prog *RE2::ReverseProg() const {
+  std::call_once(
+      rprog_once_,
+      [](const RE2 *re) {
+        re->rprog_ = re->suffix_regexp_->CompileToReverseProg(re->options_.max_mem() / 3);
+        if (re->rprog_ == NULL) {
+          if (re->options_.log_errors())
+            LOG(ERROR) << "Error reverse compiling '" << trunc(*re->pattern_) << "'";
+          // We no longer touch error_ and error_code_ because failing to compile
+          // the reverse Prog is not a showstopper: falling back to NFA execution
+          // is fine. More importantly, an RE2 object is supposed to be logically
+          // immutable: whatever ok() would have returned after Init() completed,
+          // it should continue to return that no matter what ReverseProg() does.
+        }
+      },
+      this);
+  return rprog_;
+}
+
+RE2::~RE2() {
+  if (group_names_ != empty_group_names())
+    delete group_names_;
+  if (named_groups_ != empty_named_groups())
+    delete named_groups_;
+  delete rprog_;
+  delete prog_;
+  if (error_arg_ != empty_string())
+    delete error_arg_;
+  if (error_ != empty_string())
+    delete error_;
+  if (suffix_regexp_)
+    suffix_regexp_->Decref();
+  if (entire_regexp_)
+    entire_regexp_->Decref();
+  delete pattern_;
+}
+
+int RE2::ProgramSize() const {
+  if (prog_ == NULL)
+    return -1;
+  return prog_->size();
+}
+
+int RE2::ReverseProgramSize() const {
+  if (prog_ == NULL)
+    return -1;
+  Prog *prog = ReverseProg();
+  if (prog == NULL)
+    return -1;
+  return prog->size();
+}
+
+// Finds the most significant non-zero bit in n.
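+// (Illustrative note, not from the original source: FindMSBSet(1) == 0,
+// FindMSBSet(0x10) == 4 and FindMSBSet(0x80000000) == 31; Fanout() below
+// uses it to bucket per-instruction fanout counts by base-2 magnitude.)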
+static int FindMSBSet(uint32_t n) {
+  DCHECK_NE(n, 0);
+#if defined(__GNUC__)
+  return 31 ^ __builtin_clz(n);
+#elif defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))
+  unsigned long c;
+  _BitScanReverse(&c, n);
+  return static_cast<int>(c);
+#else
+  int c = 0;
+  for (int shift = 1 << 4; shift != 0; shift >>= 1) {
+    uint32_t word = n >> shift;
+    if (word != 0) {
+      n = word;
+      c += shift;
+    }
+  }
+  return c;
+#endif
+}
+
+static int Fanout(Prog *prog, std::vector<int> *histogram) {
+  SparseArray<int> fanout(prog->size());
+  prog->Fanout(&fanout);
+  int data[32] = {};
+  int size = 0;
+  for (SparseArray<int>::iterator i = fanout.begin(); i != fanout.end(); ++i) {
+    if (i->value() == 0)
+      continue;
+    uint32_t value = i->value();
+    int bucket = FindMSBSet(value);
+    bucket += value & (value - 1) ? 1 : 0;
+    ++data[bucket];
+    size = std::max(size, bucket + 1);
+  }
+  if (histogram != NULL)
+    histogram->assign(data, data + size);
+  return size - 1;
+}
+
+int RE2::ProgramFanout(std::vector<int> *histogram) const {
+  if (prog_ == NULL)
+    return -1;
+  return Fanout(prog_, histogram);
+}
+
+int RE2::ReverseProgramFanout(std::vector<int> *histogram) const {
+  if (prog_ == NULL)
+    return -1;
+  Prog *prog = ReverseProg();
+  if (prog == NULL)
+    return -1;
+  return Fanout(prog, histogram);
+}
+
+// Returns named_groups_, computing it if needed.
+const std::map<std::string, int> &RE2::NamedCapturingGroups() const {
+  std::call_once(
+      named_groups_once_,
+      [](const RE2 *re) {
+        if (re->suffix_regexp_ != NULL)
+          re->named_groups_ = re->suffix_regexp_->NamedCaptures();
+        if (re->named_groups_ == NULL)
+          re->named_groups_ = empty_named_groups();
+      },
+      this);
+  return *named_groups_;
+}
+
+// Returns group_names_, computing it if needed.
+const std::map<int, std::string> &RE2::CapturingGroupNames() const {
+  std::call_once(
+      group_names_once_,
+      [](const RE2 *re) {
+        if (re->suffix_regexp_ != NULL)
+          re->group_names_ = re->suffix_regexp_->CaptureNames();
+        if (re->group_names_ == NULL)
+          re->group_names_ = empty_group_names();
+      },
+      this);
+  return *group_names_;
+}
+
+/***** Convenience interfaces *****/
+
+bool RE2::FullMatchN(const StringPiece &text, const RE2 &re, const Arg *const args[], int n) { return re.DoMatch(text, ANCHOR_BOTH, NULL, args, n); }
+
+bool RE2::PartialMatchN(const StringPiece &text, const RE2 &re, const Arg *const args[], int n) {
+  return re.DoMatch(text, UNANCHORED, NULL, args, n);
+}
+
+bool RE2::ConsumeN(StringPiece *input, const RE2 &re, const Arg *const args[], int n) {
+  size_t consumed;
+  if (re.DoMatch(*input, ANCHOR_START, &consumed, args, n)) {
+    input->remove_prefix(consumed);
+    return true;
+  } else {
+    return false;
+  }
+}
+
+bool RE2::FindAndConsumeN(StringPiece *input, const RE2 &re, const Arg *const args[], int n) {
+  size_t consumed;
+  if (re.DoMatch(*input, UNANCHORED, &consumed, args, n)) {
+    input->remove_prefix(consumed);
+    return true;
+  } else {
+    return false;
+  }
+}
+
+bool RE2::Replace(std::string *str, const RE2 &re, const StringPiece &rewrite) {
+  StringPiece vec[kVecSize];
+  int nvec = 1 + MaxSubmatch(rewrite);
+  if (nvec > 1 + re.NumberOfCapturingGroups())
+    return false;
+  if (nvec > static_cast<int>(arraysize(vec)))
+    return false;
+  if (!re.Match(*str, 0, str->size(), UNANCHORED, vec, nvec))
+    return false;
+
+  std::string s;
+  if (!re.Rewrite(&s, rewrite, vec, nvec))
+    return false;
+
+  assert(vec[0].data() >= str->data());
+  assert(vec[0].data() + vec[0].size() <= str->data() + str->size());
+  str->replace(vec[0].data() - str->data(), vec[0].size(), s);
+  return true;
+}
+
+int RE2::GlobalReplace(std::string *str, const RE2 &re, const StringPiece &rewrite) {
+  StringPiece vec[kVecSize];
+  int nvec = 1 + MaxSubmatch(rewrite);
+  if (nvec > 1 + re.NumberOfCapturingGroups())
+    return false;
+  if (nvec > static_cast<int>(arraysize(vec)))
+    return false;
+
+  const char *p = str->data();
+  const char *ep = p + str->size();
+  const char *lastend = NULL;
+  std::string out;
+  int count = 0;
+  while (p <= ep) {
+    if (maximum_global_replace_count != -1 && count >= maximum_global_replace_count)
+      break;
+    if (!re.Match(*str, static_cast<size_t>(p - str->data()), str->size(), UNANCHORED, vec, nvec))
+      break;
+    if (p < vec[0].data())
+      out.append(p, vec[0].data() - p);
+    if (vec[0].data() == lastend && vec[0].empty()) {
+      // Disallow empty match at end of last match: skip ahead.
+      //
+      // fullrune() takes int, not ptrdiff_t. However, it just looks
+      // at the leading byte and treats any length >= 4 the same.
+      if (re.options().encoding() == RE2::Options::EncodingUTF8 && fullrune(p, static_cast<int>(std::min(ptrdiff_t{4}, ep - p)))) {
+        // re is in UTF-8 mode and there is enough left of str
+        // to allow us to advance by up to UTFmax bytes.
+        Rune r;
+        int n = chartorune(&r, p);
+        // Some copies of chartorune have a bug that accepts
+        // encodings of values in (10FFFF, 1FFFFF] as valid.
+        if (r > Runemax) {
+          n = 1;
+          r = Runeerror;
+        }
+        if (!(n == 1 && r == Runeerror)) {  // no decoding error
+          out.append(p, n);
+          p += n;
+          continue;
+        }
+      }
+      // Most likely, re is in Latin-1 mode.  If it is in UTF-8 mode,
+      // we fell through from above and the GIGO principle applies.
+      if (p < ep)
+        out.append(p, 1);
+      p++;
+      continue;
+    }
+    re.Rewrite(&out, rewrite, vec, nvec);
+    p = vec[0].data() + vec[0].size();
+    lastend = p;
+    count++;
+  }
+
+  if (count == 0)
+    return 0;
+
+  if (p < ep)
+    out.append(p, ep - p);
+  using std::swap;
+  swap(out, *str);
+  return count;
+}
+
+bool RE2::Extract(const StringPiece &text, const RE2 &re, const StringPiece &rewrite, std::string *out) {
+  StringPiece vec[kVecSize];
+  int nvec = 1 + MaxSubmatch(rewrite);
+  if (nvec > 1 + re.NumberOfCapturingGroups())
+    return false;
+  if (nvec > static_cast<int>(arraysize(vec)))
+    return false;
+  if (!re.Match(text, 0, text.size(), UNANCHORED, vec, nvec))
+    return false;
+
+  out->clear();
+  return re.Rewrite(out, rewrite, vec, nvec);
+}
+
+std::string RE2::QuoteMeta(const StringPiece &unquoted) {
+  std::string result;
+  result.reserve(unquoted.size() << 1);
+
+  // Escape any ascii character not in [A-Za-z_0-9].
+  //
+  // Note that it's legal to escape a character even if it has no
+  // special meaning in a regular expression -- so this function does
+  // that.  (This also makes it identical to the perl function of the
+  // same name except for the null-character special case;
+  // see `perldoc -f quotemeta`.)
+  for (size_t ii = 0; ii < unquoted.size(); ++ii) {
+    // Note that using 'isalnum' here raises the benchmark time from
+    // 32ns to 58ns:
+    if ((unquoted[ii] < 'a' || unquoted[ii] > 'z') && (unquoted[ii] < 'A' || unquoted[ii] > 'Z') && (unquoted[ii] < '0' || unquoted[ii] > '9') &&
+        unquoted[ii] != '_' &&
+        // If this is the part of a UTF8 or Latin1 character, we need
+        // to copy this byte without escaping.  Experimentally this is
+        // what works correctly with the regexp library.
+        !(unquoted[ii] & 128)) {
+      if (unquoted[ii] == '\0') {  // Special handling for null chars.
+        // Note that this special handling is not strictly required for RE2,
+        // but this quoting is required for other regexp libraries such as
+        // PCRE.
+ // Can't use "\\0" since the next character might be a digit. + result += "\\x00"; + continue; + } + result += '\\'; + } + result += unquoted[ii]; + } + + return result; +} + +bool RE2::PossibleMatchRange(std::string *min, std::string *max, int maxlen) const { + if (prog_ == NULL) + return false; + + int n = static_cast(prefix_.size()); + if (n > maxlen) + n = maxlen; + + // Determine initial min max from prefix_ literal. + *min = prefix_.substr(0, n); + *max = prefix_.substr(0, n); + if (prefix_foldcase_) { + // prefix is ASCII lowercase; change *min to uppercase. + for (int i = 0; i < n; i++) { + char &c = (*min)[i]; + if ('a' <= c && c <= 'z') + c += 'A' - 'a'; + } + } + + // Add to prefix min max using PossibleMatchRange on regexp. + std::string dmin, dmax; + maxlen -= n; + if (maxlen > 0 && prog_->PossibleMatchRange(&dmin, &dmax, maxlen)) { + min->append(dmin); + max->append(dmax); + } else if (!max->empty()) { + // prog_->PossibleMatchRange has failed us, + // but we still have useful information from prefix_. + // Round up *max to allow any possible suffix. + PrefixSuccessor(max); + } else { + // Nothing useful. + *min = ""; + *max = ""; + return false; + } + + return true; +} + +// Avoid possible locale nonsense in standard strcasecmp. +// The string a is known to be all lowercase. +static int ascii_strcasecmp(const char *a, const char *b, size_t len) { + const char *ae = a + len; + + for (; a < ae; a++, b++) { + uint8_t x = *a; + uint8_t y = *b; + if ('A' <= y && y <= 'Z') + y += 'a' - 'A'; + if (x != y) + return x - y; + } + return 0; +} + +/***** Actual matching and rewriting code *****/ + +bool RE2::Match(const StringPiece &text, size_t startpos, size_t endpos, Anchor re_anchor, StringPiece *submatch, int nsubmatch) const { + if (!ok()) { + if (options_.log_errors()) + LOG(ERROR) << "Invalid RE2: " << *error_; + return false; + } + + if (startpos > endpos || endpos > text.size()) { + if (options_.log_errors()) + LOG(ERROR) << "RE2: invalid startpos, endpos pair. [" + << "startpos: " << startpos << ", " + << "endpos: " << endpos << ", " + << "text size: " << text.size() << "]"; + return false; + } + + StringPiece subtext = text; + subtext.remove_prefix(startpos); + subtext.remove_suffix(text.size() - endpos); + + // Use DFAs to find exact location of match, filter out non-matches. + + // Don't ask for the location if we won't use it. + // SearchDFA can do extra optimizations in that case. + StringPiece match; + StringPiece *matchp = &match; + if (nsubmatch == 0) + matchp = NULL; + + int ncap = 1 + NumberOfCapturingGroups(); + if (ncap > nsubmatch) + ncap = nsubmatch; + + // If the regexp is anchored explicitly, must not be in middle of text. + if (prog_->anchor_start() && startpos != 0) + return false; + if (prog_->anchor_end() && endpos != text.size()) + return false; + + // If the regexp is anchored explicitly, update re_anchor + // so that we can potentially fall into a faster case below. + if (prog_->anchor_start() && prog_->anchor_end()) + re_anchor = ANCHOR_BOTH; + else if (prog_->anchor_start() && re_anchor != ANCHOR_BOTH) + re_anchor = ANCHOR_START; + + // Check for the required prefix, if any. 
+ size_t prefixlen = 0; + if (!prefix_.empty()) { + if (startpos != 0) + return false; + prefixlen = prefix_.size(); + if (prefixlen > subtext.size()) + return false; + if (prefix_foldcase_) { + if (ascii_strcasecmp(&prefix_[0], subtext.data(), prefixlen) != 0) + return false; + } else { + if (memcmp(&prefix_[0], subtext.data(), prefixlen) != 0) + return false; + } + subtext.remove_prefix(prefixlen); + // If there is a required prefix, the anchor must be at least ANCHOR_START. + if (re_anchor != ANCHOR_BOTH) + re_anchor = ANCHOR_START; + } + + Prog::Anchor anchor = Prog::kUnanchored; + Prog::MatchKind kind = longest_match_ ? Prog::kLongestMatch : Prog::kFirstMatch; + + bool can_one_pass = is_one_pass_ && ncap <= Prog::kMaxOnePassCapture; + bool can_bit_state = prog_->CanBitState(); + size_t bit_state_text_max_size = prog_->bit_state_text_max_size(); + +#ifdef RE2_HAVE_THREAD_LOCAL + hooks::context = this; +#endif + bool dfa_failed = false; + bool skipped_test = false; + switch (re_anchor) { + default: + LOG(DFATAL) << "Unexpected re_anchor value: " << re_anchor; + return false; + + case UNANCHORED: { + if (prog_->anchor_end()) { + // This is a very special case: we don't need the forward DFA because + // we already know where the match must end! Instead, the reverse DFA + // can say whether there is a match and (optionally) where it starts. + Prog *prog = ReverseProg(); + if (prog == NULL) { + // Fall back to NFA below. + skipped_test = true; + break; + } + if (!prog->SearchDFA(subtext, text, Prog::kAnchored, Prog::kLongestMatch, matchp, &dfa_failed, NULL)) { + if (dfa_failed) { + if (options_.log_errors()) + LOG(ERROR) << "DFA out of memory: " + << "pattern length " << pattern_->size() << ", " + << "program size " << prog->size() << ", " + << "list count " << prog->list_count() << ", " + << "bytemap range " << prog->bytemap_range(); + // Fall back to NFA below. + skipped_test = true; + break; + } + return false; + } + if (matchp == NULL) // Matched. Don't care where. + return true; + break; + } + + if (!prog_->SearchDFA(subtext, text, anchor, kind, matchp, &dfa_failed, NULL)) { + if (dfa_failed) { + if (options_.log_errors()) + LOG(ERROR) << "DFA out of memory: " + << "pattern length " << pattern_->size() << ", " + << "program size " << prog_->size() << ", " + << "list count " << prog_->list_count() << ", " + << "bytemap range " << prog_->bytemap_range(); + // Fall back to NFA below. + skipped_test = true; + break; + } + return false; + } + if (matchp == NULL) // Matched. Don't care where. + return true; + // SearchDFA set match.end() but didn't know where the + // match started. Run the regexp backward from match.end() + // to find the longest possible match -- that's where it started. + Prog *prog = ReverseProg(); + if (prog == NULL) { + // Fall back to NFA below. + skipped_test = true; + break; + } + if (!prog->SearchDFA(match, text, Prog::kAnchored, Prog::kLongestMatch, &match, &dfa_failed, NULL)) { + if (dfa_failed) { + if (options_.log_errors()) + LOG(ERROR) << "DFA out of memory: " + << "pattern length " << pattern_->size() << ", " + << "program size " << prog->size() << ", " + << "list count " << prog->list_count() << ", " + << "bytemap range " << prog->bytemap_range(); + // Fall back to NFA below. 
+ skipped_test = true; + break; + } + if (options_.log_errors()) + LOG(ERROR) << "SearchDFA inconsistency"; + return false; + } + break; + } + + case ANCHOR_BOTH: + case ANCHOR_START: + if (re_anchor == ANCHOR_BOTH) + kind = Prog::kFullMatch; + anchor = Prog::kAnchored; + + // If only a small amount of text and need submatch + // information anyway and we're going to use OnePass or BitState + // to get it, we might as well not even bother with the DFA: + // OnePass or BitState will be fast enough. + // On tiny texts, OnePass outruns even the DFA, and + // it doesn't have the shared state and occasional mutex that + // the DFA does. + if (can_one_pass && text.size() <= 4096 && (ncap > 1 || text.size() <= 16)) { + skipped_test = true; + break; + } + if (can_bit_state && text.size() <= bit_state_text_max_size && ncap > 1) { + skipped_test = true; + break; + } + if (!prog_->SearchDFA(subtext, text, anchor, kind, &match, &dfa_failed, NULL)) { + if (dfa_failed) { + if (options_.log_errors()) + LOG(ERROR) << "DFA out of memory: " + << "pattern length " << pattern_->size() << ", " + << "program size " << prog_->size() << ", " + << "list count " << prog_->list_count() << ", " + << "bytemap range " << prog_->bytemap_range(); + // Fall back to NFA below. + skipped_test = true; + break; + } + return false; + } + break; + } + + if (!skipped_test && ncap <= 1) { + // We know exactly where it matches. That's enough. + if (ncap == 1) + submatch[0] = match; + } else { + StringPiece subtext1; + if (skipped_test) { + // DFA ran out of memory or was skipped: + // need to search in entire original text. + subtext1 = subtext; + } else { + // DFA found the exact match location: + // let NFA run an anchored, full match search + // to find submatch locations. + subtext1 = match; + anchor = Prog::kAnchored; + kind = Prog::kFullMatch; + } + + if (can_one_pass && anchor != Prog::kUnanchored) { + if (!prog_->SearchOnePass(subtext1, text, anchor, kind, submatch, ncap)) { + if (!skipped_test && options_.log_errors()) + LOG(ERROR) << "SearchOnePass inconsistency"; + return false; + } + } else if (can_bit_state && subtext1.size() <= bit_state_text_max_size) { + if (!prog_->SearchBitState(subtext1, text, anchor, kind, submatch, ncap)) { + if (!skipped_test && options_.log_errors()) + LOG(ERROR) << "SearchBitState inconsistency"; + return false; + } + } else { + if (!prog_->SearchNFA(subtext1, text, anchor, kind, submatch, ncap)) { + if (!skipped_test && options_.log_errors()) + LOG(ERROR) << "SearchNFA inconsistency"; + return false; + } + } + } + + // Adjust overall match for required prefix that we stripped off. + if (prefixlen > 0 && nsubmatch > 0) + submatch[0] = StringPiece(submatch[0].data() - prefixlen, submatch[0].size() + prefixlen); + + // Zero submatches that don't exist in the regexp. + for (int i = ncap; i < nsubmatch; i++) + submatch[i] = StringPiece(); + return true; +} + +// Internal matcher - like Match() but takes Args not StringPieces. +bool RE2::DoMatch(const StringPiece &text, Anchor re_anchor, size_t *consumed, const Arg *const *args, int n) const { + if (!ok()) { + if (options_.log_errors()) + LOG(ERROR) << "Invalid RE2: " << *error_; + return false; + } + + if (NumberOfCapturingGroups() < n) { + // RE has fewer capturing groups than number of Arg pointers passed in. + return false; + } + + // Count number of capture groups needed. 
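+  // (Illustrative note, not from the original source: slot 0 always holds
+  // the overall match, so n argument pointers need n+1 StringPieces; when
+  // the caller wants neither submatches nor the consumed length, nvec stays
+  // 0 and Match() skips submatch extraction entirely.)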
+  int nvec;
+  if (n == 0 && consumed == NULL)
+    nvec = 0;
+  else
+    nvec = n + 1;
+
+  StringPiece *vec;
+  StringPiece stkvec[kVecSize];
+  StringPiece *heapvec = NULL;
+
+  if (nvec <= static_cast<int>(arraysize(stkvec))) {
+    vec = stkvec;
+  } else {
+    vec = new StringPiece[nvec];
+    heapvec = vec;
+  }
+
+  if (!Match(text, 0, text.size(), re_anchor, vec, nvec)) {
+    delete[] heapvec;
+    return false;
+  }
+
+  if (consumed != NULL)
+    *consumed = static_cast<size_t>(EndPtr(vec[0]) - BeginPtr(text));
+
+  if (n == 0 || args == NULL) {
+    // We are not interested in results
+    delete[] heapvec;
+    return true;
+  }
+
+  // If we got here, we must have matched the whole pattern.
+  for (int i = 0; i < n; i++) {
+    const StringPiece &s = vec[i + 1];
+    if (!args[i]->Parse(s.data(), s.size())) {
+      // TODO: Should we indicate what the error was?
+      delete[] heapvec;
+      return false;
+    }
+  }
+
+  delete[] heapvec;
+  return true;
+}
+
+// Checks that the rewrite string is well-formed with respect to this
+// regular expression.
+bool RE2::CheckRewriteString(const StringPiece &rewrite, std::string *error) const {
+  int max_token = -1;
+  for (const char *s = rewrite.data(), *end = s + rewrite.size(); s < end; s++) {
+    int c = *s;
+    if (c != '\\') {
+      continue;
+    }
+    if (++s == end) {
+      *error = "Rewrite schema error: '\\' not allowed at end.";
+      return false;
+    }
+    c = *s;
+    if (c == '\\') {
+      continue;
+    }
+    if (!isdigit(c)) {
+      *error = "Rewrite schema error: "
+               "'\\' must be followed by a digit or '\\'.";
+      return false;
+    }
+    int n = (c - '0');
+    if (max_token < n) {
+      max_token = n;
+    }
+  }
+
+  if (max_token > NumberOfCapturingGroups()) {
+    *error = StringPrintf("Rewrite schema requests %d matches, but the regexp only has %d "
+                          "parenthesized subexpressions.",
+                          max_token,
+                          NumberOfCapturingGroups());
+    return false;
+  }
+  return true;
+}
+
+// Returns the maximum submatch needed for the rewrite to be done by Replace().
+// E.g. if rewrite == "foo \\2,\\1", returns 2.
+int RE2::MaxSubmatch(const StringPiece &rewrite) {
+  int max = 0;
+  for (const char *s = rewrite.data(), *end = s + rewrite.size(); s < end; s++) {
+    if (*s == '\\') {
+      s++;
+      int c = (s < end) ? *s : -1;
+      if (isdigit(c)) {
+        int n = (c - '0');
+        if (n > max)
+          max = n;
+      }
+    }
+  }
+  return max;
+}
+
+// Append the "rewrite" string, with backslash substitutions from "vec",
+// to string "out".
+bool RE2::Rewrite(std::string *out, const StringPiece &rewrite, const StringPiece *vec, int veclen) const {
+  for (const char *s = rewrite.data(), *end = s + rewrite.size(); s < end; s++) {
+    if (*s != '\\') {
+      out->push_back(*s);
+      continue;
+    }
+    s++;
+    int c = (s < end) ? *s : -1;
+    if (isdigit(c)) {
+      int n = (c - '0');
+      if (n >= veclen) {
+        if (options_.log_errors()) {
+          LOG(ERROR) << "invalid substitution \\" << n << " from " << veclen << " groups";
+        }
+        return false;
+      }
+      StringPiece snip = vec[n];
+      if (!snip.empty())
+        out->append(snip.data(), snip.size());
+    } else if (c == '\\') {
+      out->push_back('\\');
+    } else {
+      if (options_.log_errors())
+        LOG(ERROR) << "invalid rewrite pattern: " << rewrite.data();
+      return false;
+    }
+  }
+  return true;
+}
+
+/***** Parsers for various types *****/
+
+namespace re2_internal {
+
+template <>
+bool Parse(const char *str, size_t n, void *dest) {
+  // We fail if somebody asked us to store into a non-NULL void* pointer
+  return (dest == NULL);
+}
+
+template <>
+bool Parse(const char *str, size_t n, std::string *dest) {
+  if (dest == NULL)
+    return true;
+  dest->assign(str, n);
+  return true;
+}
+
+template <>
+bool Parse(const char *str, size_t n, StringPiece *dest) {
+  if (dest == NULL)
+    return true;
+  *dest = StringPiece(str, n);
+  return true;
+}
+
+template <>
+bool Parse(const char *str, size_t n, char *dest) {
+  if (n != 1)
+    return false;
+  if (dest == NULL)
+    return true;
+  *dest = str[0];
+  return true;
+}
+
+template <>
+bool Parse(const char *str, size_t n, signed char *dest) {
+  if (n != 1)
+    return false;
+  if (dest == NULL)
+    return true;
+  *dest = str[0];
+  return true;
+}
+
+template <>
+bool Parse(const char *str, size_t n, unsigned char *dest) {
+  if (n != 1)
+    return false;
+  if (dest == NULL)
+    return true;
+  *dest = str[0];
+  return true;
+}
+
+// Largest number spec that we are willing to parse
+static const int kMaxNumberLength = 32;
+
+// REQUIRES "buf" must have length at least nbuf.
+// Copies "str" into "buf" and null-terminates.
+// Overwrites *np with the new length.
+static const char *TerminateNumber(char *buf, size_t nbuf, const char *str, size_t *np, bool accept_spaces) {
+  size_t n = *np;
+  if (n == 0)
+    return "";
+  if (n > 0 && isspace(*str)) {
+    // We are less forgiving than the strtoxxx() routines and do not
+    // allow leading spaces. We do allow leading spaces for floats.
+    if (!accept_spaces) {
+      return "";
+    }
+    while (n > 0 && isspace(*str)) {
+      n--;
+      str++;
+    }
+  }
+
+  // Although buf has a fixed maximum size, we can still handle
+  // arbitrarily large integers correctly by omitting leading zeros.
+  // (Numbers that are still too long will be out of range.)
+  // Before deciding whether str is too long,
+  // remove leading zeros with s/000+/00/.
+  // Leaving the leading two zeros in place means that
+  // we don't change 0000x123 (invalid) into 0x123 (valid).
+  // Skip over leading - before replacing.
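+  // (Illustrative note, not from the original source: "-00000123" becomes
+  // "-00123" below; two zeros are kept so that an invalid "0000x123" can
+  // only shrink to "00x123", never to a valid "0x123".)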
+ bool neg = false; + if (n >= 1 && str[0] == '-') { + neg = true; + n--; + str++; + } + + if (n >= 3 && str[0] == '0' && str[1] == '0') { + while (n >= 3 && str[2] == '0') { + n--; + str++; + } + } + + if (neg) { // make room in buf for - + n++; + str--; + } + + if (n > nbuf - 1) + return ""; + + memmove(buf, str, n); + if (neg) { + buf[0] = '-'; + } + buf[n] = '\0'; + *np = n; + return buf; +} + +template <> +bool Parse(const char *str, size_t n, float *dest) { + if (n == 0) + return false; + static const int kMaxLength = 200; + char buf[kMaxLength + 1]; + str = TerminateNumber(buf, sizeof buf, str, &n, true); + char *end; + errno = 0; + float r = strtof(str, &end); + if (end != str + n) + return false; // Leftover junk + if (errno) + return false; + if (dest == NULL) + return true; + *dest = r; + return true; +} + +template <> +bool Parse(const char *str, size_t n, double *dest) { + if (n == 0) + return false; + static const int kMaxLength = 200; + char buf[kMaxLength + 1]; + str = TerminateNumber(buf, sizeof buf, str, &n, true); + char *end; + errno = 0; + double r = strtod(str, &end); + if (end != str + n) + return false; // Leftover junk + if (errno) + return false; + if (dest == NULL) + return true; + *dest = r; + return true; +} + +template <> +bool Parse(const char *str, size_t n, long *dest, int radix) { + if (n == 0) + return false; + char buf[kMaxNumberLength + 1]; + str = TerminateNumber(buf, sizeof buf, str, &n, false); + char *end; + errno = 0; + long r = strtol(str, &end, radix); + if (end != str + n) + return false; // Leftover junk + if (errno) + return false; + if (dest == NULL) + return true; + *dest = r; + return true; +} + +template <> +bool Parse(const char *str, size_t n, unsigned long *dest, int radix) { + if (n == 0) + return false; + char buf[kMaxNumberLength + 1]; + str = TerminateNumber(buf, sizeof buf, str, &n, false); + if (str[0] == '-') { + // strtoul() will silently accept negative numbers and parse + // them. This module is more strict and treats them as errors. 
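+    // (Illustrative note, not from the original source: strtoul("-1", ...)
+    // would otherwise wrap around to ULONG_MAX without reporting an error.)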
+    return false;
+  }
+
+  char *end;
+  errno = 0;
+  unsigned long r = strtoul(str, &end, radix);
+  if (end != str + n)
+    return false;  // Leftover junk
+  if (errno)
+    return false;
+  if (dest == NULL)
+    return true;
+  *dest = r;
+  return true;
+}
+
+template <>
+bool Parse(const char *str, size_t n, short *dest, int radix) {
+  long r;
+  if (!Parse(str, n, &r, radix))
+    return false;  // Could not parse
+  if ((short)r != r)
+    return false;  // Out of range
+  if (dest == NULL)
+    return true;
+  *dest = (short)r;
+  return true;
+}
+
+template <>
+bool Parse(const char *str, size_t n, unsigned short *dest, int radix) {
+  unsigned long r;
+  if (!Parse(str, n, &r, radix))
+    return false;  // Could not parse
+  if ((unsigned short)r != r)
+    return false;  // Out of range
+  if (dest == NULL)
+    return true;
+  *dest = (unsigned short)r;
+  return true;
+}
+
+template <>
+bool Parse(const char *str, size_t n, int *dest, int radix) {
+  long r;
+  if (!Parse(str, n, &r, radix))
+    return false;  // Could not parse
+  if ((int)r != r)
+    return false;  // Out of range
+  if (dest == NULL)
+    return true;
+  *dest = (int)r;
+  return true;
+}
+
+template <>
+bool Parse(const char *str, size_t n, unsigned int *dest, int radix) {
+  unsigned long r;
+  if (!Parse(str, n, &r, radix))
+    return false;  // Could not parse
+  if ((unsigned int)r != r)
+    return false;  // Out of range
+  if (dest == NULL)
+    return true;
+  *dest = (unsigned int)r;
+  return true;
+}
+
+template <>
+bool Parse(const char *str, size_t n, long long *dest, int radix) {
+  if (n == 0)
+    return false;
+  char buf[kMaxNumberLength + 1];
+  str = TerminateNumber(buf, sizeof buf, str, &n, false);
+  char *end;
+  errno = 0;
+  long long r = strtoll(str, &end, radix);
+  if (end != str + n)
+    return false;  // Leftover junk
+  if (errno)
+    return false;
+  if (dest == NULL)
+    return true;
+  *dest = r;
+  return true;
+}
+
+template <>
+bool Parse(const char *str, size_t n, unsigned long long *dest, int radix) {
+  if (n == 0)
+    return false;
+  char buf[kMaxNumberLength + 1];
+  str = TerminateNumber(buf, sizeof buf, str, &n, false);
+  if (str[0] == '-') {
+    // strtoull() will silently accept negative numbers and parse
+    // them. This module is more strict and treats them as errors.
+    return false;
+  }
+  char *end;
+  errno = 0;
+  unsigned long long r = strtoull(str, &end, radix);
+  if (end != str + n)
+    return false;  // Leftover junk
+  if (errno)
+    return false;
+  if (dest == NULL)
+    return true;
+  *dest = r;
+  return true;
+}
+
+}  // namespace re2_internal
+
+namespace hooks {
+
+#ifdef RE2_HAVE_THREAD_LOCAL
+thread_local const RE2 *context = NULL;
+#endif
+
+template <typename T>
+union Hook {
+  void Store(T *cb) { cb_.store(cb, std::memory_order_release); }
+  T *Load() const { return cb_.load(std::memory_order_acquire); }
+
+#if !defined(__clang__) && defined(_MSC_VER)
+  // Citing https://github.com/protocolbuffers/protobuf/pull/4777 as precedent,
+  // this is a gross hack to make std::atomic<T*> constant-initialized on MSVC.
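+  // (Illustrative note, not from the original source: cb_for_constinit_ is
+  // the union's first member, so the {{&DoNothing}} initializer in
+  // DEFINE_HOOK constant-initializes that plain pointer instead of running
+  // the std::atomic constructor at runtime.)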
+  static_assert(ATOMIC_POINTER_LOCK_FREE == 2, "std::atomic<T*> must be always lock-free");
+  T *cb_for_constinit_;
+#endif
+
+  std::atomic<T*> cb_;
+};
+
+template <typename T>
+static void DoNothing(const T &) {}
+
+#define DEFINE_HOOK(type, name) \
+  static Hook<type##Callback> name##_hook = {{&DoNothing}}; \
+  void Set##type##Hook(type##Callback *cb) { name##_hook.Store(cb); } \
+  type##Callback *Get##type##Hook() { return name##_hook.Load(); }
+
+DEFINE_HOOK(DFAStateCacheReset, dfa_state_cache_reset)
+DEFINE_HOOK(DFASearchFailure, dfa_search_failure)
+
+#undef DEFINE_HOOK
+
+}  // namespace hooks
+
+}  // namespace re2
diff --git a/internal/cpp/re2/re2.h b/internal/cpp/re2/re2.h
new file mode 100644
index 00000000000..51872db547e
--- /dev/null
+++ b/internal/cpp/re2/re2.h
@@ -0,0 +1,991 @@
+// Copyright 2003-2009 The RE2 Authors.  All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef RE2_RE2_H_
+#define RE2_RE2_H_
+
+// C++ interface to the re2 regular-expression library.
+// RE2 supports Perl-style regular expressions (with extensions like
+// \d, \w, \s, ...).
+//
+// -----------------------------------------------------------------------
+// REGEXP SYNTAX:
+//
+// This module uses the re2 library and hence supports
+// its syntax for regular expressions, which is similar to Perl's with
+// some of the more complicated things thrown away. In particular,
+// backreferences and generalized assertions are not available, nor is \Z.
+//
+// See https://github.com/google/re2/wiki/Syntax for the syntax
+// supported by RE2, and a comparison with PCRE and PERL regexps.
+//
+// For those not familiar with Perl's regular expressions,
+// here are some examples of the most commonly used extensions:
+//
+//   "hello (\\w+) world"  -- \w matches a "word" character
+//   "version (\\d+)"      -- \d matches a digit
+//   "hello\\s+world"      -- \s matches any whitespace character
+//   "\\b(\\w+)\\b"        -- \b matches non-empty string at word boundary
+//   "(?i)hello"           -- (?i) turns on case-insensitive matching
+//   "/\\*(.*?)\\*/"       -- .*? matches . minimum no. of times possible
+//
+// The double backslashes are needed when writing C++ string literals.
+// However, they should NOT be used when writing C++11 raw string literals:
+//
+//   R"(hello (\w+) world)"  -- \w matches a "word" character
+//   R"(version (\d+))"      -- \d matches a digit
+//   R"(hello\s+world)"      -- \s matches any whitespace character
+//   R"(\b(\w+)\b)"          -- \b matches non-empty string at word boundary
+//   R"((?i)hello)"          -- (?i) turns on case-insensitive matching
+//   R"(/\*(.*?)\*/)"        -- .*? matches . minimum no. of times possible
+//
+// When using UTF-8 encoding, case-insensitive matching will perform
+// simple case folding, not full case folding.
+//
+// -----------------------------------------------------------------------
+// MATCHING INTERFACE:
+//
+// The "FullMatch" operation checks that supplied text matches a
+// supplied pattern exactly.
+//
+// Example: successful match
+//    CHECK(RE2::FullMatch("hello", "h.*o"));
+//
+// Example: unsuccessful match (requires full match):
+//    CHECK(!RE2::FullMatch("hello", "e"));
+//
+// -----------------------------------------------------------------------
+// UTF-8 AND THE MATCHING INTERFACE:
+//
+// By default, the pattern and input text are interpreted as UTF-8.
+// The RE2::Latin1 option causes them to be interpreted as Latin-1.
+// +// Example: +// CHECK(RE2::FullMatch(utf8_string, RE2(utf8_pattern))); +// CHECK(RE2::FullMatch(latin1_string, RE2(latin1_pattern, RE2::Latin1))); +// +// ----------------------------------------------------------------------- +// SUBMATCH EXTRACTION: +// +// You can supply extra pointer arguments to extract submatches. +// On match failure, none of the pointees will have been modified. +// On match success, the submatches will be converted (as necessary) and +// their values will be assigned to their pointees until all conversions +// have succeeded or one conversion has failed. +// On conversion failure, the pointees will be in an indeterminate state +// because the caller has no way of knowing which conversion failed. +// However, conversion cannot fail for types like string and StringPiece +// that do not inspect the submatch contents. Hence, in the common case +// where all of the pointees are of such types, failure is always due to +// match failure and thus none of the pointees will have been modified. +// +// Example: extracts "ruby" into "s" and 1234 into "i" +// int i; +// std::string s; +// CHECK(RE2::FullMatch("ruby:1234", "(\\w+):(\\d+)", &s, &i)); +// +// Example: fails because string cannot be stored in integer +// CHECK(!RE2::FullMatch("ruby", "(.*)", &i)); +// +// Example: fails because there aren't enough sub-patterns +// CHECK(!RE2::FullMatch("ruby:1234", "\\w+:\\d+", &s)); +// +// Example: does not try to extract any extra sub-patterns +// CHECK(RE2::FullMatch("ruby:1234", "(\\w+):(\\d+)", &s)); +// +// Example: does not try to extract into NULL +// CHECK(RE2::FullMatch("ruby:1234", "(\\w+):(\\d+)", NULL, &i)); +// +// Example: integer overflow causes failure +// CHECK(!RE2::FullMatch("ruby:1234567891234", "\\w+:(\\d+)", &i)); +// +// NOTE(rsc): Asking for submatches slows successful matches quite a bit. +// This may get a little faster in the future, but right now is slower +// than PCRE. On the other hand, failed matches run *very* fast (faster +// than PCRE), as do matches without submatch extraction. +// +// ----------------------------------------------------------------------- +// PARTIAL MATCHES +// +// You can use the "PartialMatch" operation when you want the pattern +// to match any substring of the text. +// +// Example: simple search for a string: +// CHECK(RE2::PartialMatch("hello", "ell")); +// +// Example: find first number in a string +// int number; +// CHECK(RE2::PartialMatch("x*100 + 20", "(\\d+)", &number)); +// CHECK_EQ(number, 100); +// +// ----------------------------------------------------------------------- +// PRE-COMPILED REGULAR EXPRESSIONS +// +// RE2 makes it easy to use any string as a regular expression, without +// requiring a separate compilation step. +// +// If speed is of the essence, you can create a pre-compiled "RE2" +// object from the pattern and use it multiple times. If you do so, +// you can typically parse text faster than with sscanf. +// +// Example: precompile pattern for faster matching: +// RE2 pattern("h.*o"); +// while (ReadLine(&str)) { +// if (RE2::FullMatch(str, pattern)) ...; +// } +// +// ----------------------------------------------------------------------- +// SCANNING TEXT INCREMENTALLY +// +// The "Consume" operation may be useful if you want to repeatedly +// match regular expressions at the front of a string and skip over +// them as they match. This requires use of the "StringPiece" type, +// which represents a sub-range of a real string. 
+//
+// Example: read lines of the form "var = value" from a string.
+// std::string contents = ...; // Fill string somehow
+// StringPiece input(contents); // Wrap a StringPiece around it
+//
+// std::string var;
+// int value;
+// while (RE2::Consume(&input, "(\\w+) = (\\d+)\n", &var, &value)) {
+// ...;
+// }
+//
+// Each successful call to "Consume" will set "var/value", and also
+// advance "input" so it points past the matched text. Note that if the
+// regular expression matches an empty string, input will advance
+// by 0 bytes. If the regular expression being used might match
+// an empty string, the loop body must check for this case and either
+// advance the string or break out of the loop.
+//
+// The "FindAndConsume" operation is similar to "Consume" but does not
+// anchor your match at the beginning of the string. For example, you
+// could extract all words from a string by repeatedly calling
+// RE2::FindAndConsume(&input, "(\\w+)", &word)
+//
+// -----------------------------------------------------------------------
+// USING VARIABLE NUMBER OF ARGUMENTS
+//
+// The above operations require you to know the number of arguments
+// when you write the code. This is not always possible or easy (for
+// example, the regular expression may be calculated at run time).
+// You can use the "N" version of the operations when the number of
+// match arguments is determined at run time.
+//
+// Example:
+// const RE2::Arg* args[10];
+// int n;
+// // ... populate args with pointers to RE2::Arg values ...
+// // ... set n to the number of RE2::Arg objects ...
+// bool match = RE2::FullMatchN(input, pattern, args, n);
+//
+// The last statement is equivalent to
+//
+// bool match = RE2::FullMatch(input, pattern,
+// *args[0], *args[1], ..., *args[n - 1]);
+//
+// -----------------------------------------------------------------------
+// PARSING HEX/OCTAL/C-RADIX NUMBERS
+//
+// By default, if you pass a pointer to a numeric value, the
+// corresponding text is interpreted as a base-10 number. You can
+// instead wrap the pointer with a call to one of the operators Hex(),
+// Octal(), or CRadix() to interpret the text in another base. The
+// CRadix operator interprets C-style "0" (base-8) and "0x" (base-16)
+// prefixes, but defaults to base-10.
+//
+// Example:
+// int a, b, c, d;
+// CHECK(RE2::FullMatch("100 40 0100 0x40", "(.*) (.*) (.*) (.*)",
+// RE2::Octal(&a), RE2::Hex(&b), RE2::CRadix(&c), RE2::CRadix(&d)));
+// will leave 64 in a, b, c, and d.
+
+#include <stddef.h>
+#include <stdint.h>
+#include <algorithm>
+#include <map>
+#include <mutex>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#if defined(__APPLE__)
+#include <TargetConditionals.h>
+#endif
+
+#include "stringpiece.h"
+
+namespace re2 {
+class Prog;
+class Regexp;
+} // namespace re2
+
+namespace re2 {
+
+// Interface for regular expression matching. Also corresponds to a
+// pre-compiled regular expression. An "RE2" object is safe for
+// concurrent use by multiple threads.
+class RE2 {
+public:
+ // We convert user-passed pointers into special Arg objects
+ class Arg;
+ class Options;
+
+ // Defined in set.h.
+ class Set;
+
+ enum ErrorCode {
+ NoError = 0,
+
+ // Unexpected error
+ ErrorInternal,
+
+ // Parse errors
+ ErrorBadEscape, // bad escape sequence
+ ErrorBadCharClass, // bad character class
+ ErrorBadCharRange, // bad character class range
+ ErrorMissingBracket, // missing closing ]
+ ErrorMissingParen, // missing closing )
+ ErrorUnexpectedParen, // unexpected closing )
+ ErrorTrailingBackslash, // trailing \ at end of regexp
+ ErrorRepeatArgument, // repeat argument missing, e.g.
"*" + ErrorRepeatSize, // bad repetition argument + ErrorRepeatOp, // bad repetition operator + ErrorBadPerlOp, // bad perl operator + ErrorBadUTF8, // invalid UTF-8 in regexp + ErrorBadNamedCapture, // bad named capture group + ErrorPatternTooLarge // pattern too large (compile failed) + }; + + // Predefined common options. + // If you need more complicated things, instantiate + // an Option class, possibly passing one of these to + // the Option constructor, change the settings, and pass that + // Option class to the RE2 constructor. + enum CannedOptions { + DefaultOptions = 0, + Latin1, // treat input as Latin-1 (default UTF-8) + POSIX, // POSIX syntax, leftmost-longest match + Quiet // do not log about regexp parse errors + }; + + // Need to have the const char* and const std::string& forms for implicit + // conversions when passing string literals to FullMatch and PartialMatch. + // Otherwise the StringPiece form would be sufficient. + RE2(const char *pattern); + RE2(const std::string &pattern); + RE2(const StringPiece &pattern); + RE2(const StringPiece &pattern, const Options &options); + ~RE2(); + + // Not copyable. + // RE2 objects are expensive. You should probably use std::shared_ptr + // instead. If you really must copy, RE2(first.pattern(), first.options()) + // effectively does so: it produces a second object that mimics the first. + RE2(const RE2 &) = delete; + RE2 &operator=(const RE2 &) = delete; + // Not movable. + // RE2 objects are thread-safe and logically immutable. You should probably + // use std::unique_ptr instead. Otherwise, consider std::deque if + // direct emplacement into a container is desired. If you really must move, + // be prepared to submit a design document along with your feature request. + RE2(RE2 &&) = delete; + RE2 &operator=(RE2 &&) = delete; + + // Returns whether RE2 was created properly. + bool ok() const { return error_code() == NoError; } + + // The string specification for this RE2. E.g. + // RE2 re("ab*c?d+"); + // re.pattern(); // "ab*c?d+" + const std::string &pattern() const { return *pattern_; } + + // If RE2 could not be created properly, returns an error string. + // Else returns the empty string. + const std::string &error() const { return *error_; } + + // If RE2 could not be created properly, returns an error code. + // Else returns RE2::NoError (== 0). + ErrorCode error_code() const { return error_code_; } + + // If RE2 could not be created properly, returns the offending + // portion of the regexp. + const std::string &error_arg() const { return *error_arg_; } + + // Returns the program size, a very approximate measure of a regexp's "cost". + // Larger numbers are more expensive than smaller numbers. + int ProgramSize() const; + int ReverseProgramSize() const; + + // If histogram is not null, outputs the program fanout + // as a histogram bucketed by powers of 2. + // Returns the number of the largest non-empty bucket. + int ProgramFanout(std::vector *histogram) const; + int ReverseProgramFanout(std::vector *histogram) const; + + // Returns the underlying Regexp; not for general use. + // Returns entire_regexp_ so that callers don't need + // to know about prefix_ and prefix_foldcase_. + re2::Regexp *Regexp() const { return entire_regexp_; } + + /***** The array-based matching interface ******/ + + // The functions here have names ending in 'N' and are used to implement + // the functions whose names are the prefix before the 'N'. 
It is sometimes
+ // useful to invoke them directly, but the syntax is awkward, so the 'N'-less
+ // versions should be preferred.
+ static bool FullMatchN(const StringPiece &text, const RE2 &re, const Arg *const args[], int n);
+ static bool PartialMatchN(const StringPiece &text, const RE2 &re, const Arg *const args[], int n);
+ static bool ConsumeN(StringPiece *input, const RE2 &re, const Arg *const args[], int n);
+ static bool FindAndConsumeN(StringPiece *input, const RE2 &re, const Arg *const args[], int n);
+
+private:
+ template <typename F, typename SP>
+ static inline bool Apply(F f, SP sp, const RE2 &re) {
+ return f(sp, re, NULL, 0);
+ }
+
+ template <typename F, typename SP, typename... A>
+ static inline bool Apply(F f, SP sp, const RE2 &re, const A &...a) {
+ const Arg *const args[] = {&a...};
+ const int n = sizeof...(a);
+ return f(sp, re, args, n);
+ }
+
+public:
+ // In order to allow FullMatch() et al. to be called with a varying number
+ // of arguments of varying types, we use two layers of variadic templates.
+ // The first layer constructs the temporary Arg objects. The second layer
+ // (above) constructs the array of pointers to the temporary Arg objects.
+
+ /***** The useful part: the matching interface *****/
+
+ // Matches "text" against "re". If pointer arguments are
+ // supplied, copies matched sub-patterns into them.
+ //
+ // You can pass in a "const char*" or a "std::string" for "text".
+ // You can pass in a "const char*" or a "std::string" or a "RE2" for "re".
+ //
+ // The provided pointer arguments can be pointers to any scalar numeric
+ // type, or one of:
+ // std::string (matched piece is copied to string)
+ // StringPiece (StringPiece is mutated to point to matched piece)
+ // T (where "bool T::ParseFrom(const char*, size_t)" exists)
+ // (void*)NULL (the corresponding matched sub-pattern is not copied)
+ //
+ // Returns true iff all of the following conditions are satisfied:
+ // a. "text" matches "re" fully - from the beginning to the end of "text".
+ // b. The number of matched sub-patterns is >= number of supplied pointers.
+ // c. The "i"th argument has a suitable type for holding the
+ // string captured as the "i"th sub-pattern. If you pass in
+ // NULL for the "i"th argument, or pass fewer arguments than
+ // number of sub-patterns, the "i"th captured sub-pattern is
+ // ignored.
+ //
+ // CAVEAT: An optional sub-pattern that does not exist in the
+ // matched string is assigned the empty string. Therefore, the
+ // following will return false (because the empty string is not a
+ // valid number):
+ // int number;
+ // RE2::FullMatch("abc", "[a-z]+(\\d+)?", &number);
+ template <typename... A>
+ static bool FullMatch(const StringPiece &text, const RE2 &re, A &&...a) {
+ return Apply(FullMatchN, text, re, Arg(std::forward<A>(a))...);
+ }
+
+ // Like FullMatch(), except that "re" is allowed to match a substring
+ // of "text".
+ //
+ // Returns true iff all of the following conditions are satisfied:
+ // a. "text" matches "re" partially - for some substring of "text".
+ // b. The number of matched sub-patterns is >= number of supplied pointers.
+ // c. The "i"th argument has a suitable type for holding the
+ // string captured as the "i"th sub-pattern. If you pass in
+ // NULL for the "i"th argument, or pass fewer arguments than
+ // number of sub-patterns, the "i"th captured sub-pattern is
+ // ignored.
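+ //
+ // For instance (an illustrative sketch, not one of the original
+ // examples):
+ // int code;
+ // CHECK(RE2::PartialMatch("error 404 page", "(\\d+)", &code));
+ // CHECK_EQ(code, 404);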
+ template <typename... A>
+ static bool PartialMatch(const StringPiece &text, const RE2 &re, A &&...a) {
+ return Apply(PartialMatchN, text, re, Arg(std::forward<A>(a))...);
+ }
+
+ // Like FullMatch() and PartialMatch(), except that "re" has to match
+ // a prefix of the text, and "input" is advanced past the matched
+ // text. Note: "input" is modified iff this routine returns true
+ // and "re" matched a non-empty substring of "input".
+ //
+ // Returns true iff all of the following conditions are satisfied:
+ // a. "input" matches "re" partially - for some prefix of "input".
+ // b. The number of matched sub-patterns is >= number of supplied pointers.
+ // c. The "i"th argument has a suitable type for holding the
+ // string captured as the "i"th sub-pattern. If you pass in
+ // NULL for the "i"th argument, or pass fewer arguments than
+ // number of sub-patterns, the "i"th captured sub-pattern is
+ // ignored.
+ template <typename... A>
+ static bool Consume(StringPiece *input, const RE2 &re, A &&...a) {
+ return Apply(ConsumeN, input, re, Arg(std::forward<A>(a))...);
+ }
+
+ // Like Consume(), but does not anchor the match at the beginning of
+ // the text. That is, "re" need not start its match at the beginning
+ // of "input". For example, "FindAndConsume(s, "(\\w+)", &word)" finds
+ // the next word in "s" and stores it in "word".
+ //
+ // Returns true iff all of the following conditions are satisfied:
+ // a. "input" matches "re" partially - for some substring of "input".
+ // b. The number of matched sub-patterns is >= number of supplied pointers.
+ // c. The "i"th argument has a suitable type for holding the
+ // string captured as the "i"th sub-pattern. If you pass in
+ // NULL for the "i"th argument, or pass fewer arguments than
+ // number of sub-patterns, the "i"th captured sub-pattern is
+ // ignored.
+ template <typename... A>
+ static bool FindAndConsume(StringPiece *input, const RE2 &re, A &&...a) {
+ return Apply(FindAndConsumeN, input, re, Arg(std::forward<A>(a))...);
+ }
+
+ // Replace the first match of "re" in "str" with "rewrite".
+ // Within "rewrite", backslash-escaped digits (\1 to \9) can be
+ // used to insert text matching corresponding parenthesized group
+ // from the pattern. \0 in "rewrite" refers to the entire matching
+ // text. E.g.,
+ //
+ // std::string s = "yabba dabba doo";
+ // CHECK(RE2::Replace(&s, "b+", "d"));
+ //
+ // will leave "s" containing "yada dabba doo"
+ //
+ // Returns true if the pattern matches and a replacement occurs,
+ // false otherwise.
+ static bool Replace(std::string *str, const RE2 &re, const StringPiece &rewrite);
+
+ // Like Replace(), except replaces successive non-overlapping occurrences
+ // of the pattern in the string with the rewrite. E.g.
+ //
+ // std::string s = "yabba dabba doo";
+ // CHECK(RE2::GlobalReplace(&s, "b+", "d"));
+ //
+ // will leave "s" containing "yada dada doo"
+ // Replacements are not subject to re-matching.
+ //
+ // Because GlobalReplace only replaces non-overlapping matches,
+ // replacing "ana" within "banana" makes only one replacement, not two.
+ //
+ // Returns the number of replacements made.
+ static int GlobalReplace(std::string *str, const RE2 &re, const StringPiece &rewrite);
+
+ // Like Replace, except that if the pattern matches, "rewrite"
+ // is copied into "out" with substitutions. The non-matching
+ // portions of "text" are ignored.
+ //
+ // Returns true iff a match occurred and the extraction happened
+ // successfully; if no match occurs, the string is left unaffected.
+ // + // REQUIRES: "text" must not alias any part of "*out". + static bool Extract(const StringPiece &text, const RE2 &re, const StringPiece &rewrite, std::string *out); + + // Escapes all potentially meaningful regexp characters in + // 'unquoted'. The returned string, used as a regular expression, + // will match exactly the original string. For example, + // 1.5-2.0? + // may become: + // 1\.5\-2\.0\? + static std::string QuoteMeta(const StringPiece &unquoted); + + // Computes range for any strings matching regexp. The min and max can in + // some cases be arbitrarily precise, so the caller gets to specify the + // maximum desired length of string returned. + // + // Assuming PossibleMatchRange(&min, &max, N) returns successfully, any + // string s that is an anchored match for this regexp satisfies + // min <= s && s <= max. + // + // Note that PossibleMatchRange() will only consider the first copy of an + // infinitely repeated element (i.e., any regexp element followed by a '*' or + // '+' operator). Regexps with "{N}" constructions are not affected, as those + // do not compile down to infinite repetitions. + // + // Returns true on success, false on error. + bool PossibleMatchRange(std::string *min, std::string *max, int maxlen) const; + + // Generic matching interface + + // Type of match. + enum Anchor { + UNANCHORED, // No anchoring + ANCHOR_START, // Anchor at start only + ANCHOR_BOTH // Anchor at start and end + }; + + // Return the number of capturing subpatterns, or -1 if the + // regexp wasn't valid on construction. The overall match ($0) + // does not count: if the regexp is "(a)(b)", returns 2. + int NumberOfCapturingGroups() const { return num_captures_; } + + // Return a map from names to capturing indices. + // The map records the index of the leftmost group + // with the given name. + // Only valid until the re is deleted. + const std::map &NamedCapturingGroups() const; + + // Return a map from capturing indices to names. + // The map has no entries for unnamed groups. + // Only valid until the re is deleted. + const std::map &CapturingGroupNames() const; + + // General matching routine. + // Match against text starting at offset startpos + // and stopping the search at offset endpos. + // Returns true if match found, false if not. + // On a successful match, fills in submatch[] (up to nsubmatch entries) + // with information about submatches. + // I.e. matching RE2("(foo)|(bar)baz") on "barbazbla" will return true, with + // submatch[0] = "barbaz", submatch[1].data() = NULL, submatch[2] = "bar", + // submatch[3].data() = NULL, ..., up to submatch[nsubmatch-1].data() = NULL. + // Caveat: submatch[] may be clobbered even on match failure. + // + // Don't ask for more match information than you will use: + // runs much faster with nsubmatch == 1 than nsubmatch > 1, and + // runs even faster if nsubmatch == 0. + // Doesn't make sense to use nsubmatch > 1 + NumberOfCapturingGroups(), + // but will be handled correctly. + // + // Passing text == StringPiece(NULL, 0) will be handled like any other + // empty string, but note that on return, it will not be possible to tell + // whether submatch i matched the empty string or did not match: + // either way, submatch[i].data() == NULL. + bool Match(const StringPiece &text, size_t startpos, size_t endpos, Anchor re_anchor, StringPiece *submatch, int nsubmatch) const; + + // Check that the given rewrite string is suitable for use with this + // regular expression. 
It checks that:
+ // * The regular expression has enough parenthesized subexpressions
+ // to satisfy all of the \N tokens in rewrite
+ // * The rewrite string doesn't have any syntax errors. E.g.,
+ // '\' followed by anything other than a digit or '\'.
+ // A true return value guarantees that Replace() and Extract() won't
+ // fail because of a bad rewrite string.
+ bool CheckRewriteString(const StringPiece &rewrite, std::string *error) const;
+
+ // Returns the maximum submatch needed for the rewrite to be done by
+ // Replace(). E.g. if rewrite == "foo \\2,\\1", returns 2.
+ static int MaxSubmatch(const StringPiece &rewrite);
+
+ // Append the "rewrite" string, with backslash substitutions from "vec",
+ // to string "out".
+ // Returns true on success. This method can fail because of a malformed
+ // rewrite string. CheckRewriteString guarantees that the rewrite will
+ // be successful.
+ bool Rewrite(std::string *out, const StringPiece &rewrite, const StringPiece *vec, int veclen) const;
+
+ // Constructor options
+ class Options {
+ public:
+ // The options are (defaults in parentheses):
+ //
+ // utf8 (true) text and pattern are UTF-8; otherwise Latin-1
+ // posix_syntax (false) restrict regexps to POSIX egrep syntax
+ // longest_match (false) search for longest match, not first match
+ // log_errors (true) log syntax and execution errors to ERROR
+ // max_mem (see below) approx. max memory footprint of RE2
+ // literal (false) interpret string as literal, not regexp
+ // never_nl (false) never match \n, even if it is in regexp
+ // dot_nl (false) dot matches everything including new line
+ // never_capture (false) parse all parens as non-capturing
+ // case_sensitive (true) match is case-sensitive (regexp can override
+ // with (?i) unless in posix_syntax mode)
+ //
+ // The following options are only consulted when posix_syntax == true.
+ // When posix_syntax == false, these features are always enabled and
+ // cannot be turned off; to perform multi-line matching in that case,
+ // begin the regexp with (?m).
+ // perl_classes (false) allow Perl's \d \s \w \D \S \W
+ // word_boundary (false) allow Perl's \b \B (word boundary and not)
+ // one_line (false) ^ and $ only match beginning and end of text
+ //
+ // The max_mem option controls how much memory can be used
+ // to hold the compiled form of the regexp (the Prog) and
+ // its cached DFA graphs. Code Search placed limits on the number
+ // of Prog instructions and DFA states: 10,000 for both.
+ // In RE2, those limits would translate to about 240 KB per Prog
+ // and perhaps 2.5 MB per DFA (DFA state sizes vary by regexp; RE2 does a
+ // better job of keeping them small than Code Search did).
+ // Each RE2 has two Progs (one forward, one reverse), and each Prog
+ // can have two DFAs (one first match, one longest match).
+ // That makes 4 DFAs:
+ //
+ // forward, first-match - used for UNANCHORED or ANCHOR_START searches
+ // if opt.longest_match() == false
+ // forward, longest-match - used for all ANCHOR_BOTH searches,
+ // and the other two kinds if
+ // opt.longest_match() == true
+ // reverse, first-match - never used
+ // reverse, longest-match - used as second phase for unanchored searches
+ //
+ // The RE2 memory budget is statically divided between the two
+ // Progs and then the DFAs: two thirds to the forward Prog
+ // and one third to the reverse Prog. The forward Prog gives half
+ // of what it has left over to each of its DFAs. The reverse Prog
+ // gives it all to its longest-match DFA.
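+ //
+ // Worked example (illustrative, using the default max_mem of 8 << 20):
+ // the forward Prog's share is ~5.3 MB and the reverse Prog's ~2.7 MB;
+ // each forward DFA then receives half of whatever the forward Prog has
+ // left over, and the reverse longest-match DFA receives all of the
+ // reverse leftover.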
+ // + // Once a DFA fills its budget, it flushes its cache and starts over. + // If this happens too often, RE2 falls back on the NFA implementation. + + // For now, make the default budget something close to Code Search. + static const int kDefaultMaxMem = 8 << 20; + + enum Encoding { EncodingUTF8 = 1, EncodingLatin1 }; + + Options() + : max_mem_(kDefaultMaxMem), encoding_(EncodingUTF8), posix_syntax_(false), longest_match_(false), log_errors_(true), literal_(false), + never_nl_(false), dot_nl_(false), never_capture_(false), case_sensitive_(true), perl_classes_(false), word_boundary_(false), + one_line_(false) {} + + /*implicit*/ Options(CannedOptions); + + int64_t max_mem() const { return max_mem_; } + void set_max_mem(int64_t m) { max_mem_ = m; } + + Encoding encoding() const { return encoding_; } + void set_encoding(Encoding encoding) { encoding_ = encoding; } + + bool posix_syntax() const { return posix_syntax_; } + void set_posix_syntax(bool b) { posix_syntax_ = b; } + + bool longest_match() const { return longest_match_; } + void set_longest_match(bool b) { longest_match_ = b; } + + bool log_errors() const { return log_errors_; } + void set_log_errors(bool b) { log_errors_ = b; } + + bool literal() const { return literal_; } + void set_literal(bool b) { literal_ = b; } + + bool never_nl() const { return never_nl_; } + void set_never_nl(bool b) { never_nl_ = b; } + + bool dot_nl() const { return dot_nl_; } + void set_dot_nl(bool b) { dot_nl_ = b; } + + bool never_capture() const { return never_capture_; } + void set_never_capture(bool b) { never_capture_ = b; } + + bool case_sensitive() const { return case_sensitive_; } + void set_case_sensitive(bool b) { case_sensitive_ = b; } + + bool perl_classes() const { return perl_classes_; } + void set_perl_classes(bool b) { perl_classes_ = b; } + + bool word_boundary() const { return word_boundary_; } + void set_word_boundary(bool b) { word_boundary_ = b; } + + bool one_line() const { return one_line_; } + void set_one_line(bool b) { one_line_ = b; } + + void Copy(const Options &src) { *this = src; } + + int ParseFlags() const; + + private: + int64_t max_mem_; + Encoding encoding_; + bool posix_syntax_; + bool longest_match_; + bool log_errors_; + bool literal_; + bool never_nl_; + bool dot_nl_; + bool never_capture_; + bool case_sensitive_; + bool perl_classes_; + bool word_boundary_; + bool one_line_; + }; + + // Returns the options set in the constructor. + const Options &options() const { return options_; } + + // Argument converters; see below. + template + static Arg CRadix(T *ptr); + template + static Arg Hex(T *ptr); + template + static Arg Octal(T *ptr); + + // Controls the maximum count permitted by GlobalReplace(); -1 is unlimited. + // FOR FUZZING ONLY. + static void FUZZING_ONLY_set_maximum_global_replace_count(int i); + +private: + void Init(const StringPiece &pattern, const Options &options); + + bool DoMatch(const StringPiece &text, Anchor re_anchor, size_t *consumed, const Arg *const args[], int n) const; + + re2::Prog *ReverseProg() const; + + // First cache line is relatively cold fields. + const std::string *pattern_; // string regular expression + Options options_; // option flags + re2::Regexp *entire_regexp_; // parsed regular expression + re2::Regexp *suffix_regexp_; // parsed regular expression, prefix_ removed + const std::string *error_; // error indicator (or points to empty string) + const std::string *error_arg_; // fragment of regexp showing error (or ditto) + + // Second cache line is relatively hot fields. 
+ // These are ordered oddly to pack everything.
+ int num_captures_; // number of capturing groups
+ ErrorCode error_code_ : 29; // error code (29 bits is more than enough)
+ bool longest_match_ : 1; // cached copy of options_.longest_match()
+ bool is_one_pass_ : 1; // can use prog_->SearchOnePass?
+ bool prefix_foldcase_ : 1; // prefix_ is ASCII case-insensitive
+ std::string prefix_; // required prefix (before suffix_regexp_)
+ re2::Prog *prog_; // compiled program for regexp
+
+ // Reverse Prog for DFA execution only
+ mutable re2::Prog *rprog_;
+ // Map from capture names to indices
+ mutable const std::map<std::string, int> *named_groups_;
+ // Map from capture indices to names
+ mutable const std::map<int, std::string> *group_names_;
+
+ mutable std::once_flag rprog_once_;
+ mutable std::once_flag named_groups_once_;
+ mutable std::once_flag group_names_once_;
+};
+
+/***** Implementation details *****/
+
+namespace re2_internal {
+
+// Types for which the 3-ary Parse() function template has specializations.
+template <typename T>
+struct Parse3ary : public std::false_type {};
+template <>
+struct Parse3ary<void> : public std::true_type {};
+template <>
+struct Parse3ary<std::string> : public std::true_type {};
+template <>
+struct Parse3ary<StringPiece> : public std::true_type {};
+template <>
+struct Parse3ary<char> : public std::true_type {};
+template <>
+struct Parse3ary<signed char> : public std::true_type {};
+template <>
+struct Parse3ary<unsigned char> : public std::true_type {};
+template <>
+struct Parse3ary<float> : public std::true_type {};
+template <>
+struct Parse3ary<double> : public std::true_type {};
+
+template <typename T>
+bool Parse(const char *str, size_t n, T *dest);
+
+// Types for which the 4-ary Parse() function template has specializations.
+template <typename T>
+struct Parse4ary : public std::false_type {};
+template <>
+struct Parse4ary<long> : public std::true_type {};
+template <>
+struct Parse4ary<unsigned long> : public std::true_type {};
+template <>
+struct Parse4ary<short> : public std::true_type {};
+template <>
+struct Parse4ary<unsigned short> : public std::true_type {};
+template <>
+struct Parse4ary<int> : public std::true_type {};
+template <>
+struct Parse4ary<unsigned int> : public std::true_type {};
+template <>
+struct Parse4ary<long long> : public std::true_type {};
+template <>
+struct Parse4ary<unsigned long long> : public std::true_type {};
+
+template <typename T>
+bool Parse(const char *str, size_t n, T *dest, int radix);
+
+} // namespace re2_internal
+
+class RE2::Arg {
+private:
+ template <typename T>
+ using CanParse3ary = typename std::enable_if<re2_internal::Parse3ary<T>::value, int>::type;
+
+ template <typename T>
+ using CanParse4ary = typename std::enable_if<re2_internal::Parse4ary<T>::value, int>::type;
+
+#if !defined(_MSC_VER)
+ template <typename T>
+ using CanParseFrom =
+ typename std::enable_if<std::is_member_function_pointer<decltype(static_cast<bool (T::*)(const char *, size_t)>(&T::ParseFrom))>::value,
+ int>::type;
+#endif
+
+public:
+ Arg() : Arg(nullptr) {}
+ Arg(std::nullptr_t ptr) : arg_(ptr), parser_(DoNothing) {}
+
+ template <typename T, CanParse3ary<T> = 0>
+ Arg(T *ptr) : arg_(ptr), parser_(DoParse3ary<T>) {}
+
+ template <typename T, CanParse4ary<T> = 0>
+ Arg(T *ptr) : arg_(ptr), parser_(DoParse4ary<T>) {}
+
+#if !defined(_MSC_VER)
+ template <typename T, CanParseFrom<T> = 0>
+ Arg(T *ptr) : arg_(ptr), parser_(DoParseFrom<T>) {}
+#endif
+
+ typedef bool (*Parser)(const char *str, size_t n, void *dest);
+
+ template <typename T>
+ Arg(T *ptr, Parser parser) : arg_(ptr), parser_(parser) {}
+
+ bool Parse(const char *str, size_t n) const { return (*parser_)(str, n, arg_); }
+
+private:
+ static bool DoNothing(const char * /*str*/, size_t /*n*/, void * /*dest*/) { return true; }
+
+ template <typename T>
+ static bool DoParse3ary(const char *str, size_t n, void *dest) {
+ return re2_internal::Parse(str, n, reinterpret_cast<T *>(dest));
+ }
+
+ template <typename T>
+ static bool DoParse4ary(const char *str, size_t n, void *dest) {
+ return re2_internal::Parse(str, n, reinterpret_cast<T *>(dest), 10);
+ }
+
+#if !defined(_MSC_VER)
+ template <typename T>
+ static bool DoParseFrom(const char *str, size_t n, void *dest) {
+ if (dest == NULL)
+ return true;
+ return reinterpret_cast<T *>(dest)->ParseFrom(str, n);
+ }
+#endif
+
+ void *arg_;
+ Parser parser_;
+};
+
+template <typename T>
+inline RE2::Arg RE2::CRadix(T *ptr) {
+ return RE2::Arg(ptr, [](const char *str, size_t n, void *dest) -> bool { return re2_internal::Parse(str, n, reinterpret_cast<T *>(dest), 0); });
+}
+
+template <typename T>
+inline RE2::Arg RE2::Hex(T *ptr) {
+ return RE2::Arg(ptr, [](const char *str, size_t n, void *dest) -> bool { return re2_internal::Parse(str, n, reinterpret_cast<T *>(dest), 16); });
+}
+
+template <typename T>
+inline RE2::Arg RE2::Octal(T *ptr) {
+ return RE2::Arg(ptr, [](const char *str, size_t n, void *dest) -> bool { return re2_internal::Parse(str, n, reinterpret_cast<T *>(dest), 8); });
+}
+
+// Silence warnings about missing initializers for members of LazyRE2.
+#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 6
+#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
+#endif
+
+// Helper for writing global or static RE2s safely.
+// Write
+// static LazyRE2 re = {".*"};
+// and then use *re instead of writing
+// static RE2 re(".*");
+// The former is more careful about multithreaded
+// situations than the latter.
+//
+// N.B. This class never deletes the RE2 object that
+// it constructs: that's a feature, so that it can be used
+// for global and function static variables.
+class LazyRE2 {
+private:
+ struct NoArg {};
+
+public:
+ typedef RE2 element_type; // support std::pointer_traits
+
+ // Constructor omitted to preserve braced initialization in C++98.
+
+ // Pretend to be a pointer to Type (never NULL due to on-demand creation):
+ RE2 &operator*() const { return *get(); }
+ RE2 *operator->() const { return get(); }
+
+ // Named accessor/initializer:
+ RE2 *get() const {
+ std::call_once(once_, &LazyRE2::Init, this);
+ return ptr_;
+ }
+
+ // All data fields must be public to support {"foo"} initialization.
+ const char *pattern_;
+ RE2::CannedOptions options_;
+ NoArg barrier_against_excess_initializers_;
+
+ mutable RE2 *ptr_;
+ mutable std::once_flag once_;
+
+private:
+ static void Init(const LazyRE2 *lazy_re2) { lazy_re2->ptr_ = new RE2(lazy_re2->pattern_, lazy_re2->options_); }
+
+ void operator=(const LazyRE2 &); // disallowed
+};
+
+namespace hooks {
+
+// Most platforms support thread_local. Older versions of iOS don't support
+// thread_local, but for the sake of brevity, we lump together all versions
+// of Apple platforms that aren't macOS. If an iOS application really needs
+// the context pointee someday, we can get more specific then...
+//
+// As per https://github.com/google/re2/issues/325, thread_local support in
+// MinGW seems to be buggy. (FWIW, Abseil folks also avoid it.)
+#define RE2_HAVE_THREAD_LOCAL
+#if (defined(__APPLE__) && !(defined(TARGET_OS_OSX) && TARGET_OS_OSX)) || defined(__MINGW32__)
+#undef RE2_HAVE_THREAD_LOCAL
+#endif
+
+// A hook must not make any assumptions regarding the lifetime of the context
+// pointee beyond the current invocation of the hook. Pointers and references
+// obtained via the context pointee should be considered invalidated when the
+// hook returns. Hence, any data about the context pointee (e.g. its pattern)
+// would have to be copied in order for it to be kept for an indefinite time.
+//
+// A hook must not use RE2 for matching. Control flow reentering RE2::Match()
+// could result in infinite mutual recursion.
To discourage that possibility, +// RE2 will not maintain the context pointer correctly when used in that way. +#ifdef RE2_HAVE_THREAD_LOCAL +extern thread_local const RE2 *context; +#endif + +struct DFAStateCacheReset { + int64_t state_budget; + size_t state_cache_size; +}; + +struct DFASearchFailure { + // Nothing yet... +}; + +#define DECLARE_HOOK(type) \ + using type##Callback = void(const type &); \ + void Set##type##Hook(type##Callback *cb); \ + type##Callback *Get##type##Hook(); + +DECLARE_HOOK(DFAStateCacheReset) +DECLARE_HOOK(DFASearchFailure) + +#undef DECLARE_HOOK + +} // namespace hooks + +} // namespace re2 + +using re2::LazyRE2; +using re2::RE2; + +#endif // RE2_RE2_H_ diff --git a/internal/cpp/re2/regexp.cc b/internal/cpp/re2/regexp.cc new file mode 100644 index 00000000000..08fa34d8b9d --- /dev/null +++ b/internal/cpp/re2/regexp.cc @@ -0,0 +1,957 @@ +// Copyright 2006 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Regular expression representation. +// Tested by parse_test.cc + +#include "re2/regexp.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "re2/pod_array.h" +#include "re2/stringpiece.h" +#include "re2/walker-inl.h" +#include "util/logging.h" +#include "util/mutex.h" +#include "util/utf.h" +#include "util/util.h" + +#ifdef min +#undef min +#endif +#ifdef max +#undef max +#endif + +namespace re2 { + +// Constructor. Allocates vectors as appropriate for operator. +Regexp::Regexp(RegexpOp op, ParseFlags parse_flags) + : op_(static_cast(op)), simple_(false), parse_flags_(static_cast(parse_flags)), ref_(1), nsub_(0), down_(NULL) { + subone_ = NULL; + memset(arguments.the_union_, 0, sizeof arguments.the_union_); +} + +// Destructor. Assumes already cleaned up children. +// Private: use Decref() instead of delete to destroy Regexps. +// Can't call Decref on the sub-Regexps here because +// that could cause arbitrarily deep recursion, so +// required Decref() to have handled them for us. +Regexp::~Regexp() { + if (nsub_ > 0) + LOG(DFATAL) << "Regexp not destroyed."; + + switch (op_) { + default: + break; + case kRegexpCapture: + delete arguments.capture.name_; + break; + case kRegexpLiteralString: + delete[] arguments.literal_string.runes_; + break; + case kRegexpCharClass: + if (arguments.char_class.cc_) + arguments.char_class.cc_->Delete(); + delete arguments.char_class.ccb_; + break; + } +} + +// If it's possible to destroy this regexp without recurring, +// do so and return true. Else return false. +bool Regexp::QuickDestroy() { + if (nsub_ == 0) { + delete this; + return true; + } + return false; +} + +// Similar to EmptyStorage in re2.cc. +struct RefStorage { + Mutex ref_mutex; + std::map ref_map; +}; +alignas(RefStorage) static char ref_storage[sizeof(RefStorage)]; + +static inline Mutex *ref_mutex() { return &reinterpret_cast(ref_storage)->ref_mutex; } + +static inline std::map *ref_map() { return &reinterpret_cast(ref_storage)->ref_map; } + +int Regexp::Ref() { + if (ref_ < kMaxRef) + return ref_; + + MutexLock l(ref_mutex()); + return (*ref_map())[this]; +} + +// Increments reference count, returns object as convenience. +Regexp *Regexp::Incref() { + if (ref_ >= kMaxRef - 1) { + static std::once_flag ref_once; + std::call_once(ref_once, []() { (void)new (ref_storage) RefStorage; }); + + // Store ref count in overflow map. 
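+ // (ref_ is a small integer field, so any count that reaches kMaxRef
+ // spills into ref_map(), a side table guarded by ref_mutex(); the
+ // common case of a small count stays inline in the Regexp itself.)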
+ MutexLock l(ref_mutex()); + if (ref_ == kMaxRef) { + // already overflowed + (*ref_map())[this]++; + } else { + // overflowing now + (*ref_map())[this] = kMaxRef; + ref_ = kMaxRef; + } + return this; + } + + ref_++; + return this; +} + +// Decrements reference count and deletes this object if count reaches 0. +void Regexp::Decref() { + if (ref_ == kMaxRef) { + // Ref count is stored in overflow map. + MutexLock l(ref_mutex()); + int r = (*ref_map())[this] - 1; + if (r < kMaxRef) { + ref_ = static_cast(r); + ref_map()->erase(this); + } else { + (*ref_map())[this] = r; + } + return; + } + ref_--; + if (ref_ == 0) + Destroy(); +} + +// Deletes this object; ref count has count reached 0. +void Regexp::Destroy() { + if (QuickDestroy()) + return; + + // Handle recursive Destroy with explicit stack + // to avoid arbitrarily deep recursion on process stack [sigh]. + down_ = NULL; + Regexp *stack = this; + while (stack != NULL) { + Regexp *re = stack; + stack = re->down_; + if (re->ref_ != 0) + LOG(DFATAL) << "Bad reference count " << re->ref_; + if (re->nsub_ > 0) { + Regexp **subs = re->sub(); + for (int i = 0; i < re->nsub_; i++) { + Regexp *sub = subs[i]; + if (sub == NULL) + continue; + if (sub->ref_ == kMaxRef) + sub->Decref(); + else + --sub->ref_; + if (sub->ref_ == 0 && !sub->QuickDestroy()) { + sub->down_ = stack; + stack = sub; + } + } + if (re->nsub_ > 1) + delete[] subs; + re->nsub_ = 0; + } + delete re; + } +} + +void Regexp::AddRuneToString(Rune r) { + DCHECK(op_ == kRegexpLiteralString); + if (arguments.literal_string.nrunes_ == 0) { + // start with 8 + arguments.literal_string.runes_ = new Rune[8]; + } else if (arguments.literal_string.nrunes_ >= 8 && (arguments.literal_string.nrunes_ & (arguments.literal_string.nrunes_ - 1)) == 0) { + // double on powers of two + Rune *old = arguments.literal_string.runes_; + arguments.literal_string.runes_ = new Rune[arguments.literal_string.nrunes_ * 2]; + for (int i = 0; i < arguments.literal_string.nrunes_; i++) + arguments.literal_string.runes_[i] = old[i]; + delete[] old; + } + + arguments.literal_string.runes_[arguments.literal_string.nrunes_++] = r; +} + +Regexp *Regexp::HaveMatch(int match_id, ParseFlags flags) { + Regexp *re = new Regexp(kRegexpHaveMatch, flags); + re->arguments.match_id_ = match_id; + return re; +} + +Regexp *Regexp::StarPlusOrQuest(RegexpOp op, Regexp *sub, ParseFlags flags) { + // Squash **, ++ and ??. + if (op == sub->op() && flags == sub->parse_flags()) + return sub; + + // Squash *+, *?, +*, +?, ?* and ?+. They all squash to *, so because + // op is Star/Plus/Quest, we just have to check that sub->op() is too. + if ((sub->op() == kRegexpStar || sub->op() == kRegexpPlus || sub->op() == kRegexpQuest) && flags == sub->parse_flags()) { + // If sub is Star, no need to rewrite it. + if (sub->op() == kRegexpStar) + return sub; + + // Rewrite sub to Star. + Regexp *re = new Regexp(kRegexpStar, flags); + re->AllocSub(1); + re->sub()[0] = sub->sub()[0]->Incref(); + sub->Decref(); // We didn't consume the reference after all. 
+ return re; + } + + Regexp *re = new Regexp(op, flags); + re->AllocSub(1); + re->sub()[0] = sub; + return re; +} + +Regexp *Regexp::Plus(Regexp *sub, ParseFlags flags) { return StarPlusOrQuest(kRegexpPlus, sub, flags); } + +Regexp *Regexp::Star(Regexp *sub, ParseFlags flags) { return StarPlusOrQuest(kRegexpStar, sub, flags); } + +Regexp *Regexp::Quest(Regexp *sub, ParseFlags flags) { return StarPlusOrQuest(kRegexpQuest, sub, flags); } + +Regexp *Regexp::ConcatOrAlternate(RegexpOp op, Regexp **sub, int nsub, ParseFlags flags, bool can_factor) { + if (nsub == 1) + return sub[0]; + + if (nsub == 0) { + if (op == kRegexpAlternate) + return new Regexp(kRegexpNoMatch, flags); + else + return new Regexp(kRegexpEmptyMatch, flags); + } + + PODArray subcopy; + if (op == kRegexpAlternate && can_factor) { + // Going to edit sub; make a copy so we don't step on caller. + subcopy = PODArray(nsub); + memmove(subcopy.data(), sub, nsub * sizeof sub[0]); + sub = subcopy.data(); + nsub = FactorAlternation(sub, nsub, flags); + if (nsub == 1) { + Regexp *re = sub[0]; + return re; + } + } + + if (nsub > kMaxNsub) { + // Too many subexpressions to fit in a single Regexp. + // Make a two-level tree. Two levels gets us to 65535^2. + int nbigsub = (nsub + kMaxNsub - 1) / kMaxNsub; + Regexp *re = new Regexp(op, flags); + re->AllocSub(nbigsub); + Regexp **subs = re->sub(); + for (int i = 0; i < nbigsub - 1; i++) + subs[i] = ConcatOrAlternate(op, sub + i * kMaxNsub, kMaxNsub, flags, false); + subs[nbigsub - 1] = ConcatOrAlternate(op, sub + (nbigsub - 1) * kMaxNsub, nsub - (nbigsub - 1) * kMaxNsub, flags, false); + return re; + } + + Regexp *re = new Regexp(op, flags); + re->AllocSub(nsub); + Regexp **subs = re->sub(); + for (int i = 0; i < nsub; i++) + subs[i] = sub[i]; + return re; +} + +Regexp *Regexp::Concat(Regexp **sub, int nsub, ParseFlags flags) { return ConcatOrAlternate(kRegexpConcat, sub, nsub, flags, false); } + +Regexp *Regexp::Alternate(Regexp **sub, int nsub, ParseFlags flags) { return ConcatOrAlternate(kRegexpAlternate, sub, nsub, flags, true); } + +Regexp *Regexp::AlternateNoFactor(Regexp **sub, int nsub, ParseFlags flags) { return ConcatOrAlternate(kRegexpAlternate, sub, nsub, flags, false); } + +Regexp *Regexp::Capture(Regexp *sub, ParseFlags flags, int cap) { + Regexp *re = new Regexp(kRegexpCapture, flags); + re->AllocSub(1); + re->sub()[0] = sub; + re->arguments.capture.cap_ = cap; + return re; +} + +Regexp *Regexp::Repeat(Regexp *sub, ParseFlags flags, int min, int max) { + Regexp *re = new Regexp(kRegexpRepeat, flags); + re->AllocSub(1); + re->sub()[0] = sub; + re->arguments.repeat.min_ = min; + re->arguments.repeat.max_ = max; + return re; +} + +Regexp *Regexp::NewLiteral(Rune rune, ParseFlags flags) { + Regexp *re = new Regexp(kRegexpLiteral, flags); + re->arguments.rune_ = rune; + return re; +} + +Regexp *Regexp::LiteralString(Rune *runes, int nrunes, ParseFlags flags) { + if (nrunes <= 0) + return new Regexp(kRegexpEmptyMatch, flags); + if (nrunes == 1) + return NewLiteral(runes[0], flags); + Regexp *re = new Regexp(kRegexpLiteralString, flags); + for (int i = 0; i < nrunes; i++) + re->AddRuneToString(runes[i]); + return re; +} + +Regexp *Regexp::NewCharClass(CharClass *cc, ParseFlags flags) { + Regexp *re = new Regexp(kRegexpCharClass, flags); + re->arguments.char_class.cc_ = cc; + return re; +} + +void Regexp::Swap(Regexp *that) { + // Regexp is not trivially copyable, so we cannot freely copy it with + // memmove(3), but swapping objects like so is safe for our purposes. 
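+ // (The three memmove(3) calls below swap the raw bytes through a stack
+ // buffer; this is tenable because a Regexp holds no pointers into its
+ // own storage.)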
+ char tmp[sizeof *this]; + void *vthis = reinterpret_cast(this); + void *vthat = reinterpret_cast(that); + memmove(tmp, vthis, sizeof *this); + memmove(vthis, vthat, sizeof *this); + memmove(vthat, tmp, sizeof *this); +} + +// Tests equality of all top-level structure but not subregexps. +static bool TopEqual(Regexp *a, Regexp *b) { + if (a->op() != b->op()) + return false; + + switch (a->op()) { + case kRegexpNoMatch: + case kRegexpEmptyMatch: + case kRegexpAnyChar: + case kRegexpAnyByte: + case kRegexpBeginLine: + case kRegexpEndLine: + case kRegexpWordBoundary: + case kRegexpNoWordBoundary: + case kRegexpBeginText: + return true; + + case kRegexpEndText: + // The parse flags remember whether it's \z or (?-m:$), + // which matters when testing against PCRE. + return ((a->parse_flags() ^ b->parse_flags()) & Regexp::WasDollar) == 0; + + case kRegexpLiteral: + return a->rune() == b->rune() && ((a->parse_flags() ^ b->parse_flags()) & Regexp::FoldCase) == 0; + + case kRegexpLiteralString: + return a->nrunes() == b->nrunes() && ((a->parse_flags() ^ b->parse_flags()) & Regexp::FoldCase) == 0 && + memcmp(a->runes(), b->runes(), a->nrunes() * sizeof a->runes()[0]) == 0; + + case kRegexpAlternate: + case kRegexpConcat: + return a->nsub() == b->nsub(); + + case kRegexpStar: + case kRegexpPlus: + case kRegexpQuest: + return ((a->parse_flags() ^ b->parse_flags()) & Regexp::NonGreedy) == 0; + + case kRegexpRepeat: + return ((a->parse_flags() ^ b->parse_flags()) & Regexp::NonGreedy) == 0 && a->min() == b->min() && a->max() == b->max(); + + case kRegexpCapture: + return a->cap() == b->cap() && a->name() == b->name(); + + case kRegexpHaveMatch: + return a->match_id() == b->match_id(); + + case kRegexpCharClass: { + CharClass *acc = a->cc(); + CharClass *bcc = b->cc(); + return acc->size() == bcc->size() && acc->end() - acc->begin() == bcc->end() - bcc->begin() && + memcmp(acc->begin(), bcc->begin(), (acc->end() - acc->begin()) * sizeof acc->begin()[0]) == 0; + } + } + + LOG(DFATAL) << "Unexpected op in Regexp::Equal: " << a->op(); + return 0; +} + +bool Regexp::Equal(Regexp *a, Regexp *b) { + if (a == NULL || b == NULL) + return a == b; + + if (!TopEqual(a, b)) + return false; + + // Fast path: + // return without allocating vector if there are no subregexps. + switch (a->op()) { + case kRegexpAlternate: + case kRegexpConcat: + case kRegexpStar: + case kRegexpPlus: + case kRegexpQuest: + case kRegexpRepeat: + case kRegexpCapture: + break; + + default: + return true; + } + + // Committed to doing real work. + // The stack (vector) has pairs of regexps waiting to + // be compared. The regexps are only equal if + // all the pairs end up being equal. + std::vector stk; + + for (;;) { + // Invariant: TopEqual(a, b) == true. + Regexp *a2; + Regexp *b2; + switch (a->op()) { + default: + break; + case kRegexpAlternate: + case kRegexpConcat: + for (int i = 0; i < a->nsub(); i++) { + a2 = a->sub()[i]; + b2 = b->sub()[i]; + if (!TopEqual(a2, b2)) + return false; + stk.push_back(a2); + stk.push_back(b2); + } + break; + + case kRegexpStar: + case kRegexpPlus: + case kRegexpQuest: + case kRegexpRepeat: + case kRegexpCapture: + a2 = a->sub()[0]; + b2 = b->sub()[0]; + if (!TopEqual(a2, b2)) + return false; + // Really: + // stk.push_back(a2); + // stk.push_back(b2); + // break; + // but faster to assign directly and loop. 
+ a = a2; + b = b2; + continue; + } + + size_t n = stk.size(); + if (n == 0) + break; + + DCHECK_GE(n, 2); + a = stk[n - 2]; + b = stk[n - 1]; + stk.resize(n - 2); + } + + return true; +} + +// Keep in sync with enum RegexpStatusCode in regexp.h +static const char *kErrorStrings[] = { + "no error", + "unexpected error", + "invalid escape sequence", + "invalid character class", + "invalid character class range", + "missing ]", + "missing )", + "unexpected )", + "trailing \\", + "no argument for repetition operator", + "invalid repetition size", + "bad repetition operator", + "invalid perl operator", + "invalid UTF-8", + "invalid named capture group", +}; + +std::string RegexpStatus::CodeText(enum RegexpStatusCode code) { + if (code < 0 || code >= arraysize(kErrorStrings)) + code = kRegexpInternalError; + return kErrorStrings[code]; +} + +std::string RegexpStatus::Text() const { + if (error_arg_.empty()) + return CodeText(code_); + std::string s; + s.append(CodeText(code_)); + s.append(": "); + s.append(error_arg_.data(), error_arg_.size()); + return s; +} + +void RegexpStatus::Copy(const RegexpStatus &status) { + code_ = status.code_; + error_arg_ = status.error_arg_; +} + +typedef int Ignored; // Walker doesn't exist + +// Walker subclass to count capturing parens in regexp. +class NumCapturesWalker : public Regexp::Walker { +public: + NumCapturesWalker() : ncapture_(0) {} + int ncapture() { return ncapture_; } + + virtual Ignored PreVisit(Regexp *re, Ignored ignored, bool *stop) { + if (re->op() == kRegexpCapture) + ncapture_++; + return ignored; + } + + virtual Ignored ShortVisit(Regexp *re, Ignored ignored) { + // Should never be called: we use Walk(), not WalkExponential(). +#ifndef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + LOG(DFATAL) << "NumCapturesWalker::ShortVisit called"; +#endif + return ignored; + } + +private: + int ncapture_; + + NumCapturesWalker(const NumCapturesWalker &) = delete; + NumCapturesWalker &operator=(const NumCapturesWalker &) = delete; +}; + +int Regexp::NumCaptures() { + NumCapturesWalker w; + w.Walk(this, 0); + return w.ncapture(); +} + +// Walker class to build map of named capture groups and their indices. +class NamedCapturesWalker : public Regexp::Walker { +public: + NamedCapturesWalker() : map_(NULL) {} + ~NamedCapturesWalker() { delete map_; } + + std::map *TakeMap() { + std::map *m = map_; + map_ = NULL; + return m; + } + + virtual Ignored PreVisit(Regexp *re, Ignored ignored, bool *stop) { + if (re->op() == kRegexpCapture && re->name() != NULL) { + // Allocate map once we find a name. + if (map_ == NULL) + map_ = new std::map; + + // Record first occurrence of each name. + // (The rule is that if you have the same name + // multiple times, only the leftmost one counts.) + map_->insert({*re->name(), re->cap()}); + } + return ignored; + } + + virtual Ignored ShortVisit(Regexp *re, Ignored ignored) { + // Should never be called: we use Walk(), not WalkExponential(). +#ifndef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + LOG(DFATAL) << "NamedCapturesWalker::ShortVisit called"; +#endif + return ignored; + } + +private: + std::map *map_; + + NamedCapturesWalker(const NamedCapturesWalker &) = delete; + NamedCapturesWalker &operator=(const NamedCapturesWalker &) = delete; +}; + +std::map *Regexp::NamedCaptures() { + NamedCapturesWalker w; + w.Walk(this, 0); + return w.TakeMap(); +} + +// Walker class to build map from capture group indices to their names. 
+class CaptureNamesWalker : public Regexp::Walker { +public: + CaptureNamesWalker() : map_(NULL) {} + ~CaptureNamesWalker() { delete map_; } + + std::map *TakeMap() { + std::map *m = map_; + map_ = NULL; + return m; + } + + virtual Ignored PreVisit(Regexp *re, Ignored ignored, bool *stop) { + if (re->op() == kRegexpCapture && re->name() != NULL) { + // Allocate map once we find a name. + if (map_ == NULL) + map_ = new std::map; + + (*map_)[re->cap()] = *re->name(); + } + return ignored; + } + + virtual Ignored ShortVisit(Regexp *re, Ignored ignored) { + // Should never be called: we use Walk(), not WalkExponential(). +#ifndef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + LOG(DFATAL) << "CaptureNamesWalker::ShortVisit called"; +#endif + return ignored; + } + +private: + std::map *map_; + + CaptureNamesWalker(const CaptureNamesWalker &) = delete; + CaptureNamesWalker &operator=(const CaptureNamesWalker &) = delete; +}; + +std::map *Regexp::CaptureNames() { + CaptureNamesWalker w; + w.Walk(this, 0); + return w.TakeMap(); +} + +void ConvertRunesToBytes(bool latin1, Rune *runes, int nrunes, std::string *bytes) { + if (latin1) { + bytes->resize(nrunes); + for (int i = 0; i < nrunes; i++) + (*bytes)[i] = static_cast(runes[i]); + } else { + bytes->resize(nrunes * UTFmax); // worst case + char *p = &(*bytes)[0]; + for (int i = 0; i < nrunes; i++) + p += runetochar(p, &runes[i]); + bytes->resize(p - &(*bytes)[0]); + bytes->shrink_to_fit(); + } +} + +// Determines whether regexp matches must be anchored +// with a fixed string prefix. If so, returns the prefix and +// the regexp that remains after the prefix. The prefix might +// be ASCII case-insensitive. +bool Regexp::RequiredPrefix(std::string *prefix, bool *foldcase, Regexp **suffix) { + prefix->clear(); + *foldcase = false; + *suffix = NULL; + + // No need for a walker: the regexp must be of the form + // 1. some number of ^ anchors + // 2. a literal char or string + // 3. the rest + if (op_ != kRegexpConcat) + return false; + int i = 0; + while (i < nsub_ && sub()[i]->op_ == kRegexpBeginText) + i++; + if (i == 0 || i >= nsub_) + return false; + Regexp *re = sub()[i]; + if (re->op_ != kRegexpLiteral && re->op_ != kRegexpLiteralString) + return false; + i++; + if (i < nsub_) { + for (int j = i; j < nsub_; j++) + sub()[j]->Incref(); + *suffix = Concat(sub() + i, nsub_ - i, parse_flags()); + } else { + *suffix = new Regexp(kRegexpEmptyMatch, parse_flags()); + } + + bool latin1 = (re->parse_flags() & Latin1) != 0; + Rune *runes = re->op_ == kRegexpLiteral ? &re->arguments.rune_ : re->arguments.literal_string.runes_; + int nrunes = re->op_ == kRegexpLiteral ? 1 : re->arguments.literal_string.nrunes_; + ConvertRunesToBytes(latin1, runes, nrunes, prefix); + *foldcase = (re->parse_flags() & FoldCase) != 0; + return true; +} + +// Determines whether regexp matches must be unanchored +// with a fixed string prefix. If so, returns the prefix. +// The prefix might be ASCII case-insensitive. +bool Regexp::RequiredPrefixForAccel(std::string *prefix, bool *foldcase) { + prefix->clear(); + *foldcase = false; + + // No need for a walker: the regexp must either begin with or be + // a literal char or string. We "see through" capturing groups, + // but make no effort to glue multiple prefix fragments together. + Regexp *re = op_ == kRegexpConcat && nsub_ > 0 ? 
sub()[0] : this; + while (re->op_ == kRegexpCapture) { + re = re->sub()[0]; + if (re->op_ == kRegexpConcat && re->nsub_ > 0) + re = re->sub()[0]; + } + if (re->op_ != kRegexpLiteral && re->op_ != kRegexpLiteralString) + return false; + + bool latin1 = (re->parse_flags() & Latin1) != 0; + Rune *runes = re->op_ == kRegexpLiteral ? &re->arguments.rune_ : re->arguments.literal_string.runes_; + int nrunes = re->op_ == kRegexpLiteral ? 1 : re->arguments.literal_string.nrunes_; + ConvertRunesToBytes(latin1, runes, nrunes, prefix); + *foldcase = (re->parse_flags() & FoldCase) != 0; + return true; +} + +// Character class builder is a balanced binary tree (STL set) +// containing non-overlapping, non-abutting RuneRanges. +// The less-than operator used in the tree treats two +// ranges as equal if they overlap at all, so that +// lookups for a particular Rune are possible. + +CharClassBuilder::CharClassBuilder() { + nrunes_ = 0; + upper_ = 0; + lower_ = 0; +} + +// Add lo-hi to the class; return whether class got bigger. +bool CharClassBuilder::AddRange(Rune lo, Rune hi) { + if (hi < lo) + return false; + + if (lo <= 'z' && hi >= 'A') { + // Overlaps some alpha, maybe not all. + // Update bitmaps telling which ASCII letters are in the set. + Rune lo1 = std::max(lo, 'A'); + Rune hi1 = std::min(hi, 'Z'); + if (lo1 <= hi1) + upper_ |= ((1 << (hi1 - lo1 + 1)) - 1) << (lo1 - 'A'); + + lo1 = std::max(lo, 'a'); + hi1 = std::min(hi, 'z'); + if (lo1 <= hi1) + lower_ |= ((1 << (hi1 - lo1 + 1)) - 1) << (lo1 - 'a'); + } + + { // Check whether lo, hi is already in the class. + iterator it = ranges_.find(RuneRange(lo, lo)); + if (it != end() && it->lo <= lo && hi <= it->hi) + return false; + } + + // Look for a range abutting lo on the left. + // If it exists, take it out and increase our range. + if (lo > 0) { + iterator it = ranges_.find(RuneRange(lo - 1, lo - 1)); + if (it != end()) { + lo = it->lo; + if (it->hi > hi) + hi = it->hi; + nrunes_ -= it->hi - it->lo + 1; + ranges_.erase(it); + } + } + + // Look for a range abutting hi on the right. + // If it exists, take it out and increase our range. + if (hi < Runemax) { + iterator it = ranges_.find(RuneRange(hi + 1, hi + 1)); + if (it != end()) { + hi = it->hi; + nrunes_ -= it->hi - it->lo + 1; + ranges_.erase(it); + } + } + + // Look for ranges between lo and hi. Take them out. + // This is only safe because the set has no overlapping ranges. + // We've already removed any ranges abutting lo and hi, so + // any that overlap [lo, hi] must be contained within it. + for (;;) { + iterator it = ranges_.find(RuneRange(lo, hi)); + if (it == end()) + break; + nrunes_ -= it->hi - it->lo + 1; + ranges_.erase(it); + } + + // Finally, add [lo, hi]. + nrunes_ += hi - lo + 1; + ranges_.insert(RuneRange(lo, hi)); + return true; +} + +void CharClassBuilder::AddCharClass(CharClassBuilder *cc) { + for (iterator it = cc->begin(); it != cc->end(); ++it) + AddRange(it->lo, it->hi); +} + +bool CharClassBuilder::Contains(Rune r) { return ranges_.find(RuneRange(r, r)) != end(); } + +// Does the character class behave the same on A-Z as on a-z? 
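+// (upper_ and lower_ are bitmaps, maintained by AddRange() above, of which
+// ASCII letters the class contains; the class folds iff the two bitmaps
+// coincide, i.e. their XOR has no bits set within AlphaMask.)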
+bool CharClassBuilder::FoldsASCII() { return ((upper_ ^ lower_) & AlphaMask) == 0; }
+
+CharClassBuilder *CharClassBuilder::Copy() {
+    CharClassBuilder *cc = new CharClassBuilder;
+    for (iterator it = begin(); it != end(); ++it)
+        cc->ranges_.insert(RuneRange(it->lo, it->hi));
+    cc->upper_ = upper_;
+    cc->lower_ = lower_;
+    cc->nrunes_ = nrunes_;
+    return cc;
+}
+
+void CharClassBuilder::RemoveAbove(Rune r) {
+    if (r >= Runemax)
+        return;
+
+    if (r < 'z') {
+        if (r < 'a')
+            lower_ = 0;
+        else
+            lower_ &= AlphaMask >> ('z' - r);
+    }
+
+    if (r < 'Z') {
+        if (r < 'A')
+            upper_ = 0;
+        else
+            upper_ &= AlphaMask >> ('Z' - r);
+    }
+
+    for (;;) {
+        iterator it = ranges_.find(RuneRange(r + 1, Runemax));
+        if (it == end())
+            break;
+        RuneRange rr = *it;
+        ranges_.erase(it);
+        nrunes_ -= rr.hi - rr.lo + 1;
+        if (rr.lo <= r) {
+            rr.hi = r;
+            ranges_.insert(rr);
+            nrunes_ += rr.hi - rr.lo + 1;
+        }
+    }
+}
+
+void CharClassBuilder::Negate() {
+    // Build up negation and then copy in.
+    // Could edit ranges in place, but C++ won't let me.
+    std::vector<RuneRange> v;
+    v.reserve(ranges_.size() + 1);
+
+    // In negation, first range begins at 0, unless
+    // the current class begins at 0.
+    iterator it = begin();
+    if (it == end()) {
+        v.push_back(RuneRange(0, Runemax));
+    } else {
+        int nextlo = 0;
+        if (it->lo == 0) {
+            nextlo = it->hi + 1;
+            ++it;
+        }
+        for (; it != end(); ++it) {
+            v.push_back(RuneRange(nextlo, it->lo - 1));
+            nextlo = it->hi + 1;
+        }
+        if (nextlo <= Runemax)
+            v.push_back(RuneRange(nextlo, Runemax));
+    }
+
+    ranges_.clear();
+    for (size_t i = 0; i < v.size(); i++)
+        ranges_.insert(v[i]);
+
+    upper_ = AlphaMask & ~upper_;
+    lower_ = AlphaMask & ~lower_;
+    nrunes_ = Runemax + 1 - nrunes_;
+}
+
+// Character class is a sorted list of ranges.
+// The ranges are allocated in the same block as the header,
+// necessitating a special allocator and Delete method.
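+//
+// An illustrative usage sketch (the local names `ccb`, `cc` and `in` are
+// arbitrary): a class is built up through a CharClassBuilder and then
+// frozen into this compact representation.
+//
+//     CharClassBuilder ccb;
+//     ccb.AddRange('a', 'z');              // [a-z]
+//     ccb.AddRange('0', '9');              // [a-z0-9]
+//     ccb.Negate();                        // [^a-z0-9]
+//     CharClass *cc = ccb.GetCharClass();  // header and ranges in one block
+//     bool in = cc->Contains('!');         // binary search over the ranges
+//     cc->Delete();                        // must be Delete(), not delete
+//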
+
+CharClass *CharClass::New(size_t maxranges) {
+    CharClass *cc;
+    uint8_t *data = new uint8_t[sizeof *cc + maxranges * sizeof cc->ranges_[0]];
+    cc = reinterpret_cast<CharClass *>(data);
+    cc->ranges_ = reinterpret_cast<RuneRange *>(data + sizeof *cc);
+    cc->nranges_ = 0;
+    cc->folds_ascii_ = false;
+    cc->nrunes_ = 0;
+    return cc;
+}
+
+void CharClass::Delete() {
+    uint8_t *data = reinterpret_cast<uint8_t *>(this);
+    delete[] data;
+}
+
+CharClass *CharClass::Negate() {
+    CharClass *cc = CharClass::New(static_cast<size_t>(nranges_ + 1));
+    cc->folds_ascii_ = folds_ascii_;
+    cc->nrunes_ = Runemax + 1 - nrunes_;
+    int n = 0;
+    int nextlo = 0;
+    for (CharClass::iterator it = begin(); it != end(); ++it) {
+        if (it->lo == nextlo) {
+            nextlo = it->hi + 1;
+        } else {
+            cc->ranges_[n++] = RuneRange(nextlo, it->lo - 1);
+            nextlo = it->hi + 1;
+        }
+    }
+    if (nextlo <= Runemax)
+        cc->ranges_[n++] = RuneRange(nextlo, Runemax);
+    cc->nranges_ = n;
+    return cc;
+}
+
+bool CharClass::Contains(Rune r) const {
+    RuneRange *rr = ranges_;
+    int n = nranges_;
+    while (n > 0) {
+        int m = n / 2;
+        if (rr[m].hi < r) {
+            rr += m + 1;
+            n -= m + 1;
+        } else if (r < rr[m].lo) {
+            n = m;
+        } else { // rr[m].lo <= r && r <= rr[m].hi
+            return true;
+        }
+    }
+    return false;
+}
+
+CharClass *CharClassBuilder::GetCharClass() {
+    CharClass *cc = CharClass::New(ranges_.size());
+    int n = 0;
+    for (iterator it = begin(); it != end(); ++it)
+        cc->ranges_[n++] = *it;
+    cc->nranges_ = n;
+    DCHECK_LE(n, static_cast<int>(ranges_.size()));
+    cc->nrunes_ = nrunes_;
+    cc->folds_ascii_ = FoldsASCII();
+    return cc;
+}
+
+} // namespace re2
diff --git a/internal/cpp/re2/regexp.h b/internal/cpp/re2/regexp.h
new file mode 100644
index 00000000000..20155fcf55f
--- /dev/null
+++ b/internal/cpp/re2/regexp.h
@@ -0,0 +1,680 @@
+// Copyright 2006 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef RE2_REGEXP_H_
+#define RE2_REGEXP_H_
+
+// --- SPONSORED LINK --------------------------------------------------
+// If you want to use this library for regular expression matching,
+// you should use re2/re2.h, which provides a class RE2 that
+// mimics the PCRE interface provided by PCRE's C++ wrappers.
+// This header describes the low-level interface used to implement RE2
+// and may change in backwards-incompatible ways from time to time.
+// In contrast, RE2's interface will not.
+// ---------------------------------------------------------------------
+
+// Regular expression library: parsing, execution, and manipulation
+// of regular expressions.
+//
+// Any operation that traverses the Regexp structures should be written
+// using Regexp::Walker (see walker-inl.h), not recursively, because deeply nested
+// regular expressions such as x++++++++++++++++++++... might cause recursive
+// traversals to overflow the stack.
+//
+// It is the caller's responsibility to provide appropriate mutual exclusion
+// around manipulation of the regexps. RE2 does this.
+//
+// PARSING
+//
+// Regexp::Parse parses regular expressions encoded in UTF-8.
+// The default syntax is POSIX extended regular expressions,
+// with the following changes:
+//
+//   1. Backreferences (optional in POSIX EREs) are not supported.
+//      (Supporting them precludes the use of DFA-based
+//      matching engines.)
+//
+//   2. Collating elements and collation classes are not supported.
+//      (No one has needed or wanted them.)
+//
+// The exact syntax accepted can be modified by passing flags to
+// Regexp::Parse.
In particular, many of the basic Perl additions +// are available. The flags are documented below (search for LikePerl). +// +// If parsed with the flag Regexp::Latin1, both the regular expression +// and the input to the matching routines are assumed to be encoded in +// Latin-1, not UTF-8. +// +// EXECUTION +// +// Once Regexp has parsed a regular expression, it provides methods +// to search text using that regular expression. These methods are +// implemented via calling out to other regular expression libraries. +// (Let's call them the sublibraries.) +// +// To call a sublibrary, Regexp does not simply prepare a +// string version of the regular expression and hand it to the +// sublibrary. Instead, Regexp prepares, from its own parsed form, the +// corresponding internal representation used by the sublibrary. +// This has the drawback of needing to know the internal representation +// used by the sublibrary, but it has two important benefits: +// +// 1. The syntax and meaning of regular expressions is guaranteed +// to be that used by Regexp's parser, not the syntax expected +// by the sublibrary. Regexp might accept a restricted or +// expanded syntax for regular expressions as compared with +// the sublibrary. As long as Regexp can translate from its +// internal form into the sublibrary's, clients need not know +// exactly which sublibrary they are using. +// +// 2. The sublibrary parsers are bypassed. For whatever reason, +// sublibrary regular expression parsers often have security +// problems. For example, plan9grep's regular expression parser +// has a buffer overflow in its handling of large character +// classes, and PCRE's parser has had buffer overflow problems +// in the past. Security-team requires sandboxing of sublibrary +// regular expression parsers. Avoiding the sublibrary parsers +// avoids the sandbox. +// +// The execution methods we use now are provided by the compiled form, +// Prog, described in prog.h +// +// MANIPULATION +// +// Unlike other regular expression libraries, Regexp makes its parsed +// form accessible to clients, so that client code can analyze the +// parsed regular expressions. + +#include +#include +#include +#include +#include + +#include "re2/stringpiece.h" +#include "util/logging.h" +#include "util/utf.h" +#include "util/util.h" + +namespace re2 { + +// Keep in sync with string list kOpcodeNames[] in testing/dump.cc +enum RegexpOp { + // Matches no strings. + kRegexpNoMatch = 1, + + // Matches empty string. + kRegexpEmptyMatch, + + // Matches rune_. + kRegexpLiteral, + + // Matches runes_. + kRegexpLiteralString, + + // Matches concatenation of sub_[0..nsub-1]. + kRegexpConcat, + // Matches union of sub_[0..nsub-1]. + kRegexpAlternate, + + // Matches sub_[0] zero or more times. + kRegexpStar, + // Matches sub_[0] one or more times. + kRegexpPlus, + // Matches sub_[0] zero or one times. + kRegexpQuest, + + // Matches sub_[0] at least min_ times, at most max_ times. + // max_ == -1 means no upper limit. + kRegexpRepeat, + + // Parenthesized (capturing) subexpression. Index is cap_. + // Optionally, capturing name is name_. + kRegexpCapture, + + // Matches any character. + kRegexpAnyChar, + + // Matches any byte [sic]. + kRegexpAnyByte, + + // Matches empty string at beginning of line. + kRegexpBeginLine, + // Matches empty string at end of line. + kRegexpEndLine, + + // Matches word boundary "\b". + kRegexpWordBoundary, + // Matches not-a-word boundary "\B". + kRegexpNoWordBoundary, + + // Matches empty string at beginning of text. 
+ kRegexpBeginText, + // Matches empty string at end of text. + kRegexpEndText, + + // Matches character class given by cc_. + kRegexpCharClass, + + // Forces match of entire expression right now, + // with match ID match_id_ (used by RE2::Set). + kRegexpHaveMatch, + + kMaxRegexpOp = kRegexpHaveMatch, +}; + +// Keep in sync with string list in regexp.cc +enum RegexpStatusCode { + // No error + kRegexpSuccess = 0, + + // Unexpected error + kRegexpInternalError, + + // Parse errors + kRegexpBadEscape, // bad escape sequence + kRegexpBadCharClass, // bad character class + kRegexpBadCharRange, // bad character class range + kRegexpMissingBracket, // missing closing ] + kRegexpMissingParen, // missing closing ) + kRegexpUnexpectedParen, // unexpected closing ) + kRegexpTrailingBackslash, // at end of regexp + kRegexpRepeatArgument, // repeat argument missing, e.g. "*" + kRegexpRepeatSize, // bad repetition argument + kRegexpRepeatOp, // bad repetition operator + kRegexpBadPerlOp, // bad perl operator + kRegexpBadUTF8, // invalid UTF-8 in regexp + kRegexpBadNamedCapture, // bad named capture +}; + +// Error status for certain operations. +class RegexpStatus { +public: + RegexpStatus() : code_(kRegexpSuccess), tmp_(NULL) {} + ~RegexpStatus() { delete tmp_; } + + void set_code(RegexpStatusCode code) { code_ = code; } + void set_error_arg(const StringPiece &error_arg) { error_arg_ = error_arg; } + void set_tmp(std::string *tmp) { + delete tmp_; + tmp_ = tmp; + } + RegexpStatusCode code() const { return code_; } + const StringPiece &error_arg() const { return error_arg_; } + bool ok() const { return code() == kRegexpSuccess; } + + // Copies state from status. + void Copy(const RegexpStatus &status); + + // Returns text equivalent of code, e.g.: + // "Bad character class" + static std::string CodeText(RegexpStatusCode code); + + // Returns text describing error, e.g.: + // "Bad character class: [z-a]" + std::string Text() const; + +private: + RegexpStatusCode code_; // Kind of error + StringPiece error_arg_; // Piece of regexp containing syntax error. + std::string *tmp_; // Temporary storage, possibly where error_arg_ is. + + RegexpStatus(const RegexpStatus &) = delete; + RegexpStatus &operator=(const RegexpStatus &) = delete; +}; + +// Compiled form; see prog.h +class Prog; + +struct RuneRange { + RuneRange() : lo(0), hi(0) {} + RuneRange(int l, int h) : lo(l), hi(h) {} + Rune lo; + Rune hi; +}; + +// Less-than on RuneRanges treats a == b if they overlap at all. +// This lets us look in a set to find the range covering a particular Rune. +struct RuneRangeLess { + bool operator()(const RuneRange &a, const RuneRange &b) const { return a.hi < b.lo; } +}; + +class CharClassBuilder; + +class CharClass { +public: + void Delete(); + + typedef RuneRange *iterator; + iterator begin() { return ranges_; } + iterator end() { return ranges_ + nranges_; } + + int size() { return nrunes_; } + bool empty() { return nrunes_ == 0; } + bool full() { return nrunes_ == Runemax + 1; } + bool FoldsASCII() { return folds_ascii_; } + + bool Contains(Rune r) const; + CharClass *Negate(); + +private: + CharClass(); // not implemented + ~CharClass(); // not implemented + static CharClass *New(size_t maxranges); + + friend class CharClassBuilder; + + bool folds_ascii_; + int nrunes_; + RuneRange *ranges_; + int nranges_; + + CharClass(const CharClass &) = delete; + CharClass &operator=(const CharClass &) = delete; +}; + +class Regexp { +public: + // Flags for parsing. Can be ORed together. 
+ enum ParseFlags { + NoParseFlags = 0, + FoldCase = 1 << 0, // Fold case during matching (case-insensitive). + Literal = 1 << 1, // Treat s as literal string instead of a regexp. + ClassNL = 1 << 2, // Allow char classes like [^a-z] and \D and \s + // and [[:space:]] to match newline. + DotNL = 1 << 3, // Allow . to match newline. + MatchNL = ClassNL | DotNL, + OneLine = 1 << 4, // Treat ^ and $ as only matching at beginning and + // end of text, not around embedded newlines. + // (Perl's default) + Latin1 = 1 << 5, // Regexp and text are in Latin1, not UTF-8. + NonGreedy = 1 << 6, // Repetition operators are non-greedy by default. + PerlClasses = 1 << 7, // Allow Perl character classes like \d. + PerlB = 1 << 8, // Allow Perl's \b and \B. + PerlX = 1 << 9, // Perl extensions: + // non-capturing parens - (?: ) + // non-greedy operators - *? +? ?? {}? + // flag edits - (?i) (?-i) (?i: ) + // i - FoldCase + // m - !OneLine + // s - DotNL + // U - NonGreedy + // line ends: \A \z + // \Q and \E to disable/enable metacharacters + // (?Pexpr) for named captures + // \C to match any single byte + UnicodeGroups = 1 << 10, // Allow \p{Han} for Unicode Han group + // and \P{Han} for its negation. + NeverNL = 1 << 11, // Never match NL, even if the regexp mentions + // it explicitly. + NeverCapture = 1 << 12, // Parse all parens as non-capturing. + + // As close to Perl as we can get. + LikePerl = ClassNL | OneLine | PerlClasses | PerlB | PerlX | UnicodeGroups, + + // Internal use only. + WasDollar = 1 << 13, // on kRegexpEndText: was $ in regexp text + AllParseFlags = (1 << 14) - 1, + }; + + // Get. No set, Regexps are logically immutable once created. + RegexpOp op() { return static_cast(op_); } + int nsub() { return nsub_; } + bool simple() { return simple_ != 0; } + ParseFlags parse_flags() { return static_cast(parse_flags_); } + int Ref(); // For testing. + + Regexp **sub() { + if (nsub_ <= 1) + return &subone_; + else + return submany_; + } + + int min() { + DCHECK_EQ(op_, kRegexpRepeat); + return arguments.repeat.min_; + } + int max() { + DCHECK_EQ(op_, kRegexpRepeat); + return arguments.repeat.max_; + } + Rune rune() { + DCHECK_EQ(op_, kRegexpLiteral); + return arguments.rune_; + } + CharClass *cc() { + DCHECK_EQ(op_, kRegexpCharClass); + return arguments.char_class.cc_; + } + int cap() { + DCHECK_EQ(op_, kRegexpCapture); + return arguments.capture.cap_; + } + const std::string *name() { + DCHECK_EQ(op_, kRegexpCapture); + return arguments.capture.name_; + } + Rune *runes() { + DCHECK_EQ(op_, kRegexpLiteralString); + return arguments.literal_string.runes_; + } + int nrunes() { + DCHECK_EQ(op_, kRegexpLiteralString); + return arguments.literal_string.nrunes_; + } + int match_id() { + DCHECK_EQ(op_, kRegexpHaveMatch); + return arguments.match_id_; + } + + // Increments reference count, returns object as convenience. + Regexp *Incref(); + + // Decrements reference count and deletes this object if count reaches 0. + void Decref(); + + // Parses string s to produce regular expression, returned. + // Caller must release return value with re->Decref(). + // On failure, sets *status (if status != NULL) and returns NULL. + static Regexp *Parse(const StringPiece &s, ParseFlags flags, RegexpStatus *status); + + // Returns a _new_ simplified version of the current regexp. + // Does not edit the current regexp. + // Caller must release return value with re->Decref(). 
+ // Simplified means that counted repetition has been rewritten + // into simpler terms and all Perl/POSIX features have been + // removed. The result will capture exactly the same + // subexpressions the original did, unless formatted with ToString. + Regexp *Simplify(); + friend class CoalesceWalker; + friend class SimplifyWalker; + + // Parses the regexp src and then simplifies it and sets *dst to the + // string representation of the simplified form. Returns true on success. + // Returns false and sets *status (if status != NULL) on parse error. + static bool SimplifyRegexp(const StringPiece &src, ParseFlags flags, std::string *dst, RegexpStatus *status); + + // Returns the number of capturing groups in the regexp. + int NumCaptures(); + friend class NumCapturesWalker; + + // Returns a map from names to capturing group indices, + // or NULL if the regexp contains no named capture groups. + // The caller is responsible for deleting the map. + std::map *NamedCaptures(); + + // Returns a map from capturing group indices to capturing group + // names or NULL if the regexp contains no named capture groups. The + // caller is responsible for deleting the map. + std::map *CaptureNames(); + + // Returns a string representation of the current regexp, + // using as few parentheses as possible. + std::string ToString(); + + // Convenience functions. They consume the passed reference, + // so in many cases you should use, e.g., Plus(re->Incref(), flags). + // They do not consume allocated arrays like subs or runes. + static Regexp *Plus(Regexp *sub, ParseFlags flags); + static Regexp *Star(Regexp *sub, ParseFlags flags); + static Regexp *Quest(Regexp *sub, ParseFlags flags); + static Regexp *Concat(Regexp **subs, int nsubs, ParseFlags flags); + static Regexp *Alternate(Regexp **subs, int nsubs, ParseFlags flags); + static Regexp *Capture(Regexp *sub, ParseFlags flags, int cap); + static Regexp *Repeat(Regexp *sub, ParseFlags flags, int min, int max); + static Regexp *NewLiteral(Rune rune, ParseFlags flags); + static Regexp *NewCharClass(CharClass *cc, ParseFlags flags); + static Regexp *LiteralString(Rune *runes, int nrunes, ParseFlags flags); + static Regexp *HaveMatch(int match_id, ParseFlags flags); + + // Like Alternate but does not factor out common prefixes. + static Regexp *AlternateNoFactor(Regexp **subs, int nsubs, ParseFlags flags); + + // Debugging function. Returns string format for regexp + // that makes structure clear. Does NOT use regexp syntax. + std::string Dump(); + + // Helper traversal class, defined fully in walker-inl.h. + template + class Walker; + + // Compile to Prog. See prog.h + // Reverse prog expects to be run over text backward. + // Construction and execution of prog will + // stay within approximately max_mem bytes of memory. + // If max_mem <= 0, a reasonable default is used. + Prog *CompileToProg(int64_t max_mem); + Prog *CompileToReverseProg(int64_t max_mem); + + // Whether to expect this library to find exactly the same answer as PCRE + // when running this regexp. Most regexps do mimic PCRE exactly, but a few + // obscure cases behave differently. Technically this is more a property + // of the Prog than the Regexp, but the computation is much easier to do + // on the Regexp. See mimics_pcre.cc for the exact conditions. + bool MimicsPCRE(); + + // Benchmarking function. + void NullWalk(); + + // Whether every match of this regexp must be anchored and + // begin with a non-empty fixed string (perhaps after ASCII + // case-folding). 
If so, returns the prefix and the sub-regexp that + // follows it. + // Callers should expect *prefix, *foldcase and *suffix to be "zeroed" + // regardless of the return value. + bool RequiredPrefix(std::string *prefix, bool *foldcase, Regexp **suffix); + + // Whether every match of this regexp must be unanchored and + // begin with a non-empty fixed string (perhaps after ASCII + // case-folding). If so, returns the prefix. + // Callers should expect *prefix and *foldcase to be "zeroed" + // regardless of the return value. + bool RequiredPrefixForAccel(std::string *prefix, bool *foldcase); + + // Controls the maximum repeat count permitted by the parser. + // FOR FUZZING ONLY. + static void FUZZING_ONLY_set_maximum_repeat_count(int i); + +private: + // Constructor allocates vectors as appropriate for operator. + explicit Regexp(RegexpOp op, ParseFlags parse_flags); + + // Use Decref() instead of delete to release Regexps. + // This is private to catch deletes at compile time. + ~Regexp(); + void Destroy(); + bool QuickDestroy(); + + // Helpers for Parse. Listed here so they can edit Regexps. + class ParseState; + + friend class ParseState; + friend bool ParseCharClass(StringPiece *s, Regexp **out_re, RegexpStatus *status); + + // Helper for testing [sic]. + friend bool RegexpEqualTestingOnly(Regexp *, Regexp *); + + // Computes whether Regexp is already simple. + bool ComputeSimple(); + + // Constructor that generates a Star, Plus or Quest, + // squashing the pair if sub is also a Star, Plus or Quest. + static Regexp *StarPlusOrQuest(RegexpOp op, Regexp *sub, ParseFlags flags); + + // Constructor that generates a concatenation or alternation, + // enforcing the limit on the number of subexpressions for + // a particular Regexp. + static Regexp *ConcatOrAlternate(RegexpOp op, Regexp **subs, int nsubs, ParseFlags flags, bool can_factor); + + // Returns the leading string that re starts with. + // The returned Rune* points into a piece of re, + // so it must not be used after the caller calls re->Decref(). + static Rune *LeadingString(Regexp *re, int *nrune, ParseFlags *flags); + + // Removes the first n leading runes from the beginning of re. + // Edits re in place. + static void RemoveLeadingString(Regexp *re, int n); + + // Returns the leading regexp in re's top-level concatenation. + // The returned Regexp* points at re or a sub-expression of re, + // so it must not be used after the caller calls re->Decref(). + static Regexp *LeadingRegexp(Regexp *re); + + // Removes LeadingRegexp(re) from re and returns the remainder. + // Might edit re in place. + static Regexp *RemoveLeadingRegexp(Regexp *re); + + // Simplifies an alternation of literal strings by factoring out + // common prefixes. + static int FactorAlternation(Regexp **sub, int nsub, ParseFlags flags); + friend class FactorAlternationImpl; + + // Is a == b? Only efficient on regexps that have not been through + // Simplify yet - the expansion of a kRegexpRepeat will make this + // take a long time. Do not call on such regexps, hence private. + static bool Equal(Regexp *a, Regexp *b); + + // Allocate space for n sub-regexps. + void AllocSub(int n) { + DCHECK(n >= 0 && static_cast(n) == n); + if (n > 1) + submany_ = new Regexp *[n]; + nsub_ = static_cast(n); + } + + // Add Rune to LiteralString + void AddRuneToString(Rune r); + + // Swaps this with that, in place. + void Swap(Regexp *that); + + // Operator. See description of operators above. + // uint8_t instead of RegexpOp to control space usage. 
+ uint8_t op_; + + // Is this regexp structure already simple + // (has it been returned by Simplify)? + // uint8_t instead of bool to control space usage. + uint8_t simple_; + + // Flags saved from parsing and used during execution. + // (Only FoldCase is used.) + // uint16_t instead of ParseFlags to control space usage. + uint16_t parse_flags_; + + // Reference count. Exists so that SimplifyRegexp can build + // regexp structures that are dags rather than trees to avoid + // exponential blowup in space requirements. + // uint16_t to control space usage. + // The standard regexp routines will never generate a + // ref greater than the maximum repeat count (kMaxRepeat), + // but even so, Incref and Decref consult an overflow map + // when ref_ reaches kMaxRef. + uint16_t ref_; + static const uint16_t kMaxRef = 0xffff; + + // Subexpressions. + // uint16_t to control space usage. + // Concat and Alternate handle larger numbers of subexpressions + // by building concatenation or alternation trees. + // Other routines should call Concat or Alternate instead of + // filling in sub() by hand. + uint16_t nsub_; + static const uint16_t kMaxNsub = 0xffff; + union { + Regexp **submany_; // if nsub_ > 1 + Regexp *subone_; // if nsub_ == 1 + }; + + // Extra space for parse and teardown stacks. + Regexp *down_; + + // Arguments to operator. See description of operators above. + union { + struct { // Repeat + int max_; + int min_; + } repeat; + struct { // Capture + int cap_; + std::string *name_; + } capture; + struct { // LiteralString + int nrunes_; + Rune *runes_; + } literal_string; + struct { // CharClass + // These two could be in separate union members, + // but it wouldn't save any space (there are other two-word structs) + // and keeping them separate avoids confusion during parsing. + CharClass *cc_; + CharClassBuilder *ccb_; + } char_class; + Rune rune_; // Literal + int match_id_; // HaveMatch + void *the_union_[2]; // as big as any other element, for memset + } arguments; + + Regexp(const Regexp &) = delete; + Regexp &operator=(const Regexp &) = delete; +}; + +// Character class set: contains non-overlapping, non-abutting RuneRanges. +typedef std::set RuneRangeSet; + +class CharClassBuilder { +public: + CharClassBuilder(); + + typedef RuneRangeSet::iterator iterator; + iterator begin() { return ranges_.begin(); } + iterator end() { return ranges_.end(); } + + int size() { return nrunes_; } + bool empty() { return nrunes_ == 0; } + bool full() { return nrunes_ == Runemax + 1; } + + bool Contains(Rune r); + bool FoldsASCII(); + bool AddRange(Rune lo, Rune hi); // returns whether class changed + CharClassBuilder *Copy(); + void AddCharClass(CharClassBuilder *cc); + void Negate(); + void RemoveAbove(Rune r); + CharClass *GetCharClass(); + void AddRangeFlags(Rune lo, Rune hi, Regexp::ParseFlags parse_flags); + +private: + static const uint32_t AlphaMask = (1 << 26) - 1; + uint32_t upper_; // bitmap of A-Z + uint32_t lower_; // bitmap of a-z + int nrunes_; + RuneRangeSet ranges_; + + CharClassBuilder(const CharClassBuilder &) = delete; + CharClassBuilder &operator=(const CharClassBuilder &) = delete; +}; + +// Bitwise ops on ParseFlags produce ParseFlags. 
+inline Regexp::ParseFlags operator|(Regexp::ParseFlags a, Regexp::ParseFlags b) {
+    return static_cast<Regexp::ParseFlags>(static_cast<int>(a) | static_cast<int>(b));
+}
+
+inline Regexp::ParseFlags operator^(Regexp::ParseFlags a, Regexp::ParseFlags b) {
+    return static_cast<Regexp::ParseFlags>(static_cast<int>(a) ^ static_cast<int>(b));
+}
+
+inline Regexp::ParseFlags operator&(Regexp::ParseFlags a, Regexp::ParseFlags b) {
+    return static_cast<Regexp::ParseFlags>(static_cast<int>(a) & static_cast<int>(b));
+}
+
+inline Regexp::ParseFlags operator~(Regexp::ParseFlags a) {
+    // Attempting to produce a value out of the enum's range has undefined behaviour.
+    return static_cast<Regexp::ParseFlags>(~static_cast<int>(a) & static_cast<int>(Regexp::AllParseFlags));
+}
+
+} // namespace re2
+
+#endif // RE2_REGEXP_H_
diff --git a/internal/cpp/re2/set.cc b/internal/cpp/re2/set.cc
new file mode 100644
index 00000000000..84e013f9c63
--- /dev/null
+++ b/internal/cpp/re2/set.cc
@@ -0,0 +1,159 @@
+// Copyright 2010 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "re2/set.h"
+
+#include <stddef.h>
+#include <algorithm>
+#include <memory>
+#include <utility>
+
+#include "re2/pod_array.h"
+#include "re2/prog.h"
+#include "re2/re2.h"
+#include "re2/regexp.h"
+#include "re2/stringpiece.h"
+#include "util/logging.h"
+#include "util/util.h"
+
+namespace re2 {
+
+RE2::Set::Set(const RE2::Options &options, RE2::Anchor anchor) : options_(options), anchor_(anchor), compiled_(false), size_(0) {
+    options_.set_never_capture(true); // might unblock some optimisations
+}
+
+RE2::Set::~Set() {
+    for (size_t i = 0; i < elem_.size(); i++)
+        elem_[i].second->Decref();
+}
+
+RE2::Set::Set(Set &&other)
+    : options_(other.options_), anchor_(other.anchor_), elem_(std::move(other.elem_)), compiled_(other.compiled_), size_(other.size_),
+      prog_(std::move(other.prog_)) {
+    other.elem_.clear();
+    other.elem_.shrink_to_fit();
+    other.compiled_ = false;
+    other.size_ = 0;
+    other.prog_.reset();
+}
+
+RE2::Set &RE2::Set::operator=(Set &&other) {
+    this->~Set();
+    (void)new (this) Set(std::move(other));
+    return *this;
+}
+
+int RE2::Set::Add(const StringPiece &pattern, std::string *error) {
+    if (compiled_) {
+        LOG(DFATAL) << "RE2::Set::Add() called after compiling";
+        return -1;
+    }
+
+    Regexp::ParseFlags pf = static_cast<Regexp::ParseFlags>(options_.ParseFlags());
+    RegexpStatus status;
+    re2::Regexp *re = Regexp::Parse(pattern, pf, &status);
+    if (re == NULL) {
+        if (error != NULL)
+            *error = status.Text();
+        if (options_.log_errors())
+            LOG(ERROR) << "Error parsing '" << pattern << "': " << status.Text();
+        return -1;
+    }
+
+    // Concatenate with match index and push on vector.
+    int n = static_cast<int>(elem_.size());
+    re2::Regexp *m = re2::Regexp::HaveMatch(n, pf);
+    if (re->op() == kRegexpConcat) {
+        int nsub = re->nsub();
+        PODArray<re2::Regexp *> sub(nsub + 1);
+        for (int i = 0; i < nsub; i++)
+            sub[i] = re->sub()[i]->Incref();
+        sub[nsub] = m;
+        re->Decref();
+        re = re2::Regexp::Concat(sub.data(), nsub + 1, pf);
+    } else {
+        re2::Regexp *sub[2];
+        sub[0] = re;
+        sub[1] = m;
+        re = re2::Regexp::Concat(sub, 2, pf);
+    }
+    elem_.emplace_back(std::string(pattern), re);
+    return n;
+}
+
+bool RE2::Set::Compile() {
+    if (compiled_) {
+        LOG(DFATAL) << "RE2::Set::Compile() called more than once";
+        return false;
+    }
+    compiled_ = true;
+    size_ = static_cast<int>(elem_.size());
+
+    // Sort the elements by their patterns. This is good enough for now
+    // until we have a Regexp comparison function. (Maybe someday...)
+ std::sort(elem_.begin(), elem_.end(), [](const Elem &a, const Elem &b) -> bool { return a.first < b.first; }); + + PODArray sub(size_); + for (int i = 0; i < size_; i++) + sub[i] = elem_[i].second; + elem_.clear(); + elem_.shrink_to_fit(); + + Regexp::ParseFlags pf = static_cast(options_.ParseFlags()); + re2::Regexp *re = re2::Regexp::Alternate(sub.data(), size_, pf); + + prog_.reset(Prog::CompileSet(re, anchor_, options_.max_mem())); + re->Decref(); + return prog_ != nullptr; +} + +bool RE2::Set::Match(const StringPiece &text, std::vector *v) const { return Match(text, v, NULL); } + +bool RE2::Set::Match(const StringPiece &text, std::vector *v, ErrorInfo *error_info) const { + if (!compiled_) { + if (error_info != NULL) + error_info->kind = kNotCompiled; + LOG(DFATAL) << "RE2::Set::Match() called before compiling"; + return false; + } +#ifdef RE2_HAVE_THREAD_LOCAL + hooks::context = NULL; +#endif + bool dfa_failed = false; + std::unique_ptr matches; + if (v != NULL) { + matches.reset(new SparseSet(size_)); + v->clear(); + } + bool ret = prog_->SearchDFA(text, text, Prog::kAnchored, Prog::kManyMatch, NULL, &dfa_failed, matches.get()); + if (dfa_failed) { + if (options_.log_errors()) + LOG(ERROR) << "DFA out of memory: " + << "program size " << prog_->size() << ", " + << "list count " << prog_->list_count() << ", " + << "bytemap range " << prog_->bytemap_range(); + if (error_info != NULL) + error_info->kind = kOutOfMemory; + return false; + } + if (ret == false) { + if (error_info != NULL) + error_info->kind = kNoError; + return false; + } + if (v != NULL) { + if (matches->empty()) { + if (error_info != NULL) + error_info->kind = kInconsistent; + LOG(DFATAL) << "RE2::Set::Match() matched, but no matches returned?!"; + return false; + } + v->assign(matches->begin(), matches->end()); + } + if (error_info != NULL) + error_info->kind = kNoError; + return true; +} + +} // namespace re2 diff --git a/internal/cpp/re2/set.h b/internal/cpp/re2/set.h new file mode 100644 index 00000000000..f57443d6a14 --- /dev/null +++ b/internal/cpp/re2/set.h @@ -0,0 +1,84 @@ +// Copyright 2010 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef RE2_SET_H_ +#define RE2_SET_H_ + +#include +#include +#include +#include + +#include "re2/re2.h" + +namespace re2 { +class Prog; +class Regexp; +} // namespace re2 + +namespace re2 { + +// An RE2::Set represents a collection of regexps that can +// be searched for simultaneously. +class RE2::Set { +public: + enum ErrorKind { + kNoError = 0, + kNotCompiled, // The set is not compiled. + kOutOfMemory, // The DFA ran out of memory. + kInconsistent, // The result is inconsistent. This should never happen. + }; + + struct ErrorInfo { + ErrorKind kind; + }; + + Set(const RE2::Options &options, RE2::Anchor anchor); + ~Set(); + + // Not copyable. + Set(const Set &) = delete; + Set &operator=(const Set &) = delete; + // Movable. + Set(Set &&other); + Set &operator=(Set &&other); + + // Adds pattern to the set using the options passed to the constructor. + // Returns the index that will identify the regexp in the output of Match(), + // or -1 if the regexp cannot be parsed. + // Indices are assigned in sequential order starting from 0. + // Errors do not increment the index; if error is not NULL, *error will hold + // the error message from the parser. + int Add(const StringPiece &pattern, std::string *error); + + // Compiles the set in preparation for matching. 
+ // Returns false if the compiler runs out of memory. + // Add() must not be called again after Compile(). + // Compile() must be called before Match(). + bool Compile(); + + // Returns true if text matches at least one of the regexps in the set. + // Fills v (if not NULL) with the indices of the matching regexps. + // Callers must not expect v to be sorted. + bool Match(const StringPiece &text, std::vector *v) const; + + // As above, but populates error_info (if not NULL) when none of the regexps + // in the set matched. This can inform callers when DFA execution fails, for + // example, because they might wish to handle that case differently. + bool Match(const StringPiece &text, std::vector *v, ErrorInfo *error_info) const; + +private: + typedef std::pair Elem; + + RE2::Options options_; + RE2::Anchor anchor_; + std::vector elem_; + bool compiled_; + int size_; + std::unique_ptr prog_; +}; + +} // namespace re2 + +#endif // RE2_SET_H_ diff --git a/internal/cpp/re2/simplify.cc b/internal/cpp/re2/simplify.cc new file mode 100644 index 00000000000..cbc7edb380a --- /dev/null +++ b/internal/cpp/re2/simplify.cc @@ -0,0 +1,629 @@ +// Copyright 2006 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Rewrite POSIX and other features in re +// to use simple extended regular expression features. +// Also sort and simplify character classes. + +#include + +#include "re2/pod_array.h" +#include "re2/regexp.h" +#include "re2/walker-inl.h" +#include "util/logging.h" +#include "util/utf.h" +#include "util/util.h" + +namespace re2 { + +// Parses the regexp src and then simplifies it and sets *dst to the +// string representation of the simplified form. Returns true on success. +// Returns false and sets *error (if error != NULL) on error. +bool Regexp::SimplifyRegexp(const StringPiece &src, ParseFlags flags, std::string *dst, RegexpStatus *status) { + Regexp *re = Parse(src, flags, status); + if (re == NULL) + return false; + Regexp *sre = re->Simplify(); + re->Decref(); + if (sre == NULL) { + if (status) { + status->set_code(kRegexpInternalError); + status->set_error_arg(src); + } + return false; + } + *dst = sre->ToString(); + sre->Decref(); + return true; +} + +// Assuming the simple_ flags on the children are accurate, +// is this Regexp* simple? +bool Regexp::ComputeSimple() { + Regexp **subs; + switch (op_) { + case kRegexpNoMatch: + case kRegexpEmptyMatch: + case kRegexpLiteral: + case kRegexpLiteralString: + case kRegexpBeginLine: + case kRegexpEndLine: + case kRegexpBeginText: + case kRegexpWordBoundary: + case kRegexpNoWordBoundary: + case kRegexpEndText: + case kRegexpAnyChar: + case kRegexpAnyByte: + case kRegexpHaveMatch: + return true; + case kRegexpConcat: + case kRegexpAlternate: + // These are simple as long as the subpieces are simple. + subs = sub(); + for (int i = 0; i < nsub_; i++) + if (!subs[i]->simple()) + return false; + return true; + case kRegexpCharClass: + // Simple as long as the char class is not empty, not full. 
+        if (arguments.char_class.ccb_ != NULL)
+            return !arguments.char_class.ccb_->empty() && !arguments.char_class.ccb_->full();
+        return !arguments.char_class.cc_->empty() && !arguments.char_class.cc_->full();
+    case kRegexpCapture:
+        subs = sub();
+        return subs[0]->simple();
+    case kRegexpStar:
+    case kRegexpPlus:
+    case kRegexpQuest:
+        subs = sub();
+        if (!subs[0]->simple())
+            return false;
+        switch (subs[0]->op_) {
+        case kRegexpStar:
+        case kRegexpPlus:
+        case kRegexpQuest:
+        case kRegexpEmptyMatch:
+        case kRegexpNoMatch:
+            return false;
+        default:
+            break;
+        }
+        return true;
+    case kRegexpRepeat:
+        return false;
+    }
+    LOG(DFATAL) << "Case not handled in ComputeSimple: " << op_;
+    return false;
+}
+
+// Walker subclass used by Simplify.
+// Coalesces runs of star/plus/quest/repeat of the same literal along with any
+// occurrences of that literal into repeats of that literal. It also works for
+// char classes, any char and any byte.
+// PostVisit creates the coalesced result, which should then be simplified.
+class CoalesceWalker : public Regexp::Walker<Regexp *> {
+public:
+    CoalesceWalker() {}
+    virtual Regexp *PostVisit(Regexp *re, Regexp *parent_arg, Regexp *pre_arg, Regexp **child_args, int nchild_args);
+    virtual Regexp *Copy(Regexp *re);
+    virtual Regexp *ShortVisit(Regexp *re, Regexp *parent_arg);
+
+private:
+    // These functions are declared inside CoalesceWalker so that
+    // they can edit the private fields of the Regexps they construct.
+
+    // Returns true if r1 and r2 can be coalesced. In particular, ensures that
+    // the parse flags are consistent. (They will not be checked again later.)
+    static bool CanCoalesce(Regexp *r1, Regexp *r2);
+
+    // Coalesces *r1ptr and *r2ptr. In most cases, the array elements afterwards
+    // will be empty match and the coalesced op. In other cases, where part of a
+    // literal string was removed to be coalesced, the array elements afterwards
+    // will be the coalesced op and the remainder of the literal string.
+    static void DoCoalesce(Regexp **r1ptr, Regexp **r2ptr);
+
+    CoalesceWalker(const CoalesceWalker &) = delete;
+    CoalesceWalker &operator=(const CoalesceWalker &) = delete;
+};
+
+// Walker subclass used by Simplify.
+// The simplify walk is purely post-recursive: given the simplified children,
+// PostVisit creates the simplified result.
+// The child_args are simplified Regexp*s.
+class SimplifyWalker : public Regexp::Walker<Regexp *> {
+public:
+    SimplifyWalker() {}
+    virtual Regexp *PreVisit(Regexp *re, Regexp *parent_arg, bool *stop);
+    virtual Regexp *PostVisit(Regexp *re, Regexp *parent_arg, Regexp *pre_arg, Regexp **child_args, int nchild_args);
+    virtual Regexp *Copy(Regexp *re);
+    virtual Regexp *ShortVisit(Regexp *re, Regexp *parent_arg);
+
+private:
+    // These functions are declared inside SimplifyWalker so that
+    // they can edit the private fields of the Regexps they construct.
+
+    // Creates a concatenation of two Regexp, consuming refs to re1 and re2.
+    // Caller must Decref return value when done with it.
+    static Regexp *Concat2(Regexp *re1, Regexp *re2, Regexp::ParseFlags flags);
+
+    // Simplifies the expression re{min,max} in terms of *, +, and ?.
+    // Returns a new regexp. Does not edit re. Does not consume reference to re.
+    // Caller must Decref return value when done with it.
+    static Regexp *SimplifyRepeat(Regexp *re, int min, int max, Regexp::ParseFlags parse_flags);
+
+    // Simplifies a character class by expanding any named classes
+    // into rune ranges. Does not edit re. Does not consume ref to re.
+ // Caller must Decref return value when done with it. + static Regexp *SimplifyCharClass(Regexp *re); + + SimplifyWalker(const SimplifyWalker &) = delete; + SimplifyWalker &operator=(const SimplifyWalker &) = delete; +}; + +// Simplifies a regular expression, returning a new regexp. +// The new regexp uses traditional Unix egrep features only, +// plus the Perl (?:) non-capturing parentheses. +// Otherwise, no POSIX or Perl additions. The new regexp +// captures exactly the same subexpressions (with the same indices) +// as the original. +// Does not edit current object. +// Caller must Decref() return value when done with it. + +Regexp *Regexp::Simplify() { + CoalesceWalker cw; + Regexp *cre = cw.Walk(this, NULL); + if (cre == NULL) + return NULL; + if (cw.stopped_early()) { + cre->Decref(); + return NULL; + } + SimplifyWalker sw; + Regexp *sre = sw.Walk(cre, NULL); + cre->Decref(); + if (sre == NULL) + return NULL; + if (sw.stopped_early()) { + sre->Decref(); + return NULL; + } + return sre; +} + +#define Simplify DontCallSimplify // Avoid accidental recursion + +// Utility function for PostVisit implementations that compares re->sub() with +// child_args to determine whether any child_args changed. In the common case, +// where nothing changed, calls Decref() for all child_args and returns false, +// so PostVisit must return re->Incref(). Otherwise, returns true. +static bool ChildArgsChanged(Regexp *re, Regexp **child_args) { + for (int i = 0; i < re->nsub(); i++) { + Regexp *sub = re->sub()[i]; + Regexp *newsub = child_args[i]; + if (newsub != sub) + return true; + } + for (int i = 0; i < re->nsub(); i++) { + Regexp *newsub = child_args[i]; + newsub->Decref(); + } + return false; +} + +Regexp *CoalesceWalker::Copy(Regexp *re) { return re->Incref(); } + +Regexp *CoalesceWalker::ShortVisit(Regexp *re, Regexp *parent_arg) { + // Should never be called: we use Walk(), not WalkExponential(). +#ifndef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + LOG(DFATAL) << "CoalesceWalker::ShortVisit called"; +#endif + return re->Incref(); +} + +Regexp *CoalesceWalker::PostVisit(Regexp *re, Regexp *parent_arg, Regexp *pre_arg, Regexp **child_args, int nchild_args) { + if (re->nsub() == 0) + return re->Incref(); + + if (re->op() != kRegexpConcat) { + if (!ChildArgsChanged(re, child_args)) + return re->Incref(); + + // Something changed. Build a new op. + Regexp *nre = new Regexp(re->op(), re->parse_flags()); + nre->AllocSub(re->nsub()); + Regexp **nre_subs = nre->sub(); + for (int i = 0; i < re->nsub(); i++) + nre_subs[i] = child_args[i]; + // Repeats and Captures have additional data that must be copied. + if (re->op() == kRegexpRepeat) { + nre->arguments.repeat.min_ = re->min(); + nre->arguments.repeat.max_ = re->max(); + } else if (re->op() == kRegexpCapture) { + nre->arguments.capture.cap_ = re->cap(); + } + return nre; + } + + bool can_coalesce = false; + for (int i = 0; i < re->nsub(); i++) { + if (i + 1 < re->nsub() && CanCoalesce(child_args[i], child_args[i + 1])) { + can_coalesce = true; + break; + } + } + if (!can_coalesce) { + if (!ChildArgsChanged(re, child_args)) + return re->Incref(); + + // Something changed. Build a new op. 
+ Regexp *nre = new Regexp(re->op(), re->parse_flags()); + nre->AllocSub(re->nsub()); + Regexp **nre_subs = nre->sub(); + for (int i = 0; i < re->nsub(); i++) + nre_subs[i] = child_args[i]; + return nre; + } + + for (int i = 0; i < re->nsub(); i++) { + if (i + 1 < re->nsub() && CanCoalesce(child_args[i], child_args[i + 1])) + DoCoalesce(&child_args[i], &child_args[i + 1]); + } + // Determine how many empty matches were left by DoCoalesce. + int n = 0; + for (int i = n; i < re->nsub(); i++) { + if (child_args[i]->op() == kRegexpEmptyMatch) + n++; + } + // Build a new op. + Regexp *nre = new Regexp(re->op(), re->parse_flags()); + nre->AllocSub(re->nsub() - n); + Regexp **nre_subs = nre->sub(); + for (int i = 0, j = 0; i < re->nsub(); i++) { + if (child_args[i]->op() == kRegexpEmptyMatch) { + child_args[i]->Decref(); + continue; + } + nre_subs[j] = child_args[i]; + j++; + } + return nre; +} + +bool CoalesceWalker::CanCoalesce(Regexp *r1, Regexp *r2) { + // r1 must be a star/plus/quest/repeat of a literal, char class, any char or + // any byte. + if ((r1->op() == kRegexpStar || r1->op() == kRegexpPlus || r1->op() == kRegexpQuest || r1->op() == kRegexpRepeat) && + (r1->sub()[0]->op() == kRegexpLiteral || r1->sub()[0]->op() == kRegexpCharClass || r1->sub()[0]->op() == kRegexpAnyChar || + r1->sub()[0]->op() == kRegexpAnyByte)) { + // r2 must be a star/plus/quest/repeat of the same literal, char class, + // any char or any byte. + if ((r2->op() == kRegexpStar || r2->op() == kRegexpPlus || r2->op() == kRegexpQuest || r2->op() == kRegexpRepeat) && + Regexp::Equal(r1->sub()[0], r2->sub()[0]) && + // The parse flags must be consistent. + ((r1->parse_flags() & Regexp::NonGreedy) == (r2->parse_flags() & Regexp::NonGreedy))) { + return true; + } + // ... OR an occurrence of that literal, char class, any char or any byte + if (Regexp::Equal(r1->sub()[0], r2)) { + return true; + } + // ... OR a literal string that begins with that literal. + if (r1->sub()[0]->op() == kRegexpLiteral && r2->op() == kRegexpLiteralString && r2->runes()[0] == r1->sub()[0]->rune() && + // The parse flags must be consistent. 
+ ((r1->sub()[0]->parse_flags() & Regexp::FoldCase) == (r2->parse_flags() & Regexp::FoldCase))) { + return true; + } + } + return false; +} + +void CoalesceWalker::DoCoalesce(Regexp **r1ptr, Regexp **r2ptr) { + Regexp *r1 = *r1ptr; + Regexp *r2 = *r2ptr; + + Regexp *nre = Regexp::Repeat(r1->sub()[0]->Incref(), r1->parse_flags(), 0, 0); + + switch (r1->op()) { + case kRegexpStar: + nre->arguments.repeat.min_ = 0; + nre->arguments.repeat.max_ = -1; + break; + + case kRegexpPlus: + nre->arguments.repeat.min_ = 1; + nre->arguments.repeat.max_ = -1; + break; + + case kRegexpQuest: + nre->arguments.repeat.min_ = 0; + nre->arguments.repeat.max_ = 1; + break; + + case kRegexpRepeat: + nre->arguments.repeat.min_ = r1->min(); + nre->arguments.repeat.max_ = r1->max(); + break; + + default: + nre->Decref(); + LOG(DFATAL) << "DoCoalesce failed: r1->op() is " << r1->op(); + return; + } + + switch (r2->op()) { + case kRegexpStar: + nre->arguments.repeat.max_ = -1; + goto LeaveEmpty; + + case kRegexpPlus: + nre->arguments.repeat.min_++; + nre->arguments.repeat.max_ = -1; + goto LeaveEmpty; + + case kRegexpQuest: + if (nre->max() != -1) + nre->arguments.repeat.max_++; + goto LeaveEmpty; + + case kRegexpRepeat: + nre->arguments.repeat.min_ += r2->min(); + if (r2->max() == -1) + nre->arguments.repeat.max_ = -1; + else if (nre->max() != -1) + nre->arguments.repeat.max_ += r2->max(); + goto LeaveEmpty; + + case kRegexpLiteral: + case kRegexpCharClass: + case kRegexpAnyChar: + case kRegexpAnyByte: + nre->arguments.repeat.min_++; + if (nre->max() != -1) + nre->arguments.repeat.max_++; + goto LeaveEmpty; + + LeaveEmpty: + *r1ptr = new Regexp(kRegexpEmptyMatch, Regexp::NoParseFlags); + *r2ptr = nre; + break; + + case kRegexpLiteralString: { + Rune r = r1->sub()[0]->rune(); + // Determine how much of the literal string is removed. + // We know that we have at least one rune. :) + int n = 1; + while (n < r2->nrunes() && r2->runes()[n] == r) + n++; + nre->arguments.repeat.min_ += n; + if (nre->max() != -1) + nre->arguments.repeat.max_ += n; + if (n == r2->nrunes()) + goto LeaveEmpty; + *r1ptr = nre; + *r2ptr = Regexp::LiteralString(&r2->runes()[n], r2->nrunes() - n, r2->parse_flags()); + break; + } + + default: + nre->Decref(); + LOG(DFATAL) << "DoCoalesce failed: r2->op() is " << r2->op(); + return; + } + + r1->Decref(); + r2->Decref(); +} + +Regexp *SimplifyWalker::Copy(Regexp *re) { return re->Incref(); } + +Regexp *SimplifyWalker::ShortVisit(Regexp *re, Regexp *parent_arg) { + // Should never be called: we use Walk(), not WalkExponential(). +#ifndef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + LOG(DFATAL) << "SimplifyWalker::ShortVisit called"; +#endif + return re->Incref(); +} + +Regexp *SimplifyWalker::PreVisit(Regexp *re, Regexp *parent_arg, bool *stop) { + if (re->simple()) { + *stop = true; + return re->Incref(); + } + return NULL; +} + +Regexp *SimplifyWalker::PostVisit(Regexp *re, Regexp *parent_arg, Regexp *pre_arg, Regexp **child_args, int nchild_args) { + switch (re->op()) { + case kRegexpNoMatch: + case kRegexpEmptyMatch: + case kRegexpLiteral: + case kRegexpLiteralString: + case kRegexpBeginLine: + case kRegexpEndLine: + case kRegexpBeginText: + case kRegexpWordBoundary: + case kRegexpNoWordBoundary: + case kRegexpEndText: + case kRegexpAnyChar: + case kRegexpAnyByte: + case kRegexpHaveMatch: + // All these are always simple. + re->simple_ = true; + return re->Incref(); + + case kRegexpConcat: + case kRegexpAlternate: { + // These are simple as long as the subpieces are simple. 
+ if (!ChildArgsChanged(re, child_args)) { + re->simple_ = true; + return re->Incref(); + } + Regexp *nre = new Regexp(re->op(), re->parse_flags()); + nre->AllocSub(re->nsub()); + Regexp **nre_subs = nre->sub(); + for (int i = 0; i < re->nsub(); i++) + nre_subs[i] = child_args[i]; + nre->simple_ = true; + return nre; + } + + case kRegexpCapture: { + Regexp *newsub = child_args[0]; + if (newsub == re->sub()[0]) { + newsub->Decref(); + re->simple_ = true; + return re->Incref(); + } + Regexp *nre = new Regexp(kRegexpCapture, re->parse_flags()); + nre->AllocSub(1); + nre->sub()[0] = newsub; + nre->arguments.capture.cap_ = re->cap(); + nre->simple_ = true; + return nre; + } + + case kRegexpStar: + case kRegexpPlus: + case kRegexpQuest: { + Regexp *newsub = child_args[0]; + // Special case: repeat the empty string as much as + // you want, but it's still the empty string. + if (newsub->op() == kRegexpEmptyMatch) + return newsub; + + // These are simple as long as the subpiece is simple. + if (newsub == re->sub()[0]) { + newsub->Decref(); + re->simple_ = true; + return re->Incref(); + } + + // These are also idempotent if flags are constant. + if (re->op() == newsub->op() && re->parse_flags() == newsub->parse_flags()) + return newsub; + + Regexp *nre = new Regexp(re->op(), re->parse_flags()); + nre->AllocSub(1); + nre->sub()[0] = newsub; + nre->simple_ = true; + return nre; + } + + case kRegexpRepeat: { + Regexp *newsub = child_args[0]; + // Special case: repeat the empty string as much as + // you want, but it's still the empty string. + if (newsub->op() == kRegexpEmptyMatch) + return newsub; + + Regexp *nre = SimplifyRepeat(newsub, re->arguments.repeat.min_, re->arguments.repeat.max_, re->parse_flags()); + newsub->Decref(); + nre->simple_ = true; + return nre; + } + + case kRegexpCharClass: { + Regexp *nre = SimplifyCharClass(re); + nre->simple_ = true; + return nre; + } + } + + LOG(ERROR) << "Simplify case not handled: " << re->op(); + return re->Incref(); +} + +// Creates a concatenation of two Regexp, consuming refs to re1 and re2. +// Returns a new Regexp, handing the ref to the caller. +Regexp *SimplifyWalker::Concat2(Regexp *re1, Regexp *re2, Regexp::ParseFlags parse_flags) { + Regexp *re = new Regexp(kRegexpConcat, parse_flags); + re->AllocSub(2); + Regexp **subs = re->sub(); + subs[0] = re1; + subs[1] = re2; + return re; +} + +// Simplifies the expression re{min,max} in terms of *, +, and ?. +// Returns a new regexp. Does not edit re. Does not consume reference to re. +// Caller must Decref return value when done with it. +// The result will *not* necessarily have the right capturing parens +// if you call ToString() and re-parse it: (x){2} becomes (x)(x), +// but in the Regexp* representation, both (x) are marked as $1. +Regexp *SimplifyWalker::SimplifyRepeat(Regexp *re, int min, int max, Regexp::ParseFlags f) { + // x{n,} means at least n matches of x. + if (max == -1) { + // Special case: x{0,} is x* + if (min == 0) + return Regexp::Star(re->Incref(), f); + + // Special case: x{1,} is x+ + if (min == 1) + return Regexp::Plus(re->Incref(), f); + + // General case: x{4,} is xxxx+ + PODArray nre_subs(min); + for (int i = 0; i < min - 1; i++) + nre_subs[i] = re->Incref(); + nre_subs[min - 1] = Regexp::Plus(re->Incref(), f); + return Regexp::Concat(nre_subs.data(), min, f); + } + + // Special case: (x){0} matches only empty string. + if (min == 0 && max == 0) + return new Regexp(kRegexpEmptyMatch, f); + + // Special case: x{1} is just x. 
+ if (min == 1 && max == 1) + return re->Incref(); + + // General case: x{n,m} means n copies of x and m copies of x?. + // The machine will do less work if we nest the final m copies, + // so that x{2,5} = xx(x(x(x)?)?)? + + // Build leading prefix: xx. Capturing only on the last one. + Regexp *nre = NULL; + if (min > 0) { + PODArray nre_subs(min); + for (int i = 0; i < min; i++) + nre_subs[i] = re->Incref(); + nre = Regexp::Concat(nre_subs.data(), min, f); + } + + // Build and attach suffix: (x(x(x)?)?)? + if (max > min) { + Regexp *suf = Regexp::Quest(re->Incref(), f); + for (int i = min + 1; i < max; i++) + suf = Regexp::Quest(Concat2(re->Incref(), suf, f), f); + if (nre == NULL) + nre = suf; + else + nre = Concat2(nre, suf, f); + } + + if (nre == NULL) { + // Some degenerate case, like min > max, or min < max < 0. + // This shouldn't happen, because the parser rejects such regexps. + LOG(DFATAL) << "Malformed repeat " << re->ToString() << " " << min << " " << max; + return new Regexp(kRegexpNoMatch, f); + } + + return nre; +} + +// Simplifies a character class. +// Caller must Decref return value when done with it. +Regexp *SimplifyWalker::SimplifyCharClass(Regexp *re) { + CharClass *cc = re->cc(); + + // Special cases + if (cc->empty()) + return new Regexp(kRegexpNoMatch, re->parse_flags()); + if (cc->full()) + return new Regexp(kRegexpAnyChar, re->parse_flags()); + + return re->Incref(); +} + +} // namespace re2 diff --git a/internal/cpp/re2/sparse_array.h b/internal/cpp/re2/sparse_array.h new file mode 100644 index 00000000000..02023ecbdd8 --- /dev/null +++ b/internal/cpp/re2/sparse_array.h @@ -0,0 +1,367 @@ +// Copyright 2006 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef RE2_SPARSE_ARRAY_H_ +#define RE2_SPARSE_ARRAY_H_ + +// DESCRIPTION +// +// SparseArray(m) is a map from integers in [0, m) to T values. +// It requires (sizeof(T)+sizeof(int))*m memory, but it provides +// fast iteration through the elements in the array and fast clearing +// of the array. The array has a concept of certain elements being +// uninitialized (having no value). +// +// Insertion and deletion are constant time operations. +// +// Allocating the array is a constant time operation +// when memory allocation is a constant time operation. +// +// Clearing the array is a constant time operation (unusual!). +// +// Iterating through the array is an O(n) operation, where n +// is the number of items in the array (not O(m)). +// +// The array iterator visits entries in the order they were first +// inserted into the array. It is safe to add items to the array while +// using an iterator: the iterator will visit indices added to the array +// during the iteration, but will not re-visit indices whose values +// change after visiting. Thus SparseArray can be a convenient +// implementation of a work queue. +// +// The SparseArray implementation is NOT thread-safe. It is up to the +// caller to make sure only one thread is accessing the array. (Typically +// these arrays are temporary values and used in situations where speed is +// important.) +// +// The SparseArray interface does not present all the usual STL bells and +// whistles. +// +// Implemented with reference to Briggs & Torczon, An Efficient +// Representation for Sparse Sets, ACM Letters on Programming Languages +// and Systems, Volume 2, Issue 1-4 (March-Dec. 1993), pp. 59-69. 
+// +// Briggs & Torczon popularized this technique, but it had been known +// long before their paper. They point out that Aho, Hopcroft, and +// Ullman's 1974 Design and Analysis of Computer Algorithms and Bentley's +// 1986 Programming Pearls both hint at the technique in exercises to the +// reader (in Aho & Hopcroft, exercise 2.12; in Bentley, column 1 +// exercise 8). +// +// Briggs & Torczon describe a sparse set implementation. I have +// trivially generalized it to create a sparse array (actually the original +// target of the AHU and Bentley exercises). + +// IMPLEMENTATION +// +// SparseArray is an array dense_ and an array sparse_ of identical size. +// At any point, the number of elements in the sparse array is size_. +// +// The array dense_ contains the size_ elements in the sparse array (with +// their indices), +// in the order that the elements were first inserted. This array is dense: +// the size_ pairs are dense_[0] through dense_[size_-1]. +// +// The array sparse_ maps from indices in [0,m) to indices in [0,size_). +// For indices present in the array, dense_[sparse_[i]].index_ == i. +// For indices not present in the array, sparse_ can contain any value at all, +// perhaps outside the range [0, size_) but perhaps not. +// +// The lax requirement on sparse_ values makes clearing the array very easy: +// set size_ to 0. Lookups are slightly more complicated. +// An index i has a value in the array if and only if: +// sparse_[i] is in [0, size_) AND +// dense_[sparse_[i]].index_ == i. +// If both these properties hold, only then it is safe to refer to +// dense_[sparse_[i]].value_ +// as the value associated with index i. +// +// To insert a new entry, set sparse_[i] to size_, +// initialize dense_[size_], and then increment size_. +// +// To make the sparse array as efficient as possible for non-primitive types, +// elements may or may not be destroyed when they are deleted from the sparse +// array through a call to resize(). They immediately become inaccessible, but +// they are only guaranteed to be destroyed when the SparseArray destructor is +// called. +// +// A moved-from SparseArray will be empty. + +// Doing this simplifies the logic below. +#ifndef __has_feature +#define __has_feature(x) 0 +#endif + +#include +#include +#if __has_feature(memory_sanitizer) +#include +#endif +#include +#include +#include + +#include "re2/pod_array.h" + +namespace re2 { + +template +class SparseArray { +public: + SparseArray(); + explicit SparseArray(int max_size); + ~SparseArray(); + + // IndexValue pairs: exposed in SparseArray::iterator. + class IndexValue; + + typedef IndexValue *iterator; + typedef const IndexValue *const_iterator; + + SparseArray(const SparseArray &src); + SparseArray(SparseArray &&src); + + SparseArray &operator=(const SparseArray &src); + SparseArray &operator=(SparseArray &&src); + + // Return the number of entries in the array. + int size() const { return size_; } + + // Indicate whether the array is empty. + int empty() const { return size_ == 0; } + + // Iterate over the array. + iterator begin() { return dense_.data(); } + iterator end() { return dense_.data() + size_; } + + const_iterator begin() const { return dense_.data(); } + const_iterator end() const { return dense_.data() + size_; } + + // Change the maximum size of the array. + // Invalidates all iterators. + void resize(int new_max_size); + + // Return the maximum size of the array. + // Indices can be in the range [0, max_size). 
+ int max_size() const { + if (dense_.data() != NULL) + return dense_.size(); + else + return 0; + } + + // Clear the array. + void clear() { size_ = 0; } + + // Check whether index i is in the array. + bool has_index(int i) const; + + // Comparison function for sorting. + // Can sort the sparse array so that future iterations + // will visit indices in increasing order using + // std::sort(arr.begin(), arr.end(), arr.less); + static bool less(const IndexValue &a, const IndexValue &b); + +public: + // Set the value at index i to v. + iterator set(int i, const Value &v) { return SetInternal(true, i, v); } + + // Set the value at new index i to v. + // Fast but unsafe: only use if has_index(i) is false. + iterator set_new(int i, const Value &v) { return SetInternal(false, i, v); } + + // Set the value at index i to v. + // Fast but unsafe: only use if has_index(i) is true. + iterator set_existing(int i, const Value &v) { return SetExistingInternal(i, v); } + + // Get the value at index i. + // Fast but unsafe: only use if has_index(i) is true. + Value &get_existing(int i) { + assert(has_index(i)); + return dense_[sparse_[i]].value_; + } + const Value &get_existing(int i) const { + assert(has_index(i)); + return dense_[sparse_[i]].value_; + } + +private: + iterator SetInternal(bool allow_existing, int i, const Value &v) { + DebugCheckInvariants(); + if (static_cast(i) >= static_cast(max_size())) { + assert(false && "illegal index"); + // Semantically, end() would be better here, but we already know + // the user did something stupid, so begin() insulates them from + // dereferencing an invalid pointer. + return begin(); + } + if (!allow_existing) { + assert(!has_index(i)); + create_index(i); + } else { + if (!has_index(i)) + create_index(i); + } + return SetExistingInternal(i, v); + } + + iterator SetExistingInternal(int i, const Value &v) { + DebugCheckInvariants(); + assert(has_index(i)); + dense_[sparse_[i]].value_ = v; + DebugCheckInvariants(); + return dense_.data() + sparse_[i]; + } + + // Add the index i to the array. + // Only use if has_index(i) is known to be false. + // Since it doesn't set the value associated with i, + // this function is private, only intended as a helper + // for other methods. + void create_index(int i); + + // In debug mode, verify that some invariant properties of the class + // are being maintained. This is called at the end of the constructor + // and at the beginning and end of all public non-const member functions. + void DebugCheckInvariants() const; + + // Initializes memory for elements [min, max). + void MaybeInitializeMemory(int min, int max) { +#if __has_feature(memory_sanitizer) + __msan_unpoison(sparse_.data() + min, (max - min) * sizeof sparse_[0]); +#elif defined(RE2_ON_VALGRIND) + for (int i = min; i < max; i++) { + sparse_[i] = 0xababababU; + } +#endif + } + + int size_ = 0; + PODArray sparse_; + PODArray dense_; +}; + +template +SparseArray::SparseArray() = default; + +template +SparseArray::SparseArray(const SparseArray &src) : size_(src.size_), sparse_(src.max_size()), dense_(src.max_size()) { + std::copy_n(src.sparse_.data(), src.max_size(), sparse_.data()); + std::copy_n(src.dense_.data(), src.max_size(), dense_.data()); +} + +template +SparseArray::SparseArray(SparseArray &&src) : size_(src.size_), sparse_(std::move(src.sparse_)), dense_(std::move(src.dense_)) { + src.size_ = 0; +} + +template +SparseArray &SparseArray::operator=(const SparseArray &src) { + // Construct these first for exception safety. 
+ PODArray a(src.max_size()); + PODArray b(src.max_size()); + + size_ = src.size_; + sparse_ = std::move(a); + dense_ = std::move(b); + std::copy_n(src.sparse_.data(), src.max_size(), sparse_.data()); + std::copy_n(src.dense_.data(), src.max_size(), dense_.data()); + return *this; +} + +template +SparseArray &SparseArray::operator=(SparseArray &&src) { + size_ = src.size_; + sparse_ = std::move(src.sparse_); + dense_ = std::move(src.dense_); + src.size_ = 0; + return *this; +} + +// IndexValue pairs: exposed in SparseArray::iterator. +template +class SparseArray::IndexValue { +public: + int index() const { return index_; } + Value &value() { return value_; } + const Value &value() const { return value_; } + +private: + friend class SparseArray; + int index_; + Value value_; +}; + +// Change the maximum size of the array. +// Invalidates all iterators. +template +void SparseArray::resize(int new_max_size) { + DebugCheckInvariants(); + if (new_max_size > max_size()) { + const int old_max_size = max_size(); + + // Construct these first for exception safety. + PODArray a(new_max_size); + PODArray b(new_max_size); + + std::copy_n(sparse_.data(), old_max_size, a.data()); + std::copy_n(dense_.data(), old_max_size, b.data()); + + sparse_ = std::move(a); + dense_ = std::move(b); + + MaybeInitializeMemory(old_max_size, new_max_size); + } + if (size_ > new_max_size) + size_ = new_max_size; + DebugCheckInvariants(); +} + +// Check whether index i is in the array. +template +bool SparseArray::has_index(int i) const { + assert(i >= 0); + assert(i < max_size()); + if (static_cast(i) >= static_cast(max_size())) { + return false; + } + // Unsigned comparison avoids checking sparse_[i] < 0. + return (uint32_t)sparse_[i] < (uint32_t)size_ && dense_[sparse_[i]].index_ == i; +} + +template +void SparseArray::create_index(int i) { + assert(!has_index(i)); + assert(size_ < max_size()); + sparse_[i] = size_; + dense_[size_].index_ = i; + size_++; +} + +template +SparseArray::SparseArray(int max_size) : sparse_(max_size), dense_(max_size) { + MaybeInitializeMemory(size_, max_size); + DebugCheckInvariants(); +} + +template +SparseArray::~SparseArray() { + DebugCheckInvariants(); +} + +template +void SparseArray::DebugCheckInvariants() const { + assert(0 <= size_); + assert(size_ <= max_size()); +} + +// Comparison function for sorting. +template +bool SparseArray::less(const IndexValue &a, const IndexValue &b) { + return a.index_ < b.index_; +} + +} // namespace re2 + +#endif // RE2_SPARSE_ARRAY_H_ diff --git a/internal/cpp/re2/sparse_set.h b/internal/cpp/re2/sparse_set.h new file mode 100644 index 00000000000..7a993968a13 --- /dev/null +++ b/internal/cpp/re2/sparse_set.h @@ -0,0 +1,248 @@ +// Copyright 2006 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef RE2_SPARSE_SET_H_ +#define RE2_SPARSE_SET_H_ + +// DESCRIPTION +// +// SparseSet(m) is a set of integers in [0, m). +// It requires sizeof(int)*m memory, but it provides +// fast iteration through the elements in the set and fast clearing +// of the set. +// +// Insertion and deletion are constant time operations. +// +// Allocating the set is a constant time operation +// when memory allocation is a constant time operation. +// +// Clearing the set is a constant time operation (unusual!). +// +// Iterating through the set is an O(n) operation, where n +// is the number of items in the set (not O(m)). 
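The work-queue behavior described just below is worth seeing concretely. A hedged usage sketch of the `SparseSet` typedef this header goes on to define (the traversal loop and the ids are the editor's):

```cpp
#include <cstdio>
#include "re2/sparse_set.h"

// SparseSet as a de-duplicating work queue: items inserted during iteration
// are appended to dense_ and therefore visited exactly once.
int main() {
    re2::SparseSet q(8);               // ids are drawn from [0, 8)
    q.insert(3);
    for (re2::SparseSet::iterator it = q.begin(); it != q.end(); ++it) {
        int id = *it;
        std::printf("visit %d\n", id); // prints 3, then 5, then 7
        if (id + 2 < q.max_size())     // "discover" a successor
            q.insert(id + 2);          // duplicates are silently ignored
    }
}
```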
+// +// The set iterator visits entries in the order they were first +// inserted into the set. It is safe to add items to the set while +// using an iterator: the iterator will visit indices added to the set +// during the iteration, but will not re-visit indices whose values +// change after visiting. Thus SparseSet can be a convenient +// implementation of a work queue. +// +// The SparseSet implementation is NOT thread-safe. It is up to the +// caller to make sure only one thread is accessing the set. (Typically +// these sets are temporary values and used in situations where speed is +// important.) +// +// The SparseSet interface does not present all the usual STL bells and +// whistles. +// +// Implemented with reference to Briggs & Torczon, An Efficient +// Representation for Sparse Sets, ACM Letters on Programming Languages +// and Systems, Volume 2, Issue 1-4 (March-Dec. 1993), pp. 59-69. +// +// This is a specialization of sparse array; see sparse_array.h. + +// IMPLEMENTATION +// +// See sparse_array.h for implementation details. + +// Doing this simplifies the logic below. +#ifndef __has_feature +#define __has_feature(x) 0 +#endif + +#include +#include +#if __has_feature(memory_sanitizer) +#include +#endif +#include +#include +#include + +#include "re2/pod_array.h" + +namespace re2 { + +template +class SparseSetT { +public: + SparseSetT(); + explicit SparseSetT(int max_size); + ~SparseSetT(); + + typedef int *iterator; + typedef const int *const_iterator; + + // Return the number of entries in the set. + int size() const { return size_; } + + // Indicate whether the set is empty. + int empty() const { return size_ == 0; } + + // Iterate over the set. + iterator begin() { return dense_.data(); } + iterator end() { return dense_.data() + size_; } + + const_iterator begin() const { return dense_.data(); } + const_iterator end() const { return dense_.data() + size_; } + + // Change the maximum size of the set. + // Invalidates all iterators. + void resize(int new_max_size); + + // Return the maximum size of the set. + // Indices can be in the range [0, max_size). + int max_size() const { + if (dense_.data() != NULL) + return dense_.size(); + else + return 0; + } + + // Clear the set. + void clear() { size_ = 0; } + + // Check whether index i is in the set. + bool contains(int i) const; + + // Comparison function for sorting. + // Can sort the sparse set so that future iterations + // will visit indices in increasing order using + // std::sort(arr.begin(), arr.end(), arr.less); + static bool less(int a, int b); + +public: + // Insert index i into the set. + iterator insert(int i) { return InsertInternal(true, i); } + + // Insert index i into the set. + // Fast but unsafe: only use if contains(i) is false. + iterator insert_new(int i) { return InsertInternal(false, i); } + +private: + iterator InsertInternal(bool allow_existing, int i) { + DebugCheckInvariants(); + if (static_cast(i) >= static_cast(max_size())) { + assert(false && "illegal index"); + // Semantically, end() would be better here, but we already know + // the user did something stupid, so begin() insulates them from + // dereferencing an invalid pointer. + return begin(); + } + if (!allow_existing) { + assert(!contains(i)); + create_index(i); + } else { + if (!contains(i)) + create_index(i); + } + DebugCheckInvariants(); + return dense_.data() + sparse_[i]; + } + + // Add the index i to the set. + // Only use if contains(i) is known to be false. 
+ // This function is private, only intended as a helper + // for other methods. + void create_index(int i); + + // In debug mode, verify that some invariant properties of the class + // are being maintained. This is called at the end of the constructor + // and at the beginning and end of all public non-const member functions. + void DebugCheckInvariants() const; + + // Initializes memory for elements [min, max). + void MaybeInitializeMemory(int min, int max) { +#if __has_feature(memory_sanitizer) + __msan_unpoison(sparse_.data() + min, (max - min) * sizeof sparse_[0]); +#elif defined(RE2_ON_VALGRIND) + for (int i = min; i < max; i++) { + sparse_[i] = 0xababababU; + } +#endif + } + + int size_ = 0; + PODArray sparse_; + PODArray dense_; +}; + +template +SparseSetT::SparseSetT() = default; + +// Change the maximum size of the set. +// Invalidates all iterators. +template +void SparseSetT::resize(int new_max_size) { + DebugCheckInvariants(); + if (new_max_size > max_size()) { + const int old_max_size = max_size(); + + // Construct these first for exception safety. + PODArray a(new_max_size); + PODArray b(new_max_size); + + std::copy_n(sparse_.data(), old_max_size, a.data()); + std::copy_n(dense_.data(), old_max_size, b.data()); + + sparse_ = std::move(a); + dense_ = std::move(b); + + MaybeInitializeMemory(old_max_size, new_max_size); + } + if (size_ > new_max_size) + size_ = new_max_size; + DebugCheckInvariants(); +} + +// Check whether index i is in the set. +template +bool SparseSetT::contains(int i) const { + assert(i >= 0); + assert(i < max_size()); + if (static_cast(i) >= static_cast(max_size())) { + return false; + } + // Unsigned comparison avoids checking sparse_[i] < 0. + return (uint32_t)sparse_[i] < (uint32_t)size_ && dense_[sparse_[i]] == i; +} + +template +void SparseSetT::create_index(int i) { + assert(!contains(i)); + assert(size_ < max_size()); + sparse_[i] = size_; + dense_[size_] = i; + size_++; +} + +template +SparseSetT::SparseSetT(int max_size) : sparse_(max_size), dense_(max_size) { + MaybeInitializeMemory(size_, max_size); + DebugCheckInvariants(); +} + +template +SparseSetT::~SparseSetT() { + DebugCheckInvariants(); +} + +template +void SparseSetT::DebugCheckInvariants() const { + assert(0 <= size_); + assert(size_ <= max_size()); +} + +// Comparison function for sorting. +template +bool SparseSetT::less(int a, int b) { + return a < b; +} + +typedef SparseSetT SparseSet; + +} // namespace re2 + +#endif // RE2_SPARSE_SET_H_ diff --git a/internal/cpp/re2/stringpiece.cc b/internal/cpp/re2/stringpiece.cc new file mode 100644 index 00000000000..41e95bbb910 --- /dev/null +++ b/internal/cpp/re2/stringpiece.cc @@ -0,0 +1,69 @@ +// Copyright 2004 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
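The search members implemented in this file follow `std::string` conventions. A brief usage sketch (the editor's, compiled against the header this file includes; the expected values follow from the definitions below):

```cpp
#include <cassert>
#include "re2/stringpiece.h"

int main() {
    re2::StringPiece s("abcabc");
    assert(s.find("bc") == 1);                      // first occurrence
    assert(s.rfind("bc") == 4);                     // last occurrence
    assert(s.find('z') == re2::StringPiece::npos);  // absent: npos
    assert(s.substr(2, 3) == "cab");                // pos clamped like std::string
    assert(s.substr(4, 100) == "bc");               // n is clamped to the tail
}
```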
+ +#include "re2/stringpiece.h" + +#include + +#include "util/util.h" + +namespace re2 { + +const StringPiece::size_type StringPiece::npos; // initialized in stringpiece.h + +StringPiece::size_type StringPiece::copy(char *buf, size_type n, size_type pos) const { + size_type ret = std::min(size_ - pos, n); + memcpy(buf, data_ + pos, ret); + return ret; +} + +StringPiece StringPiece::substr(size_type pos, size_type n) const { + if (pos > size_) + pos = size_; + if (n > size_ - pos) + n = size_ - pos; + return StringPiece(data_ + pos, n); +} + +StringPiece::size_type StringPiece::find(const StringPiece &s, size_type pos) const { + if (pos > size_) + return npos; + const_pointer result = std::search(data_ + pos, data_ + size_, s.data_, s.data_ + s.size_); + size_type xpos = result - data_; + return xpos + s.size_ <= size_ ? xpos : npos; +} + +StringPiece::size_type StringPiece::find(char c, size_type pos) const { + if (size_ <= 0 || pos >= size_) + return npos; + const_pointer result = std::find(data_ + pos, data_ + size_, c); + return result != data_ + size_ ? result - data_ : npos; +} + +StringPiece::size_type StringPiece::rfind(const StringPiece &s, size_type pos) const { + if (size_ < s.size_) + return npos; + if (s.size_ == 0) + return std::min(size_, pos); + const_pointer last = data_ + std::min(size_ - s.size_, pos) + s.size_; + const_pointer result = std::find_end(data_, last, s.data_, s.data_ + s.size_); + return result != last ? result - data_ : npos; +} + +StringPiece::size_type StringPiece::rfind(char c, size_type pos) const { + if (size_ <= 0) + return npos; + for (size_t i = std::min(pos + 1, size_); i != 0;) { + if (data_[--i] == c) + return i; + } + return npos; +} + +std::ostream &operator<<(std::ostream &o, const StringPiece &p) { + o.write(p.data(), p.size()); + return o; +} + +} // namespace re2 diff --git a/internal/cpp/re2/stringpiece.h b/internal/cpp/re2/stringpiece.h new file mode 100644 index 00000000000..2429a8c917d --- /dev/null +++ b/internal/cpp/re2/stringpiece.h @@ -0,0 +1,189 @@ +// Copyright 2001-2010 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef RE2_STRINGPIECE_H_ +#define RE2_STRINGPIECE_H_ + +#ifdef min +#undef min +#endif + +// A string-like object that points to a sized piece of memory. +// +// Functions or methods may use const StringPiece& parameters to accept either +// a "const char*" or a "string" value that will be implicitly converted to +// a StringPiece. The implicit conversion means that it is often appropriate +// to include this .h file in other files rather than forward-declaring +// StringPiece as would be appropriate for most other Google classes. +// +// Systematic usage of StringPiece is encouraged as it will reduce unnecessary +// conversions from "const char*" to "string" and back again. +// +// +// Arghh! I wish C++ literals were "string". 
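As a concrete illustration of that calling convention, one `const StringPiece&` parameter accepts string literals, `std::string`, and explicit pointer/length pairs alike, without copying the bytes. This usage sketch is the editor's; `CountChar` is a hypothetical helper:

```cpp
#include <cstddef>
#include <string>
#include "re2/stringpiece.h"

// One parameter type, many argument types, zero copies of the underlying bytes.
static size_t CountChar(const re2::StringPiece &s, char c) {
    size_t n = 0;
    for (char ch : s)                  // StringPiece provides begin()/end()
        if (ch == c)
            ++n;
    return n;
}

int main() {
    std::string owned = "banana";
    size_t a = CountChar("banana", 'a');                       // const char*
    size_t b = CountChar(owned, 'a');                          // std::string
    size_t c = CountChar(re2::StringPiece("banana", 3), 'a');  // pointer + length
    return (a == 3 && b == 3 && c == 1) ? 0 : 1;
}
```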
+
+#include <string.h>
+#include <algorithm>
+#include <cstddef>
+#include <iosfwd>
+#include <iterator>
+#include <string>
+#ifdef __cpp_lib_string_view
+#include <string_view>
+#endif
+
+namespace re2 {
+
+class StringPiece {
+public:
+    typedef std::char_traits<char> traits_type;
+    typedef char value_type;
+    typedef char *pointer;
+    typedef const char *const_pointer;
+    typedef char &reference;
+    typedef const char &const_reference;
+    typedef const char *const_iterator;
+    typedef const_iterator iterator;
+    typedef std::reverse_iterator<const_iterator> const_reverse_iterator;
+    typedef const_reverse_iterator reverse_iterator;
+    typedef size_t size_type;
+    typedef ptrdiff_t difference_type;
+    static const size_type npos = static_cast<size_type>(-1);
+
+    // We provide non-explicit singleton constructors so users can pass
+    // in a "const char*" or a "string" wherever a "StringPiece" is
+    // expected.
+    StringPiece() : data_(NULL), size_(0) {}
+#ifdef __cpp_lib_string_view
+    StringPiece(const std::string_view &str) : data_(str.data()), size_(str.size()) {}
+#endif
+    StringPiece(const std::string &str) : data_(str.data()), size_(str.size()) {}
+    StringPiece(const char *str) : data_(str), size_(str == NULL ? 0 : strlen(str)) {}
+    StringPiece(const char *str, size_type len) : data_(str), size_(len) {}
+
+    const_iterator begin() const { return data_; }
+    const_iterator end() const { return data_ + size_; }
+    const_reverse_iterator rbegin() const { return const_reverse_iterator(data_ + size_); }
+    const_reverse_iterator rend() const { return const_reverse_iterator(data_); }
+
+    size_type size() const { return size_; }
+    size_type length() const { return size_; }
+    bool empty() const { return size_ == 0; }
+
+    const_reference operator[](size_type i) const { return data_[i]; }
+    const_pointer data() const { return data_; }
+
+    void remove_prefix(size_type n) {
+        data_ += n;
+        size_ -= n;
+    }
+
+    void remove_suffix(size_type n) { size_ -= n; }
+
+    void set(const char *str) {
+        data_ = str;
+        size_ = str == NULL ? 0 : strlen(str);
+    }
+
+    void set(const char *str, size_type len) {
+        data_ = str;
+        size_ = len;
+    }
+
+#ifdef __cpp_lib_string_view
+    // Converts to `std::basic_string_view`.
+    operator std::basic_string_view<char, traits_type>() const {
+        if (!data_)
+            return {};
+        return std::basic_string_view<char, traits_type>(data_, size_);
+    }
+#endif
+
+    // Converts to `std::basic_string`.
+    template <typename A>
+    explicit operator std::basic_string<char, traits_type, A>() const {
+        if (!data_)
+            return {};
+        return std::basic_string<char, traits_type, A>(data_, size_);
+    }
+
+    std::string as_string() const { return std::string(data_, size_); }
+
+    // We also define ToString() here, since many other string-like
+    // interfaces name the routine that converts to a C++ string
+    // "ToString", and it's confusing to have the method that does that
+    // for a StringPiece be called "as_string()". We also leave the
+    // "as_string()" method defined here for existing code.
+    std::string ToString() const { return std::string(data_, size_); }
+
+    void CopyToString(std::string *target) const { target->assign(data_, size_); }
+
+    void AppendToString(std::string *target) const { target->append(data_, size_); }
+
+    size_type copy(char *buf, size_type n, size_type pos = 0) const;
+    StringPiece substr(size_type pos = 0, size_type n = npos) const;
+
+    int compare(const StringPiece &x) const {
+        size_type min_size = std::min(size(), x.size());
+        if (min_size > 0) {
+            int r = memcmp(data(), x.data(), min_size);
+            if (r < 0)
+                return -1;
+            if (r > 0)
+                return 1;
+        }
+        if (size() < x.size())
+            return -1;
+        if (size() > x.size())
+            return 1;
+        return 0;
+    }
+
+    // Does "this" start with "x"?
+    bool starts_with(const StringPiece &x) const { return x.empty() || (size() >= x.size() && memcmp(data(), x.data(), x.size()) == 0); }
+
+    // Does "this" end with "x"?
+    bool ends_with(const StringPiece &x) const {
+        return x.empty() || (size() >= x.size() && memcmp(data() + (size() - x.size()), x.data(), x.size()) == 0);
+    }
+
+    bool contains(const StringPiece &s) const { return find(s) != npos; }
+
+    size_type find(const StringPiece &s, size_type pos = 0) const;
+    size_type find(char c, size_type pos = 0) const;
+    size_type rfind(const StringPiece &s, size_type pos = npos) const;
+    size_type rfind(char c, size_type pos = npos) const;
+
+private:
+    const_pointer data_;
+    size_type size_;
+};
+
+inline bool operator==(const StringPiece &x, const StringPiece &y) {
+    StringPiece::size_type len = x.size();
+    if (len != y.size())
+        return false;
+    return x.data() == y.data() || len == 0 || memcmp(x.data(), y.data(), len) == 0;
+}
+
+inline bool operator!=(const StringPiece &x, const StringPiece &y) { return !(x == y); }
+
+inline bool operator<(const StringPiece &x, const StringPiece &y) {
+    StringPiece::size_type min_size = std::min(x.size(), y.size());
+    int r = min_size == 0 ? 0 : memcmp(x.data(), y.data(), min_size);
+    return (r < 0) || (r == 0 && x.size() < y.size());
+}
+
+inline bool operator>(const StringPiece &x, const StringPiece &y) { return y < x; }
+
+inline bool operator<=(const StringPiece &x, const StringPiece &y) { return !(x > y); }
+
+inline bool operator>=(const StringPiece &x, const StringPiece &y) { return !(x < y); }
+
+// Allow StringPiece to be logged.
+std::ostream &operator<<(std::ostream &o, const StringPiece &p);
+
+} // namespace re2
+
+#endif // RE2_STRINGPIECE_H_
diff --git a/internal/cpp/re2/tostring.cc b/internal/cpp/re2/tostring.cc
new file mode 100644
index 00000000000..e86185be16c
--- /dev/null
+++ b/internal/cpp/re2/tostring.cc
@@ -0,0 +1,345 @@
+// Copyright 2006 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Format a regular expression structure as a string.
+// Tested by parse_test.cc
+
+#include <string.h>
+#include <string>
+
+#include "re2/regexp.h"
+#include "re2/walker-inl.h"
+#include "util/logging.h"
+#include "util/strutil.h"
+#include "util/utf.h"
+#include "util/util.h"
+
+namespace re2 {
+
+enum {
+    PrecAtom,
+    PrecUnary,
+    PrecConcat,
+    PrecAlternate,
+    PrecEmpty,
+    PrecParen,
+    PrecToplevel,
+};
+
+// Helper function. See description below.
+static void AppendCCRange(std::string *t, Rune lo, Rune hi);
+
+// Walker to generate string in t_.
+// The arg pointers are actually integers giving the
+// context precedence.
+// The child_args are always NULL.
+class ToStringWalker : public Regexp::Walker<int> {
+public:
+    explicit ToStringWalker(std::string *t) : t_(t) {}
+
+    virtual int PreVisit(Regexp *re, int parent_arg, bool *stop);
+    virtual int PostVisit(Regexp *re, int parent_arg, int pre_arg, int *child_args, int nchild_args);
+    virtual int ShortVisit(Regexp *re, int parent_arg) { return 0; }
+
+private:
+    std::string *t_; // The string the walker appends to.
+
+    ToStringWalker(const ToStringWalker &) = delete;
+    ToStringWalker &operator=(const ToStringWalker &) = delete;
+};
+
+std::string Regexp::ToString() {
+    std::string t;
+    ToStringWalker w(&t);
+    w.WalkExponential(this, PrecToplevel, 100000);
+    if (w.stopped_early())
+        t += " [truncated]";
+    return t;
+}
+
+#define ToString DontCallToString // Avoid accidental recursion.
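Before the visitor methods, the precedence scheme deserves a concrete illustration: each node passes its children the precedence of the context they will print in, and a child parenthesizes itself only when its own operator binds too loosely for that context. A standalone toy version of the same idea, for arithmetic rather than regexps (all names hypothetical):

```cpp
#include <iostream>
#include <memory>
#include <string>

// Toy precedence-driven printer: wrap a subexpression in parentheses only
// when its operator binds less tightly than the surrounding context.
enum Prec { PrecAdd, PrecMul };        // higher value binds tighter here

struct Expr {
    char op;                           // '+', '*', or 0 for a leaf
    std::string leaf;
    std::unique_ptr<Expr> l, r;
};

static std::unique_ptr<Expr> Leaf(std::string s) {
    return std::unique_ptr<Expr>(new Expr{0, std::move(s), nullptr, nullptr});
}
static std::unique_ptr<Expr> Node(char op, std::unique_ptr<Expr> l, std::unique_ptr<Expr> r) {
    return std::unique_ptr<Expr>(new Expr{op, "", std::move(l), std::move(r)});
}

static void Print(const Expr &e, int context, std::string *out) {
    if (e.op == 0) {
        *out += e.leaf;
        return;
    }
    int mine = (e.op == '+') ? PrecAdd : PrecMul;
    if (mine < context) *out += "(";   // binds too loosely for this context
    Print(*e.l, mine, out);
    *out += e.op;
    Print(*e.r, mine, out);
    if (mine < context) *out += ")";
}

int main() {
    // (a+b)*c needs parentheses; a bare a+b at toplevel would not.
    auto e = Node('*', Node('+', Leaf("a"), Leaf("b")), Leaf("c"));
    std::string s;
    Print(*e, PrecAdd, &s);            // toplevel: the loosest context
    std::cout << s << "\n";            // prints (a+b)*c
}
```

(RE2's enum runs the other numeric direction, with `PrecAtom` lowest and `PrecToplevel` highest, but the wrap-when-too-loose comparison in `PreVisit` below is the same idea.)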
+ +// Visits re before children are processed. +// Appends ( if needed and passes new precedence to children. +int ToStringWalker::PreVisit(Regexp *re, int parent_arg, bool *stop) { + int prec = parent_arg; + int nprec = PrecAtom; + + switch (re->op()) { + case kRegexpNoMatch: + case kRegexpEmptyMatch: + case kRegexpLiteral: + case kRegexpAnyChar: + case kRegexpAnyByte: + case kRegexpBeginLine: + case kRegexpEndLine: + case kRegexpBeginText: + case kRegexpEndText: + case kRegexpWordBoundary: + case kRegexpNoWordBoundary: + case kRegexpCharClass: + case kRegexpHaveMatch: + nprec = PrecAtom; + break; + + case kRegexpConcat: + case kRegexpLiteralString: + if (prec < PrecConcat) + t_->append("(?:"); + nprec = PrecConcat; + break; + + case kRegexpAlternate: + if (prec < PrecAlternate) + t_->append("(?:"); + nprec = PrecAlternate; + break; + + case kRegexpCapture: + t_->append("("); + if (re->cap() == 0) + LOG(DFATAL) << "kRegexpCapture cap() == 0"; + if (re->name()) { + t_->append("?P<"); + t_->append(*re->name()); + t_->append(">"); + } + nprec = PrecParen; + break; + + case kRegexpStar: + case kRegexpPlus: + case kRegexpQuest: + case kRegexpRepeat: + if (prec < PrecUnary) + t_->append("(?:"); + // The subprecedence here is PrecAtom instead of PrecUnary + // because PCRE treats two unary ops in a row as a parse error. + nprec = PrecAtom; + break; + } + + return nprec; +} + +static void AppendLiteral(std::string *t, Rune r, bool foldcase) { + if (r != 0 && r < 0x80 && strchr("(){}[]*+?|.^$\\", r)) { + t->append(1, '\\'); + t->append(1, static_cast(r)); + } else if (foldcase && 'a' <= r && r <= 'z') { + r -= 'a' - 'A'; + t->append(1, '['); + t->append(1, static_cast(r)); + t->append(1, static_cast(r) + 'a' - 'A'); + t->append(1, ']'); + } else { + AppendCCRange(t, r, r); + } +} + +// Visits re after children are processed. +// For childless regexps, all the work is done here. +// For regexps with children, append any unary suffixes or ). +int ToStringWalker::PostVisit(Regexp *re, int parent_arg, int pre_arg, int *child_args, int nchild_args) { + int prec = parent_arg; + switch (re->op()) { + case kRegexpNoMatch: + // There's no simple symbol for "no match", but + // [^0-Runemax] excludes everything. + t_->append("[^\\x00-\\x{10ffff}]"); + break; + + case kRegexpEmptyMatch: + // Append (?:) to make empty string visible, + // unless this is already being parenthesized. + if (prec < PrecEmpty) + t_->append("(?:)"); + break; + + case kRegexpLiteral: + AppendLiteral(t_, re->rune(), (re->parse_flags() & Regexp::FoldCase) != 0); + break; + + case kRegexpLiteralString: + for (int i = 0; i < re->nrunes(); i++) + AppendLiteral(t_, re->runes()[i], (re->parse_flags() & Regexp::FoldCase) != 0); + if (prec < PrecConcat) + t_->append(")"); + break; + + case kRegexpConcat: + if (prec < PrecConcat) + t_->append(")"); + break; + + case kRegexpAlternate: + // Clumsy but workable: the children all appended | + // at the end of their strings, so just remove the last one. 
+ if ((*t_)[t_->size() - 1] == '|') + t_->erase(t_->size() - 1); + else + LOG(DFATAL) << "Bad final char: " << t_; + if (prec < PrecAlternate) + t_->append(")"); + break; + + case kRegexpStar: + t_->append("*"); + if (re->parse_flags() & Regexp::NonGreedy) + t_->append("?"); + if (prec < PrecUnary) + t_->append(")"); + break; + + case kRegexpPlus: + t_->append("+"); + if (re->parse_flags() & Regexp::NonGreedy) + t_->append("?"); + if (prec < PrecUnary) + t_->append(")"); + break; + + case kRegexpQuest: + t_->append("?"); + if (re->parse_flags() & Regexp::NonGreedy) + t_->append("?"); + if (prec < PrecUnary) + t_->append(")"); + break; + + case kRegexpRepeat: + if (re->max() == -1) + t_->append(StringPrintf("{%d,}", re->min())); + else if (re->min() == re->max()) + t_->append(StringPrintf("{%d}", re->min())); + else + t_->append(StringPrintf("{%d,%d}", re->min(), re->max())); + if (re->parse_flags() & Regexp::NonGreedy) + t_->append("?"); + if (prec < PrecUnary) + t_->append(")"); + break; + + case kRegexpAnyChar: + t_->append("."); + break; + + case kRegexpAnyByte: + t_->append("\\C"); + break; + + case kRegexpBeginLine: + t_->append("^"); + break; + + case kRegexpEndLine: + t_->append("$"); + break; + + case kRegexpBeginText: + t_->append("(?-m:^)"); + break; + + case kRegexpEndText: + if (re->parse_flags() & Regexp::WasDollar) + t_->append("(?-m:$)"); + else + t_->append("\\z"); + break; + + case kRegexpWordBoundary: + t_->append("\\b"); + break; + + case kRegexpNoWordBoundary: + t_->append("\\B"); + break; + + case kRegexpCharClass: { + if (re->cc()->size() == 0) { + t_->append("[^\\x00-\\x{10ffff}]"); + break; + } + t_->append("["); + // Heuristic: show class as negated if it contains the + // non-character 0xFFFE and yet somehow isn't full. + CharClass *cc = re->cc(); + if (cc->Contains(0xFFFE) && !cc->full()) { + cc = cc->Negate(); + t_->append("^"); + } + for (CharClass::iterator i = cc->begin(); i != cc->end(); ++i) + AppendCCRange(t_, i->lo, i->hi); + if (cc != re->cc()) + cc->Delete(); + t_->append("]"); + break; + } + + case kRegexpCapture: + t_->append(")"); + break; + + case kRegexpHaveMatch: + // There's no syntax accepted by the parser to generate + // this node (it is generated by RE2::Set) so make something + // up that is readable but won't compile. + t_->append(StringPrintf("(?HaveMatch:%d)", re->match_id())); + break; + } + + // If the parent is an alternation, append the | for it. + if (prec == PrecAlternate) + t_->append("|"); + + return 0; +} + +// Appends a rune for use in a character class to the string t. 
+static void AppendCCChar(std::string *t, Rune r) { + if (0x20 <= r && r <= 0x7E) { + if (strchr("[]^-\\", r)) + t->append("\\"); + t->append(1, static_cast(r)); + return; + } + switch (r) { + default: + break; + + case '\r': + t->append("\\r"); + return; + + case '\t': + t->append("\\t"); + return; + + case '\n': + t->append("\\n"); + return; + + case '\f': + t->append("\\f"); + return; + } + + if (r < 0x100) { + *t += StringPrintf("\\x%02x", static_cast(r)); + return; + } + *t += StringPrintf("\\x{%x}", static_cast(r)); +} + +static void AppendCCRange(std::string *t, Rune lo, Rune hi) { + if (lo > hi) + return; + AppendCCChar(t, lo); + if (lo < hi) { + t->append("-"); + AppendCCChar(t, hi); + } +} + +} // namespace re2 diff --git a/internal/cpp/re2/unicode_casefold.cc b/internal/cpp/re2/unicode_casefold.cc new file mode 100644 index 00000000000..f7818ff24c3 --- /dev/null +++ b/internal/cpp/re2/unicode_casefold.cc @@ -0,0 +1,591 @@ + +// GENERATED BY make_unicode_casefold.py; DO NOT EDIT. +// make_unicode_casefold.py >unicode_casefold.cc + +#include "re2/unicode_casefold.h" + +namespace re2 { + +// 1424 groups, 2878 pairs, 367 ranges +const CaseFold unicode_casefold[] = { + {65, 90, 32}, + {97, 106, -32}, + {107, 107, 8383}, + {108, 114, -32}, + {115, 115, 268}, + {116, 122, -32}, + {181, 181, 743}, + {192, 214, 32}, + {216, 222, 32}, + {223, 223, 7615}, + {224, 228, -32}, + {229, 229, 8262}, + {230, 246, -32}, + {248, 254, -32}, + {255, 255, 121}, + {256, 303, EvenOdd}, + {306, 311, EvenOdd}, + {313, 328, OddEven}, + {330, 375, EvenOdd}, + {376, 376, -121}, + {377, 382, OddEven}, + {383, 383, -300}, + {384, 384, 195}, + {385, 385, 210}, + {386, 389, EvenOdd}, + {390, 390, 206}, + {391, 392, OddEven}, + {393, 394, 205}, + {395, 396, OddEven}, + {398, 398, 79}, + {399, 399, 202}, + {400, 400, 203}, + {401, 402, OddEven}, + {403, 403, 205}, + {404, 404, 207}, + {405, 405, 97}, + {406, 406, 211}, + {407, 407, 209}, + {408, 409, EvenOdd}, + {410, 410, 163}, + {412, 412, 211}, + {413, 413, 213}, + {414, 414, 130}, + {415, 415, 214}, + {416, 421, EvenOdd}, + {422, 422, 218}, + {423, 424, OddEven}, + {425, 425, 218}, + {428, 429, EvenOdd}, + {430, 430, 218}, + {431, 432, OddEven}, + {433, 434, 217}, + {435, 438, OddEven}, + {439, 439, 219}, + {440, 441, EvenOdd}, + {444, 445, EvenOdd}, + {447, 447, 56}, + {452, 452, EvenOdd}, + {453, 453, OddEven}, + {454, 454, -2}, + {455, 455, OddEven}, + {456, 456, EvenOdd}, + {457, 457, -2}, + {458, 458, EvenOdd}, + {459, 459, OddEven}, + {460, 460, -2}, + {461, 476, OddEven}, + {477, 477, -79}, + {478, 495, EvenOdd}, + {497, 497, OddEven}, + {498, 498, EvenOdd}, + {499, 499, -2}, + {500, 501, EvenOdd}, + {502, 502, -97}, + {503, 503, -56}, + {504, 543, EvenOdd}, + {544, 544, -130}, + {546, 563, EvenOdd}, + {570, 570, 10795}, + {571, 572, OddEven}, + {573, 573, -163}, + {574, 574, 10792}, + {575, 576, 10815}, + {577, 578, OddEven}, + {579, 579, -195}, + {580, 580, 69}, + {581, 581, 71}, + {582, 591, EvenOdd}, + {592, 592, 10783}, + {593, 593, 10780}, + {594, 594, 10782}, + {595, 595, -210}, + {596, 596, -206}, + {598, 599, -205}, + {601, 601, -202}, + {603, 603, -203}, + {604, 604, 42319}, + {608, 608, -205}, + {609, 609, 42315}, + {611, 611, -207}, + {613, 613, 42280}, + {614, 614, 42308}, + {616, 616, -209}, + {617, 617, -211}, + {618, 618, 42308}, + {619, 619, 10743}, + {620, 620, 42305}, + {623, 623, -211}, + {625, 625, 10749}, + {626, 626, -213}, + {629, 629, -214}, + {637, 637, 10727}, + {640, 640, -218}, + {642, 642, 42307}, + {643, 643, -218}, + 
{647, 647, 42282}, + {648, 648, -218}, + {649, 649, -69}, + {650, 651, -217}, + {652, 652, -71}, + {658, 658, -219}, + {669, 669, 42261}, + {670, 670, 42258}, + {837, 837, 84}, + {880, 883, EvenOdd}, + {886, 887, EvenOdd}, + {891, 893, 130}, + {895, 895, 116}, + {902, 902, 38}, + {904, 906, 37}, + {908, 908, 64}, + {910, 911, 63}, + {913, 929, 32}, + {931, 931, 31}, + {932, 939, 32}, + {940, 940, -38}, + {941, 943, -37}, + {945, 945, -32}, + {946, 946, 30}, + {947, 948, -32}, + {949, 949, 64}, + {950, 951, -32}, + {952, 952, 25}, + {953, 953, 7173}, + {954, 954, 54}, + {955, 955, -32}, + {956, 956, -775}, + {957, 959, -32}, + {960, 960, 22}, + {961, 961, 48}, + {962, 962, EvenOdd}, + {963, 965, -32}, + {966, 966, 15}, + {967, 968, -32}, + {969, 969, 7517}, + {970, 971, -32}, + {972, 972, -64}, + {973, 974, -63}, + {975, 975, 8}, + {976, 976, -62}, + {977, 977, 35}, + {981, 981, -47}, + {982, 982, -54}, + {983, 983, -8}, + {984, 1007, EvenOdd}, + {1008, 1008, -86}, + {1009, 1009, -80}, + {1010, 1010, 7}, + {1011, 1011, -116}, + {1012, 1012, -92}, + {1013, 1013, -96}, + {1015, 1016, OddEven}, + {1017, 1017, -7}, + {1018, 1019, EvenOdd}, + {1021, 1023, -130}, + {1024, 1039, 80}, + {1040, 1071, 32}, + {1072, 1073, -32}, + {1074, 1074, 6222}, + {1075, 1075, -32}, + {1076, 1076, 6221}, + {1077, 1085, -32}, + {1086, 1086, 6212}, + {1087, 1088, -32}, + {1089, 1090, 6210}, + {1091, 1097, -32}, + {1098, 1098, 6204}, + {1099, 1103, -32}, + {1104, 1119, -80}, + {1120, 1122, EvenOdd}, + {1123, 1123, 6180}, + {1124, 1153, EvenOdd}, + {1162, 1215, EvenOdd}, + {1216, 1216, 15}, + {1217, 1230, OddEven}, + {1231, 1231, -15}, + {1232, 1327, EvenOdd}, + {1329, 1366, 48}, + {1377, 1414, -48}, + {4256, 4293, 7264}, + {4295, 4295, 7264}, + {4301, 4301, 7264}, + {4304, 4346, 3008}, + {4349, 4351, 3008}, + {5024, 5103, 38864}, + {5104, 5109, 8}, + {5112, 5117, -8}, + {7296, 7296, -6254}, + {7297, 7297, -6253}, + {7298, 7298, -6244}, + {7299, 7299, -6242}, + {7300, 7300, EvenOdd}, + {7301, 7301, -6243}, + {7302, 7302, -6236}, + {7303, 7303, -6181}, + {7304, 7304, 35266}, + {7312, 7354, -3008}, + {7357, 7359, -3008}, + {7545, 7545, 35332}, + {7549, 7549, 3814}, + {7566, 7566, 35384}, + {7680, 7776, EvenOdd}, + {7777, 7777, 58}, + {7778, 7829, EvenOdd}, + {7835, 7835, -59}, + {7838, 7838, -7615}, + {7840, 7935, EvenOdd}, + {7936, 7943, 8}, + {7944, 7951, -8}, + {7952, 7957, 8}, + {7960, 7965, -8}, + {7968, 7975, 8}, + {7976, 7983, -8}, + {7984, 7991, 8}, + {7992, 7999, -8}, + {8000, 8005, 8}, + {8008, 8013, -8}, + {8017, 8017, 8}, + {8019, 8019, 8}, + {8021, 8021, 8}, + {8023, 8023, 8}, + {8025, 8025, -8}, + {8027, 8027, -8}, + {8029, 8029, -8}, + {8031, 8031, -8}, + {8032, 8039, 8}, + {8040, 8047, -8}, + {8048, 8049, 74}, + {8050, 8053, 86}, + {8054, 8055, 100}, + {8056, 8057, 128}, + {8058, 8059, 112}, + {8060, 8061, 126}, + {8064, 8071, 8}, + {8072, 8079, -8}, + {8080, 8087, 8}, + {8088, 8095, -8}, + {8096, 8103, 8}, + {8104, 8111, -8}, + {8112, 8113, 8}, + {8115, 8115, 9}, + {8120, 8121, -8}, + {8122, 8123, -74}, + {8124, 8124, -9}, + {8126, 8126, -7289}, + {8131, 8131, 9}, + {8136, 8139, -86}, + {8140, 8140, -9}, + {8144, 8145, 8}, + {8152, 8153, -8}, + {8154, 8155, -100}, + {8160, 8161, 8}, + {8165, 8165, 7}, + {8168, 8169, -8}, + {8170, 8171, -112}, + {8172, 8172, -7}, + {8179, 8179, 9}, + {8184, 8185, -128}, + {8186, 8187, -126}, + {8188, 8188, -9}, + {8486, 8486, -7549}, + {8490, 8490, -8415}, + {8491, 8491, -8294}, + {8498, 8498, 28}, + {8526, 8526, -28}, + {8544, 8559, 16}, + {8560, 8575, -16}, + {8579, 
8580, OddEven}, + {9398, 9423, 26}, + {9424, 9449, -26}, + {11264, 11311, 48}, + {11312, 11359, -48}, + {11360, 11361, EvenOdd}, + {11362, 11362, -10743}, + {11363, 11363, -3814}, + {11364, 11364, -10727}, + {11365, 11365, -10795}, + {11366, 11366, -10792}, + {11367, 11372, OddEven}, + {11373, 11373, -10780}, + {11374, 11374, -10749}, + {11375, 11375, -10783}, + {11376, 11376, -10782}, + {11378, 11379, EvenOdd}, + {11381, 11382, OddEven}, + {11390, 11391, -10815}, + {11392, 11491, EvenOdd}, + {11499, 11502, OddEven}, + {11506, 11507, EvenOdd}, + {11520, 11557, -7264}, + {11559, 11559, -7264}, + {11565, 11565, -7264}, + {42560, 42570, EvenOdd}, + {42571, 42571, -35267}, + {42572, 42605, EvenOdd}, + {42624, 42651, EvenOdd}, + {42786, 42799, EvenOdd}, + {42802, 42863, EvenOdd}, + {42873, 42876, OddEven}, + {42877, 42877, -35332}, + {42878, 42887, EvenOdd}, + {42891, 42892, OddEven}, + {42893, 42893, -42280}, + {42896, 42899, EvenOdd}, + {42900, 42900, 48}, + {42902, 42921, EvenOdd}, + {42922, 42922, -42308}, + {42923, 42923, -42319}, + {42924, 42924, -42315}, + {42925, 42925, -42305}, + {42926, 42926, -42308}, + {42928, 42928, -42258}, + {42929, 42929, -42282}, + {42930, 42930, -42261}, + {42931, 42931, 928}, + {42932, 42947, EvenOdd}, + {42948, 42948, -48}, + {42949, 42949, -42307}, + {42950, 42950, -35384}, + {42951, 42954, OddEven}, + {42960, 42961, EvenOdd}, + {42966, 42969, EvenOdd}, + {42997, 42998, OddEven}, + {43859, 43859, -928}, + {43888, 43967, -38864}, + {65313, 65338, 32}, + {65345, 65370, -32}, + {66560, 66599, 40}, + {66600, 66639, -40}, + {66736, 66771, 40}, + {66776, 66811, -40}, + {66928, 66938, 39}, + {66940, 66954, 39}, + {66956, 66962, 39}, + {66964, 66965, 39}, + {66967, 66977, -39}, + {66979, 66993, -39}, + {66995, 67001, -39}, + {67003, 67004, -39}, + {68736, 68786, 64}, + {68800, 68850, -64}, + {71840, 71871, 32}, + {71872, 71903, -32}, + {93760, 93791, 32}, + {93792, 93823, -32}, + {125184, 125217, 34}, + {125218, 125251, -34}, +}; +const int num_unicode_casefold = 367; + +// 1424 groups, 1454 pairs, 205 ranges +const CaseFold unicode_tolower[] = { + {65, 90, 32}, + {181, 181, 775}, + {192, 214, 32}, + {216, 222, 32}, + {256, 302, EvenOddSkip}, + {306, 310, EvenOddSkip}, + {313, 327, OddEvenSkip}, + {330, 374, EvenOddSkip}, + {376, 376, -121}, + {377, 381, OddEvenSkip}, + {383, 383, -268}, + {385, 385, 210}, + {386, 388, EvenOddSkip}, + {390, 390, 206}, + {391, 391, OddEven}, + {393, 394, 205}, + {395, 395, OddEven}, + {398, 398, 79}, + {399, 399, 202}, + {400, 400, 203}, + {401, 401, OddEven}, + {403, 403, 205}, + {404, 404, 207}, + {406, 406, 211}, + {407, 407, 209}, + {408, 408, EvenOdd}, + {412, 412, 211}, + {413, 413, 213}, + {415, 415, 214}, + {416, 420, EvenOddSkip}, + {422, 422, 218}, + {423, 423, OddEven}, + {425, 425, 218}, + {428, 428, EvenOdd}, + {430, 430, 218}, + {431, 431, OddEven}, + {433, 434, 217}, + {435, 437, OddEvenSkip}, + {439, 439, 219}, + {440, 440, EvenOdd}, + {444, 444, EvenOdd}, + {452, 452, 2}, + {453, 453, OddEven}, + {455, 455, 2}, + {456, 456, EvenOdd}, + {458, 458, 2}, + {459, 475, OddEvenSkip}, + {478, 494, EvenOddSkip}, + {497, 497, 2}, + {498, 500, EvenOddSkip}, + {502, 502, -97}, + {503, 503, -56}, + {504, 542, EvenOddSkip}, + {544, 544, -130}, + {546, 562, EvenOddSkip}, + {570, 570, 10795}, + {571, 571, OddEven}, + {573, 573, -163}, + {574, 574, 10792}, + {577, 577, OddEven}, + {579, 579, -195}, + {580, 580, 69}, + {581, 581, 71}, + {582, 590, EvenOddSkip}, + {837, 837, 116}, + {880, 882, EvenOddSkip}, + {886, 886, EvenOdd}, + 
{895, 895, 116}, + {902, 902, 38}, + {904, 906, 37}, + {908, 908, 64}, + {910, 911, 63}, + {913, 929, 32}, + {931, 939, 32}, + {962, 962, EvenOdd}, + {975, 975, 8}, + {976, 976, -30}, + {977, 977, -25}, + {981, 981, -15}, + {982, 982, -22}, + {984, 1006, EvenOddSkip}, + {1008, 1008, -54}, + {1009, 1009, -48}, + {1012, 1012, -60}, + {1013, 1013, -64}, + {1015, 1015, OddEven}, + {1017, 1017, -7}, + {1018, 1018, EvenOdd}, + {1021, 1023, -130}, + {1024, 1039, 80}, + {1040, 1071, 32}, + {1120, 1152, EvenOddSkip}, + {1162, 1214, EvenOddSkip}, + {1216, 1216, 15}, + {1217, 1229, OddEvenSkip}, + {1232, 1326, EvenOddSkip}, + {1329, 1366, 48}, + {4256, 4293, 7264}, + {4295, 4295, 7264}, + {4301, 4301, 7264}, + {5112, 5117, -8}, + {7296, 7296, -6222}, + {7297, 7297, -6221}, + {7298, 7298, -6212}, + {7299, 7300, -6210}, + {7301, 7301, -6211}, + {7302, 7302, -6204}, + {7303, 7303, -6180}, + {7304, 7304, 35267}, + {7312, 7354, -3008}, + {7357, 7359, -3008}, + {7680, 7828, EvenOddSkip}, + {7835, 7835, -58}, + {7838, 7838, -7615}, + {7840, 7934, EvenOddSkip}, + {7944, 7951, -8}, + {7960, 7965, -8}, + {7976, 7983, -8}, + {7992, 7999, -8}, + {8008, 8013, -8}, + {8025, 8025, -8}, + {8027, 8027, -8}, + {8029, 8029, -8}, + {8031, 8031, -8}, + {8040, 8047, -8}, + {8072, 8079, -8}, + {8088, 8095, -8}, + {8104, 8111, -8}, + {8120, 8121, -8}, + {8122, 8123, -74}, + {8124, 8124, -9}, + {8126, 8126, -7173}, + {8136, 8139, -86}, + {8140, 8140, -9}, + {8152, 8153, -8}, + {8154, 8155, -100}, + {8168, 8169, -8}, + {8170, 8171, -112}, + {8172, 8172, -7}, + {8184, 8185, -128}, + {8186, 8187, -126}, + {8188, 8188, -9}, + {8486, 8486, -7517}, + {8490, 8490, -8383}, + {8491, 8491, -8262}, + {8498, 8498, 28}, + {8544, 8559, 16}, + {8579, 8579, OddEven}, + {9398, 9423, 26}, + {11264, 11311, 48}, + {11360, 11360, EvenOdd}, + {11362, 11362, -10743}, + {11363, 11363, -3814}, + {11364, 11364, -10727}, + {11367, 11371, OddEvenSkip}, + {11373, 11373, -10780}, + {11374, 11374, -10749}, + {11375, 11375, -10783}, + {11376, 11376, -10782}, + {11378, 11378, EvenOdd}, + {11381, 11381, OddEven}, + {11390, 11391, -10815}, + {11392, 11490, EvenOddSkip}, + {11499, 11501, OddEvenSkip}, + {11506, 11506, EvenOdd}, + {42560, 42604, EvenOddSkip}, + {42624, 42650, EvenOddSkip}, + {42786, 42798, EvenOddSkip}, + {42802, 42862, EvenOddSkip}, + {42873, 42875, OddEvenSkip}, + {42877, 42877, -35332}, + {42878, 42886, EvenOddSkip}, + {42891, 42891, OddEven}, + {42893, 42893, -42280}, + {42896, 42898, EvenOddSkip}, + {42902, 42920, EvenOddSkip}, + {42922, 42922, -42308}, + {42923, 42923, -42319}, + {42924, 42924, -42315}, + {42925, 42925, -42305}, + {42926, 42926, -42308}, + {42928, 42928, -42258}, + {42929, 42929, -42282}, + {42930, 42930, -42261}, + {42931, 42931, 928}, + {42932, 42946, EvenOddSkip}, + {42948, 42948, -48}, + {42949, 42949, -42307}, + {42950, 42950, -35384}, + {42951, 42953, OddEvenSkip}, + {42960, 42960, EvenOdd}, + {42966, 42968, EvenOddSkip}, + {42997, 42997, OddEven}, + {43888, 43967, -38864}, + {65313, 65338, 32}, + {66560, 66599, 40}, + {66736, 66771, 40}, + {66928, 66938, 39}, + {66940, 66954, 39}, + {66956, 66962, 39}, + {66964, 66965, 39}, + {68736, 68786, 64}, + {71840, 71871, 32}, + {93760, 93791, 32}, + {125184, 125217, 34}, +}; +const int num_unicode_tolower = 205; + +} // namespace re2 diff --git a/internal/cpp/re2/unicode_casefold.h b/internal/cpp/re2/unicode_casefold.h new file mode 100644 index 00000000000..0e5e3a4ad83 --- /dev/null +++ b/internal/cpp/re2/unicode_casefold.h @@ -0,0 +1,78 @@ +// Copyright 2008 The RE2 
Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef RE2_UNICODE_CASEFOLD_H_ +#define RE2_UNICODE_CASEFOLD_H_ + +// Unicode case folding tables. + +// The Unicode case folding tables encode the mapping from one Unicode point +// to the next largest Unicode point with equivalent folding. The largest +// point wraps back to the first. For example, the tables map: +// +// 'A' -> 'a' +// 'a' -> 'A' +// +// 'K' -> 'k' +// 'k' -> 'K' (Kelvin symbol) +// 'K' -> 'K' +// +// Like everything Unicode, these tables are big. If we represent the table +// as a sorted list of uint32_t pairs, it has 2049 entries and is 16 kB. +// Most table entries look like the ones around them: +// 'A' maps to 'A'+32, 'B' maps to 'B'+32, etc. +// Instead of listing all the pairs explicitly, we make a list of ranges +// and deltas, so that the table entries for 'A' through 'Z' can be represented +// as a single entry { 'A', 'Z', +32 }. +// +// In addition to blocks that map to each other (A-Z mapping to a-z) +// there are blocks of pairs that individually map to each other +// (for example, 0100<->0101, 0102<->0103, 0104<->0105, ...). +// For those, the special delta value EvenOdd marks even/odd pairs +// (if even, add 1; if odd, subtract 1), and OddEven marks odd/even pairs. +// +// In this form, the table has 274 entries, about 3kB. If we were to split +// the table into one for 16-bit codes and an overflow table for larger ones, +// we could get it down to about 1.5kB, but that's not worth the complexity. +// +// The grouped form also allows for efficient fold range calculations +// rather than looping one character at a time. + +#include + +#include "util/utf.h" +#include "util/util.h" + +namespace re2 { + +enum { + EvenOdd = 1, + OddEven = -1, + EvenOddSkip = 1 << 30, + OddEvenSkip, +}; + +struct CaseFold { + Rune lo; + Rune hi; + int32_t delta; +}; + +extern const CaseFold unicode_casefold[]; +extern const int num_unicode_casefold; + +extern const CaseFold unicode_tolower[]; +extern const int num_unicode_tolower; + +// Returns the CaseFold* in the tables that contains rune. +// If rune is not in the tables, returns the first CaseFold* after rune. +// If rune is larger than any value in the tables, returns NULL. +extern const CaseFold *LookupCaseFold(const CaseFold *, int, Rune rune); + +// Returns the result of applying the fold f to the rune r. +extern Rune ApplyFold(const CaseFold *f, Rune r); + +} // namespace re2 + +#endif // RE2_UNICODE_CASEFOLD_H_ diff --git a/internal/cpp/re2/unicode_groups.cc b/internal/cpp/re2/unicode_groups.cc new file mode 100644 index 00000000000..3b58be4cb8e --- /dev/null +++ b/internal/cpp/re2/unicode_groups.cc @@ -0,0 +1,6512 @@ + +// GENERATED BY make_unicode_groups.py; DO NOT EDIT. 
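Before the group tables, a usage sketch for the case-folding API declared in unicode_casefold.h above. The loop is the editor's; `LookupCaseFold`, `ApplyFold`, and the table symbols are as declared in that header, and the expected orbit follows from the table entries shown earlier:

```cpp
#include <cstdio>
#include "re2/unicode_casefold.h"

// Walk a rune's case-folding orbit by repeatedly applying the fold until it
// wraps back to the starting point, per the header's documented contract.
int main() {
    re2::Rune start = 'k';
    re2::Rune r = start;
    do {
        std::printf("U+%04X\n", static_cast<unsigned>(r));
        const re2::CaseFold *f =
            re2::LookupCaseFold(re2::unicode_casefold, re2::num_unicode_casefold, r);
        if (f == nullptr || r < f->lo)  // rune folds only to itself
            break;
        r = re2::ApplyFold(f, r);
    } while (r != start);
    // For 'k' this visits U+006B ('k'), U+212A (Kelvin sign), and U+004B ('K').
}
```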
+// make_unicode_groups.py >unicode_groups.cc + +#include "re2/unicode_groups.h" + +namespace re2 { + + +static const URange16 C_range16[] = { + { 0, 31 }, + { 127, 159 }, + { 173, 173 }, + { 1536, 1541 }, + { 1564, 1564 }, + { 1757, 1757 }, + { 1807, 1807 }, + { 2192, 2193 }, + { 2274, 2274 }, + { 6158, 6158 }, + { 8203, 8207 }, + { 8234, 8238 }, + { 8288, 8292 }, + { 8294, 8303 }, + { 55296, 63743 }, + { 65279, 65279 }, + { 65529, 65531 }, +}; +static const URange32 C_range32[] = { + { 69821, 69821 }, + { 69837, 69837 }, + { 78896, 78911 }, + { 113824, 113827 }, + { 119155, 119162 }, + { 917505, 917505 }, + { 917536, 917631 }, + { 983040, 1048573 }, + { 1048576, 1114109 }, +}; +static const URange16 Cc_range16[] = { + { 0, 31 }, + { 127, 159 }, +}; +static const URange16 Cf_range16[] = { + { 173, 173 }, + { 1536, 1541 }, + { 1564, 1564 }, + { 1757, 1757 }, + { 1807, 1807 }, + { 2192, 2193 }, + { 2274, 2274 }, + { 6158, 6158 }, + { 8203, 8207 }, + { 8234, 8238 }, + { 8288, 8292 }, + { 8294, 8303 }, + { 65279, 65279 }, + { 65529, 65531 }, +}; +static const URange32 Cf_range32[] = { + { 69821, 69821 }, + { 69837, 69837 }, + { 78896, 78911 }, + { 113824, 113827 }, + { 119155, 119162 }, + { 917505, 917505 }, + { 917536, 917631 }, +}; +static const URange16 Co_range16[] = { + { 57344, 63743 }, +}; +static const URange32 Co_range32[] = { + { 983040, 1048573 }, + { 1048576, 1114109 }, +}; +static const URange16 Cs_range16[] = { + { 55296, 57343 }, +}; +static const URange16 L_range16[] = { + { 65, 90 }, + { 97, 122 }, + { 170, 170 }, + { 181, 181 }, + { 186, 186 }, + { 192, 214 }, + { 216, 246 }, + { 248, 705 }, + { 710, 721 }, + { 736, 740 }, + { 748, 748 }, + { 750, 750 }, + { 880, 884 }, + { 886, 887 }, + { 890, 893 }, + { 895, 895 }, + { 902, 902 }, + { 904, 906 }, + { 908, 908 }, + { 910, 929 }, + { 931, 1013 }, + { 1015, 1153 }, + { 1162, 1327 }, + { 1329, 1366 }, + { 1369, 1369 }, + { 1376, 1416 }, + { 1488, 1514 }, + { 1519, 1522 }, + { 1568, 1610 }, + { 1646, 1647 }, + { 1649, 1747 }, + { 1749, 1749 }, + { 1765, 1766 }, + { 1774, 1775 }, + { 1786, 1788 }, + { 1791, 1791 }, + { 1808, 1808 }, + { 1810, 1839 }, + { 1869, 1957 }, + { 1969, 1969 }, + { 1994, 2026 }, + { 2036, 2037 }, + { 2042, 2042 }, + { 2048, 2069 }, + { 2074, 2074 }, + { 2084, 2084 }, + { 2088, 2088 }, + { 2112, 2136 }, + { 2144, 2154 }, + { 2160, 2183 }, + { 2185, 2190 }, + { 2208, 2249 }, + { 2308, 2361 }, + { 2365, 2365 }, + { 2384, 2384 }, + { 2392, 2401 }, + { 2417, 2432 }, + { 2437, 2444 }, + { 2447, 2448 }, + { 2451, 2472 }, + { 2474, 2480 }, + { 2482, 2482 }, + { 2486, 2489 }, + { 2493, 2493 }, + { 2510, 2510 }, + { 2524, 2525 }, + { 2527, 2529 }, + { 2544, 2545 }, + { 2556, 2556 }, + { 2565, 2570 }, + { 2575, 2576 }, + { 2579, 2600 }, + { 2602, 2608 }, + { 2610, 2611 }, + { 2613, 2614 }, + { 2616, 2617 }, + { 2649, 2652 }, + { 2654, 2654 }, + { 2674, 2676 }, + { 2693, 2701 }, + { 2703, 2705 }, + { 2707, 2728 }, + { 2730, 2736 }, + { 2738, 2739 }, + { 2741, 2745 }, + { 2749, 2749 }, + { 2768, 2768 }, + { 2784, 2785 }, + { 2809, 2809 }, + { 2821, 2828 }, + { 2831, 2832 }, + { 2835, 2856 }, + { 2858, 2864 }, + { 2866, 2867 }, + { 2869, 2873 }, + { 2877, 2877 }, + { 2908, 2909 }, + { 2911, 2913 }, + { 2929, 2929 }, + { 2947, 2947 }, + { 2949, 2954 }, + { 2958, 2960 }, + { 2962, 2965 }, + { 2969, 2970 }, + { 2972, 2972 }, + { 2974, 2975 }, + { 2979, 2980 }, + { 2984, 2986 }, + { 2990, 3001 }, + { 3024, 3024 }, + { 3077, 3084 }, + { 3086, 3088 }, + { 3090, 3112 }, + { 3114, 3129 }, + { 3133, 3133 }, + { 3160, 3162 }, + 
{ 3165, 3165 }, + { 3168, 3169 }, + { 3200, 3200 }, + { 3205, 3212 }, + { 3214, 3216 }, + { 3218, 3240 }, + { 3242, 3251 }, + { 3253, 3257 }, + { 3261, 3261 }, + { 3293, 3294 }, + { 3296, 3297 }, + { 3313, 3314 }, + { 3332, 3340 }, + { 3342, 3344 }, + { 3346, 3386 }, + { 3389, 3389 }, + { 3406, 3406 }, + { 3412, 3414 }, + { 3423, 3425 }, + { 3450, 3455 }, + { 3461, 3478 }, + { 3482, 3505 }, + { 3507, 3515 }, + { 3517, 3517 }, + { 3520, 3526 }, + { 3585, 3632 }, + { 3634, 3635 }, + { 3648, 3654 }, + { 3713, 3714 }, + { 3716, 3716 }, + { 3718, 3722 }, + { 3724, 3747 }, + { 3749, 3749 }, + { 3751, 3760 }, + { 3762, 3763 }, + { 3773, 3773 }, + { 3776, 3780 }, + { 3782, 3782 }, + { 3804, 3807 }, + { 3840, 3840 }, + { 3904, 3911 }, + { 3913, 3948 }, + { 3976, 3980 }, + { 4096, 4138 }, + { 4159, 4159 }, + { 4176, 4181 }, + { 4186, 4189 }, + { 4193, 4193 }, + { 4197, 4198 }, + { 4206, 4208 }, + { 4213, 4225 }, + { 4238, 4238 }, + { 4256, 4293 }, + { 4295, 4295 }, + { 4301, 4301 }, + { 4304, 4346 }, + { 4348, 4680 }, + { 4682, 4685 }, + { 4688, 4694 }, + { 4696, 4696 }, + { 4698, 4701 }, + { 4704, 4744 }, + { 4746, 4749 }, + { 4752, 4784 }, + { 4786, 4789 }, + { 4792, 4798 }, + { 4800, 4800 }, + { 4802, 4805 }, + { 4808, 4822 }, + { 4824, 4880 }, + { 4882, 4885 }, + { 4888, 4954 }, + { 4992, 5007 }, + { 5024, 5109 }, + { 5112, 5117 }, + { 5121, 5740 }, + { 5743, 5759 }, + { 5761, 5786 }, + { 5792, 5866 }, + { 5873, 5880 }, + { 5888, 5905 }, + { 5919, 5937 }, + { 5952, 5969 }, + { 5984, 5996 }, + { 5998, 6000 }, + { 6016, 6067 }, + { 6103, 6103 }, + { 6108, 6108 }, + { 6176, 6264 }, + { 6272, 6276 }, + { 6279, 6312 }, + { 6314, 6314 }, + { 6320, 6389 }, + { 6400, 6430 }, + { 6480, 6509 }, + { 6512, 6516 }, + { 6528, 6571 }, + { 6576, 6601 }, + { 6656, 6678 }, + { 6688, 6740 }, + { 6823, 6823 }, + { 6917, 6963 }, + { 6981, 6988 }, + { 7043, 7072 }, + { 7086, 7087 }, + { 7098, 7141 }, + { 7168, 7203 }, + { 7245, 7247 }, + { 7258, 7293 }, + { 7296, 7304 }, + { 7312, 7354 }, + { 7357, 7359 }, + { 7401, 7404 }, + { 7406, 7411 }, + { 7413, 7414 }, + { 7418, 7418 }, + { 7424, 7615 }, + { 7680, 7957 }, + { 7960, 7965 }, + { 7968, 8005 }, + { 8008, 8013 }, + { 8016, 8023 }, + { 8025, 8025 }, + { 8027, 8027 }, + { 8029, 8029 }, + { 8031, 8061 }, + { 8064, 8116 }, + { 8118, 8124 }, + { 8126, 8126 }, + { 8130, 8132 }, + { 8134, 8140 }, + { 8144, 8147 }, + { 8150, 8155 }, + { 8160, 8172 }, + { 8178, 8180 }, + { 8182, 8188 }, + { 8305, 8305 }, + { 8319, 8319 }, + { 8336, 8348 }, + { 8450, 8450 }, + { 8455, 8455 }, + { 8458, 8467 }, + { 8469, 8469 }, + { 8473, 8477 }, + { 8484, 8484 }, + { 8486, 8486 }, + { 8488, 8488 }, + { 8490, 8493 }, + { 8495, 8505 }, + { 8508, 8511 }, + { 8517, 8521 }, + { 8526, 8526 }, + { 8579, 8580 }, + { 11264, 11492 }, + { 11499, 11502 }, + { 11506, 11507 }, + { 11520, 11557 }, + { 11559, 11559 }, + { 11565, 11565 }, + { 11568, 11623 }, + { 11631, 11631 }, + { 11648, 11670 }, + { 11680, 11686 }, + { 11688, 11694 }, + { 11696, 11702 }, + { 11704, 11710 }, + { 11712, 11718 }, + { 11720, 11726 }, + { 11728, 11734 }, + { 11736, 11742 }, + { 11823, 11823 }, + { 12293, 12294 }, + { 12337, 12341 }, + { 12347, 12348 }, + { 12353, 12438 }, + { 12445, 12447 }, + { 12449, 12538 }, + { 12540, 12543 }, + { 12549, 12591 }, + { 12593, 12686 }, + { 12704, 12735 }, + { 12784, 12799 }, + { 13312, 19903 }, + { 19968, 42124 }, + { 42192, 42237 }, + { 42240, 42508 }, + { 42512, 42527 }, + { 42538, 42539 }, + { 42560, 42606 }, + { 42623, 42653 }, + { 42656, 42725 }, + { 42775, 42783 }, + { 42786, 42888 }, + 
{ 42891, 42954 }, + { 42960, 42961 }, + { 42963, 42963 }, + { 42965, 42969 }, + { 42994, 43009 }, + { 43011, 43013 }, + { 43015, 43018 }, + { 43020, 43042 }, + { 43072, 43123 }, + { 43138, 43187 }, + { 43250, 43255 }, + { 43259, 43259 }, + { 43261, 43262 }, + { 43274, 43301 }, + { 43312, 43334 }, + { 43360, 43388 }, + { 43396, 43442 }, + { 43471, 43471 }, + { 43488, 43492 }, + { 43494, 43503 }, + { 43514, 43518 }, + { 43520, 43560 }, + { 43584, 43586 }, + { 43588, 43595 }, + { 43616, 43638 }, + { 43642, 43642 }, + { 43646, 43695 }, + { 43697, 43697 }, + { 43701, 43702 }, + { 43705, 43709 }, + { 43712, 43712 }, + { 43714, 43714 }, + { 43739, 43741 }, + { 43744, 43754 }, + { 43762, 43764 }, + { 43777, 43782 }, + { 43785, 43790 }, + { 43793, 43798 }, + { 43808, 43814 }, + { 43816, 43822 }, + { 43824, 43866 }, + { 43868, 43881 }, + { 43888, 44002 }, + { 44032, 55203 }, + { 55216, 55238 }, + { 55243, 55291 }, + { 63744, 64109 }, + { 64112, 64217 }, + { 64256, 64262 }, + { 64275, 64279 }, + { 64285, 64285 }, + { 64287, 64296 }, + { 64298, 64310 }, + { 64312, 64316 }, + { 64318, 64318 }, + { 64320, 64321 }, + { 64323, 64324 }, + { 64326, 64433 }, + { 64467, 64829 }, + { 64848, 64911 }, + { 64914, 64967 }, + { 65008, 65019 }, + { 65136, 65140 }, + { 65142, 65276 }, + { 65313, 65338 }, + { 65345, 65370 }, + { 65382, 65470 }, + { 65474, 65479 }, + { 65482, 65487 }, + { 65490, 65495 }, + { 65498, 65500 }, +}; +static const URange32 L_range32[] = { + { 65536, 65547 }, + { 65549, 65574 }, + { 65576, 65594 }, + { 65596, 65597 }, + { 65599, 65613 }, + { 65616, 65629 }, + { 65664, 65786 }, + { 66176, 66204 }, + { 66208, 66256 }, + { 66304, 66335 }, + { 66349, 66368 }, + { 66370, 66377 }, + { 66384, 66421 }, + { 66432, 66461 }, + { 66464, 66499 }, + { 66504, 66511 }, + { 66560, 66717 }, + { 66736, 66771 }, + { 66776, 66811 }, + { 66816, 66855 }, + { 66864, 66915 }, + { 66928, 66938 }, + { 66940, 66954 }, + { 66956, 66962 }, + { 66964, 66965 }, + { 66967, 66977 }, + { 66979, 66993 }, + { 66995, 67001 }, + { 67003, 67004 }, + { 67072, 67382 }, + { 67392, 67413 }, + { 67424, 67431 }, + { 67456, 67461 }, + { 67463, 67504 }, + { 67506, 67514 }, + { 67584, 67589 }, + { 67592, 67592 }, + { 67594, 67637 }, + { 67639, 67640 }, + { 67644, 67644 }, + { 67647, 67669 }, + { 67680, 67702 }, + { 67712, 67742 }, + { 67808, 67826 }, + { 67828, 67829 }, + { 67840, 67861 }, + { 67872, 67897 }, + { 67968, 68023 }, + { 68030, 68031 }, + { 68096, 68096 }, + { 68112, 68115 }, + { 68117, 68119 }, + { 68121, 68149 }, + { 68192, 68220 }, + { 68224, 68252 }, + { 68288, 68295 }, + { 68297, 68324 }, + { 68352, 68405 }, + { 68416, 68437 }, + { 68448, 68466 }, + { 68480, 68497 }, + { 68608, 68680 }, + { 68736, 68786 }, + { 68800, 68850 }, + { 68864, 68899 }, + { 69248, 69289 }, + { 69296, 69297 }, + { 69376, 69404 }, + { 69415, 69415 }, + { 69424, 69445 }, + { 69488, 69505 }, + { 69552, 69572 }, + { 69600, 69622 }, + { 69635, 69687 }, + { 69745, 69746 }, + { 69749, 69749 }, + { 69763, 69807 }, + { 69840, 69864 }, + { 69891, 69926 }, + { 69956, 69956 }, + { 69959, 69959 }, + { 69968, 70002 }, + { 70006, 70006 }, + { 70019, 70066 }, + { 70081, 70084 }, + { 70106, 70106 }, + { 70108, 70108 }, + { 70144, 70161 }, + { 70163, 70187 }, + { 70207, 70208 }, + { 70272, 70278 }, + { 70280, 70280 }, + { 70282, 70285 }, + { 70287, 70301 }, + { 70303, 70312 }, + { 70320, 70366 }, + { 70405, 70412 }, + { 70415, 70416 }, + { 70419, 70440 }, + { 70442, 70448 }, + { 70450, 70451 }, + { 70453, 70457 }, + { 70461, 70461 }, + { 70480, 70480 }, + { 70493, 
70497 }, + { 70656, 70708 }, + { 70727, 70730 }, + { 70751, 70753 }, + { 70784, 70831 }, + { 70852, 70853 }, + { 70855, 70855 }, + { 71040, 71086 }, + { 71128, 71131 }, + { 71168, 71215 }, + { 71236, 71236 }, + { 71296, 71338 }, + { 71352, 71352 }, + { 71424, 71450 }, + { 71488, 71494 }, + { 71680, 71723 }, + { 71840, 71903 }, + { 71935, 71942 }, + { 71945, 71945 }, + { 71948, 71955 }, + { 71957, 71958 }, + { 71960, 71983 }, + { 71999, 71999 }, + { 72001, 72001 }, + { 72096, 72103 }, + { 72106, 72144 }, + { 72161, 72161 }, + { 72163, 72163 }, + { 72192, 72192 }, + { 72203, 72242 }, + { 72250, 72250 }, + { 72272, 72272 }, + { 72284, 72329 }, + { 72349, 72349 }, + { 72368, 72440 }, + { 72704, 72712 }, + { 72714, 72750 }, + { 72768, 72768 }, + { 72818, 72847 }, + { 72960, 72966 }, + { 72968, 72969 }, + { 72971, 73008 }, + { 73030, 73030 }, + { 73056, 73061 }, + { 73063, 73064 }, + { 73066, 73097 }, + { 73112, 73112 }, + { 73440, 73458 }, + { 73474, 73474 }, + { 73476, 73488 }, + { 73490, 73523 }, + { 73648, 73648 }, + { 73728, 74649 }, + { 74880, 75075 }, + { 77712, 77808 }, + { 77824, 78895 }, + { 78913, 78918 }, + { 82944, 83526 }, + { 92160, 92728 }, + { 92736, 92766 }, + { 92784, 92862 }, + { 92880, 92909 }, + { 92928, 92975 }, + { 92992, 92995 }, + { 93027, 93047 }, + { 93053, 93071 }, + { 93760, 93823 }, + { 93952, 94026 }, + { 94032, 94032 }, + { 94099, 94111 }, + { 94176, 94177 }, + { 94179, 94179 }, + { 94208, 100343 }, + { 100352, 101589 }, + { 101632, 101640 }, + { 110576, 110579 }, + { 110581, 110587 }, + { 110589, 110590 }, + { 110592, 110882 }, + { 110898, 110898 }, + { 110928, 110930 }, + { 110933, 110933 }, + { 110948, 110951 }, + { 110960, 111355 }, + { 113664, 113770 }, + { 113776, 113788 }, + { 113792, 113800 }, + { 113808, 113817 }, + { 119808, 119892 }, + { 119894, 119964 }, + { 119966, 119967 }, + { 119970, 119970 }, + { 119973, 119974 }, + { 119977, 119980 }, + { 119982, 119993 }, + { 119995, 119995 }, + { 119997, 120003 }, + { 120005, 120069 }, + { 120071, 120074 }, + { 120077, 120084 }, + { 120086, 120092 }, + { 120094, 120121 }, + { 120123, 120126 }, + { 120128, 120132 }, + { 120134, 120134 }, + { 120138, 120144 }, + { 120146, 120485 }, + { 120488, 120512 }, + { 120514, 120538 }, + { 120540, 120570 }, + { 120572, 120596 }, + { 120598, 120628 }, + { 120630, 120654 }, + { 120656, 120686 }, + { 120688, 120712 }, + { 120714, 120744 }, + { 120746, 120770 }, + { 120772, 120779 }, + { 122624, 122654 }, + { 122661, 122666 }, + { 122928, 122989 }, + { 123136, 123180 }, + { 123191, 123197 }, + { 123214, 123214 }, + { 123536, 123565 }, + { 123584, 123627 }, + { 124112, 124139 }, + { 124896, 124902 }, + { 124904, 124907 }, + { 124909, 124910 }, + { 124912, 124926 }, + { 124928, 125124 }, + { 125184, 125251 }, + { 125259, 125259 }, + { 126464, 126467 }, + { 126469, 126495 }, + { 126497, 126498 }, + { 126500, 126500 }, + { 126503, 126503 }, + { 126505, 126514 }, + { 126516, 126519 }, + { 126521, 126521 }, + { 126523, 126523 }, + { 126530, 126530 }, + { 126535, 126535 }, + { 126537, 126537 }, + { 126539, 126539 }, + { 126541, 126543 }, + { 126545, 126546 }, + { 126548, 126548 }, + { 126551, 126551 }, + { 126553, 126553 }, + { 126555, 126555 }, + { 126557, 126557 }, + { 126559, 126559 }, + { 126561, 126562 }, + { 126564, 126564 }, + { 126567, 126570 }, + { 126572, 126578 }, + { 126580, 126583 }, + { 126585, 126588 }, + { 126590, 126590 }, + { 126592, 126601 }, + { 126603, 126619 }, + { 126625, 126627 }, + { 126629, 126633 }, + { 126635, 126651 }, + { 131072, 173791 }, + { 173824, 
177977 }, + { 177984, 178205 }, + { 178208, 183969 }, + { 183984, 191456 }, + { 194560, 195101 }, + { 196608, 201546 }, + { 201552, 205743 }, +}; +static const URange16 Ll_range16[] = { + { 97, 122 }, + { 181, 181 }, + { 223, 246 }, + { 248, 255 }, + { 257, 257 }, + { 259, 259 }, + { 261, 261 }, + { 263, 263 }, + { 265, 265 }, + { 267, 267 }, + { 269, 269 }, + { 271, 271 }, + { 273, 273 }, + { 275, 275 }, + { 277, 277 }, + { 279, 279 }, + { 281, 281 }, + { 283, 283 }, + { 285, 285 }, + { 287, 287 }, + { 289, 289 }, + { 291, 291 }, + { 293, 293 }, + { 295, 295 }, + { 297, 297 }, + { 299, 299 }, + { 301, 301 }, + { 303, 303 }, + { 305, 305 }, + { 307, 307 }, + { 309, 309 }, + { 311, 312 }, + { 314, 314 }, + { 316, 316 }, + { 318, 318 }, + { 320, 320 }, + { 322, 322 }, + { 324, 324 }, + { 326, 326 }, + { 328, 329 }, + { 331, 331 }, + { 333, 333 }, + { 335, 335 }, + { 337, 337 }, + { 339, 339 }, + { 341, 341 }, + { 343, 343 }, + { 345, 345 }, + { 347, 347 }, + { 349, 349 }, + { 351, 351 }, + { 353, 353 }, + { 355, 355 }, + { 357, 357 }, + { 359, 359 }, + { 361, 361 }, + { 363, 363 }, + { 365, 365 }, + { 367, 367 }, + { 369, 369 }, + { 371, 371 }, + { 373, 373 }, + { 375, 375 }, + { 378, 378 }, + { 380, 380 }, + { 382, 384 }, + { 387, 387 }, + { 389, 389 }, + { 392, 392 }, + { 396, 397 }, + { 402, 402 }, + { 405, 405 }, + { 409, 411 }, + { 414, 414 }, + { 417, 417 }, + { 419, 419 }, + { 421, 421 }, + { 424, 424 }, + { 426, 427 }, + { 429, 429 }, + { 432, 432 }, + { 436, 436 }, + { 438, 438 }, + { 441, 442 }, + { 445, 447 }, + { 454, 454 }, + { 457, 457 }, + { 460, 460 }, + { 462, 462 }, + { 464, 464 }, + { 466, 466 }, + { 468, 468 }, + { 470, 470 }, + { 472, 472 }, + { 474, 474 }, + { 476, 477 }, + { 479, 479 }, + { 481, 481 }, + { 483, 483 }, + { 485, 485 }, + { 487, 487 }, + { 489, 489 }, + { 491, 491 }, + { 493, 493 }, + { 495, 496 }, + { 499, 499 }, + { 501, 501 }, + { 505, 505 }, + { 507, 507 }, + { 509, 509 }, + { 511, 511 }, + { 513, 513 }, + { 515, 515 }, + { 517, 517 }, + { 519, 519 }, + { 521, 521 }, + { 523, 523 }, + { 525, 525 }, + { 527, 527 }, + { 529, 529 }, + { 531, 531 }, + { 533, 533 }, + { 535, 535 }, + { 537, 537 }, + { 539, 539 }, + { 541, 541 }, + { 543, 543 }, + { 545, 545 }, + { 547, 547 }, + { 549, 549 }, + { 551, 551 }, + { 553, 553 }, + { 555, 555 }, + { 557, 557 }, + { 559, 559 }, + { 561, 561 }, + { 563, 569 }, + { 572, 572 }, + { 575, 576 }, + { 578, 578 }, + { 583, 583 }, + { 585, 585 }, + { 587, 587 }, + { 589, 589 }, + { 591, 659 }, + { 661, 687 }, + { 881, 881 }, + { 883, 883 }, + { 887, 887 }, + { 891, 893 }, + { 912, 912 }, + { 940, 974 }, + { 976, 977 }, + { 981, 983 }, + { 985, 985 }, + { 987, 987 }, + { 989, 989 }, + { 991, 991 }, + { 993, 993 }, + { 995, 995 }, + { 997, 997 }, + { 999, 999 }, + { 1001, 1001 }, + { 1003, 1003 }, + { 1005, 1005 }, + { 1007, 1011 }, + { 1013, 1013 }, + { 1016, 1016 }, + { 1019, 1020 }, + { 1072, 1119 }, + { 1121, 1121 }, + { 1123, 1123 }, + { 1125, 1125 }, + { 1127, 1127 }, + { 1129, 1129 }, + { 1131, 1131 }, + { 1133, 1133 }, + { 1135, 1135 }, + { 1137, 1137 }, + { 1139, 1139 }, + { 1141, 1141 }, + { 1143, 1143 }, + { 1145, 1145 }, + { 1147, 1147 }, + { 1149, 1149 }, + { 1151, 1151 }, + { 1153, 1153 }, + { 1163, 1163 }, + { 1165, 1165 }, + { 1167, 1167 }, + { 1169, 1169 }, + { 1171, 1171 }, + { 1173, 1173 }, + { 1175, 1175 }, + { 1177, 1177 }, + { 1179, 1179 }, + { 1181, 1181 }, + { 1183, 1183 }, + { 1185, 1185 }, + { 1187, 1187 }, + { 1189, 1189 }, + { 1191, 1191 }, + { 1193, 1193 }, + { 1195, 1195 }, + { 1197, 1197 }, + { 
1199, 1199 }, + { 1201, 1201 }, + { 1203, 1203 }, + { 1205, 1205 }, + { 1207, 1207 }, + { 1209, 1209 }, + { 1211, 1211 }, + { 1213, 1213 }, + { 1215, 1215 }, + { 1218, 1218 }, + { 1220, 1220 }, + { 1222, 1222 }, + { 1224, 1224 }, + { 1226, 1226 }, + { 1228, 1228 }, + { 1230, 1231 }, + { 1233, 1233 }, + { 1235, 1235 }, + { 1237, 1237 }, + { 1239, 1239 }, + { 1241, 1241 }, + { 1243, 1243 }, + { 1245, 1245 }, + { 1247, 1247 }, + { 1249, 1249 }, + { 1251, 1251 }, + { 1253, 1253 }, + { 1255, 1255 }, + { 1257, 1257 }, + { 1259, 1259 }, + { 1261, 1261 }, + { 1263, 1263 }, + { 1265, 1265 }, + { 1267, 1267 }, + { 1269, 1269 }, + { 1271, 1271 }, + { 1273, 1273 }, + { 1275, 1275 }, + { 1277, 1277 }, + { 1279, 1279 }, + { 1281, 1281 }, + { 1283, 1283 }, + { 1285, 1285 }, + { 1287, 1287 }, + { 1289, 1289 }, + { 1291, 1291 }, + { 1293, 1293 }, + { 1295, 1295 }, + { 1297, 1297 }, + { 1299, 1299 }, + { 1301, 1301 }, + { 1303, 1303 }, + { 1305, 1305 }, + { 1307, 1307 }, + { 1309, 1309 }, + { 1311, 1311 }, + { 1313, 1313 }, + { 1315, 1315 }, + { 1317, 1317 }, + { 1319, 1319 }, + { 1321, 1321 }, + { 1323, 1323 }, + { 1325, 1325 }, + { 1327, 1327 }, + { 1376, 1416 }, + { 4304, 4346 }, + { 4349, 4351 }, + { 5112, 5117 }, + { 7296, 7304 }, + { 7424, 7467 }, + { 7531, 7543 }, + { 7545, 7578 }, + { 7681, 7681 }, + { 7683, 7683 }, + { 7685, 7685 }, + { 7687, 7687 }, + { 7689, 7689 }, + { 7691, 7691 }, + { 7693, 7693 }, + { 7695, 7695 }, + { 7697, 7697 }, + { 7699, 7699 }, + { 7701, 7701 }, + { 7703, 7703 }, + { 7705, 7705 }, + { 7707, 7707 }, + { 7709, 7709 }, + { 7711, 7711 }, + { 7713, 7713 }, + { 7715, 7715 }, + { 7717, 7717 }, + { 7719, 7719 }, + { 7721, 7721 }, + { 7723, 7723 }, + { 7725, 7725 }, + { 7727, 7727 }, + { 7729, 7729 }, + { 7731, 7731 }, + { 7733, 7733 }, + { 7735, 7735 }, + { 7737, 7737 }, + { 7739, 7739 }, + { 7741, 7741 }, + { 7743, 7743 }, + { 7745, 7745 }, + { 7747, 7747 }, + { 7749, 7749 }, + { 7751, 7751 }, + { 7753, 7753 }, + { 7755, 7755 }, + { 7757, 7757 }, + { 7759, 7759 }, + { 7761, 7761 }, + { 7763, 7763 }, + { 7765, 7765 }, + { 7767, 7767 }, + { 7769, 7769 }, + { 7771, 7771 }, + { 7773, 7773 }, + { 7775, 7775 }, + { 7777, 7777 }, + { 7779, 7779 }, + { 7781, 7781 }, + { 7783, 7783 }, + { 7785, 7785 }, + { 7787, 7787 }, + { 7789, 7789 }, + { 7791, 7791 }, + { 7793, 7793 }, + { 7795, 7795 }, + { 7797, 7797 }, + { 7799, 7799 }, + { 7801, 7801 }, + { 7803, 7803 }, + { 7805, 7805 }, + { 7807, 7807 }, + { 7809, 7809 }, + { 7811, 7811 }, + { 7813, 7813 }, + { 7815, 7815 }, + { 7817, 7817 }, + { 7819, 7819 }, + { 7821, 7821 }, + { 7823, 7823 }, + { 7825, 7825 }, + { 7827, 7827 }, + { 7829, 7837 }, + { 7839, 7839 }, + { 7841, 7841 }, + { 7843, 7843 }, + { 7845, 7845 }, + { 7847, 7847 }, + { 7849, 7849 }, + { 7851, 7851 }, + { 7853, 7853 }, + { 7855, 7855 }, + { 7857, 7857 }, + { 7859, 7859 }, + { 7861, 7861 }, + { 7863, 7863 }, + { 7865, 7865 }, + { 7867, 7867 }, + { 7869, 7869 }, + { 7871, 7871 }, + { 7873, 7873 }, + { 7875, 7875 }, + { 7877, 7877 }, + { 7879, 7879 }, + { 7881, 7881 }, + { 7883, 7883 }, + { 7885, 7885 }, + { 7887, 7887 }, + { 7889, 7889 }, + { 7891, 7891 }, + { 7893, 7893 }, + { 7895, 7895 }, + { 7897, 7897 }, + { 7899, 7899 }, + { 7901, 7901 }, + { 7903, 7903 }, + { 7905, 7905 }, + { 7907, 7907 }, + { 7909, 7909 }, + { 7911, 7911 }, + { 7913, 7913 }, + { 7915, 7915 }, + { 7917, 7917 }, + { 7919, 7919 }, + { 7921, 7921 }, + { 7923, 7923 }, + { 7925, 7925 }, + { 7927, 7927 }, + { 7929, 7929 }, + { 7931, 7931 }, + { 7933, 7933 }, + { 7935, 7943 }, + { 7952, 7957 }, + { 7968, 
7975 }, + { 7984, 7991 }, + { 8000, 8005 }, + { 8016, 8023 }, + { 8032, 8039 }, + { 8048, 8061 }, + { 8064, 8071 }, + { 8080, 8087 }, + { 8096, 8103 }, + { 8112, 8116 }, + { 8118, 8119 }, + { 8126, 8126 }, + { 8130, 8132 }, + { 8134, 8135 }, + { 8144, 8147 }, + { 8150, 8151 }, + { 8160, 8167 }, + { 8178, 8180 }, + { 8182, 8183 }, + { 8458, 8458 }, + { 8462, 8463 }, + { 8467, 8467 }, + { 8495, 8495 }, + { 8500, 8500 }, + { 8505, 8505 }, + { 8508, 8509 }, + { 8518, 8521 }, + { 8526, 8526 }, + { 8580, 8580 }, + { 11312, 11359 }, + { 11361, 11361 }, + { 11365, 11366 }, + { 11368, 11368 }, + { 11370, 11370 }, + { 11372, 11372 }, + { 11377, 11377 }, + { 11379, 11380 }, + { 11382, 11387 }, + { 11393, 11393 }, + { 11395, 11395 }, + { 11397, 11397 }, + { 11399, 11399 }, + { 11401, 11401 }, + { 11403, 11403 }, + { 11405, 11405 }, + { 11407, 11407 }, + { 11409, 11409 }, + { 11411, 11411 }, + { 11413, 11413 }, + { 11415, 11415 }, + { 11417, 11417 }, + { 11419, 11419 }, + { 11421, 11421 }, + { 11423, 11423 }, + { 11425, 11425 }, + { 11427, 11427 }, + { 11429, 11429 }, + { 11431, 11431 }, + { 11433, 11433 }, + { 11435, 11435 }, + { 11437, 11437 }, + { 11439, 11439 }, + { 11441, 11441 }, + { 11443, 11443 }, + { 11445, 11445 }, + { 11447, 11447 }, + { 11449, 11449 }, + { 11451, 11451 }, + { 11453, 11453 }, + { 11455, 11455 }, + { 11457, 11457 }, + { 11459, 11459 }, + { 11461, 11461 }, + { 11463, 11463 }, + { 11465, 11465 }, + { 11467, 11467 }, + { 11469, 11469 }, + { 11471, 11471 }, + { 11473, 11473 }, + { 11475, 11475 }, + { 11477, 11477 }, + { 11479, 11479 }, + { 11481, 11481 }, + { 11483, 11483 }, + { 11485, 11485 }, + { 11487, 11487 }, + { 11489, 11489 }, + { 11491, 11492 }, + { 11500, 11500 }, + { 11502, 11502 }, + { 11507, 11507 }, + { 11520, 11557 }, + { 11559, 11559 }, + { 11565, 11565 }, + { 42561, 42561 }, + { 42563, 42563 }, + { 42565, 42565 }, + { 42567, 42567 }, + { 42569, 42569 }, + { 42571, 42571 }, + { 42573, 42573 }, + { 42575, 42575 }, + { 42577, 42577 }, + { 42579, 42579 }, + { 42581, 42581 }, + { 42583, 42583 }, + { 42585, 42585 }, + { 42587, 42587 }, + { 42589, 42589 }, + { 42591, 42591 }, + { 42593, 42593 }, + { 42595, 42595 }, + { 42597, 42597 }, + { 42599, 42599 }, + { 42601, 42601 }, + { 42603, 42603 }, + { 42605, 42605 }, + { 42625, 42625 }, + { 42627, 42627 }, + { 42629, 42629 }, + { 42631, 42631 }, + { 42633, 42633 }, + { 42635, 42635 }, + { 42637, 42637 }, + { 42639, 42639 }, + { 42641, 42641 }, + { 42643, 42643 }, + { 42645, 42645 }, + { 42647, 42647 }, + { 42649, 42649 }, + { 42651, 42651 }, + { 42787, 42787 }, + { 42789, 42789 }, + { 42791, 42791 }, + { 42793, 42793 }, + { 42795, 42795 }, + { 42797, 42797 }, + { 42799, 42801 }, + { 42803, 42803 }, + { 42805, 42805 }, + { 42807, 42807 }, + { 42809, 42809 }, + { 42811, 42811 }, + { 42813, 42813 }, + { 42815, 42815 }, + { 42817, 42817 }, + { 42819, 42819 }, + { 42821, 42821 }, + { 42823, 42823 }, + { 42825, 42825 }, + { 42827, 42827 }, + { 42829, 42829 }, + { 42831, 42831 }, + { 42833, 42833 }, + { 42835, 42835 }, + { 42837, 42837 }, + { 42839, 42839 }, + { 42841, 42841 }, + { 42843, 42843 }, + { 42845, 42845 }, + { 42847, 42847 }, + { 42849, 42849 }, + { 42851, 42851 }, + { 42853, 42853 }, + { 42855, 42855 }, + { 42857, 42857 }, + { 42859, 42859 }, + { 42861, 42861 }, + { 42863, 42863 }, + { 42865, 42872 }, + { 42874, 42874 }, + { 42876, 42876 }, + { 42879, 42879 }, + { 42881, 42881 }, + { 42883, 42883 }, + { 42885, 42885 }, + { 42887, 42887 }, + { 42892, 42892 }, + { 42894, 42894 }, + { 42897, 42897 }, + { 42899, 42901 }, + 
{ 42903, 42903 }, + { 42905, 42905 }, + { 42907, 42907 }, + { 42909, 42909 }, + { 42911, 42911 }, + { 42913, 42913 }, + { 42915, 42915 }, + { 42917, 42917 }, + { 42919, 42919 }, + { 42921, 42921 }, + { 42927, 42927 }, + { 42933, 42933 }, + { 42935, 42935 }, + { 42937, 42937 }, + { 42939, 42939 }, + { 42941, 42941 }, + { 42943, 42943 }, + { 42945, 42945 }, + { 42947, 42947 }, + { 42952, 42952 }, + { 42954, 42954 }, + { 42961, 42961 }, + { 42963, 42963 }, + { 42965, 42965 }, + { 42967, 42967 }, + { 42969, 42969 }, + { 42998, 42998 }, + { 43002, 43002 }, + { 43824, 43866 }, + { 43872, 43880 }, + { 43888, 43967 }, + { 64256, 64262 }, + { 64275, 64279 }, + { 65345, 65370 }, +}; +static const URange32 Ll_range32[] = { + { 66600, 66639 }, + { 66776, 66811 }, + { 66967, 66977 }, + { 66979, 66993 }, + { 66995, 67001 }, + { 67003, 67004 }, + { 68800, 68850 }, + { 71872, 71903 }, + { 93792, 93823 }, + { 119834, 119859 }, + { 119886, 119892 }, + { 119894, 119911 }, + { 119938, 119963 }, + { 119990, 119993 }, + { 119995, 119995 }, + { 119997, 120003 }, + { 120005, 120015 }, + { 120042, 120067 }, + { 120094, 120119 }, + { 120146, 120171 }, + { 120198, 120223 }, + { 120250, 120275 }, + { 120302, 120327 }, + { 120354, 120379 }, + { 120406, 120431 }, + { 120458, 120485 }, + { 120514, 120538 }, + { 120540, 120545 }, + { 120572, 120596 }, + { 120598, 120603 }, + { 120630, 120654 }, + { 120656, 120661 }, + { 120688, 120712 }, + { 120714, 120719 }, + { 120746, 120770 }, + { 120772, 120777 }, + { 120779, 120779 }, + { 122624, 122633 }, + { 122635, 122654 }, + { 122661, 122666 }, + { 125218, 125251 }, +}; +static const URange16 Lm_range16[] = { + { 688, 705 }, + { 710, 721 }, + { 736, 740 }, + { 748, 748 }, + { 750, 750 }, + { 884, 884 }, + { 890, 890 }, + { 1369, 1369 }, + { 1600, 1600 }, + { 1765, 1766 }, + { 2036, 2037 }, + { 2042, 2042 }, + { 2074, 2074 }, + { 2084, 2084 }, + { 2088, 2088 }, + { 2249, 2249 }, + { 2417, 2417 }, + { 3654, 3654 }, + { 3782, 3782 }, + { 4348, 4348 }, + { 6103, 6103 }, + { 6211, 6211 }, + { 6823, 6823 }, + { 7288, 7293 }, + { 7468, 7530 }, + { 7544, 7544 }, + { 7579, 7615 }, + { 8305, 8305 }, + { 8319, 8319 }, + { 8336, 8348 }, + { 11388, 11389 }, + { 11631, 11631 }, + { 11823, 11823 }, + { 12293, 12293 }, + { 12337, 12341 }, + { 12347, 12347 }, + { 12445, 12446 }, + { 12540, 12542 }, + { 40981, 40981 }, + { 42232, 42237 }, + { 42508, 42508 }, + { 42623, 42623 }, + { 42652, 42653 }, + { 42775, 42783 }, + { 42864, 42864 }, + { 42888, 42888 }, + { 42994, 42996 }, + { 43000, 43001 }, + { 43471, 43471 }, + { 43494, 43494 }, + { 43632, 43632 }, + { 43741, 43741 }, + { 43763, 43764 }, + { 43868, 43871 }, + { 43881, 43881 }, + { 65392, 65392 }, + { 65438, 65439 }, +}; +static const URange32 Lm_range32[] = { + { 67456, 67461 }, + { 67463, 67504 }, + { 67506, 67514 }, + { 92992, 92995 }, + { 94099, 94111 }, + { 94176, 94177 }, + { 94179, 94179 }, + { 110576, 110579 }, + { 110581, 110587 }, + { 110589, 110590 }, + { 122928, 122989 }, + { 123191, 123197 }, + { 124139, 124139 }, + { 125259, 125259 }, +}; +static const URange16 Lo_range16[] = { + { 170, 170 }, + { 186, 186 }, + { 443, 443 }, + { 448, 451 }, + { 660, 660 }, + { 1488, 1514 }, + { 1519, 1522 }, + { 1568, 1599 }, + { 1601, 1610 }, + { 1646, 1647 }, + { 1649, 1747 }, + { 1749, 1749 }, + { 1774, 1775 }, + { 1786, 1788 }, + { 1791, 1791 }, + { 1808, 1808 }, + { 1810, 1839 }, + { 1869, 1957 }, + { 1969, 1969 }, + { 1994, 2026 }, + { 2048, 2069 }, + { 2112, 2136 }, + { 2144, 2154 }, + { 2160, 2183 }, + { 2185, 2190 }, + { 2208, 2248 
}, + { 2308, 2361 }, + { 2365, 2365 }, + { 2384, 2384 }, + { 2392, 2401 }, + { 2418, 2432 }, + { 2437, 2444 }, + { 2447, 2448 }, + { 2451, 2472 }, + { 2474, 2480 }, + { 2482, 2482 }, + { 2486, 2489 }, + { 2493, 2493 }, + { 2510, 2510 }, + { 2524, 2525 }, + { 2527, 2529 }, + { 2544, 2545 }, + { 2556, 2556 }, + { 2565, 2570 }, + { 2575, 2576 }, + { 2579, 2600 }, + { 2602, 2608 }, + { 2610, 2611 }, + { 2613, 2614 }, + { 2616, 2617 }, + { 2649, 2652 }, + { 2654, 2654 }, + { 2674, 2676 }, + { 2693, 2701 }, + { 2703, 2705 }, + { 2707, 2728 }, + { 2730, 2736 }, + { 2738, 2739 }, + { 2741, 2745 }, + { 2749, 2749 }, + { 2768, 2768 }, + { 2784, 2785 }, + { 2809, 2809 }, + { 2821, 2828 }, + { 2831, 2832 }, + { 2835, 2856 }, + { 2858, 2864 }, + { 2866, 2867 }, + { 2869, 2873 }, + { 2877, 2877 }, + { 2908, 2909 }, + { 2911, 2913 }, + { 2929, 2929 }, + { 2947, 2947 }, + { 2949, 2954 }, + { 2958, 2960 }, + { 2962, 2965 }, + { 2969, 2970 }, + { 2972, 2972 }, + { 2974, 2975 }, + { 2979, 2980 }, + { 2984, 2986 }, + { 2990, 3001 }, + { 3024, 3024 }, + { 3077, 3084 }, + { 3086, 3088 }, + { 3090, 3112 }, + { 3114, 3129 }, + { 3133, 3133 }, + { 3160, 3162 }, + { 3165, 3165 }, + { 3168, 3169 }, + { 3200, 3200 }, + { 3205, 3212 }, + { 3214, 3216 }, + { 3218, 3240 }, + { 3242, 3251 }, + { 3253, 3257 }, + { 3261, 3261 }, + { 3293, 3294 }, + { 3296, 3297 }, + { 3313, 3314 }, + { 3332, 3340 }, + { 3342, 3344 }, + { 3346, 3386 }, + { 3389, 3389 }, + { 3406, 3406 }, + { 3412, 3414 }, + { 3423, 3425 }, + { 3450, 3455 }, + { 3461, 3478 }, + { 3482, 3505 }, + { 3507, 3515 }, + { 3517, 3517 }, + { 3520, 3526 }, + { 3585, 3632 }, + { 3634, 3635 }, + { 3648, 3653 }, + { 3713, 3714 }, + { 3716, 3716 }, + { 3718, 3722 }, + { 3724, 3747 }, + { 3749, 3749 }, + { 3751, 3760 }, + { 3762, 3763 }, + { 3773, 3773 }, + { 3776, 3780 }, + { 3804, 3807 }, + { 3840, 3840 }, + { 3904, 3911 }, + { 3913, 3948 }, + { 3976, 3980 }, + { 4096, 4138 }, + { 4159, 4159 }, + { 4176, 4181 }, + { 4186, 4189 }, + { 4193, 4193 }, + { 4197, 4198 }, + { 4206, 4208 }, + { 4213, 4225 }, + { 4238, 4238 }, + { 4352, 4680 }, + { 4682, 4685 }, + { 4688, 4694 }, + { 4696, 4696 }, + { 4698, 4701 }, + { 4704, 4744 }, + { 4746, 4749 }, + { 4752, 4784 }, + { 4786, 4789 }, + { 4792, 4798 }, + { 4800, 4800 }, + { 4802, 4805 }, + { 4808, 4822 }, + { 4824, 4880 }, + { 4882, 4885 }, + { 4888, 4954 }, + { 4992, 5007 }, + { 5121, 5740 }, + { 5743, 5759 }, + { 5761, 5786 }, + { 5792, 5866 }, + { 5873, 5880 }, + { 5888, 5905 }, + { 5919, 5937 }, + { 5952, 5969 }, + { 5984, 5996 }, + { 5998, 6000 }, + { 6016, 6067 }, + { 6108, 6108 }, + { 6176, 6210 }, + { 6212, 6264 }, + { 6272, 6276 }, + { 6279, 6312 }, + { 6314, 6314 }, + { 6320, 6389 }, + { 6400, 6430 }, + { 6480, 6509 }, + { 6512, 6516 }, + { 6528, 6571 }, + { 6576, 6601 }, + { 6656, 6678 }, + { 6688, 6740 }, + { 6917, 6963 }, + { 6981, 6988 }, + { 7043, 7072 }, + { 7086, 7087 }, + { 7098, 7141 }, + { 7168, 7203 }, + { 7245, 7247 }, + { 7258, 7287 }, + { 7401, 7404 }, + { 7406, 7411 }, + { 7413, 7414 }, + { 7418, 7418 }, + { 8501, 8504 }, + { 11568, 11623 }, + { 11648, 11670 }, + { 11680, 11686 }, + { 11688, 11694 }, + { 11696, 11702 }, + { 11704, 11710 }, + { 11712, 11718 }, + { 11720, 11726 }, + { 11728, 11734 }, + { 11736, 11742 }, + { 12294, 12294 }, + { 12348, 12348 }, + { 12353, 12438 }, + { 12447, 12447 }, + { 12449, 12538 }, + { 12543, 12543 }, + { 12549, 12591 }, + { 12593, 12686 }, + { 12704, 12735 }, + { 12784, 12799 }, + { 13312, 19903 }, + { 19968, 40980 }, + { 40982, 42124 }, + { 42192, 42231 }, + { 42240, 
42507 }, + { 42512, 42527 }, + { 42538, 42539 }, + { 42606, 42606 }, + { 42656, 42725 }, + { 42895, 42895 }, + { 42999, 42999 }, + { 43003, 43009 }, + { 43011, 43013 }, + { 43015, 43018 }, + { 43020, 43042 }, + { 43072, 43123 }, + { 43138, 43187 }, + { 43250, 43255 }, + { 43259, 43259 }, + { 43261, 43262 }, + { 43274, 43301 }, + { 43312, 43334 }, + { 43360, 43388 }, + { 43396, 43442 }, + { 43488, 43492 }, + { 43495, 43503 }, + { 43514, 43518 }, + { 43520, 43560 }, + { 43584, 43586 }, + { 43588, 43595 }, + { 43616, 43631 }, + { 43633, 43638 }, + { 43642, 43642 }, + { 43646, 43695 }, + { 43697, 43697 }, + { 43701, 43702 }, + { 43705, 43709 }, + { 43712, 43712 }, + { 43714, 43714 }, + { 43739, 43740 }, + { 43744, 43754 }, + { 43762, 43762 }, + { 43777, 43782 }, + { 43785, 43790 }, + { 43793, 43798 }, + { 43808, 43814 }, + { 43816, 43822 }, + { 43968, 44002 }, + { 44032, 55203 }, + { 55216, 55238 }, + { 55243, 55291 }, + { 63744, 64109 }, + { 64112, 64217 }, + { 64285, 64285 }, + { 64287, 64296 }, + { 64298, 64310 }, + { 64312, 64316 }, + { 64318, 64318 }, + { 64320, 64321 }, + { 64323, 64324 }, + { 64326, 64433 }, + { 64467, 64829 }, + { 64848, 64911 }, + { 64914, 64967 }, + { 65008, 65019 }, + { 65136, 65140 }, + { 65142, 65276 }, + { 65382, 65391 }, + { 65393, 65437 }, + { 65440, 65470 }, + { 65474, 65479 }, + { 65482, 65487 }, + { 65490, 65495 }, + { 65498, 65500 }, +}; +static const URange32 Lo_range32[] = { + { 65536, 65547 }, + { 65549, 65574 }, + { 65576, 65594 }, + { 65596, 65597 }, + { 65599, 65613 }, + { 65616, 65629 }, + { 65664, 65786 }, + { 66176, 66204 }, + { 66208, 66256 }, + { 66304, 66335 }, + { 66349, 66368 }, + { 66370, 66377 }, + { 66384, 66421 }, + { 66432, 66461 }, + { 66464, 66499 }, + { 66504, 66511 }, + { 66640, 66717 }, + { 66816, 66855 }, + { 66864, 66915 }, + { 67072, 67382 }, + { 67392, 67413 }, + { 67424, 67431 }, + { 67584, 67589 }, + { 67592, 67592 }, + { 67594, 67637 }, + { 67639, 67640 }, + { 67644, 67644 }, + { 67647, 67669 }, + { 67680, 67702 }, + { 67712, 67742 }, + { 67808, 67826 }, + { 67828, 67829 }, + { 67840, 67861 }, + { 67872, 67897 }, + { 67968, 68023 }, + { 68030, 68031 }, + { 68096, 68096 }, + { 68112, 68115 }, + { 68117, 68119 }, + { 68121, 68149 }, + { 68192, 68220 }, + { 68224, 68252 }, + { 68288, 68295 }, + { 68297, 68324 }, + { 68352, 68405 }, + { 68416, 68437 }, + { 68448, 68466 }, + { 68480, 68497 }, + { 68608, 68680 }, + { 68864, 68899 }, + { 69248, 69289 }, + { 69296, 69297 }, + { 69376, 69404 }, + { 69415, 69415 }, + { 69424, 69445 }, + { 69488, 69505 }, + { 69552, 69572 }, + { 69600, 69622 }, + { 69635, 69687 }, + { 69745, 69746 }, + { 69749, 69749 }, + { 69763, 69807 }, + { 69840, 69864 }, + { 69891, 69926 }, + { 69956, 69956 }, + { 69959, 69959 }, + { 69968, 70002 }, + { 70006, 70006 }, + { 70019, 70066 }, + { 70081, 70084 }, + { 70106, 70106 }, + { 70108, 70108 }, + { 70144, 70161 }, + { 70163, 70187 }, + { 70207, 70208 }, + { 70272, 70278 }, + { 70280, 70280 }, + { 70282, 70285 }, + { 70287, 70301 }, + { 70303, 70312 }, + { 70320, 70366 }, + { 70405, 70412 }, + { 70415, 70416 }, + { 70419, 70440 }, + { 70442, 70448 }, + { 70450, 70451 }, + { 70453, 70457 }, + { 70461, 70461 }, + { 70480, 70480 }, + { 70493, 70497 }, + { 70656, 70708 }, + { 70727, 70730 }, + { 70751, 70753 }, + { 70784, 70831 }, + { 70852, 70853 }, + { 70855, 70855 }, + { 71040, 71086 }, + { 71128, 71131 }, + { 71168, 71215 }, + { 71236, 71236 }, + { 71296, 71338 }, + { 71352, 71352 }, + { 71424, 71450 }, + { 71488, 71494 }, + { 71680, 71723 }, + { 71935, 71942 }, + 
{ 71945, 71945 }, + { 71948, 71955 }, + { 71957, 71958 }, + { 71960, 71983 }, + { 71999, 71999 }, + { 72001, 72001 }, + { 72096, 72103 }, + { 72106, 72144 }, + { 72161, 72161 }, + { 72163, 72163 }, + { 72192, 72192 }, + { 72203, 72242 }, + { 72250, 72250 }, + { 72272, 72272 }, + { 72284, 72329 }, + { 72349, 72349 }, + { 72368, 72440 }, + { 72704, 72712 }, + { 72714, 72750 }, + { 72768, 72768 }, + { 72818, 72847 }, + { 72960, 72966 }, + { 72968, 72969 }, + { 72971, 73008 }, + { 73030, 73030 }, + { 73056, 73061 }, + { 73063, 73064 }, + { 73066, 73097 }, + { 73112, 73112 }, + { 73440, 73458 }, + { 73474, 73474 }, + { 73476, 73488 }, + { 73490, 73523 }, + { 73648, 73648 }, + { 73728, 74649 }, + { 74880, 75075 }, + { 77712, 77808 }, + { 77824, 78895 }, + { 78913, 78918 }, + { 82944, 83526 }, + { 92160, 92728 }, + { 92736, 92766 }, + { 92784, 92862 }, + { 92880, 92909 }, + { 92928, 92975 }, + { 93027, 93047 }, + { 93053, 93071 }, + { 93952, 94026 }, + { 94032, 94032 }, + { 94208, 100343 }, + { 100352, 101589 }, + { 101632, 101640 }, + { 110592, 110882 }, + { 110898, 110898 }, + { 110928, 110930 }, + { 110933, 110933 }, + { 110948, 110951 }, + { 110960, 111355 }, + { 113664, 113770 }, + { 113776, 113788 }, + { 113792, 113800 }, + { 113808, 113817 }, + { 122634, 122634 }, + { 123136, 123180 }, + { 123214, 123214 }, + { 123536, 123565 }, + { 123584, 123627 }, + { 124112, 124138 }, + { 124896, 124902 }, + { 124904, 124907 }, + { 124909, 124910 }, + { 124912, 124926 }, + { 124928, 125124 }, + { 126464, 126467 }, + { 126469, 126495 }, + { 126497, 126498 }, + { 126500, 126500 }, + { 126503, 126503 }, + { 126505, 126514 }, + { 126516, 126519 }, + { 126521, 126521 }, + { 126523, 126523 }, + { 126530, 126530 }, + { 126535, 126535 }, + { 126537, 126537 }, + { 126539, 126539 }, + { 126541, 126543 }, + { 126545, 126546 }, + { 126548, 126548 }, + { 126551, 126551 }, + { 126553, 126553 }, + { 126555, 126555 }, + { 126557, 126557 }, + { 126559, 126559 }, + { 126561, 126562 }, + { 126564, 126564 }, + { 126567, 126570 }, + { 126572, 126578 }, + { 126580, 126583 }, + { 126585, 126588 }, + { 126590, 126590 }, + { 126592, 126601 }, + { 126603, 126619 }, + { 126625, 126627 }, + { 126629, 126633 }, + { 126635, 126651 }, + { 131072, 173791 }, + { 173824, 177977 }, + { 177984, 178205 }, + { 178208, 183969 }, + { 183984, 191456 }, + { 194560, 195101 }, + { 196608, 201546 }, + { 201552, 205743 }, +}; +static const URange16 Lt_range16[] = { + { 453, 453 }, + { 456, 456 }, + { 459, 459 }, + { 498, 498 }, + { 8072, 8079 }, + { 8088, 8095 }, + { 8104, 8111 }, + { 8124, 8124 }, + { 8140, 8140 }, + { 8188, 8188 }, +}; +static const URange16 Lu_range16[] = { + { 65, 90 }, + { 192, 214 }, + { 216, 222 }, + { 256, 256 }, + { 258, 258 }, + { 260, 260 }, + { 262, 262 }, + { 264, 264 }, + { 266, 266 }, + { 268, 268 }, + { 270, 270 }, + { 272, 272 }, + { 274, 274 }, + { 276, 276 }, + { 278, 278 }, + { 280, 280 }, + { 282, 282 }, + { 284, 284 }, + { 286, 286 }, + { 288, 288 }, + { 290, 290 }, + { 292, 292 }, + { 294, 294 }, + { 296, 296 }, + { 298, 298 }, + { 300, 300 }, + { 302, 302 }, + { 304, 304 }, + { 306, 306 }, + { 308, 308 }, + { 310, 310 }, + { 313, 313 }, + { 315, 315 }, + { 317, 317 }, + { 319, 319 }, + { 321, 321 }, + { 323, 323 }, + { 325, 325 }, + { 327, 327 }, + { 330, 330 }, + { 332, 332 }, + { 334, 334 }, + { 336, 336 }, + { 338, 338 }, + { 340, 340 }, + { 342, 342 }, + { 344, 344 }, + { 346, 346 }, + { 348, 348 }, + { 350, 350 }, + { 352, 352 }, + { 354, 354 }, + { 356, 356 }, + { 358, 358 }, + { 360, 360 }, + { 362, 
362 }, + { 364, 364 }, + { 366, 366 }, + { 368, 368 }, + { 370, 370 }, + { 372, 372 }, + { 374, 374 }, + { 376, 377 }, + { 379, 379 }, + { 381, 381 }, + { 385, 386 }, + { 388, 388 }, + { 390, 391 }, + { 393, 395 }, + { 398, 401 }, + { 403, 404 }, + { 406, 408 }, + { 412, 413 }, + { 415, 416 }, + { 418, 418 }, + { 420, 420 }, + { 422, 423 }, + { 425, 425 }, + { 428, 428 }, + { 430, 431 }, + { 433, 435 }, + { 437, 437 }, + { 439, 440 }, + { 444, 444 }, + { 452, 452 }, + { 455, 455 }, + { 458, 458 }, + { 461, 461 }, + { 463, 463 }, + { 465, 465 }, + { 467, 467 }, + { 469, 469 }, + { 471, 471 }, + { 473, 473 }, + { 475, 475 }, + { 478, 478 }, + { 480, 480 }, + { 482, 482 }, + { 484, 484 }, + { 486, 486 }, + { 488, 488 }, + { 490, 490 }, + { 492, 492 }, + { 494, 494 }, + { 497, 497 }, + { 500, 500 }, + { 502, 504 }, + { 506, 506 }, + { 508, 508 }, + { 510, 510 }, + { 512, 512 }, + { 514, 514 }, + { 516, 516 }, + { 518, 518 }, + { 520, 520 }, + { 522, 522 }, + { 524, 524 }, + { 526, 526 }, + { 528, 528 }, + { 530, 530 }, + { 532, 532 }, + { 534, 534 }, + { 536, 536 }, + { 538, 538 }, + { 540, 540 }, + { 542, 542 }, + { 544, 544 }, + { 546, 546 }, + { 548, 548 }, + { 550, 550 }, + { 552, 552 }, + { 554, 554 }, + { 556, 556 }, + { 558, 558 }, + { 560, 560 }, + { 562, 562 }, + { 570, 571 }, + { 573, 574 }, + { 577, 577 }, + { 579, 582 }, + { 584, 584 }, + { 586, 586 }, + { 588, 588 }, + { 590, 590 }, + { 880, 880 }, + { 882, 882 }, + { 886, 886 }, + { 895, 895 }, + { 902, 902 }, + { 904, 906 }, + { 908, 908 }, + { 910, 911 }, + { 913, 929 }, + { 931, 939 }, + { 975, 975 }, + { 978, 980 }, + { 984, 984 }, + { 986, 986 }, + { 988, 988 }, + { 990, 990 }, + { 992, 992 }, + { 994, 994 }, + { 996, 996 }, + { 998, 998 }, + { 1000, 1000 }, + { 1002, 1002 }, + { 1004, 1004 }, + { 1006, 1006 }, + { 1012, 1012 }, + { 1015, 1015 }, + { 1017, 1018 }, + { 1021, 1071 }, + { 1120, 1120 }, + { 1122, 1122 }, + { 1124, 1124 }, + { 1126, 1126 }, + { 1128, 1128 }, + { 1130, 1130 }, + { 1132, 1132 }, + { 1134, 1134 }, + { 1136, 1136 }, + { 1138, 1138 }, + { 1140, 1140 }, + { 1142, 1142 }, + { 1144, 1144 }, + { 1146, 1146 }, + { 1148, 1148 }, + { 1150, 1150 }, + { 1152, 1152 }, + { 1162, 1162 }, + { 1164, 1164 }, + { 1166, 1166 }, + { 1168, 1168 }, + { 1170, 1170 }, + { 1172, 1172 }, + { 1174, 1174 }, + { 1176, 1176 }, + { 1178, 1178 }, + { 1180, 1180 }, + { 1182, 1182 }, + { 1184, 1184 }, + { 1186, 1186 }, + { 1188, 1188 }, + { 1190, 1190 }, + { 1192, 1192 }, + { 1194, 1194 }, + { 1196, 1196 }, + { 1198, 1198 }, + { 1200, 1200 }, + { 1202, 1202 }, + { 1204, 1204 }, + { 1206, 1206 }, + { 1208, 1208 }, + { 1210, 1210 }, + { 1212, 1212 }, + { 1214, 1214 }, + { 1216, 1217 }, + { 1219, 1219 }, + { 1221, 1221 }, + { 1223, 1223 }, + { 1225, 1225 }, + { 1227, 1227 }, + { 1229, 1229 }, + { 1232, 1232 }, + { 1234, 1234 }, + { 1236, 1236 }, + { 1238, 1238 }, + { 1240, 1240 }, + { 1242, 1242 }, + { 1244, 1244 }, + { 1246, 1246 }, + { 1248, 1248 }, + { 1250, 1250 }, + { 1252, 1252 }, + { 1254, 1254 }, + { 1256, 1256 }, + { 1258, 1258 }, + { 1260, 1260 }, + { 1262, 1262 }, + { 1264, 1264 }, + { 1266, 1266 }, + { 1268, 1268 }, + { 1270, 1270 }, + { 1272, 1272 }, + { 1274, 1274 }, + { 1276, 1276 }, + { 1278, 1278 }, + { 1280, 1280 }, + { 1282, 1282 }, + { 1284, 1284 }, + { 1286, 1286 }, + { 1288, 1288 }, + { 1290, 1290 }, + { 1292, 1292 }, + { 1294, 1294 }, + { 1296, 1296 }, + { 1298, 1298 }, + { 1300, 1300 }, + { 1302, 1302 }, + { 1304, 1304 }, + { 1306, 1306 }, + { 1308, 1308 }, + { 1310, 1310 }, + { 1312, 1312 }, + { 1314, 1314 }, + 
{ 1316, 1316 }, + { 1318, 1318 }, + { 1320, 1320 }, + { 1322, 1322 }, + { 1324, 1324 }, + { 1326, 1326 }, + { 1329, 1366 }, + { 4256, 4293 }, + { 4295, 4295 }, + { 4301, 4301 }, + { 5024, 5109 }, + { 7312, 7354 }, + { 7357, 7359 }, + { 7680, 7680 }, + { 7682, 7682 }, + { 7684, 7684 }, + { 7686, 7686 }, + { 7688, 7688 }, + { 7690, 7690 }, + { 7692, 7692 }, + { 7694, 7694 }, + { 7696, 7696 }, + { 7698, 7698 }, + { 7700, 7700 }, + { 7702, 7702 }, + { 7704, 7704 }, + { 7706, 7706 }, + { 7708, 7708 }, + { 7710, 7710 }, + { 7712, 7712 }, + { 7714, 7714 }, + { 7716, 7716 }, + { 7718, 7718 }, + { 7720, 7720 }, + { 7722, 7722 }, + { 7724, 7724 }, + { 7726, 7726 }, + { 7728, 7728 }, + { 7730, 7730 }, + { 7732, 7732 }, + { 7734, 7734 }, + { 7736, 7736 }, + { 7738, 7738 }, + { 7740, 7740 }, + { 7742, 7742 }, + { 7744, 7744 }, + { 7746, 7746 }, + { 7748, 7748 }, + { 7750, 7750 }, + { 7752, 7752 }, + { 7754, 7754 }, + { 7756, 7756 }, + { 7758, 7758 }, + { 7760, 7760 }, + { 7762, 7762 }, + { 7764, 7764 }, + { 7766, 7766 }, + { 7768, 7768 }, + { 7770, 7770 }, + { 7772, 7772 }, + { 7774, 7774 }, + { 7776, 7776 }, + { 7778, 7778 }, + { 7780, 7780 }, + { 7782, 7782 }, + { 7784, 7784 }, + { 7786, 7786 }, + { 7788, 7788 }, + { 7790, 7790 }, + { 7792, 7792 }, + { 7794, 7794 }, + { 7796, 7796 }, + { 7798, 7798 }, + { 7800, 7800 }, + { 7802, 7802 }, + { 7804, 7804 }, + { 7806, 7806 }, + { 7808, 7808 }, + { 7810, 7810 }, + { 7812, 7812 }, + { 7814, 7814 }, + { 7816, 7816 }, + { 7818, 7818 }, + { 7820, 7820 }, + { 7822, 7822 }, + { 7824, 7824 }, + { 7826, 7826 }, + { 7828, 7828 }, + { 7838, 7838 }, + { 7840, 7840 }, + { 7842, 7842 }, + { 7844, 7844 }, + { 7846, 7846 }, + { 7848, 7848 }, + { 7850, 7850 }, + { 7852, 7852 }, + { 7854, 7854 }, + { 7856, 7856 }, + { 7858, 7858 }, + { 7860, 7860 }, + { 7862, 7862 }, + { 7864, 7864 }, + { 7866, 7866 }, + { 7868, 7868 }, + { 7870, 7870 }, + { 7872, 7872 }, + { 7874, 7874 }, + { 7876, 7876 }, + { 7878, 7878 }, + { 7880, 7880 }, + { 7882, 7882 }, + { 7884, 7884 }, + { 7886, 7886 }, + { 7888, 7888 }, + { 7890, 7890 }, + { 7892, 7892 }, + { 7894, 7894 }, + { 7896, 7896 }, + { 7898, 7898 }, + { 7900, 7900 }, + { 7902, 7902 }, + { 7904, 7904 }, + { 7906, 7906 }, + { 7908, 7908 }, + { 7910, 7910 }, + { 7912, 7912 }, + { 7914, 7914 }, + { 7916, 7916 }, + { 7918, 7918 }, + { 7920, 7920 }, + { 7922, 7922 }, + { 7924, 7924 }, + { 7926, 7926 }, + { 7928, 7928 }, + { 7930, 7930 }, + { 7932, 7932 }, + { 7934, 7934 }, + { 7944, 7951 }, + { 7960, 7965 }, + { 7976, 7983 }, + { 7992, 7999 }, + { 8008, 8013 }, + { 8025, 8025 }, + { 8027, 8027 }, + { 8029, 8029 }, + { 8031, 8031 }, + { 8040, 8047 }, + { 8120, 8123 }, + { 8136, 8139 }, + { 8152, 8155 }, + { 8168, 8172 }, + { 8184, 8187 }, + { 8450, 8450 }, + { 8455, 8455 }, + { 8459, 8461 }, + { 8464, 8466 }, + { 8469, 8469 }, + { 8473, 8477 }, + { 8484, 8484 }, + { 8486, 8486 }, + { 8488, 8488 }, + { 8490, 8493 }, + { 8496, 8499 }, + { 8510, 8511 }, + { 8517, 8517 }, + { 8579, 8579 }, + { 11264, 11311 }, + { 11360, 11360 }, + { 11362, 11364 }, + { 11367, 11367 }, + { 11369, 11369 }, + { 11371, 11371 }, + { 11373, 11376 }, + { 11378, 11378 }, + { 11381, 11381 }, + { 11390, 11392 }, + { 11394, 11394 }, + { 11396, 11396 }, + { 11398, 11398 }, + { 11400, 11400 }, + { 11402, 11402 }, + { 11404, 11404 }, + { 11406, 11406 }, + { 11408, 11408 }, + { 11410, 11410 }, + { 11412, 11412 }, + { 11414, 11414 }, + { 11416, 11416 }, + { 11418, 11418 }, + { 11420, 11420 }, + { 11422, 11422 }, + { 11424, 11424 }, + { 11426, 11426 }, + { 11428, 11428 }, + { 
11430, 11430 }, + { 11432, 11432 }, + { 11434, 11434 }, + { 11436, 11436 }, + { 11438, 11438 }, + { 11440, 11440 }, + { 11442, 11442 }, + { 11444, 11444 }, + { 11446, 11446 }, + { 11448, 11448 }, + { 11450, 11450 }, + { 11452, 11452 }, + { 11454, 11454 }, + { 11456, 11456 }, + { 11458, 11458 }, + { 11460, 11460 }, + { 11462, 11462 }, + { 11464, 11464 }, + { 11466, 11466 }, + { 11468, 11468 }, + { 11470, 11470 }, + { 11472, 11472 }, + { 11474, 11474 }, + { 11476, 11476 }, + { 11478, 11478 }, + { 11480, 11480 }, + { 11482, 11482 }, + { 11484, 11484 }, + { 11486, 11486 }, + { 11488, 11488 }, + { 11490, 11490 }, + { 11499, 11499 }, + { 11501, 11501 }, + { 11506, 11506 }, + { 42560, 42560 }, + { 42562, 42562 }, + { 42564, 42564 }, + { 42566, 42566 }, + { 42568, 42568 }, + { 42570, 42570 }, + { 42572, 42572 }, + { 42574, 42574 }, + { 42576, 42576 }, + { 42578, 42578 }, + { 42580, 42580 }, + { 42582, 42582 }, + { 42584, 42584 }, + { 42586, 42586 }, + { 42588, 42588 }, + { 42590, 42590 }, + { 42592, 42592 }, + { 42594, 42594 }, + { 42596, 42596 }, + { 42598, 42598 }, + { 42600, 42600 }, + { 42602, 42602 }, + { 42604, 42604 }, + { 42624, 42624 }, + { 42626, 42626 }, + { 42628, 42628 }, + { 42630, 42630 }, + { 42632, 42632 }, + { 42634, 42634 }, + { 42636, 42636 }, + { 42638, 42638 }, + { 42640, 42640 }, + { 42642, 42642 }, + { 42644, 42644 }, + { 42646, 42646 }, + { 42648, 42648 }, + { 42650, 42650 }, + { 42786, 42786 }, + { 42788, 42788 }, + { 42790, 42790 }, + { 42792, 42792 }, + { 42794, 42794 }, + { 42796, 42796 }, + { 42798, 42798 }, + { 42802, 42802 }, + { 42804, 42804 }, + { 42806, 42806 }, + { 42808, 42808 }, + { 42810, 42810 }, + { 42812, 42812 }, + { 42814, 42814 }, + { 42816, 42816 }, + { 42818, 42818 }, + { 42820, 42820 }, + { 42822, 42822 }, + { 42824, 42824 }, + { 42826, 42826 }, + { 42828, 42828 }, + { 42830, 42830 }, + { 42832, 42832 }, + { 42834, 42834 }, + { 42836, 42836 }, + { 42838, 42838 }, + { 42840, 42840 }, + { 42842, 42842 }, + { 42844, 42844 }, + { 42846, 42846 }, + { 42848, 42848 }, + { 42850, 42850 }, + { 42852, 42852 }, + { 42854, 42854 }, + { 42856, 42856 }, + { 42858, 42858 }, + { 42860, 42860 }, + { 42862, 42862 }, + { 42873, 42873 }, + { 42875, 42875 }, + { 42877, 42878 }, + { 42880, 42880 }, + { 42882, 42882 }, + { 42884, 42884 }, + { 42886, 42886 }, + { 42891, 42891 }, + { 42893, 42893 }, + { 42896, 42896 }, + { 42898, 42898 }, + { 42902, 42902 }, + { 42904, 42904 }, + { 42906, 42906 }, + { 42908, 42908 }, + { 42910, 42910 }, + { 42912, 42912 }, + { 42914, 42914 }, + { 42916, 42916 }, + { 42918, 42918 }, + { 42920, 42920 }, + { 42922, 42926 }, + { 42928, 42932 }, + { 42934, 42934 }, + { 42936, 42936 }, + { 42938, 42938 }, + { 42940, 42940 }, + { 42942, 42942 }, + { 42944, 42944 }, + { 42946, 42946 }, + { 42948, 42951 }, + { 42953, 42953 }, + { 42960, 42960 }, + { 42966, 42966 }, + { 42968, 42968 }, + { 42997, 42997 }, + { 65313, 65338 }, +}; +static const URange32 Lu_range32[] = { + { 66560, 66599 }, + { 66736, 66771 }, + { 66928, 66938 }, + { 66940, 66954 }, + { 66956, 66962 }, + { 66964, 66965 }, + { 68736, 68786 }, + { 71840, 71871 }, + { 93760, 93791 }, + { 119808, 119833 }, + { 119860, 119885 }, + { 119912, 119937 }, + { 119964, 119964 }, + { 119966, 119967 }, + { 119970, 119970 }, + { 119973, 119974 }, + { 119977, 119980 }, + { 119982, 119989 }, + { 120016, 120041 }, + { 120068, 120069 }, + { 120071, 120074 }, + { 120077, 120084 }, + { 120086, 120092 }, + { 120120, 120121 }, + { 120123, 120126 }, + { 120128, 120132 }, + { 120134, 120134 }, + { 120138, 120144 
}, + { 120172, 120197 }, + { 120224, 120249 }, + { 120276, 120301 }, + { 120328, 120353 }, + { 120380, 120405 }, + { 120432, 120457 }, + { 120488, 120512 }, + { 120546, 120570 }, + { 120604, 120628 }, + { 120662, 120686 }, + { 120720, 120744 }, + { 120778, 120778 }, + { 125184, 125217 }, +}; +static const URange16 M_range16[] = { + { 768, 879 }, + { 1155, 1161 }, + { 1425, 1469 }, + { 1471, 1471 }, + { 1473, 1474 }, + { 1476, 1477 }, + { 1479, 1479 }, + { 1552, 1562 }, + { 1611, 1631 }, + { 1648, 1648 }, + { 1750, 1756 }, + { 1759, 1764 }, + { 1767, 1768 }, + { 1770, 1773 }, + { 1809, 1809 }, + { 1840, 1866 }, + { 1958, 1968 }, + { 2027, 2035 }, + { 2045, 2045 }, + { 2070, 2073 }, + { 2075, 2083 }, + { 2085, 2087 }, + { 2089, 2093 }, + { 2137, 2139 }, + { 2200, 2207 }, + { 2250, 2273 }, + { 2275, 2307 }, + { 2362, 2364 }, + { 2366, 2383 }, + { 2385, 2391 }, + { 2402, 2403 }, + { 2433, 2435 }, + { 2492, 2492 }, + { 2494, 2500 }, + { 2503, 2504 }, + { 2507, 2509 }, + { 2519, 2519 }, + { 2530, 2531 }, + { 2558, 2558 }, + { 2561, 2563 }, + { 2620, 2620 }, + { 2622, 2626 }, + { 2631, 2632 }, + { 2635, 2637 }, + { 2641, 2641 }, + { 2672, 2673 }, + { 2677, 2677 }, + { 2689, 2691 }, + { 2748, 2748 }, + { 2750, 2757 }, + { 2759, 2761 }, + { 2763, 2765 }, + { 2786, 2787 }, + { 2810, 2815 }, + { 2817, 2819 }, + { 2876, 2876 }, + { 2878, 2884 }, + { 2887, 2888 }, + { 2891, 2893 }, + { 2901, 2903 }, + { 2914, 2915 }, + { 2946, 2946 }, + { 3006, 3010 }, + { 3014, 3016 }, + { 3018, 3021 }, + { 3031, 3031 }, + { 3072, 3076 }, + { 3132, 3132 }, + { 3134, 3140 }, + { 3142, 3144 }, + { 3146, 3149 }, + { 3157, 3158 }, + { 3170, 3171 }, + { 3201, 3203 }, + { 3260, 3260 }, + { 3262, 3268 }, + { 3270, 3272 }, + { 3274, 3277 }, + { 3285, 3286 }, + { 3298, 3299 }, + { 3315, 3315 }, + { 3328, 3331 }, + { 3387, 3388 }, + { 3390, 3396 }, + { 3398, 3400 }, + { 3402, 3405 }, + { 3415, 3415 }, + { 3426, 3427 }, + { 3457, 3459 }, + { 3530, 3530 }, + { 3535, 3540 }, + { 3542, 3542 }, + { 3544, 3551 }, + { 3570, 3571 }, + { 3633, 3633 }, + { 3636, 3642 }, + { 3655, 3662 }, + { 3761, 3761 }, + { 3764, 3772 }, + { 3784, 3790 }, + { 3864, 3865 }, + { 3893, 3893 }, + { 3895, 3895 }, + { 3897, 3897 }, + { 3902, 3903 }, + { 3953, 3972 }, + { 3974, 3975 }, + { 3981, 3991 }, + { 3993, 4028 }, + { 4038, 4038 }, + { 4139, 4158 }, + { 4182, 4185 }, + { 4190, 4192 }, + { 4194, 4196 }, + { 4199, 4205 }, + { 4209, 4212 }, + { 4226, 4237 }, + { 4239, 4239 }, + { 4250, 4253 }, + { 4957, 4959 }, + { 5906, 5909 }, + { 5938, 5940 }, + { 5970, 5971 }, + { 6002, 6003 }, + { 6068, 6099 }, + { 6109, 6109 }, + { 6155, 6157 }, + { 6159, 6159 }, + { 6277, 6278 }, + { 6313, 6313 }, + { 6432, 6443 }, + { 6448, 6459 }, + { 6679, 6683 }, + { 6741, 6750 }, + { 6752, 6780 }, + { 6783, 6783 }, + { 6832, 6862 }, + { 6912, 6916 }, + { 6964, 6980 }, + { 7019, 7027 }, + { 7040, 7042 }, + { 7073, 7085 }, + { 7142, 7155 }, + { 7204, 7223 }, + { 7376, 7378 }, + { 7380, 7400 }, + { 7405, 7405 }, + { 7412, 7412 }, + { 7415, 7417 }, + { 7616, 7679 }, + { 8400, 8432 }, + { 11503, 11505 }, + { 11647, 11647 }, + { 11744, 11775 }, + { 12330, 12335 }, + { 12441, 12442 }, + { 42607, 42610 }, + { 42612, 42621 }, + { 42654, 42655 }, + { 42736, 42737 }, + { 43010, 43010 }, + { 43014, 43014 }, + { 43019, 43019 }, + { 43043, 43047 }, + { 43052, 43052 }, + { 43136, 43137 }, + { 43188, 43205 }, + { 43232, 43249 }, + { 43263, 43263 }, + { 43302, 43309 }, + { 43335, 43347 }, + { 43392, 43395 }, + { 43443, 43456 }, + { 43493, 43493 }, + { 43561, 43574 }, + { 43587, 43587 }, + { 
43596, 43597 }, + { 43643, 43645 }, + { 43696, 43696 }, + { 43698, 43700 }, + { 43703, 43704 }, + { 43710, 43711 }, + { 43713, 43713 }, + { 43755, 43759 }, + { 43765, 43766 }, + { 44003, 44010 }, + { 44012, 44013 }, + { 64286, 64286 }, + { 65024, 65039 }, + { 65056, 65071 }, +}; +static const URange32 M_range32[] = { + { 66045, 66045 }, + { 66272, 66272 }, + { 66422, 66426 }, + { 68097, 68099 }, + { 68101, 68102 }, + { 68108, 68111 }, + { 68152, 68154 }, + { 68159, 68159 }, + { 68325, 68326 }, + { 68900, 68903 }, + { 69291, 69292 }, + { 69373, 69375 }, + { 69446, 69456 }, + { 69506, 69509 }, + { 69632, 69634 }, + { 69688, 69702 }, + { 69744, 69744 }, + { 69747, 69748 }, + { 69759, 69762 }, + { 69808, 69818 }, + { 69826, 69826 }, + { 69888, 69890 }, + { 69927, 69940 }, + { 69957, 69958 }, + { 70003, 70003 }, + { 70016, 70018 }, + { 70067, 70080 }, + { 70089, 70092 }, + { 70094, 70095 }, + { 70188, 70199 }, + { 70206, 70206 }, + { 70209, 70209 }, + { 70367, 70378 }, + { 70400, 70403 }, + { 70459, 70460 }, + { 70462, 70468 }, + { 70471, 70472 }, + { 70475, 70477 }, + { 70487, 70487 }, + { 70498, 70499 }, + { 70502, 70508 }, + { 70512, 70516 }, + { 70709, 70726 }, + { 70750, 70750 }, + { 70832, 70851 }, + { 71087, 71093 }, + { 71096, 71104 }, + { 71132, 71133 }, + { 71216, 71232 }, + { 71339, 71351 }, + { 71453, 71467 }, + { 71724, 71738 }, + { 71984, 71989 }, + { 71991, 71992 }, + { 71995, 71998 }, + { 72000, 72000 }, + { 72002, 72003 }, + { 72145, 72151 }, + { 72154, 72160 }, + { 72164, 72164 }, + { 72193, 72202 }, + { 72243, 72249 }, + { 72251, 72254 }, + { 72263, 72263 }, + { 72273, 72283 }, + { 72330, 72345 }, + { 72751, 72758 }, + { 72760, 72767 }, + { 72850, 72871 }, + { 72873, 72886 }, + { 73009, 73014 }, + { 73018, 73018 }, + { 73020, 73021 }, + { 73023, 73029 }, + { 73031, 73031 }, + { 73098, 73102 }, + { 73104, 73105 }, + { 73107, 73111 }, + { 73459, 73462 }, + { 73472, 73473 }, + { 73475, 73475 }, + { 73524, 73530 }, + { 73534, 73538 }, + { 78912, 78912 }, + { 78919, 78933 }, + { 92912, 92916 }, + { 92976, 92982 }, + { 94031, 94031 }, + { 94033, 94087 }, + { 94095, 94098 }, + { 94180, 94180 }, + { 94192, 94193 }, + { 113821, 113822 }, + { 118528, 118573 }, + { 118576, 118598 }, + { 119141, 119145 }, + { 119149, 119154 }, + { 119163, 119170 }, + { 119173, 119179 }, + { 119210, 119213 }, + { 119362, 119364 }, + { 121344, 121398 }, + { 121403, 121452 }, + { 121461, 121461 }, + { 121476, 121476 }, + { 121499, 121503 }, + { 121505, 121519 }, + { 122880, 122886 }, + { 122888, 122904 }, + { 122907, 122913 }, + { 122915, 122916 }, + { 122918, 122922 }, + { 123023, 123023 }, + { 123184, 123190 }, + { 123566, 123566 }, + { 123628, 123631 }, + { 124140, 124143 }, + { 125136, 125142 }, + { 125252, 125258 }, + { 917760, 917999 }, +}; +static const URange16 Mc_range16[] = { + { 2307, 2307 }, + { 2363, 2363 }, + { 2366, 2368 }, + { 2377, 2380 }, + { 2382, 2383 }, + { 2434, 2435 }, + { 2494, 2496 }, + { 2503, 2504 }, + { 2507, 2508 }, + { 2519, 2519 }, + { 2563, 2563 }, + { 2622, 2624 }, + { 2691, 2691 }, + { 2750, 2752 }, + { 2761, 2761 }, + { 2763, 2764 }, + { 2818, 2819 }, + { 2878, 2878 }, + { 2880, 2880 }, + { 2887, 2888 }, + { 2891, 2892 }, + { 2903, 2903 }, + { 3006, 3007 }, + { 3009, 3010 }, + { 3014, 3016 }, + { 3018, 3020 }, + { 3031, 3031 }, + { 3073, 3075 }, + { 3137, 3140 }, + { 3202, 3203 }, + { 3262, 3262 }, + { 3264, 3268 }, + { 3271, 3272 }, + { 3274, 3275 }, + { 3285, 3286 }, + { 3315, 3315 }, + { 3330, 3331 }, + { 3390, 3392 }, + { 3398, 3400 }, + { 3402, 3404 }, + { 3415, 3415 
}, + { 3458, 3459 }, + { 3535, 3537 }, + { 3544, 3551 }, + { 3570, 3571 }, + { 3902, 3903 }, + { 3967, 3967 }, + { 4139, 4140 }, + { 4145, 4145 }, + { 4152, 4152 }, + { 4155, 4156 }, + { 4182, 4183 }, + { 4194, 4196 }, + { 4199, 4205 }, + { 4227, 4228 }, + { 4231, 4236 }, + { 4239, 4239 }, + { 4250, 4252 }, + { 5909, 5909 }, + { 5940, 5940 }, + { 6070, 6070 }, + { 6078, 6085 }, + { 6087, 6088 }, + { 6435, 6438 }, + { 6441, 6443 }, + { 6448, 6449 }, + { 6451, 6456 }, + { 6681, 6682 }, + { 6741, 6741 }, + { 6743, 6743 }, + { 6753, 6753 }, + { 6755, 6756 }, + { 6765, 6770 }, + { 6916, 6916 }, + { 6965, 6965 }, + { 6971, 6971 }, + { 6973, 6977 }, + { 6979, 6980 }, + { 7042, 7042 }, + { 7073, 7073 }, + { 7078, 7079 }, + { 7082, 7082 }, + { 7143, 7143 }, + { 7146, 7148 }, + { 7150, 7150 }, + { 7154, 7155 }, + { 7204, 7211 }, + { 7220, 7221 }, + { 7393, 7393 }, + { 7415, 7415 }, + { 12334, 12335 }, + { 43043, 43044 }, + { 43047, 43047 }, + { 43136, 43137 }, + { 43188, 43203 }, + { 43346, 43347 }, + { 43395, 43395 }, + { 43444, 43445 }, + { 43450, 43451 }, + { 43454, 43456 }, + { 43567, 43568 }, + { 43571, 43572 }, + { 43597, 43597 }, + { 43643, 43643 }, + { 43645, 43645 }, + { 43755, 43755 }, + { 43758, 43759 }, + { 43765, 43765 }, + { 44003, 44004 }, + { 44006, 44007 }, + { 44009, 44010 }, + { 44012, 44012 }, +}; +static const URange32 Mc_range32[] = { + { 69632, 69632 }, + { 69634, 69634 }, + { 69762, 69762 }, + { 69808, 69810 }, + { 69815, 69816 }, + { 69932, 69932 }, + { 69957, 69958 }, + { 70018, 70018 }, + { 70067, 70069 }, + { 70079, 70080 }, + { 70094, 70094 }, + { 70188, 70190 }, + { 70194, 70195 }, + { 70197, 70197 }, + { 70368, 70370 }, + { 70402, 70403 }, + { 70462, 70463 }, + { 70465, 70468 }, + { 70471, 70472 }, + { 70475, 70477 }, + { 70487, 70487 }, + { 70498, 70499 }, + { 70709, 70711 }, + { 70720, 70721 }, + { 70725, 70725 }, + { 70832, 70834 }, + { 70841, 70841 }, + { 70843, 70846 }, + { 70849, 70849 }, + { 71087, 71089 }, + { 71096, 71099 }, + { 71102, 71102 }, + { 71216, 71218 }, + { 71227, 71228 }, + { 71230, 71230 }, + { 71340, 71340 }, + { 71342, 71343 }, + { 71350, 71350 }, + { 71456, 71457 }, + { 71462, 71462 }, + { 71724, 71726 }, + { 71736, 71736 }, + { 71984, 71989 }, + { 71991, 71992 }, + { 71997, 71997 }, + { 72000, 72000 }, + { 72002, 72002 }, + { 72145, 72147 }, + { 72156, 72159 }, + { 72164, 72164 }, + { 72249, 72249 }, + { 72279, 72280 }, + { 72343, 72343 }, + { 72751, 72751 }, + { 72766, 72766 }, + { 72873, 72873 }, + { 72881, 72881 }, + { 72884, 72884 }, + { 73098, 73102 }, + { 73107, 73108 }, + { 73110, 73110 }, + { 73461, 73462 }, + { 73475, 73475 }, + { 73524, 73525 }, + { 73534, 73535 }, + { 73537, 73537 }, + { 94033, 94087 }, + { 94192, 94193 }, + { 119141, 119142 }, + { 119149, 119154 }, +}; +static const URange16 Me_range16[] = { + { 1160, 1161 }, + { 6846, 6846 }, + { 8413, 8416 }, + { 8418, 8420 }, + { 42608, 42610 }, +}; +static const URange16 Mn_range16[] = { + { 768, 879 }, + { 1155, 1159 }, + { 1425, 1469 }, + { 1471, 1471 }, + { 1473, 1474 }, + { 1476, 1477 }, + { 1479, 1479 }, + { 1552, 1562 }, + { 1611, 1631 }, + { 1648, 1648 }, + { 1750, 1756 }, + { 1759, 1764 }, + { 1767, 1768 }, + { 1770, 1773 }, + { 1809, 1809 }, + { 1840, 1866 }, + { 1958, 1968 }, + { 2027, 2035 }, + { 2045, 2045 }, + { 2070, 2073 }, + { 2075, 2083 }, + { 2085, 2087 }, + { 2089, 2093 }, + { 2137, 2139 }, + { 2200, 2207 }, + { 2250, 2273 }, + { 2275, 2306 }, + { 2362, 2362 }, + { 2364, 2364 }, + { 2369, 2376 }, + { 2381, 2381 }, + { 2385, 2391 }, + { 2402, 2403 }, + { 2433, 
2433 }, + { 2492, 2492 }, + { 2497, 2500 }, + { 2509, 2509 }, + { 2530, 2531 }, + { 2558, 2558 }, + { 2561, 2562 }, + { 2620, 2620 }, + { 2625, 2626 }, + { 2631, 2632 }, + { 2635, 2637 }, + { 2641, 2641 }, + { 2672, 2673 }, + { 2677, 2677 }, + { 2689, 2690 }, + { 2748, 2748 }, + { 2753, 2757 }, + { 2759, 2760 }, + { 2765, 2765 }, + { 2786, 2787 }, + { 2810, 2815 }, + { 2817, 2817 }, + { 2876, 2876 }, + { 2879, 2879 }, + { 2881, 2884 }, + { 2893, 2893 }, + { 2901, 2902 }, + { 2914, 2915 }, + { 2946, 2946 }, + { 3008, 3008 }, + { 3021, 3021 }, + { 3072, 3072 }, + { 3076, 3076 }, + { 3132, 3132 }, + { 3134, 3136 }, + { 3142, 3144 }, + { 3146, 3149 }, + { 3157, 3158 }, + { 3170, 3171 }, + { 3201, 3201 }, + { 3260, 3260 }, + { 3263, 3263 }, + { 3270, 3270 }, + { 3276, 3277 }, + { 3298, 3299 }, + { 3328, 3329 }, + { 3387, 3388 }, + { 3393, 3396 }, + { 3405, 3405 }, + { 3426, 3427 }, + { 3457, 3457 }, + { 3530, 3530 }, + { 3538, 3540 }, + { 3542, 3542 }, + { 3633, 3633 }, + { 3636, 3642 }, + { 3655, 3662 }, + { 3761, 3761 }, + { 3764, 3772 }, + { 3784, 3790 }, + { 3864, 3865 }, + { 3893, 3893 }, + { 3895, 3895 }, + { 3897, 3897 }, + { 3953, 3966 }, + { 3968, 3972 }, + { 3974, 3975 }, + { 3981, 3991 }, + { 3993, 4028 }, + { 4038, 4038 }, + { 4141, 4144 }, + { 4146, 4151 }, + { 4153, 4154 }, + { 4157, 4158 }, + { 4184, 4185 }, + { 4190, 4192 }, + { 4209, 4212 }, + { 4226, 4226 }, + { 4229, 4230 }, + { 4237, 4237 }, + { 4253, 4253 }, + { 4957, 4959 }, + { 5906, 5908 }, + { 5938, 5939 }, + { 5970, 5971 }, + { 6002, 6003 }, + { 6068, 6069 }, + { 6071, 6077 }, + { 6086, 6086 }, + { 6089, 6099 }, + { 6109, 6109 }, + { 6155, 6157 }, + { 6159, 6159 }, + { 6277, 6278 }, + { 6313, 6313 }, + { 6432, 6434 }, + { 6439, 6440 }, + { 6450, 6450 }, + { 6457, 6459 }, + { 6679, 6680 }, + { 6683, 6683 }, + { 6742, 6742 }, + { 6744, 6750 }, + { 6752, 6752 }, + { 6754, 6754 }, + { 6757, 6764 }, + { 6771, 6780 }, + { 6783, 6783 }, + { 6832, 6845 }, + { 6847, 6862 }, + { 6912, 6915 }, + { 6964, 6964 }, + { 6966, 6970 }, + { 6972, 6972 }, + { 6978, 6978 }, + { 7019, 7027 }, + { 7040, 7041 }, + { 7074, 7077 }, + { 7080, 7081 }, + { 7083, 7085 }, + { 7142, 7142 }, + { 7144, 7145 }, + { 7149, 7149 }, + { 7151, 7153 }, + { 7212, 7219 }, + { 7222, 7223 }, + { 7376, 7378 }, + { 7380, 7392 }, + { 7394, 7400 }, + { 7405, 7405 }, + { 7412, 7412 }, + { 7416, 7417 }, + { 7616, 7679 }, + { 8400, 8412 }, + { 8417, 8417 }, + { 8421, 8432 }, + { 11503, 11505 }, + { 11647, 11647 }, + { 11744, 11775 }, + { 12330, 12333 }, + { 12441, 12442 }, + { 42607, 42607 }, + { 42612, 42621 }, + { 42654, 42655 }, + { 42736, 42737 }, + { 43010, 43010 }, + { 43014, 43014 }, + { 43019, 43019 }, + { 43045, 43046 }, + { 43052, 43052 }, + { 43204, 43205 }, + { 43232, 43249 }, + { 43263, 43263 }, + { 43302, 43309 }, + { 43335, 43345 }, + { 43392, 43394 }, + { 43443, 43443 }, + { 43446, 43449 }, + { 43452, 43453 }, + { 43493, 43493 }, + { 43561, 43566 }, + { 43569, 43570 }, + { 43573, 43574 }, + { 43587, 43587 }, + { 43596, 43596 }, + { 43644, 43644 }, + { 43696, 43696 }, + { 43698, 43700 }, + { 43703, 43704 }, + { 43710, 43711 }, + { 43713, 43713 }, + { 43756, 43757 }, + { 43766, 43766 }, + { 44005, 44005 }, + { 44008, 44008 }, + { 44013, 44013 }, + { 64286, 64286 }, + { 65024, 65039 }, + { 65056, 65071 }, +}; +static const URange32 Mn_range32[] = { + { 66045, 66045 }, + { 66272, 66272 }, + { 66422, 66426 }, + { 68097, 68099 }, + { 68101, 68102 }, + { 68108, 68111 }, + { 68152, 68154 }, + { 68159, 68159 }, + { 68325, 68326 }, + { 68900, 68903 }, + { 69291, 
69292 }, + { 69373, 69375 }, + { 69446, 69456 }, + { 69506, 69509 }, + { 69633, 69633 }, + { 69688, 69702 }, + { 69744, 69744 }, + { 69747, 69748 }, + { 69759, 69761 }, + { 69811, 69814 }, + { 69817, 69818 }, + { 69826, 69826 }, + { 69888, 69890 }, + { 69927, 69931 }, + { 69933, 69940 }, + { 70003, 70003 }, + { 70016, 70017 }, + { 70070, 70078 }, + { 70089, 70092 }, + { 70095, 70095 }, + { 70191, 70193 }, + { 70196, 70196 }, + { 70198, 70199 }, + { 70206, 70206 }, + { 70209, 70209 }, + { 70367, 70367 }, + { 70371, 70378 }, + { 70400, 70401 }, + { 70459, 70460 }, + { 70464, 70464 }, + { 70502, 70508 }, + { 70512, 70516 }, + { 70712, 70719 }, + { 70722, 70724 }, + { 70726, 70726 }, + { 70750, 70750 }, + { 70835, 70840 }, + { 70842, 70842 }, + { 70847, 70848 }, + { 70850, 70851 }, + { 71090, 71093 }, + { 71100, 71101 }, + { 71103, 71104 }, + { 71132, 71133 }, + { 71219, 71226 }, + { 71229, 71229 }, + { 71231, 71232 }, + { 71339, 71339 }, + { 71341, 71341 }, + { 71344, 71349 }, + { 71351, 71351 }, + { 71453, 71455 }, + { 71458, 71461 }, + { 71463, 71467 }, + { 71727, 71735 }, + { 71737, 71738 }, + { 71995, 71996 }, + { 71998, 71998 }, + { 72003, 72003 }, + { 72148, 72151 }, + { 72154, 72155 }, + { 72160, 72160 }, + { 72193, 72202 }, + { 72243, 72248 }, + { 72251, 72254 }, + { 72263, 72263 }, + { 72273, 72278 }, + { 72281, 72283 }, + { 72330, 72342 }, + { 72344, 72345 }, + { 72752, 72758 }, + { 72760, 72765 }, + { 72767, 72767 }, + { 72850, 72871 }, + { 72874, 72880 }, + { 72882, 72883 }, + { 72885, 72886 }, + { 73009, 73014 }, + { 73018, 73018 }, + { 73020, 73021 }, + { 73023, 73029 }, + { 73031, 73031 }, + { 73104, 73105 }, + { 73109, 73109 }, + { 73111, 73111 }, + { 73459, 73460 }, + { 73472, 73473 }, + { 73526, 73530 }, + { 73536, 73536 }, + { 73538, 73538 }, + { 78912, 78912 }, + { 78919, 78933 }, + { 92912, 92916 }, + { 92976, 92982 }, + { 94031, 94031 }, + { 94095, 94098 }, + { 94180, 94180 }, + { 113821, 113822 }, + { 118528, 118573 }, + { 118576, 118598 }, + { 119143, 119145 }, + { 119163, 119170 }, + { 119173, 119179 }, + { 119210, 119213 }, + { 119362, 119364 }, + { 121344, 121398 }, + { 121403, 121452 }, + { 121461, 121461 }, + { 121476, 121476 }, + { 121499, 121503 }, + { 121505, 121519 }, + { 122880, 122886 }, + { 122888, 122904 }, + { 122907, 122913 }, + { 122915, 122916 }, + { 122918, 122922 }, + { 123023, 123023 }, + { 123184, 123190 }, + { 123566, 123566 }, + { 123628, 123631 }, + { 124140, 124143 }, + { 125136, 125142 }, + { 125252, 125258 }, + { 917760, 917999 }, +}; +static const URange16 N_range16[] = { + { 48, 57 }, + { 178, 179 }, + { 185, 185 }, + { 188, 190 }, + { 1632, 1641 }, + { 1776, 1785 }, + { 1984, 1993 }, + { 2406, 2415 }, + { 2534, 2543 }, + { 2548, 2553 }, + { 2662, 2671 }, + { 2790, 2799 }, + { 2918, 2927 }, + { 2930, 2935 }, + { 3046, 3058 }, + { 3174, 3183 }, + { 3192, 3198 }, + { 3302, 3311 }, + { 3416, 3422 }, + { 3430, 3448 }, + { 3558, 3567 }, + { 3664, 3673 }, + { 3792, 3801 }, + { 3872, 3891 }, + { 4160, 4169 }, + { 4240, 4249 }, + { 4969, 4988 }, + { 5870, 5872 }, + { 6112, 6121 }, + { 6128, 6137 }, + { 6160, 6169 }, + { 6470, 6479 }, + { 6608, 6618 }, + { 6784, 6793 }, + { 6800, 6809 }, + { 6992, 7001 }, + { 7088, 7097 }, + { 7232, 7241 }, + { 7248, 7257 }, + { 8304, 8304 }, + { 8308, 8313 }, + { 8320, 8329 }, + { 8528, 8578 }, + { 8581, 8585 }, + { 9312, 9371 }, + { 9450, 9471 }, + { 10102, 10131 }, + { 11517, 11517 }, + { 12295, 12295 }, + { 12321, 12329 }, + { 12344, 12346 }, + { 12690, 12693 }, + { 12832, 12841 }, + { 12872, 12879 }, + { 12881, 
12895 }, + { 12928, 12937 }, + { 12977, 12991 }, + { 42528, 42537 }, + { 42726, 42735 }, + { 43056, 43061 }, + { 43216, 43225 }, + { 43264, 43273 }, + { 43472, 43481 }, + { 43504, 43513 }, + { 43600, 43609 }, + { 44016, 44025 }, + { 65296, 65305 }, +}; +static const URange32 N_range32[] = { + { 65799, 65843 }, + { 65856, 65912 }, + { 65930, 65931 }, + { 66273, 66299 }, + { 66336, 66339 }, + { 66369, 66369 }, + { 66378, 66378 }, + { 66513, 66517 }, + { 66720, 66729 }, + { 67672, 67679 }, + { 67705, 67711 }, + { 67751, 67759 }, + { 67835, 67839 }, + { 67862, 67867 }, + { 68028, 68029 }, + { 68032, 68047 }, + { 68050, 68095 }, + { 68160, 68168 }, + { 68221, 68222 }, + { 68253, 68255 }, + { 68331, 68335 }, + { 68440, 68447 }, + { 68472, 68479 }, + { 68521, 68527 }, + { 68858, 68863 }, + { 68912, 68921 }, + { 69216, 69246 }, + { 69405, 69414 }, + { 69457, 69460 }, + { 69573, 69579 }, + { 69714, 69743 }, + { 69872, 69881 }, + { 69942, 69951 }, + { 70096, 70105 }, + { 70113, 70132 }, + { 70384, 70393 }, + { 70736, 70745 }, + { 70864, 70873 }, + { 71248, 71257 }, + { 71360, 71369 }, + { 71472, 71483 }, + { 71904, 71922 }, + { 72016, 72025 }, + { 72784, 72812 }, + { 73040, 73049 }, + { 73120, 73129 }, + { 73552, 73561 }, + { 73664, 73684 }, + { 74752, 74862 }, + { 92768, 92777 }, + { 92864, 92873 }, + { 93008, 93017 }, + { 93019, 93025 }, + { 93824, 93846 }, + { 119488, 119507 }, + { 119520, 119539 }, + { 119648, 119672 }, + { 120782, 120831 }, + { 123200, 123209 }, + { 123632, 123641 }, + { 124144, 124153 }, + { 125127, 125135 }, + { 125264, 125273 }, + { 126065, 126123 }, + { 126125, 126127 }, + { 126129, 126132 }, + { 126209, 126253 }, + { 126255, 126269 }, + { 127232, 127244 }, + { 130032, 130041 }, +}; +static const URange16 Nd_range16[] = { + { 48, 57 }, + { 1632, 1641 }, + { 1776, 1785 }, + { 1984, 1993 }, + { 2406, 2415 }, + { 2534, 2543 }, + { 2662, 2671 }, + { 2790, 2799 }, + { 2918, 2927 }, + { 3046, 3055 }, + { 3174, 3183 }, + { 3302, 3311 }, + { 3430, 3439 }, + { 3558, 3567 }, + { 3664, 3673 }, + { 3792, 3801 }, + { 3872, 3881 }, + { 4160, 4169 }, + { 4240, 4249 }, + { 6112, 6121 }, + { 6160, 6169 }, + { 6470, 6479 }, + { 6608, 6617 }, + { 6784, 6793 }, + { 6800, 6809 }, + { 6992, 7001 }, + { 7088, 7097 }, + { 7232, 7241 }, + { 7248, 7257 }, + { 42528, 42537 }, + { 43216, 43225 }, + { 43264, 43273 }, + { 43472, 43481 }, + { 43504, 43513 }, + { 43600, 43609 }, + { 44016, 44025 }, + { 65296, 65305 }, +}; +static const URange32 Nd_range32[] = { + { 66720, 66729 }, + { 68912, 68921 }, + { 69734, 69743 }, + { 69872, 69881 }, + { 69942, 69951 }, + { 70096, 70105 }, + { 70384, 70393 }, + { 70736, 70745 }, + { 70864, 70873 }, + { 71248, 71257 }, + { 71360, 71369 }, + { 71472, 71481 }, + { 71904, 71913 }, + { 72016, 72025 }, + { 72784, 72793 }, + { 73040, 73049 }, + { 73120, 73129 }, + { 73552, 73561 }, + { 92768, 92777 }, + { 92864, 92873 }, + { 93008, 93017 }, + { 120782, 120831 }, + { 123200, 123209 }, + { 123632, 123641 }, + { 124144, 124153 }, + { 125264, 125273 }, + { 130032, 130041 }, +}; +static const URange16 Nl_range16[] = { + { 5870, 5872 }, + { 8544, 8578 }, + { 8581, 8584 }, + { 12295, 12295 }, + { 12321, 12329 }, + { 12344, 12346 }, + { 42726, 42735 }, +}; +static const URange32 Nl_range32[] = { + { 65856, 65908 }, + { 66369, 66369 }, + { 66378, 66378 }, + { 66513, 66517 }, + { 74752, 74862 }, +}; +static const URange16 No_range16[] = { + { 178, 179 }, + { 185, 185 }, + { 188, 190 }, + { 2548, 2553 }, + { 2930, 2935 }, + { 3056, 3058 }, + { 3192, 3198 }, + { 3416, 3422 }, + { 
3440, 3448 }, + { 3882, 3891 }, + { 4969, 4988 }, + { 6128, 6137 }, + { 6618, 6618 }, + { 8304, 8304 }, + { 8308, 8313 }, + { 8320, 8329 }, + { 8528, 8543 }, + { 8585, 8585 }, + { 9312, 9371 }, + { 9450, 9471 }, + { 10102, 10131 }, + { 11517, 11517 }, + { 12690, 12693 }, + { 12832, 12841 }, + { 12872, 12879 }, + { 12881, 12895 }, + { 12928, 12937 }, + { 12977, 12991 }, + { 43056, 43061 }, +}; +static const URange32 No_range32[] = { + { 65799, 65843 }, + { 65909, 65912 }, + { 65930, 65931 }, + { 66273, 66299 }, + { 66336, 66339 }, + { 67672, 67679 }, + { 67705, 67711 }, + { 67751, 67759 }, + { 67835, 67839 }, + { 67862, 67867 }, + { 68028, 68029 }, + { 68032, 68047 }, + { 68050, 68095 }, + { 68160, 68168 }, + { 68221, 68222 }, + { 68253, 68255 }, + { 68331, 68335 }, + { 68440, 68447 }, + { 68472, 68479 }, + { 68521, 68527 }, + { 68858, 68863 }, + { 69216, 69246 }, + { 69405, 69414 }, + { 69457, 69460 }, + { 69573, 69579 }, + { 69714, 69733 }, + { 70113, 70132 }, + { 71482, 71483 }, + { 71914, 71922 }, + { 72794, 72812 }, + { 73664, 73684 }, + { 93019, 93025 }, + { 93824, 93846 }, + { 119488, 119507 }, + { 119520, 119539 }, + { 119648, 119672 }, + { 125127, 125135 }, + { 126065, 126123 }, + { 126125, 126127 }, + { 126129, 126132 }, + { 126209, 126253 }, + { 126255, 126269 }, + { 127232, 127244 }, +}; +static const URange16 P_range16[] = { + { 33, 35 }, + { 37, 42 }, + { 44, 47 }, + { 58, 59 }, + { 63, 64 }, + { 91, 93 }, + { 95, 95 }, + { 123, 123 }, + { 125, 125 }, + { 161, 161 }, + { 167, 167 }, + { 171, 171 }, + { 182, 183 }, + { 187, 187 }, + { 191, 191 }, + { 894, 894 }, + { 903, 903 }, + { 1370, 1375 }, + { 1417, 1418 }, + { 1470, 1470 }, + { 1472, 1472 }, + { 1475, 1475 }, + { 1478, 1478 }, + { 1523, 1524 }, + { 1545, 1546 }, + { 1548, 1549 }, + { 1563, 1563 }, + { 1565, 1567 }, + { 1642, 1645 }, + { 1748, 1748 }, + { 1792, 1805 }, + { 2039, 2041 }, + { 2096, 2110 }, + { 2142, 2142 }, + { 2404, 2405 }, + { 2416, 2416 }, + { 2557, 2557 }, + { 2678, 2678 }, + { 2800, 2800 }, + { 3191, 3191 }, + { 3204, 3204 }, + { 3572, 3572 }, + { 3663, 3663 }, + { 3674, 3675 }, + { 3844, 3858 }, + { 3860, 3860 }, + { 3898, 3901 }, + { 3973, 3973 }, + { 4048, 4052 }, + { 4057, 4058 }, + { 4170, 4175 }, + { 4347, 4347 }, + { 4960, 4968 }, + { 5120, 5120 }, + { 5742, 5742 }, + { 5787, 5788 }, + { 5867, 5869 }, + { 5941, 5942 }, + { 6100, 6102 }, + { 6104, 6106 }, + { 6144, 6154 }, + { 6468, 6469 }, + { 6686, 6687 }, + { 6816, 6822 }, + { 6824, 6829 }, + { 7002, 7008 }, + { 7037, 7038 }, + { 7164, 7167 }, + { 7227, 7231 }, + { 7294, 7295 }, + { 7360, 7367 }, + { 7379, 7379 }, + { 8208, 8231 }, + { 8240, 8259 }, + { 8261, 8273 }, + { 8275, 8286 }, + { 8317, 8318 }, + { 8333, 8334 }, + { 8968, 8971 }, + { 9001, 9002 }, + { 10088, 10101 }, + { 10181, 10182 }, + { 10214, 10223 }, + { 10627, 10648 }, + { 10712, 10715 }, + { 10748, 10749 }, + { 11513, 11516 }, + { 11518, 11519 }, + { 11632, 11632 }, + { 11776, 11822 }, + { 11824, 11855 }, + { 11858, 11869 }, + { 12289, 12291 }, + { 12296, 12305 }, + { 12308, 12319 }, + { 12336, 12336 }, + { 12349, 12349 }, + { 12448, 12448 }, + { 12539, 12539 }, + { 42238, 42239 }, + { 42509, 42511 }, + { 42611, 42611 }, + { 42622, 42622 }, + { 42738, 42743 }, + { 43124, 43127 }, + { 43214, 43215 }, + { 43256, 43258 }, + { 43260, 43260 }, + { 43310, 43311 }, + { 43359, 43359 }, + { 43457, 43469 }, + { 43486, 43487 }, + { 43612, 43615 }, + { 43742, 43743 }, + { 43760, 43761 }, + { 44011, 44011 }, + { 64830, 64831 }, + { 65040, 65049 }, + { 65072, 65106 }, + { 65108, 65121 }, + { 
65123, 65123 }, + { 65128, 65128 }, + { 65130, 65131 }, + { 65281, 65283 }, + { 65285, 65290 }, + { 65292, 65295 }, + { 65306, 65307 }, + { 65311, 65312 }, + { 65339, 65341 }, + { 65343, 65343 }, + { 65371, 65371 }, + { 65373, 65373 }, + { 65375, 65381 }, +}; +static const URange32 P_range32[] = { + { 65792, 65794 }, + { 66463, 66463 }, + { 66512, 66512 }, + { 66927, 66927 }, + { 67671, 67671 }, + { 67871, 67871 }, + { 67903, 67903 }, + { 68176, 68184 }, + { 68223, 68223 }, + { 68336, 68342 }, + { 68409, 68415 }, + { 68505, 68508 }, + { 69293, 69293 }, + { 69461, 69465 }, + { 69510, 69513 }, + { 69703, 69709 }, + { 69819, 69820 }, + { 69822, 69825 }, + { 69952, 69955 }, + { 70004, 70005 }, + { 70085, 70088 }, + { 70093, 70093 }, + { 70107, 70107 }, + { 70109, 70111 }, + { 70200, 70205 }, + { 70313, 70313 }, + { 70731, 70735 }, + { 70746, 70747 }, + { 70749, 70749 }, + { 70854, 70854 }, + { 71105, 71127 }, + { 71233, 71235 }, + { 71264, 71276 }, + { 71353, 71353 }, + { 71484, 71486 }, + { 71739, 71739 }, + { 72004, 72006 }, + { 72162, 72162 }, + { 72255, 72262 }, + { 72346, 72348 }, + { 72350, 72354 }, + { 72448, 72457 }, + { 72769, 72773 }, + { 72816, 72817 }, + { 73463, 73464 }, + { 73539, 73551 }, + { 73727, 73727 }, + { 74864, 74868 }, + { 77809, 77810 }, + { 92782, 92783 }, + { 92917, 92917 }, + { 92983, 92987 }, + { 92996, 92996 }, + { 93847, 93850 }, + { 94178, 94178 }, + { 113823, 113823 }, + { 121479, 121483 }, + { 125278, 125279 }, +}; +static const URange16 Pc_range16[] = { + { 95, 95 }, + { 8255, 8256 }, + { 8276, 8276 }, + { 65075, 65076 }, + { 65101, 65103 }, + { 65343, 65343 }, +}; +static const URange16 Pd_range16[] = { + { 45, 45 }, + { 1418, 1418 }, + { 1470, 1470 }, + { 5120, 5120 }, + { 6150, 6150 }, + { 8208, 8213 }, + { 11799, 11799 }, + { 11802, 11802 }, + { 11834, 11835 }, + { 11840, 11840 }, + { 11869, 11869 }, + { 12316, 12316 }, + { 12336, 12336 }, + { 12448, 12448 }, + { 65073, 65074 }, + { 65112, 65112 }, + { 65123, 65123 }, + { 65293, 65293 }, +}; +static const URange32 Pd_range32[] = { + { 69293, 69293 }, +}; +static const URange16 Pe_range16[] = { + { 41, 41 }, + { 93, 93 }, + { 125, 125 }, + { 3899, 3899 }, + { 3901, 3901 }, + { 5788, 5788 }, + { 8262, 8262 }, + { 8318, 8318 }, + { 8334, 8334 }, + { 8969, 8969 }, + { 8971, 8971 }, + { 9002, 9002 }, + { 10089, 10089 }, + { 10091, 10091 }, + { 10093, 10093 }, + { 10095, 10095 }, + { 10097, 10097 }, + { 10099, 10099 }, + { 10101, 10101 }, + { 10182, 10182 }, + { 10215, 10215 }, + { 10217, 10217 }, + { 10219, 10219 }, + { 10221, 10221 }, + { 10223, 10223 }, + { 10628, 10628 }, + { 10630, 10630 }, + { 10632, 10632 }, + { 10634, 10634 }, + { 10636, 10636 }, + { 10638, 10638 }, + { 10640, 10640 }, + { 10642, 10642 }, + { 10644, 10644 }, + { 10646, 10646 }, + { 10648, 10648 }, + { 10713, 10713 }, + { 10715, 10715 }, + { 10749, 10749 }, + { 11811, 11811 }, + { 11813, 11813 }, + { 11815, 11815 }, + { 11817, 11817 }, + { 11862, 11862 }, + { 11864, 11864 }, + { 11866, 11866 }, + { 11868, 11868 }, + { 12297, 12297 }, + { 12299, 12299 }, + { 12301, 12301 }, + { 12303, 12303 }, + { 12305, 12305 }, + { 12309, 12309 }, + { 12311, 12311 }, + { 12313, 12313 }, + { 12315, 12315 }, + { 12318, 12319 }, + { 64830, 64830 }, + { 65048, 65048 }, + { 65078, 65078 }, + { 65080, 65080 }, + { 65082, 65082 }, + { 65084, 65084 }, + { 65086, 65086 }, + { 65088, 65088 }, + { 65090, 65090 }, + { 65092, 65092 }, + { 65096, 65096 }, + { 65114, 65114 }, + { 65116, 65116 }, + { 65118, 65118 }, + { 65289, 65289 }, + { 65341, 65341 }, + { 65373, 
65373 }, + { 65376, 65376 }, + { 65379, 65379 }, +}; +static const URange16 Pf_range16[] = { + { 187, 187 }, + { 8217, 8217 }, + { 8221, 8221 }, + { 8250, 8250 }, + { 11779, 11779 }, + { 11781, 11781 }, + { 11786, 11786 }, + { 11789, 11789 }, + { 11805, 11805 }, + { 11809, 11809 }, +}; +static const URange16 Pi_range16[] = { + { 171, 171 }, + { 8216, 8216 }, + { 8219, 8220 }, + { 8223, 8223 }, + { 8249, 8249 }, + { 11778, 11778 }, + { 11780, 11780 }, + { 11785, 11785 }, + { 11788, 11788 }, + { 11804, 11804 }, + { 11808, 11808 }, +}; +static const URange16 Po_range16[] = { + { 33, 35 }, + { 37, 39 }, + { 42, 42 }, + { 44, 44 }, + { 46, 47 }, + { 58, 59 }, + { 63, 64 }, + { 92, 92 }, + { 161, 161 }, + { 167, 167 }, + { 182, 183 }, + { 191, 191 }, + { 894, 894 }, + { 903, 903 }, + { 1370, 1375 }, + { 1417, 1417 }, + { 1472, 1472 }, + { 1475, 1475 }, + { 1478, 1478 }, + { 1523, 1524 }, + { 1545, 1546 }, + { 1548, 1549 }, + { 1563, 1563 }, + { 1565, 1567 }, + { 1642, 1645 }, + { 1748, 1748 }, + { 1792, 1805 }, + { 2039, 2041 }, + { 2096, 2110 }, + { 2142, 2142 }, + { 2404, 2405 }, + { 2416, 2416 }, + { 2557, 2557 }, + { 2678, 2678 }, + { 2800, 2800 }, + { 3191, 3191 }, + { 3204, 3204 }, + { 3572, 3572 }, + { 3663, 3663 }, + { 3674, 3675 }, + { 3844, 3858 }, + { 3860, 3860 }, + { 3973, 3973 }, + { 4048, 4052 }, + { 4057, 4058 }, + { 4170, 4175 }, + { 4347, 4347 }, + { 4960, 4968 }, + { 5742, 5742 }, + { 5867, 5869 }, + { 5941, 5942 }, + { 6100, 6102 }, + { 6104, 6106 }, + { 6144, 6149 }, + { 6151, 6154 }, + { 6468, 6469 }, + { 6686, 6687 }, + { 6816, 6822 }, + { 6824, 6829 }, + { 7002, 7008 }, + { 7037, 7038 }, + { 7164, 7167 }, + { 7227, 7231 }, + { 7294, 7295 }, + { 7360, 7367 }, + { 7379, 7379 }, + { 8214, 8215 }, + { 8224, 8231 }, + { 8240, 8248 }, + { 8251, 8254 }, + { 8257, 8259 }, + { 8263, 8273 }, + { 8275, 8275 }, + { 8277, 8286 }, + { 11513, 11516 }, + { 11518, 11519 }, + { 11632, 11632 }, + { 11776, 11777 }, + { 11782, 11784 }, + { 11787, 11787 }, + { 11790, 11798 }, + { 11800, 11801 }, + { 11803, 11803 }, + { 11806, 11807 }, + { 11818, 11822 }, + { 11824, 11833 }, + { 11836, 11839 }, + { 11841, 11841 }, + { 11843, 11855 }, + { 11858, 11860 }, + { 12289, 12291 }, + { 12349, 12349 }, + { 12539, 12539 }, + { 42238, 42239 }, + { 42509, 42511 }, + { 42611, 42611 }, + { 42622, 42622 }, + { 42738, 42743 }, + { 43124, 43127 }, + { 43214, 43215 }, + { 43256, 43258 }, + { 43260, 43260 }, + { 43310, 43311 }, + { 43359, 43359 }, + { 43457, 43469 }, + { 43486, 43487 }, + { 43612, 43615 }, + { 43742, 43743 }, + { 43760, 43761 }, + { 44011, 44011 }, + { 65040, 65046 }, + { 65049, 65049 }, + { 65072, 65072 }, + { 65093, 65094 }, + { 65097, 65100 }, + { 65104, 65106 }, + { 65108, 65111 }, + { 65119, 65121 }, + { 65128, 65128 }, + { 65130, 65131 }, + { 65281, 65283 }, + { 65285, 65287 }, + { 65290, 65290 }, + { 65292, 65292 }, + { 65294, 65295 }, + { 65306, 65307 }, + { 65311, 65312 }, + { 65340, 65340 }, + { 65377, 65377 }, + { 65380, 65381 }, +}; +static const URange32 Po_range32[] = { + { 65792, 65794 }, + { 66463, 66463 }, + { 66512, 66512 }, + { 66927, 66927 }, + { 67671, 67671 }, + { 67871, 67871 }, + { 67903, 67903 }, + { 68176, 68184 }, + { 68223, 68223 }, + { 68336, 68342 }, + { 68409, 68415 }, + { 68505, 68508 }, + { 69461, 69465 }, + { 69510, 69513 }, + { 69703, 69709 }, + { 69819, 69820 }, + { 69822, 69825 }, + { 69952, 69955 }, + { 70004, 70005 }, + { 70085, 70088 }, + { 70093, 70093 }, + { 70107, 70107 }, + { 70109, 70111 }, + { 70200, 70205 }, + { 70313, 70313 }, + { 70731, 70735 }, + { 
70746, 70747 }, + { 70749, 70749 }, + { 70854, 70854 }, + { 71105, 71127 }, + { 71233, 71235 }, + { 71264, 71276 }, + { 71353, 71353 }, + { 71484, 71486 }, + { 71739, 71739 }, + { 72004, 72006 }, + { 72162, 72162 }, + { 72255, 72262 }, + { 72346, 72348 }, + { 72350, 72354 }, + { 72448, 72457 }, + { 72769, 72773 }, + { 72816, 72817 }, + { 73463, 73464 }, + { 73539, 73551 }, + { 73727, 73727 }, + { 74864, 74868 }, + { 77809, 77810 }, + { 92782, 92783 }, + { 92917, 92917 }, + { 92983, 92987 }, + { 92996, 92996 }, + { 93847, 93850 }, + { 94178, 94178 }, + { 113823, 113823 }, + { 121479, 121483 }, + { 125278, 125279 }, +}; +static const URange16 Ps_range16[] = { + { 40, 40 }, + { 91, 91 }, + { 123, 123 }, + { 3898, 3898 }, + { 3900, 3900 }, + { 5787, 5787 }, + { 8218, 8218 }, + { 8222, 8222 }, + { 8261, 8261 }, + { 8317, 8317 }, + { 8333, 8333 }, + { 8968, 8968 }, + { 8970, 8970 }, + { 9001, 9001 }, + { 10088, 10088 }, + { 10090, 10090 }, + { 10092, 10092 }, + { 10094, 10094 }, + { 10096, 10096 }, + { 10098, 10098 }, + { 10100, 10100 }, + { 10181, 10181 }, + { 10214, 10214 }, + { 10216, 10216 }, + { 10218, 10218 }, + { 10220, 10220 }, + { 10222, 10222 }, + { 10627, 10627 }, + { 10629, 10629 }, + { 10631, 10631 }, + { 10633, 10633 }, + { 10635, 10635 }, + { 10637, 10637 }, + { 10639, 10639 }, + { 10641, 10641 }, + { 10643, 10643 }, + { 10645, 10645 }, + { 10647, 10647 }, + { 10712, 10712 }, + { 10714, 10714 }, + { 10748, 10748 }, + { 11810, 11810 }, + { 11812, 11812 }, + { 11814, 11814 }, + { 11816, 11816 }, + { 11842, 11842 }, + { 11861, 11861 }, + { 11863, 11863 }, + { 11865, 11865 }, + { 11867, 11867 }, + { 12296, 12296 }, + { 12298, 12298 }, + { 12300, 12300 }, + { 12302, 12302 }, + { 12304, 12304 }, + { 12308, 12308 }, + { 12310, 12310 }, + { 12312, 12312 }, + { 12314, 12314 }, + { 12317, 12317 }, + { 64831, 64831 }, + { 65047, 65047 }, + { 65077, 65077 }, + { 65079, 65079 }, + { 65081, 65081 }, + { 65083, 65083 }, + { 65085, 65085 }, + { 65087, 65087 }, + { 65089, 65089 }, + { 65091, 65091 }, + { 65095, 65095 }, + { 65113, 65113 }, + { 65115, 65115 }, + { 65117, 65117 }, + { 65288, 65288 }, + { 65339, 65339 }, + { 65371, 65371 }, + { 65375, 65375 }, + { 65378, 65378 }, +}; +static const URange16 S_range16[] = { + { 36, 36 }, + { 43, 43 }, + { 60, 62 }, + { 94, 94 }, + { 96, 96 }, + { 124, 124 }, + { 126, 126 }, + { 162, 166 }, + { 168, 169 }, + { 172, 172 }, + { 174, 177 }, + { 180, 180 }, + { 184, 184 }, + { 215, 215 }, + { 247, 247 }, + { 706, 709 }, + { 722, 735 }, + { 741, 747 }, + { 749, 749 }, + { 751, 767 }, + { 885, 885 }, + { 900, 901 }, + { 1014, 1014 }, + { 1154, 1154 }, + { 1421, 1423 }, + { 1542, 1544 }, + { 1547, 1547 }, + { 1550, 1551 }, + { 1758, 1758 }, + { 1769, 1769 }, + { 1789, 1790 }, + { 2038, 2038 }, + { 2046, 2047 }, + { 2184, 2184 }, + { 2546, 2547 }, + { 2554, 2555 }, + { 2801, 2801 }, + { 2928, 2928 }, + { 3059, 3066 }, + { 3199, 3199 }, + { 3407, 3407 }, + { 3449, 3449 }, + { 3647, 3647 }, + { 3841, 3843 }, + { 3859, 3859 }, + { 3861, 3863 }, + { 3866, 3871 }, + { 3892, 3892 }, + { 3894, 3894 }, + { 3896, 3896 }, + { 4030, 4037 }, + { 4039, 4044 }, + { 4046, 4047 }, + { 4053, 4056 }, + { 4254, 4255 }, + { 5008, 5017 }, + { 5741, 5741 }, + { 6107, 6107 }, + { 6464, 6464 }, + { 6622, 6655 }, + { 7009, 7018 }, + { 7028, 7036 }, + { 8125, 8125 }, + { 8127, 8129 }, + { 8141, 8143 }, + { 8157, 8159 }, + { 8173, 8175 }, + { 8189, 8190 }, + { 8260, 8260 }, + { 8274, 8274 }, + { 8314, 8316 }, + { 8330, 8332 }, + { 8352, 8384 }, + { 8448, 8449 }, + { 8451, 8454 }, + { 
8456, 8457 }, + { 8468, 8468 }, + { 8470, 8472 }, + { 8478, 8483 }, + { 8485, 8485 }, + { 8487, 8487 }, + { 8489, 8489 }, + { 8494, 8494 }, + { 8506, 8507 }, + { 8512, 8516 }, + { 8522, 8525 }, + { 8527, 8527 }, + { 8586, 8587 }, + { 8592, 8967 }, + { 8972, 9000 }, + { 9003, 9254 }, + { 9280, 9290 }, + { 9372, 9449 }, + { 9472, 10087 }, + { 10132, 10180 }, + { 10183, 10213 }, + { 10224, 10626 }, + { 10649, 10711 }, + { 10716, 10747 }, + { 10750, 11123 }, + { 11126, 11157 }, + { 11159, 11263 }, + { 11493, 11498 }, + { 11856, 11857 }, + { 11904, 11929 }, + { 11931, 12019 }, + { 12032, 12245 }, + { 12272, 12283 }, + { 12292, 12292 }, + { 12306, 12307 }, + { 12320, 12320 }, + { 12342, 12343 }, + { 12350, 12351 }, + { 12443, 12444 }, + { 12688, 12689 }, + { 12694, 12703 }, + { 12736, 12771 }, + { 12800, 12830 }, + { 12842, 12871 }, + { 12880, 12880 }, + { 12896, 12927 }, + { 12938, 12976 }, + { 12992, 13311 }, + { 19904, 19967 }, + { 42128, 42182 }, + { 42752, 42774 }, + { 42784, 42785 }, + { 42889, 42890 }, + { 43048, 43051 }, + { 43062, 43065 }, + { 43639, 43641 }, + { 43867, 43867 }, + { 43882, 43883 }, + { 64297, 64297 }, + { 64434, 64450 }, + { 64832, 64847 }, + { 64975, 64975 }, + { 65020, 65023 }, + { 65122, 65122 }, + { 65124, 65126 }, + { 65129, 65129 }, + { 65284, 65284 }, + { 65291, 65291 }, + { 65308, 65310 }, + { 65342, 65342 }, + { 65344, 65344 }, + { 65372, 65372 }, + { 65374, 65374 }, + { 65504, 65510 }, + { 65512, 65518 }, + { 65532, 65533 }, +}; +static const URange32 S_range32[] = { + { 65847, 65855 }, + { 65913, 65929 }, + { 65932, 65934 }, + { 65936, 65948 }, + { 65952, 65952 }, + { 66000, 66044 }, + { 67703, 67704 }, + { 68296, 68296 }, + { 71487, 71487 }, + { 73685, 73713 }, + { 92988, 92991 }, + { 92997, 92997 }, + { 113820, 113820 }, + { 118608, 118723 }, + { 118784, 119029 }, + { 119040, 119078 }, + { 119081, 119140 }, + { 119146, 119148 }, + { 119171, 119172 }, + { 119180, 119209 }, + { 119214, 119274 }, + { 119296, 119361 }, + { 119365, 119365 }, + { 119552, 119638 }, + { 120513, 120513 }, + { 120539, 120539 }, + { 120571, 120571 }, + { 120597, 120597 }, + { 120629, 120629 }, + { 120655, 120655 }, + { 120687, 120687 }, + { 120713, 120713 }, + { 120745, 120745 }, + { 120771, 120771 }, + { 120832, 121343 }, + { 121399, 121402 }, + { 121453, 121460 }, + { 121462, 121475 }, + { 121477, 121478 }, + { 123215, 123215 }, + { 123647, 123647 }, + { 126124, 126124 }, + { 126128, 126128 }, + { 126254, 126254 }, + { 126704, 126705 }, + { 126976, 127019 }, + { 127024, 127123 }, + { 127136, 127150 }, + { 127153, 127167 }, + { 127169, 127183 }, + { 127185, 127221 }, + { 127245, 127405 }, + { 127462, 127490 }, + { 127504, 127547 }, + { 127552, 127560 }, + { 127568, 127569 }, + { 127584, 127589 }, + { 127744, 128727 }, + { 128732, 128748 }, + { 128752, 128764 }, + { 128768, 128886 }, + { 128891, 128985 }, + { 128992, 129003 }, + { 129008, 129008 }, + { 129024, 129035 }, + { 129040, 129095 }, + { 129104, 129113 }, + { 129120, 129159 }, + { 129168, 129197 }, + { 129200, 129201 }, + { 129280, 129619 }, + { 129632, 129645 }, + { 129648, 129660 }, + { 129664, 129672 }, + { 129680, 129725 }, + { 129727, 129733 }, + { 129742, 129755 }, + { 129760, 129768 }, + { 129776, 129784 }, + { 129792, 129938 }, + { 129940, 129994 }, +}; +static const URange16 Sc_range16[] = { + { 36, 36 }, + { 162, 165 }, + { 1423, 1423 }, + { 1547, 1547 }, + { 2046, 2047 }, + { 2546, 2547 }, + { 2555, 2555 }, + { 2801, 2801 }, + { 3065, 3065 }, + { 3647, 3647 }, + { 6107, 6107 }, + { 8352, 8384 }, + { 43064, 43064 }, 
+ { 65020, 65020 }, + { 65129, 65129 }, + { 65284, 65284 }, + { 65504, 65505 }, + { 65509, 65510 }, +}; +static const URange32 Sc_range32[] = { + { 73693, 73696 }, + { 123647, 123647 }, + { 126128, 126128 }, +}; +static const URange16 Sk_range16[] = { + { 94, 94 }, + { 96, 96 }, + { 168, 168 }, + { 175, 175 }, + { 180, 180 }, + { 184, 184 }, + { 706, 709 }, + { 722, 735 }, + { 741, 747 }, + { 749, 749 }, + { 751, 767 }, + { 885, 885 }, + { 900, 901 }, + { 2184, 2184 }, + { 8125, 8125 }, + { 8127, 8129 }, + { 8141, 8143 }, + { 8157, 8159 }, + { 8173, 8175 }, + { 8189, 8190 }, + { 12443, 12444 }, + { 42752, 42774 }, + { 42784, 42785 }, + { 42889, 42890 }, + { 43867, 43867 }, + { 43882, 43883 }, + { 64434, 64450 }, + { 65342, 65342 }, + { 65344, 65344 }, + { 65507, 65507 }, +}; +static const URange32 Sk_range32[] = { + { 127995, 127999 }, +}; +static const URange16 Sm_range16[] = { + { 43, 43 }, + { 60, 62 }, + { 124, 124 }, + { 126, 126 }, + { 172, 172 }, + { 177, 177 }, + { 215, 215 }, + { 247, 247 }, + { 1014, 1014 }, + { 1542, 1544 }, + { 8260, 8260 }, + { 8274, 8274 }, + { 8314, 8316 }, + { 8330, 8332 }, + { 8472, 8472 }, + { 8512, 8516 }, + { 8523, 8523 }, + { 8592, 8596 }, + { 8602, 8603 }, + { 8608, 8608 }, + { 8611, 8611 }, + { 8614, 8614 }, + { 8622, 8622 }, + { 8654, 8655 }, + { 8658, 8658 }, + { 8660, 8660 }, + { 8692, 8959 }, + { 8992, 8993 }, + { 9084, 9084 }, + { 9115, 9139 }, + { 9180, 9185 }, + { 9655, 9655 }, + { 9665, 9665 }, + { 9720, 9727 }, + { 9839, 9839 }, + { 10176, 10180 }, + { 10183, 10213 }, + { 10224, 10239 }, + { 10496, 10626 }, + { 10649, 10711 }, + { 10716, 10747 }, + { 10750, 11007 }, + { 11056, 11076 }, + { 11079, 11084 }, + { 64297, 64297 }, + { 65122, 65122 }, + { 65124, 65126 }, + { 65291, 65291 }, + { 65308, 65310 }, + { 65372, 65372 }, + { 65374, 65374 }, + { 65506, 65506 }, + { 65513, 65516 }, +}; +static const URange32 Sm_range32[] = { + { 120513, 120513 }, + { 120539, 120539 }, + { 120571, 120571 }, + { 120597, 120597 }, + { 120629, 120629 }, + { 120655, 120655 }, + { 120687, 120687 }, + { 120713, 120713 }, + { 120745, 120745 }, + { 120771, 120771 }, + { 126704, 126705 }, +}; +static const URange16 So_range16[] = { + { 166, 166 }, + { 169, 169 }, + { 174, 174 }, + { 176, 176 }, + { 1154, 1154 }, + { 1421, 1422 }, + { 1550, 1551 }, + { 1758, 1758 }, + { 1769, 1769 }, + { 1789, 1790 }, + { 2038, 2038 }, + { 2554, 2554 }, + { 2928, 2928 }, + { 3059, 3064 }, + { 3066, 3066 }, + { 3199, 3199 }, + { 3407, 3407 }, + { 3449, 3449 }, + { 3841, 3843 }, + { 3859, 3859 }, + { 3861, 3863 }, + { 3866, 3871 }, + { 3892, 3892 }, + { 3894, 3894 }, + { 3896, 3896 }, + { 4030, 4037 }, + { 4039, 4044 }, + { 4046, 4047 }, + { 4053, 4056 }, + { 4254, 4255 }, + { 5008, 5017 }, + { 5741, 5741 }, + { 6464, 6464 }, + { 6622, 6655 }, + { 7009, 7018 }, + { 7028, 7036 }, + { 8448, 8449 }, + { 8451, 8454 }, + { 8456, 8457 }, + { 8468, 8468 }, + { 8470, 8471 }, + { 8478, 8483 }, + { 8485, 8485 }, + { 8487, 8487 }, + { 8489, 8489 }, + { 8494, 8494 }, + { 8506, 8507 }, + { 8522, 8522 }, + { 8524, 8525 }, + { 8527, 8527 }, + { 8586, 8587 }, + { 8597, 8601 }, + { 8604, 8607 }, + { 8609, 8610 }, + { 8612, 8613 }, + { 8615, 8621 }, + { 8623, 8653 }, + { 8656, 8657 }, + { 8659, 8659 }, + { 8661, 8691 }, + { 8960, 8967 }, + { 8972, 8991 }, + { 8994, 9000 }, + { 9003, 9083 }, + { 9085, 9114 }, + { 9140, 9179 }, + { 9186, 9254 }, + { 9280, 9290 }, + { 9372, 9449 }, + { 9472, 9654 }, + { 9656, 9664 }, + { 9666, 9719 }, + { 9728, 9838 }, + { 9840, 10087 }, + { 10132, 10175 }, + { 10240, 10495 
}, + { 11008, 11055 }, + { 11077, 11078 }, + { 11085, 11123 }, + { 11126, 11157 }, + { 11159, 11263 }, + { 11493, 11498 }, + { 11856, 11857 }, + { 11904, 11929 }, + { 11931, 12019 }, + { 12032, 12245 }, + { 12272, 12283 }, + { 12292, 12292 }, + { 12306, 12307 }, + { 12320, 12320 }, + { 12342, 12343 }, + { 12350, 12351 }, + { 12688, 12689 }, + { 12694, 12703 }, + { 12736, 12771 }, + { 12800, 12830 }, + { 12842, 12871 }, + { 12880, 12880 }, + { 12896, 12927 }, + { 12938, 12976 }, + { 12992, 13311 }, + { 19904, 19967 }, + { 42128, 42182 }, + { 43048, 43051 }, + { 43062, 43063 }, + { 43065, 43065 }, + { 43639, 43641 }, + { 64832, 64847 }, + { 64975, 64975 }, + { 65021, 65023 }, + { 65508, 65508 }, + { 65512, 65512 }, + { 65517, 65518 }, + { 65532, 65533 }, +}; +static const URange32 So_range32[] = { + { 65847, 65855 }, + { 65913, 65929 }, + { 65932, 65934 }, + { 65936, 65948 }, + { 65952, 65952 }, + { 66000, 66044 }, + { 67703, 67704 }, + { 68296, 68296 }, + { 71487, 71487 }, + { 73685, 73692 }, + { 73697, 73713 }, + { 92988, 92991 }, + { 92997, 92997 }, + { 113820, 113820 }, + { 118608, 118723 }, + { 118784, 119029 }, + { 119040, 119078 }, + { 119081, 119140 }, + { 119146, 119148 }, + { 119171, 119172 }, + { 119180, 119209 }, + { 119214, 119274 }, + { 119296, 119361 }, + { 119365, 119365 }, + { 119552, 119638 }, + { 120832, 121343 }, + { 121399, 121402 }, + { 121453, 121460 }, + { 121462, 121475 }, + { 121477, 121478 }, + { 123215, 123215 }, + { 126124, 126124 }, + { 126254, 126254 }, + { 126976, 127019 }, + { 127024, 127123 }, + { 127136, 127150 }, + { 127153, 127167 }, + { 127169, 127183 }, + { 127185, 127221 }, + { 127245, 127405 }, + { 127462, 127490 }, + { 127504, 127547 }, + { 127552, 127560 }, + { 127568, 127569 }, + { 127584, 127589 }, + { 127744, 127994 }, + { 128000, 128727 }, + { 128732, 128748 }, + { 128752, 128764 }, + { 128768, 128886 }, + { 128891, 128985 }, + { 128992, 129003 }, + { 129008, 129008 }, + { 129024, 129035 }, + { 129040, 129095 }, + { 129104, 129113 }, + { 129120, 129159 }, + { 129168, 129197 }, + { 129200, 129201 }, + { 129280, 129619 }, + { 129632, 129645 }, + { 129648, 129660 }, + { 129664, 129672 }, + { 129680, 129725 }, + { 129727, 129733 }, + { 129742, 129755 }, + { 129760, 129768 }, + { 129776, 129784 }, + { 129792, 129938 }, + { 129940, 129994 }, +}; +static const URange16 Z_range16[] = { + { 32, 32 }, + { 160, 160 }, + { 5760, 5760 }, + { 8192, 8202 }, + { 8232, 8233 }, + { 8239, 8239 }, + { 8287, 8287 }, + { 12288, 12288 }, +}; +static const URange16 Zl_range16[] = { + { 8232, 8232 }, +}; +static const URange16 Zp_range16[] = { + { 8233, 8233 }, +}; +static const URange16 Zs_range16[] = { + { 32, 32 }, + { 160, 160 }, + { 5760, 5760 }, + { 8192, 8202 }, + { 8239, 8239 }, + { 8287, 8287 }, + { 12288, 12288 }, +}; +static const URange32 Adlam_range32[] = { + { 125184, 125259 }, + { 125264, 125273 }, + { 125278, 125279 }, +}; +static const URange32 Ahom_range32[] = { + { 71424, 71450 }, + { 71453, 71467 }, + { 71472, 71494 }, +}; +static const URange32 Anatolian_Hieroglyphs_range32[] = { + { 82944, 83526 }, +}; +static const URange16 Arabic_range16[] = { + { 1536, 1540 }, + { 1542, 1547 }, + { 1549, 1562 }, + { 1564, 1566 }, + { 1568, 1599 }, + { 1601, 1610 }, + { 1622, 1647 }, + { 1649, 1756 }, + { 1758, 1791 }, + { 1872, 1919 }, + { 2160, 2190 }, + { 2192, 2193 }, + { 2200, 2273 }, + { 2275, 2303 }, + { 64336, 64450 }, + { 64467, 64829 }, + { 64832, 64911 }, + { 64914, 64967 }, + { 64975, 64975 }, + { 65008, 65023 }, + { 65136, 65140 }, + { 65142, 65276 
}, +}; +static const URange32 Arabic_range32[] = { + { 69216, 69246 }, + { 69373, 69375 }, + { 126464, 126467 }, + { 126469, 126495 }, + { 126497, 126498 }, + { 126500, 126500 }, + { 126503, 126503 }, + { 126505, 126514 }, + { 126516, 126519 }, + { 126521, 126521 }, + { 126523, 126523 }, + { 126530, 126530 }, + { 126535, 126535 }, + { 126537, 126537 }, + { 126539, 126539 }, + { 126541, 126543 }, + { 126545, 126546 }, + { 126548, 126548 }, + { 126551, 126551 }, + { 126553, 126553 }, + { 126555, 126555 }, + { 126557, 126557 }, + { 126559, 126559 }, + { 126561, 126562 }, + { 126564, 126564 }, + { 126567, 126570 }, + { 126572, 126578 }, + { 126580, 126583 }, + { 126585, 126588 }, + { 126590, 126590 }, + { 126592, 126601 }, + { 126603, 126619 }, + { 126625, 126627 }, + { 126629, 126633 }, + { 126635, 126651 }, + { 126704, 126705 }, +}; +static const URange16 Armenian_range16[] = { + { 1329, 1366 }, + { 1369, 1418 }, + { 1421, 1423 }, + { 64275, 64279 }, +}; +static const URange32 Avestan_range32[] = { + { 68352, 68405 }, + { 68409, 68415 }, +}; +static const URange16 Balinese_range16[] = { + { 6912, 6988 }, + { 6992, 7038 }, +}; +static const URange16 Bamum_range16[] = { + { 42656, 42743 }, +}; +static const URange32 Bamum_range32[] = { + { 92160, 92728 }, +}; +static const URange32 Bassa_Vah_range32[] = { + { 92880, 92909 }, + { 92912, 92917 }, +}; +static const URange16 Batak_range16[] = { + { 7104, 7155 }, + { 7164, 7167 }, +}; +static const URange16 Bengali_range16[] = { + { 2432, 2435 }, + { 2437, 2444 }, + { 2447, 2448 }, + { 2451, 2472 }, + { 2474, 2480 }, + { 2482, 2482 }, + { 2486, 2489 }, + { 2492, 2500 }, + { 2503, 2504 }, + { 2507, 2510 }, + { 2519, 2519 }, + { 2524, 2525 }, + { 2527, 2531 }, + { 2534, 2558 }, +}; +static const URange32 Bhaiksuki_range32[] = { + { 72704, 72712 }, + { 72714, 72758 }, + { 72760, 72773 }, + { 72784, 72812 }, +}; +static const URange16 Bopomofo_range16[] = { + { 746, 747 }, + { 12549, 12591 }, + { 12704, 12735 }, +}; +static const URange32 Brahmi_range32[] = { + { 69632, 69709 }, + { 69714, 69749 }, + { 69759, 69759 }, +}; +static const URange16 Braille_range16[] = { + { 10240, 10495 }, +}; +static const URange16 Buginese_range16[] = { + { 6656, 6683 }, + { 6686, 6687 }, +}; +static const URange16 Buhid_range16[] = { + { 5952, 5971 }, +}; +static const URange16 Canadian_Aboriginal_range16[] = { + { 5120, 5759 }, + { 6320, 6389 }, +}; +static const URange32 Canadian_Aboriginal_range32[] = { + { 72368, 72383 }, +}; +static const URange32 Carian_range32[] = { + { 66208, 66256 }, +}; +static const URange32 Caucasian_Albanian_range32[] = { + { 66864, 66915 }, + { 66927, 66927 }, +}; +static const URange32 Chakma_range32[] = { + { 69888, 69940 }, + { 69942, 69959 }, +}; +static const URange16 Cham_range16[] = { + { 43520, 43574 }, + { 43584, 43597 }, + { 43600, 43609 }, + { 43612, 43615 }, +}; +static const URange16 Cherokee_range16[] = { + { 5024, 5109 }, + { 5112, 5117 }, + { 43888, 43967 }, +}; +static const URange32 Chorasmian_range32[] = { + { 69552, 69579 }, +}; +static const URange16 Common_range16[] = { + { 0, 64 }, + { 91, 96 }, + { 123, 169 }, + { 171, 185 }, + { 187, 191 }, + { 215, 215 }, + { 247, 247 }, + { 697, 735 }, + { 741, 745 }, + { 748, 767 }, + { 884, 884 }, + { 894, 894 }, + { 901, 901 }, + { 903, 903 }, + { 1541, 1541 }, + { 1548, 1548 }, + { 1563, 1563 }, + { 1567, 1567 }, + { 1600, 1600 }, + { 1757, 1757 }, + { 2274, 2274 }, + { 2404, 2405 }, + { 3647, 3647 }, + { 4053, 4056 }, + { 4347, 4347 }, + { 5867, 5869 }, + { 5941, 5942 }, + { 
6146, 6147 }, + { 6149, 6149 }, + { 7379, 7379 }, + { 7393, 7393 }, + { 7401, 7404 }, + { 7406, 7411 }, + { 7413, 7415 }, + { 7418, 7418 }, + { 8192, 8203 }, + { 8206, 8292 }, + { 8294, 8304 }, + { 8308, 8318 }, + { 8320, 8334 }, + { 8352, 8384 }, + { 8448, 8485 }, + { 8487, 8489 }, + { 8492, 8497 }, + { 8499, 8525 }, + { 8527, 8543 }, + { 8585, 8587 }, + { 8592, 9254 }, + { 9280, 9290 }, + { 9312, 10239 }, + { 10496, 11123 }, + { 11126, 11157 }, + { 11159, 11263 }, + { 11776, 11869 }, + { 12272, 12283 }, + { 12288, 12292 }, + { 12294, 12294 }, + { 12296, 12320 }, + { 12336, 12343 }, + { 12348, 12351 }, + { 12443, 12444 }, + { 12448, 12448 }, + { 12539, 12540 }, + { 12688, 12703 }, + { 12736, 12771 }, + { 12832, 12895 }, + { 12927, 13007 }, + { 13055, 13055 }, + { 13144, 13311 }, + { 19904, 19967 }, + { 42752, 42785 }, + { 42888, 42890 }, + { 43056, 43065 }, + { 43310, 43310 }, + { 43471, 43471 }, + { 43867, 43867 }, + { 43882, 43883 }, + { 64830, 64831 }, + { 65040, 65049 }, + { 65072, 65106 }, + { 65108, 65126 }, + { 65128, 65131 }, + { 65279, 65279 }, + { 65281, 65312 }, + { 65339, 65344 }, + { 65371, 65381 }, + { 65392, 65392 }, + { 65438, 65439 }, + { 65504, 65510 }, + { 65512, 65518 }, + { 65529, 65533 }, +}; +static const URange32 Common_range32[] = { + { 65792, 65794 }, + { 65799, 65843 }, + { 65847, 65855 }, + { 65936, 65948 }, + { 66000, 66044 }, + { 66273, 66299 }, + { 113824, 113827 }, + { 118608, 118723 }, + { 118784, 119029 }, + { 119040, 119078 }, + { 119081, 119142 }, + { 119146, 119162 }, + { 119171, 119172 }, + { 119180, 119209 }, + { 119214, 119274 }, + { 119488, 119507 }, + { 119520, 119539 }, + { 119552, 119638 }, + { 119648, 119672 }, + { 119808, 119892 }, + { 119894, 119964 }, + { 119966, 119967 }, + { 119970, 119970 }, + { 119973, 119974 }, + { 119977, 119980 }, + { 119982, 119993 }, + { 119995, 119995 }, + { 119997, 120003 }, + { 120005, 120069 }, + { 120071, 120074 }, + { 120077, 120084 }, + { 120086, 120092 }, + { 120094, 120121 }, + { 120123, 120126 }, + { 120128, 120132 }, + { 120134, 120134 }, + { 120138, 120144 }, + { 120146, 120485 }, + { 120488, 120779 }, + { 120782, 120831 }, + { 126065, 126132 }, + { 126209, 126269 }, + { 126976, 127019 }, + { 127024, 127123 }, + { 127136, 127150 }, + { 127153, 127167 }, + { 127169, 127183 }, + { 127185, 127221 }, + { 127232, 127405 }, + { 127462, 127487 }, + { 127489, 127490 }, + { 127504, 127547 }, + { 127552, 127560 }, + { 127568, 127569 }, + { 127584, 127589 }, + { 127744, 128727 }, + { 128732, 128748 }, + { 128752, 128764 }, + { 128768, 128886 }, + { 128891, 128985 }, + { 128992, 129003 }, + { 129008, 129008 }, + { 129024, 129035 }, + { 129040, 129095 }, + { 129104, 129113 }, + { 129120, 129159 }, + { 129168, 129197 }, + { 129200, 129201 }, + { 129280, 129619 }, + { 129632, 129645 }, + { 129648, 129660 }, + { 129664, 129672 }, + { 129680, 129725 }, + { 129727, 129733 }, + { 129742, 129755 }, + { 129760, 129768 }, + { 129776, 129784 }, + { 129792, 129938 }, + { 129940, 129994 }, + { 130032, 130041 }, + { 917505, 917505 }, + { 917536, 917631 }, +}; +static const URange16 Coptic_range16[] = { + { 994, 1007 }, + { 11392, 11507 }, + { 11513, 11519 }, +}; +static const URange32 Cuneiform_range32[] = { + { 73728, 74649 }, + { 74752, 74862 }, + { 74864, 74868 }, + { 74880, 75075 }, +}; +static const URange32 Cypriot_range32[] = { + { 67584, 67589 }, + { 67592, 67592 }, + { 67594, 67637 }, + { 67639, 67640 }, + { 67644, 67644 }, + { 67647, 67647 }, +}; +static const URange32 Cypro_Minoan_range32[] = { + { 77712, 77810 }, +}; 
+static const URange16 Cyrillic_range16[] = { + { 1024, 1156 }, + { 1159, 1327 }, + { 7296, 7304 }, + { 7467, 7467 }, + { 7544, 7544 }, + { 11744, 11775 }, + { 42560, 42655 }, + { 65070, 65071 }, +}; +static const URange32 Cyrillic_range32[] = { + { 122928, 122989 }, + { 123023, 123023 }, +}; +static const URange32 Deseret_range32[] = { + { 66560, 66639 }, +}; +static const URange16 Devanagari_range16[] = { + { 2304, 2384 }, + { 2389, 2403 }, + { 2406, 2431 }, + { 43232, 43263 }, +}; +static const URange32 Devanagari_range32[] = { + { 72448, 72457 }, +}; +static const URange32 Dives_Akuru_range32[] = { + { 71936, 71942 }, + { 71945, 71945 }, + { 71948, 71955 }, + { 71957, 71958 }, + { 71960, 71989 }, + { 71991, 71992 }, + { 71995, 72006 }, + { 72016, 72025 }, +}; +static const URange32 Dogra_range32[] = { + { 71680, 71739 }, +}; +static const URange32 Duployan_range32[] = { + { 113664, 113770 }, + { 113776, 113788 }, + { 113792, 113800 }, + { 113808, 113817 }, + { 113820, 113823 }, +}; +static const URange32 Egyptian_Hieroglyphs_range32[] = { + { 77824, 78933 }, +}; +static const URange32 Elbasan_range32[] = { + { 66816, 66855 }, +}; +static const URange32 Elymaic_range32[] = { + { 69600, 69622 }, +}; +static const URange16 Ethiopic_range16[] = { + { 4608, 4680 }, + { 4682, 4685 }, + { 4688, 4694 }, + { 4696, 4696 }, + { 4698, 4701 }, + { 4704, 4744 }, + { 4746, 4749 }, + { 4752, 4784 }, + { 4786, 4789 }, + { 4792, 4798 }, + { 4800, 4800 }, + { 4802, 4805 }, + { 4808, 4822 }, + { 4824, 4880 }, + { 4882, 4885 }, + { 4888, 4954 }, + { 4957, 4988 }, + { 4992, 5017 }, + { 11648, 11670 }, + { 11680, 11686 }, + { 11688, 11694 }, + { 11696, 11702 }, + { 11704, 11710 }, + { 11712, 11718 }, + { 11720, 11726 }, + { 11728, 11734 }, + { 11736, 11742 }, + { 43777, 43782 }, + { 43785, 43790 }, + { 43793, 43798 }, + { 43808, 43814 }, + { 43816, 43822 }, +}; +static const URange32 Ethiopic_range32[] = { + { 124896, 124902 }, + { 124904, 124907 }, + { 124909, 124910 }, + { 124912, 124926 }, +}; +static const URange16 Georgian_range16[] = { + { 4256, 4293 }, + { 4295, 4295 }, + { 4301, 4301 }, + { 4304, 4346 }, + { 4348, 4351 }, + { 7312, 7354 }, + { 7357, 7359 }, + { 11520, 11557 }, + { 11559, 11559 }, + { 11565, 11565 }, +}; +static const URange16 Glagolitic_range16[] = { + { 11264, 11359 }, +}; +static const URange32 Glagolitic_range32[] = { + { 122880, 122886 }, + { 122888, 122904 }, + { 122907, 122913 }, + { 122915, 122916 }, + { 122918, 122922 }, +}; +static const URange32 Gothic_range32[] = { + { 66352, 66378 }, +}; +static const URange32 Grantha_range32[] = { + { 70400, 70403 }, + { 70405, 70412 }, + { 70415, 70416 }, + { 70419, 70440 }, + { 70442, 70448 }, + { 70450, 70451 }, + { 70453, 70457 }, + { 70460, 70468 }, + { 70471, 70472 }, + { 70475, 70477 }, + { 70480, 70480 }, + { 70487, 70487 }, + { 70493, 70499 }, + { 70502, 70508 }, + { 70512, 70516 }, +}; +static const URange16 Greek_range16[] = { + { 880, 883 }, + { 885, 887 }, + { 890, 893 }, + { 895, 895 }, + { 900, 900 }, + { 902, 902 }, + { 904, 906 }, + { 908, 908 }, + { 910, 929 }, + { 931, 993 }, + { 1008, 1023 }, + { 7462, 7466 }, + { 7517, 7521 }, + { 7526, 7530 }, + { 7615, 7615 }, + { 7936, 7957 }, + { 7960, 7965 }, + { 7968, 8005 }, + { 8008, 8013 }, + { 8016, 8023 }, + { 8025, 8025 }, + { 8027, 8027 }, + { 8029, 8029 }, + { 8031, 8061 }, + { 8064, 8116 }, + { 8118, 8132 }, + { 8134, 8147 }, + { 8150, 8155 }, + { 8157, 8175 }, + { 8178, 8180 }, + { 8182, 8190 }, + { 8486, 8486 }, + { 43877, 43877 }, +}; +static const URange32 
Greek_range32[] = { + { 65856, 65934 }, + { 65952, 65952 }, + { 119296, 119365 }, +}; +static const URange16 Gujarati_range16[] = { + { 2689, 2691 }, + { 2693, 2701 }, + { 2703, 2705 }, + { 2707, 2728 }, + { 2730, 2736 }, + { 2738, 2739 }, + { 2741, 2745 }, + { 2748, 2757 }, + { 2759, 2761 }, + { 2763, 2765 }, + { 2768, 2768 }, + { 2784, 2787 }, + { 2790, 2801 }, + { 2809, 2815 }, +}; +static const URange32 Gunjala_Gondi_range32[] = { + { 73056, 73061 }, + { 73063, 73064 }, + { 73066, 73102 }, + { 73104, 73105 }, + { 73107, 73112 }, + { 73120, 73129 }, +}; +static const URange16 Gurmukhi_range16[] = { + { 2561, 2563 }, + { 2565, 2570 }, + { 2575, 2576 }, + { 2579, 2600 }, + { 2602, 2608 }, + { 2610, 2611 }, + { 2613, 2614 }, + { 2616, 2617 }, + { 2620, 2620 }, + { 2622, 2626 }, + { 2631, 2632 }, + { 2635, 2637 }, + { 2641, 2641 }, + { 2649, 2652 }, + { 2654, 2654 }, + { 2662, 2678 }, +}; +static const URange16 Han_range16[] = { + { 11904, 11929 }, + { 11931, 12019 }, + { 12032, 12245 }, + { 12293, 12293 }, + { 12295, 12295 }, + { 12321, 12329 }, + { 12344, 12347 }, + { 13312, 19903 }, + { 19968, 40959 }, + { 63744, 64109 }, + { 64112, 64217 }, +}; +static const URange32 Han_range32[] = { + { 94178, 94179 }, + { 94192, 94193 }, + { 131072, 173791 }, + { 173824, 177977 }, + { 177984, 178205 }, + { 178208, 183969 }, + { 183984, 191456 }, + { 194560, 195101 }, + { 196608, 201546 }, + { 201552, 205743 }, +}; +static const URange16 Hangul_range16[] = { + { 4352, 4607 }, + { 12334, 12335 }, + { 12593, 12686 }, + { 12800, 12830 }, + { 12896, 12926 }, + { 43360, 43388 }, + { 44032, 55203 }, + { 55216, 55238 }, + { 55243, 55291 }, + { 65440, 65470 }, + { 65474, 65479 }, + { 65482, 65487 }, + { 65490, 65495 }, + { 65498, 65500 }, +}; +static const URange32 Hanifi_Rohingya_range32[] = { + { 68864, 68903 }, + { 68912, 68921 }, +}; +static const URange16 Hanunoo_range16[] = { + { 5920, 5940 }, +}; +static const URange32 Hatran_range32[] = { + { 67808, 67826 }, + { 67828, 67829 }, + { 67835, 67839 }, +}; +static const URange16 Hebrew_range16[] = { + { 1425, 1479 }, + { 1488, 1514 }, + { 1519, 1524 }, + { 64285, 64310 }, + { 64312, 64316 }, + { 64318, 64318 }, + { 64320, 64321 }, + { 64323, 64324 }, + { 64326, 64335 }, +}; +static const URange16 Hiragana_range16[] = { + { 12353, 12438 }, + { 12445, 12447 }, +}; +static const URange32 Hiragana_range32[] = { + { 110593, 110879 }, + { 110898, 110898 }, + { 110928, 110930 }, + { 127488, 127488 }, +}; +static const URange32 Imperial_Aramaic_range32[] = { + { 67648, 67669 }, + { 67671, 67679 }, +}; +static const URange16 Inherited_range16[] = { + { 768, 879 }, + { 1157, 1158 }, + { 1611, 1621 }, + { 1648, 1648 }, + { 2385, 2388 }, + { 6832, 6862 }, + { 7376, 7378 }, + { 7380, 7392 }, + { 7394, 7400 }, + { 7405, 7405 }, + { 7412, 7412 }, + { 7416, 7417 }, + { 7616, 7679 }, + { 8204, 8205 }, + { 8400, 8432 }, + { 12330, 12333 }, + { 12441, 12442 }, + { 65024, 65039 }, + { 65056, 65069 }, +}; +static const URange32 Inherited_range32[] = { + { 66045, 66045 }, + { 66272, 66272 }, + { 70459, 70459 }, + { 118528, 118573 }, + { 118576, 118598 }, + { 119143, 119145 }, + { 119163, 119170 }, + { 119173, 119179 }, + { 119210, 119213 }, + { 917760, 917999 }, +}; +static const URange32 Inscriptional_Pahlavi_range32[] = { + { 68448, 68466 }, + { 68472, 68479 }, +}; +static const URange32 Inscriptional_Parthian_range32[] = { + { 68416, 68437 }, + { 68440, 68447 }, +}; +static const URange16 Javanese_range16[] = { + { 43392, 43469 }, + { 43472, 43481 }, + { 43486, 43487 }, +}; 
+static const URange32 Kaithi_range32[] = { + { 69760, 69826 }, + { 69837, 69837 }, +}; +static const URange16 Kannada_range16[] = { + { 3200, 3212 }, + { 3214, 3216 }, + { 3218, 3240 }, + { 3242, 3251 }, + { 3253, 3257 }, + { 3260, 3268 }, + { 3270, 3272 }, + { 3274, 3277 }, + { 3285, 3286 }, + { 3293, 3294 }, + { 3296, 3299 }, + { 3302, 3311 }, + { 3313, 3315 }, +}; +static const URange16 Katakana_range16[] = { + { 12449, 12538 }, + { 12541, 12543 }, + { 12784, 12799 }, + { 13008, 13054 }, + { 13056, 13143 }, + { 65382, 65391 }, + { 65393, 65437 }, +}; +static const URange32 Katakana_range32[] = { + { 110576, 110579 }, + { 110581, 110587 }, + { 110589, 110590 }, + { 110592, 110592 }, + { 110880, 110882 }, + { 110933, 110933 }, + { 110948, 110951 }, +}; +static const URange32 Kawi_range32[] = { + { 73472, 73488 }, + { 73490, 73530 }, + { 73534, 73561 }, +}; +static const URange16 Kayah_Li_range16[] = { + { 43264, 43309 }, + { 43311, 43311 }, +}; +static const URange32 Kharoshthi_range32[] = { + { 68096, 68099 }, + { 68101, 68102 }, + { 68108, 68115 }, + { 68117, 68119 }, + { 68121, 68149 }, + { 68152, 68154 }, + { 68159, 68168 }, + { 68176, 68184 }, +}; +static const URange32 Khitan_Small_Script_range32[] = { + { 94180, 94180 }, + { 101120, 101589 }, +}; +static const URange16 Khmer_range16[] = { + { 6016, 6109 }, + { 6112, 6121 }, + { 6128, 6137 }, + { 6624, 6655 }, +}; +static const URange32 Khojki_range32[] = { + { 70144, 70161 }, + { 70163, 70209 }, +}; +static const URange32 Khudawadi_range32[] = { + { 70320, 70378 }, + { 70384, 70393 }, +}; +static const URange16 Lao_range16[] = { + { 3713, 3714 }, + { 3716, 3716 }, + { 3718, 3722 }, + { 3724, 3747 }, + { 3749, 3749 }, + { 3751, 3773 }, + { 3776, 3780 }, + { 3782, 3782 }, + { 3784, 3790 }, + { 3792, 3801 }, + { 3804, 3807 }, +}; +static const URange16 Latin_range16[] = { + { 65, 90 }, + { 97, 122 }, + { 170, 170 }, + { 186, 186 }, + { 192, 214 }, + { 216, 246 }, + { 248, 696 }, + { 736, 740 }, + { 7424, 7461 }, + { 7468, 7516 }, + { 7522, 7525 }, + { 7531, 7543 }, + { 7545, 7614 }, + { 7680, 7935 }, + { 8305, 8305 }, + { 8319, 8319 }, + { 8336, 8348 }, + { 8490, 8491 }, + { 8498, 8498 }, + { 8526, 8526 }, + { 8544, 8584 }, + { 11360, 11391 }, + { 42786, 42887 }, + { 42891, 42954 }, + { 42960, 42961 }, + { 42963, 42963 }, + { 42965, 42969 }, + { 42994, 43007 }, + { 43824, 43866 }, + { 43868, 43876 }, + { 43878, 43881 }, + { 64256, 64262 }, + { 65313, 65338 }, + { 65345, 65370 }, +}; +static const URange32 Latin_range32[] = { + { 67456, 67461 }, + { 67463, 67504 }, + { 67506, 67514 }, + { 122624, 122654 }, + { 122661, 122666 }, +}; +static const URange16 Lepcha_range16[] = { + { 7168, 7223 }, + { 7227, 7241 }, + { 7245, 7247 }, +}; +static const URange16 Limbu_range16[] = { + { 6400, 6430 }, + { 6432, 6443 }, + { 6448, 6459 }, + { 6464, 6464 }, + { 6468, 6479 }, +}; +static const URange32 Linear_A_range32[] = { + { 67072, 67382 }, + { 67392, 67413 }, + { 67424, 67431 }, +}; +static const URange32 Linear_B_range32[] = { + { 65536, 65547 }, + { 65549, 65574 }, + { 65576, 65594 }, + { 65596, 65597 }, + { 65599, 65613 }, + { 65616, 65629 }, + { 65664, 65786 }, +}; +static const URange16 Lisu_range16[] = { + { 42192, 42239 }, +}; +static const URange32 Lisu_range32[] = { + { 73648, 73648 }, +}; +static const URange32 Lycian_range32[] = { + { 66176, 66204 }, +}; +static const URange32 Lydian_range32[] = { + { 67872, 67897 }, + { 67903, 67903 }, +}; +static const URange32 Mahajani_range32[] = { + { 69968, 70006 }, +}; +static const URange32 
Makasar_range32[] = { + { 73440, 73464 }, +}; +static const URange16 Malayalam_range16[] = { + { 3328, 3340 }, + { 3342, 3344 }, + { 3346, 3396 }, + { 3398, 3400 }, + { 3402, 3407 }, + { 3412, 3427 }, + { 3430, 3455 }, +}; +static const URange16 Mandaic_range16[] = { + { 2112, 2139 }, + { 2142, 2142 }, +}; +static const URange32 Manichaean_range32[] = { + { 68288, 68326 }, + { 68331, 68342 }, +}; +static const URange32 Marchen_range32[] = { + { 72816, 72847 }, + { 72850, 72871 }, + { 72873, 72886 }, +}; +static const URange32 Masaram_Gondi_range32[] = { + { 72960, 72966 }, + { 72968, 72969 }, + { 72971, 73014 }, + { 73018, 73018 }, + { 73020, 73021 }, + { 73023, 73031 }, + { 73040, 73049 }, +}; +static const URange32 Medefaidrin_range32[] = { + { 93760, 93850 }, +}; +static const URange16 Meetei_Mayek_range16[] = { + { 43744, 43766 }, + { 43968, 44013 }, + { 44016, 44025 }, +}; +static const URange32 Mende_Kikakui_range32[] = { + { 124928, 125124 }, + { 125127, 125142 }, +}; +static const URange32 Meroitic_Cursive_range32[] = { + { 68000, 68023 }, + { 68028, 68047 }, + { 68050, 68095 }, +}; +static const URange32 Meroitic_Hieroglyphs_range32[] = { + { 67968, 67999 }, +}; +static const URange32 Miao_range32[] = { + { 93952, 94026 }, + { 94031, 94087 }, + { 94095, 94111 }, +}; +static const URange32 Modi_range32[] = { + { 71168, 71236 }, + { 71248, 71257 }, +}; +static const URange16 Mongolian_range16[] = { + { 6144, 6145 }, + { 6148, 6148 }, + { 6150, 6169 }, + { 6176, 6264 }, + { 6272, 6314 }, +}; +static const URange32 Mongolian_range32[] = { + { 71264, 71276 }, +}; +static const URange32 Mro_range32[] = { + { 92736, 92766 }, + { 92768, 92777 }, + { 92782, 92783 }, +}; +static const URange32 Multani_range32[] = { + { 70272, 70278 }, + { 70280, 70280 }, + { 70282, 70285 }, + { 70287, 70301 }, + { 70303, 70313 }, +}; +static const URange16 Myanmar_range16[] = { + { 4096, 4255 }, + { 43488, 43518 }, + { 43616, 43647 }, +}; +static const URange32 Nabataean_range32[] = { + { 67712, 67742 }, + { 67751, 67759 }, +}; +static const URange32 Nag_Mundari_range32[] = { + { 124112, 124153 }, +}; +static const URange32 Nandinagari_range32[] = { + { 72096, 72103 }, + { 72106, 72151 }, + { 72154, 72164 }, +}; +static const URange16 New_Tai_Lue_range16[] = { + { 6528, 6571 }, + { 6576, 6601 }, + { 6608, 6618 }, + { 6622, 6623 }, +}; +static const URange32 Newa_range32[] = { + { 70656, 70747 }, + { 70749, 70753 }, +}; +static const URange16 Nko_range16[] = { + { 1984, 2042 }, + { 2045, 2047 }, +}; +static const URange32 Nushu_range32[] = { + { 94177, 94177 }, + { 110960, 111355 }, +}; +static const URange32 Nyiakeng_Puachue_Hmong_range32[] = { + { 123136, 123180 }, + { 123184, 123197 }, + { 123200, 123209 }, + { 123214, 123215 }, +}; +static const URange16 Ogham_range16[] = { + { 5760, 5788 }, +}; +static const URange16 Ol_Chiki_range16[] = { + { 7248, 7295 }, +}; +static const URange32 Old_Hungarian_range32[] = { + { 68736, 68786 }, + { 68800, 68850 }, + { 68858, 68863 }, +}; +static const URange32 Old_Italic_range32[] = { + { 66304, 66339 }, + { 66349, 66351 }, +}; +static const URange32 Old_North_Arabian_range32[] = { + { 68224, 68255 }, +}; +static const URange32 Old_Permic_range32[] = { + { 66384, 66426 }, +}; +static const URange32 Old_Persian_range32[] = { + { 66464, 66499 }, + { 66504, 66517 }, +}; +static const URange32 Old_Sogdian_range32[] = { + { 69376, 69415 }, +}; +static const URange32 Old_South_Arabian_range32[] = { + { 68192, 68223 }, +}; +static const URange32 Old_Turkic_range32[] = { + { 
68608, 68680 }, +}; +static const URange32 Old_Uyghur_range32[] = { + { 69488, 69513 }, +}; +static const URange16 Oriya_range16[] = { + { 2817, 2819 }, + { 2821, 2828 }, + { 2831, 2832 }, + { 2835, 2856 }, + { 2858, 2864 }, + { 2866, 2867 }, + { 2869, 2873 }, + { 2876, 2884 }, + { 2887, 2888 }, + { 2891, 2893 }, + { 2901, 2903 }, + { 2908, 2909 }, + { 2911, 2915 }, + { 2918, 2935 }, +}; +static const URange32 Osage_range32[] = { + { 66736, 66771 }, + { 66776, 66811 }, +}; +static const URange32 Osmanya_range32[] = { + { 66688, 66717 }, + { 66720, 66729 }, +}; +static const URange32 Pahawh_Hmong_range32[] = { + { 92928, 92997 }, + { 93008, 93017 }, + { 93019, 93025 }, + { 93027, 93047 }, + { 93053, 93071 }, +}; +static const URange32 Palmyrene_range32[] = { + { 67680, 67711 }, +}; +static const URange32 Pau_Cin_Hau_range32[] = { + { 72384, 72440 }, +}; +static const URange16 Phags_Pa_range16[] = { + { 43072, 43127 }, +}; +static const URange32 Phoenician_range32[] = { + { 67840, 67867 }, + { 67871, 67871 }, +}; +static const URange32 Psalter_Pahlavi_range32[] = { + { 68480, 68497 }, + { 68505, 68508 }, + { 68521, 68527 }, +}; +static const URange16 Rejang_range16[] = { + { 43312, 43347 }, + { 43359, 43359 }, +}; +static const URange16 Runic_range16[] = { + { 5792, 5866 }, + { 5870, 5880 }, +}; +static const URange16 Samaritan_range16[] = { + { 2048, 2093 }, + { 2096, 2110 }, +}; +static const URange16 Saurashtra_range16[] = { + { 43136, 43205 }, + { 43214, 43225 }, +}; +static const URange32 Sharada_range32[] = { + { 70016, 70111 }, +}; +static const URange32 Shavian_range32[] = { + { 66640, 66687 }, +}; +static const URange32 Siddham_range32[] = { + { 71040, 71093 }, + { 71096, 71133 }, +}; +static const URange32 SignWriting_range32[] = { + { 120832, 121483 }, + { 121499, 121503 }, + { 121505, 121519 }, +}; +static const URange16 Sinhala_range16[] = { + { 3457, 3459 }, + { 3461, 3478 }, + { 3482, 3505 }, + { 3507, 3515 }, + { 3517, 3517 }, + { 3520, 3526 }, + { 3530, 3530 }, + { 3535, 3540 }, + { 3542, 3542 }, + { 3544, 3551 }, + { 3558, 3567 }, + { 3570, 3572 }, +}; +static const URange32 Sinhala_range32[] = { + { 70113, 70132 }, +}; +static const URange32 Sogdian_range32[] = { + { 69424, 69465 }, +}; +static const URange32 Sora_Sompeng_range32[] = { + { 69840, 69864 }, + { 69872, 69881 }, +}; +static const URange32 Soyombo_range32[] = { + { 72272, 72354 }, +}; +static const URange16 Sundanese_range16[] = { + { 7040, 7103 }, + { 7360, 7367 }, +}; +static const URange16 Syloti_Nagri_range16[] = { + { 43008, 43052 }, +}; +static const URange16 Syriac_range16[] = { + { 1792, 1805 }, + { 1807, 1866 }, + { 1869, 1871 }, + { 2144, 2154 }, +}; +static const URange16 Tagalog_range16[] = { + { 5888, 5909 }, + { 5919, 5919 }, +}; +static const URange16 Tagbanwa_range16[] = { + { 5984, 5996 }, + { 5998, 6000 }, + { 6002, 6003 }, +}; +static const URange16 Tai_Le_range16[] = { + { 6480, 6509 }, + { 6512, 6516 }, +}; +static const URange16 Tai_Tham_range16[] = { + { 6688, 6750 }, + { 6752, 6780 }, + { 6783, 6793 }, + { 6800, 6809 }, + { 6816, 6829 }, +}; +static const URange16 Tai_Viet_range16[] = { + { 43648, 43714 }, + { 43739, 43743 }, +}; +static const URange32 Takri_range32[] = { + { 71296, 71353 }, + { 71360, 71369 }, +}; +static const URange16 Tamil_range16[] = { + { 2946, 2947 }, + { 2949, 2954 }, + { 2958, 2960 }, + { 2962, 2965 }, + { 2969, 2970 }, + { 2972, 2972 }, + { 2974, 2975 }, + { 2979, 2980 }, + { 2984, 2986 }, + { 2990, 3001 }, + { 3006, 3010 }, + { 3014, 3016 }, + { 3018, 3021 }, 
+ { 3024, 3024 }, + { 3031, 3031 }, + { 3046, 3066 }, +}; +static const URange32 Tamil_range32[] = { + { 73664, 73713 }, + { 73727, 73727 }, +}; +static const URange32 Tangsa_range32[] = { + { 92784, 92862 }, + { 92864, 92873 }, +}; +static const URange32 Tangut_range32[] = { + { 94176, 94176 }, + { 94208, 100343 }, + { 100352, 101119 }, + { 101632, 101640 }, +}; +static const URange16 Telugu_range16[] = { + { 3072, 3084 }, + { 3086, 3088 }, + { 3090, 3112 }, + { 3114, 3129 }, + { 3132, 3140 }, + { 3142, 3144 }, + { 3146, 3149 }, + { 3157, 3158 }, + { 3160, 3162 }, + { 3165, 3165 }, + { 3168, 3171 }, + { 3174, 3183 }, + { 3191, 3199 }, +}; +static const URange16 Thaana_range16[] = { + { 1920, 1969 }, +}; +static const URange16 Thai_range16[] = { + { 3585, 3642 }, + { 3648, 3675 }, +}; +static const URange16 Tibetan_range16[] = { + { 3840, 3911 }, + { 3913, 3948 }, + { 3953, 3991 }, + { 3993, 4028 }, + { 4030, 4044 }, + { 4046, 4052 }, + { 4057, 4058 }, +}; +static const URange16 Tifinagh_range16[] = { + { 11568, 11623 }, + { 11631, 11632 }, + { 11647, 11647 }, +}; +static const URange32 Tirhuta_range32[] = { + { 70784, 70855 }, + { 70864, 70873 }, +}; +static const URange32 Toto_range32[] = { + { 123536, 123566 }, +}; +static const URange32 Ugaritic_range32[] = { + { 66432, 66461 }, + { 66463, 66463 }, +}; +static const URange16 Vai_range16[] = { + { 42240, 42539 }, +}; +static const URange32 Vithkuqi_range32[] = { + { 66928, 66938 }, + { 66940, 66954 }, + { 66956, 66962 }, + { 66964, 66965 }, + { 66967, 66977 }, + { 66979, 66993 }, + { 66995, 67001 }, + { 67003, 67004 }, +}; +static const URange32 Wancho_range32[] = { + { 123584, 123641 }, + { 123647, 123647 }, +}; +static const URange32 Warang_Citi_range32[] = { + { 71840, 71922 }, + { 71935, 71935 }, +}; +static const URange32 Yezidi_range32[] = { + { 69248, 69289 }, + { 69291, 69293 }, + { 69296, 69297 }, +}; +static const URange16 Yi_range16[] = { + { 40960, 42124 }, + { 42128, 42182 }, +}; +static const URange32 Zanabazar_Square_range32[] = { + { 72192, 72263 }, +}; +// 4040 16-bit ranges, 1775 32-bit ranges +const UGroup unicode_groups[] = { + { "Adlam", +1, 0, 0, Adlam_range32, 3 }, + { "Ahom", +1, 0, 0, Ahom_range32, 3 }, + { "Anatolian_Hieroglyphs", +1, 0, 0, Anatolian_Hieroglyphs_range32, 1 }, + { "Arabic", +1, Arabic_range16, 22, Arabic_range32, 36 }, + { "Armenian", +1, Armenian_range16, 4, 0, 0 }, + { "Avestan", +1, 0, 0, Avestan_range32, 2 }, + { "Balinese", +1, Balinese_range16, 2, 0, 0 }, + { "Bamum", +1, Bamum_range16, 1, Bamum_range32, 1 }, + { "Bassa_Vah", +1, 0, 0, Bassa_Vah_range32, 2 }, + { "Batak", +1, Batak_range16, 2, 0, 0 }, + { "Bengali", +1, Bengali_range16, 14, 0, 0 }, + { "Bhaiksuki", +1, 0, 0, Bhaiksuki_range32, 4 }, + { "Bopomofo", +1, Bopomofo_range16, 3, 0, 0 }, + { "Brahmi", +1, 0, 0, Brahmi_range32, 3 }, + { "Braille", +1, Braille_range16, 1, 0, 0 }, + { "Buginese", +1, Buginese_range16, 2, 0, 0 }, + { "Buhid", +1, Buhid_range16, 1, 0, 0 }, + { "C", +1, C_range16, 17, C_range32, 9 }, + { "Canadian_Aboriginal", +1, Canadian_Aboriginal_range16, 2, Canadian_Aboriginal_range32, 1 }, + { "Carian", +1, 0, 0, Carian_range32, 1 }, + { "Caucasian_Albanian", +1, 0, 0, Caucasian_Albanian_range32, 2 }, + { "Cc", +1, Cc_range16, 2, 0, 0 }, + { "Cf", +1, Cf_range16, 14, Cf_range32, 7 }, + { "Chakma", +1, 0, 0, Chakma_range32, 2 }, + { "Cham", +1, Cham_range16, 4, 0, 0 }, + { "Cherokee", +1, Cherokee_range16, 3, 0, 0 }, + { "Chorasmian", +1, 0, 0, Chorasmian_range32, 1 }, + { "Co", +1, Co_range16, 1, Co_range32, 2 }, 
+ { "Common", +1, Common_range16, 91, Common_range32, 82 }, + { "Coptic", +1, Coptic_range16, 3, 0, 0 }, + { "Cs", +1, Cs_range16, 1, 0, 0 }, + { "Cuneiform", +1, 0, 0, Cuneiform_range32, 4 }, + { "Cypriot", +1, 0, 0, Cypriot_range32, 6 }, + { "Cypro_Minoan", +1, 0, 0, Cypro_Minoan_range32, 1 }, + { "Cyrillic", +1, Cyrillic_range16, 8, Cyrillic_range32, 2 }, + { "Deseret", +1, 0, 0, Deseret_range32, 1 }, + { "Devanagari", +1, Devanagari_range16, 4, Devanagari_range32, 1 }, + { "Dives_Akuru", +1, 0, 0, Dives_Akuru_range32, 8 }, + { "Dogra", +1, 0, 0, Dogra_range32, 1 }, + { "Duployan", +1, 0, 0, Duployan_range32, 5 }, + { "Egyptian_Hieroglyphs", +1, 0, 0, Egyptian_Hieroglyphs_range32, 1 }, + { "Elbasan", +1, 0, 0, Elbasan_range32, 1 }, + { "Elymaic", +1, 0, 0, Elymaic_range32, 1 }, + { "Ethiopic", +1, Ethiopic_range16, 32, Ethiopic_range32, 4 }, + { "Georgian", +1, Georgian_range16, 10, 0, 0 }, + { "Glagolitic", +1, Glagolitic_range16, 1, Glagolitic_range32, 5 }, + { "Gothic", +1, 0, 0, Gothic_range32, 1 }, + { "Grantha", +1, 0, 0, Grantha_range32, 15 }, + { "Greek", +1, Greek_range16, 33, Greek_range32, 3 }, + { "Gujarati", +1, Gujarati_range16, 14, 0, 0 }, + { "Gunjala_Gondi", +1, 0, 0, Gunjala_Gondi_range32, 6 }, + { "Gurmukhi", +1, Gurmukhi_range16, 16, 0, 0 }, + { "Han", +1, Han_range16, 11, Han_range32, 10 }, + { "Hangul", +1, Hangul_range16, 14, 0, 0 }, + { "Hanifi_Rohingya", +1, 0, 0, Hanifi_Rohingya_range32, 2 }, + { "Hanunoo", +1, Hanunoo_range16, 1, 0, 0 }, + { "Hatran", +1, 0, 0, Hatran_range32, 3 }, + { "Hebrew", +1, Hebrew_range16, 9, 0, 0 }, + { "Hiragana", +1, Hiragana_range16, 2, Hiragana_range32, 4 }, + { "Imperial_Aramaic", +1, 0, 0, Imperial_Aramaic_range32, 2 }, + { "Inherited", +1, Inherited_range16, 19, Inherited_range32, 10 }, + { "Inscriptional_Pahlavi", +1, 0, 0, Inscriptional_Pahlavi_range32, 2 }, + { "Inscriptional_Parthian", +1, 0, 0, Inscriptional_Parthian_range32, 2 }, + { "Javanese", +1, Javanese_range16, 3, 0, 0 }, + { "Kaithi", +1, 0, 0, Kaithi_range32, 2 }, + { "Kannada", +1, Kannada_range16, 13, 0, 0 }, + { "Katakana", +1, Katakana_range16, 7, Katakana_range32, 7 }, + { "Kawi", +1, 0, 0, Kawi_range32, 3 }, + { "Kayah_Li", +1, Kayah_Li_range16, 2, 0, 0 }, + { "Kharoshthi", +1, 0, 0, Kharoshthi_range32, 8 }, + { "Khitan_Small_Script", +1, 0, 0, Khitan_Small_Script_range32, 2 }, + { "Khmer", +1, Khmer_range16, 4, 0, 0 }, + { "Khojki", +1, 0, 0, Khojki_range32, 2 }, + { "Khudawadi", +1, 0, 0, Khudawadi_range32, 2 }, + { "L", +1, L_range16, 380, L_range32, 279 }, + { "Lao", +1, Lao_range16, 11, 0, 0 }, + { "Latin", +1, Latin_range16, 34, Latin_range32, 5 }, + { "Lepcha", +1, Lepcha_range16, 3, 0, 0 }, + { "Limbu", +1, Limbu_range16, 5, 0, 0 }, + { "Linear_A", +1, 0, 0, Linear_A_range32, 3 }, + { "Linear_B", +1, 0, 0, Linear_B_range32, 7 }, + { "Lisu", +1, Lisu_range16, 1, Lisu_range32, 1 }, + { "Ll", +1, Ll_range16, 617, Ll_range32, 41 }, + { "Lm", +1, Lm_range16, 57, Lm_range32, 14 }, + { "Lo", +1, Lo_range16, 290, Lo_range32, 220 }, + { "Lt", +1, Lt_range16, 10, 0, 0 }, + { "Lu", +1, Lu_range16, 605, Lu_range32, 41 }, + { "Lycian", +1, 0, 0, Lycian_range32, 1 }, + { "Lydian", +1, 0, 0, Lydian_range32, 2 }, + { "M", +1, M_range16, 190, M_range32, 120 }, + { "Mahajani", +1, 0, 0, Mahajani_range32, 1 }, + { "Makasar", +1, 0, 0, Makasar_range32, 1 }, + { "Malayalam", +1, Malayalam_range16, 7, 0, 0 }, + { "Mandaic", +1, Mandaic_range16, 2, 0, 0 }, + { "Manichaean", +1, 0, 0, Manichaean_range32, 2 }, + { "Marchen", +1, 0, 0, Marchen_range32, 3 }, + { 
"Masaram_Gondi", +1, 0, 0, Masaram_Gondi_range32, 7 }, + { "Mc", +1, Mc_range16, 112, Mc_range32, 70 }, + { "Me", +1, Me_range16, 5, 0, 0 }, + { "Medefaidrin", +1, 0, 0, Medefaidrin_range32, 1 }, + { "Meetei_Mayek", +1, Meetei_Mayek_range16, 3, 0, 0 }, + { "Mende_Kikakui", +1, 0, 0, Mende_Kikakui_range32, 2 }, + { "Meroitic_Cursive", +1, 0, 0, Meroitic_Cursive_range32, 3 }, + { "Meroitic_Hieroglyphs", +1, 0, 0, Meroitic_Hieroglyphs_range32, 1 }, + { "Miao", +1, 0, 0, Miao_range32, 3 }, + { "Mn", +1, Mn_range16, 212, Mn_range32, 134 }, + { "Modi", +1, 0, 0, Modi_range32, 2 }, + { "Mongolian", +1, Mongolian_range16, 5, Mongolian_range32, 1 }, + { "Mro", +1, 0, 0, Mro_range32, 3 }, + { "Multani", +1, 0, 0, Multani_range32, 5 }, + { "Myanmar", +1, Myanmar_range16, 3, 0, 0 }, + { "N", +1, N_range16, 67, N_range32, 70 }, + { "Nabataean", +1, 0, 0, Nabataean_range32, 2 }, + { "Nag_Mundari", +1, 0, 0, Nag_Mundari_range32, 1 }, + { "Nandinagari", +1, 0, 0, Nandinagari_range32, 3 }, + { "Nd", +1, Nd_range16, 37, Nd_range32, 27 }, + { "New_Tai_Lue", +1, New_Tai_Lue_range16, 4, 0, 0 }, + { "Newa", +1, 0, 0, Newa_range32, 2 }, + { "Nko", +1, Nko_range16, 2, 0, 0 }, + { "Nl", +1, Nl_range16, 7, Nl_range32, 5 }, + { "No", +1, No_range16, 29, No_range32, 43 }, + { "Nushu", +1, 0, 0, Nushu_range32, 2 }, + { "Nyiakeng_Puachue_Hmong", +1, 0, 0, Nyiakeng_Puachue_Hmong_range32, 4 }, + { "Ogham", +1, Ogham_range16, 1, 0, 0 }, + { "Ol_Chiki", +1, Ol_Chiki_range16, 1, 0, 0 }, + { "Old_Hungarian", +1, 0, 0, Old_Hungarian_range32, 3 }, + { "Old_Italic", +1, 0, 0, Old_Italic_range32, 2 }, + { "Old_North_Arabian", +1, 0, 0, Old_North_Arabian_range32, 1 }, + { "Old_Permic", +1, 0, 0, Old_Permic_range32, 1 }, + { "Old_Persian", +1, 0, 0, Old_Persian_range32, 2 }, + { "Old_Sogdian", +1, 0, 0, Old_Sogdian_range32, 1 }, + { "Old_South_Arabian", +1, 0, 0, Old_South_Arabian_range32, 1 }, + { "Old_Turkic", +1, 0, 0, Old_Turkic_range32, 1 }, + { "Old_Uyghur", +1, 0, 0, Old_Uyghur_range32, 1 }, + { "Oriya", +1, Oriya_range16, 14, 0, 0 }, + { "Osage", +1, 0, 0, Osage_range32, 2 }, + { "Osmanya", +1, 0, 0, Osmanya_range32, 2 }, + { "P", +1, P_range16, 133, P_range32, 58 }, + { "Pahawh_Hmong", +1, 0, 0, Pahawh_Hmong_range32, 5 }, + { "Palmyrene", +1, 0, 0, Palmyrene_range32, 1 }, + { "Pau_Cin_Hau", +1, 0, 0, Pau_Cin_Hau_range32, 1 }, + { "Pc", +1, Pc_range16, 6, 0, 0 }, + { "Pd", +1, Pd_range16, 18, Pd_range32, 1 }, + { "Pe", +1, Pe_range16, 76, 0, 0 }, + { "Pf", +1, Pf_range16, 10, 0, 0 }, + { "Phags_Pa", +1, Phags_Pa_range16, 1, 0, 0 }, + { "Phoenician", +1, 0, 0, Phoenician_range32, 2 }, + { "Pi", +1, Pi_range16, 11, 0, 0 }, + { "Po", +1, Po_range16, 130, Po_range32, 57 }, + { "Ps", +1, Ps_range16, 79, 0, 0 }, + { "Psalter_Pahlavi", +1, 0, 0, Psalter_Pahlavi_range32, 3 }, + { "Rejang", +1, Rejang_range16, 2, 0, 0 }, + { "Runic", +1, Runic_range16, 2, 0, 0 }, + { "S", +1, S_range16, 151, S_range32, 81 }, + { "Samaritan", +1, Samaritan_range16, 2, 0, 0 }, + { "Saurashtra", +1, Saurashtra_range16, 2, 0, 0 }, + { "Sc", +1, Sc_range16, 18, Sc_range32, 3 }, + { "Sharada", +1, 0, 0, Sharada_range32, 1 }, + { "Shavian", +1, 0, 0, Shavian_range32, 1 }, + { "Siddham", +1, 0, 0, Siddham_range32, 2 }, + { "SignWriting", +1, 0, 0, SignWriting_range32, 3 }, + { "Sinhala", +1, Sinhala_range16, 12, Sinhala_range32, 1 }, + { "Sk", +1, Sk_range16, 30, Sk_range32, 1 }, + { "Sm", +1, Sm_range16, 53, Sm_range32, 11 }, + { "So", +1, So_range16, 114, So_range32, 70 }, + { "Sogdian", +1, 0, 0, Sogdian_range32, 1 }, + { "Sora_Sompeng", +1, 0, 0, 
Sora_Sompeng_range32, 2 },
+  { "Soyombo", +1, 0, 0, Soyombo_range32, 1 },
+  { "Sundanese", +1, Sundanese_range16, 2, 0, 0 },
+  { "Syloti_Nagri", +1, Syloti_Nagri_range16, 1, 0, 0 },
+  { "Syriac", +1, Syriac_range16, 4, 0, 0 },
+  { "Tagalog", +1, Tagalog_range16, 2, 0, 0 },
+  { "Tagbanwa", +1, Tagbanwa_range16, 3, 0, 0 },
+  { "Tai_Le", +1, Tai_Le_range16, 2, 0, 0 },
+  { "Tai_Tham", +1, Tai_Tham_range16, 5, 0, 0 },
+  { "Tai_Viet", +1, Tai_Viet_range16, 2, 0, 0 },
+  { "Takri", +1, 0, 0, Takri_range32, 2 },
+  { "Tamil", +1, Tamil_range16, 16, Tamil_range32, 2 },
+  { "Tangsa", +1, 0, 0, Tangsa_range32, 2 },
+  { "Tangut", +1, 0, 0, Tangut_range32, 4 },
+  { "Telugu", +1, Telugu_range16, 13, 0, 0 },
+  { "Thaana", +1, Thaana_range16, 1, 0, 0 },
+  { "Thai", +1, Thai_range16, 2, 0, 0 },
+  { "Tibetan", +1, Tibetan_range16, 7, 0, 0 },
+  { "Tifinagh", +1, Tifinagh_range16, 3, 0, 0 },
+  { "Tirhuta", +1, 0, 0, Tirhuta_range32, 2 },
+  { "Toto", +1, 0, 0, Toto_range32, 1 },
+  { "Ugaritic", +1, 0, 0, Ugaritic_range32, 2 },
+  { "Vai", +1, Vai_range16, 1, 0, 0 },
+  { "Vithkuqi", +1, 0, 0, Vithkuqi_range32, 8 },
+  { "Wancho", +1, 0, 0, Wancho_range32, 2 },
+  { "Warang_Citi", +1, 0, 0, Warang_Citi_range32, 2 },
+  { "Yezidi", +1, 0, 0, Yezidi_range32, 3 },
+  { "Yi", +1, Yi_range16, 2, 0, 0 },
+  { "Z", +1, Z_range16, 8, 0, 0 },
+  { "Zanabazar_Square", +1, 0, 0, Zanabazar_Square_range32, 1 },
+  { "Zl", +1, Zl_range16, 1, 0, 0 },
+  { "Zp", +1, Zp_range16, 1, 0, 0 },
+  { "Zs", +1, Zs_range16, 7, 0, 0 },
+};
+const int num_unicode_groups = 199;
+
+
+} // namespace re2
+
+
diff --git a/internal/cpp/re2/unicode_groups.h b/internal/cpp/re2/unicode_groups.h
new file mode 100644
index 00000000000..a2bff0670e6
--- /dev/null
+++ b/internal/cpp/re2/unicode_groups.h
@@ -0,0 +1,64 @@
+// Copyright 2008 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef RE2_UNICODE_GROUPS_H_
+#define RE2_UNICODE_GROUPS_H_
+
+// Unicode character groups.
+
+// The codes get split into ranges of 16-bit codes
+// and ranges of 32-bit codes. It would be simpler
+// to use only 32-bit ranges, but these tables are large
+// enough to warrant extra care.
+//
+// Using just 32-bit ranges gives 27 kB of data.
+// Adding 16-bit ranges gives 18 kB of data.
+// Adding an extra table of 16-bit singletons would reduce
+// to 16.5 kB of data but make the data harder to use;
+// we don't bother.
+
+#include <stdint.h>
+
+#include "util/utf.h"
+#include "util/util.h"
+
+namespace re2 {
+
+struct URange16 {
+  uint16_t lo;
+  uint16_t hi;
+};
+
+struct URange32 {
+  Rune lo;
+  Rune hi;
+};
+
+struct UGroup {
+  const char *name;
+  int sign;  // +1 for [abc], -1 for [^abc]
+  const URange16 *r16;
+  int nr16;
+  const URange32 *r32;
+  int nr32;
+};
+
+// Named by property or script name (e.g., "Nd", "N", "Han").
+// Negated groups are not included.
+extern const UGroup unicode_groups[];
+extern const int num_unicode_groups;
+
+// Named by POSIX name (e.g., "[:alpha:]", "[:^lower:]").
+// Negated groups are included.
+extern const UGroup posix_groups[];
+extern const int num_posix_groups;
+
+// Named by Perl name (e.g., "\\d", "\\D").
+// Negated groups are included.
+extern const UGroup perl_groups[];
+extern const int num_perl_groups;
+
+} // namespace re2
+
+#endif // RE2_UNICODE_GROUPS_H_
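
A note for readers of this patch: the tables in `unicode_groups.cc` are consumed by lookups over sorted, non-overlapping ranges. The sketch below is editorial, not part of the diff; `is_in_group` is a hypothetical helper. It shows why the 16-bit/32-bit split described in the header comment costs nothing at lookup time: membership is a single binary search over whichever table matches the code point's width.

```cpp
#include "re2/unicode_groups.h"

// Hypothetical helper (not in this patch): test whether rune r belongs to
// group g. Both range tables are sorted and non-overlapping, so one binary
// search per table suffices; interpreting the sign field (+1/-1) is left
// to the caller.
static bool is_in_group(const re2::UGroup &g, Rune r) {
    if (r <= 0xFFFF) {
        for (int lo = 0, hi = g.nr16 - 1; lo <= hi;) {
            int mid = (lo + hi) / 2;
            if (r < g.r16[mid].lo)
                hi = mid - 1;
            else if (r > g.r16[mid].hi)
                lo = mid + 1;
            else
                return true; // r falls inside [lo, hi] of this range
        }
        return false;
    }
    for (int lo = 0, hi = g.nr32 - 1; lo <= hi;) {
        int mid = (lo + hi) / 2;
        if (r < g.r32[mid].lo)
            hi = mid - 1;
        else if (r > g.r32[mid].hi)
            lo = mid + 1;
        else
            return true;
    }
    return false;
}
```

An entry like `{ "Yi", +1, Yi_range16, 2, 0, 0 }` above has no 32-bit ranges at all, so its lookups never leave the 16-bit table.
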
diff --git a/internal/cpp/re2/walker-inl.h b/internal/cpp/re2/walker-inl.h
new file mode 100644
index 00000000000..f0313cae83d
--- /dev/null
+++ b/internal/cpp/re2/walker-inl.h
@@ -0,0 +1,246 @@
+// Copyright 2006 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef RE2_WALKER_INL_H_
+#define RE2_WALKER_INL_H_
+
+// Helper class for traversing Regexps without recursion.
+// Clients should declare their own subclasses that override
+// the PreVisit and PostVisit methods, which are called before
+// and after visiting the subexpressions.
+
+// Not quite the Visitor pattern, because (among other things)
+// the Visitor pattern is recursive.
+
+#include <stack>
+
+#include "re2/regexp.h"
+#include "util/logging.h"
+
+namespace re2 {
+
+template <typename T>
+struct WalkState;
+
+template <typename T>
+class Regexp::Walker {
+public:
+    Walker();
+    virtual ~Walker();
+
+    // Virtual method called before visiting re's children.
+    // PreVisit passes ownership of its return value to its caller.
+    // The Arg* that PreVisit returns will be passed to PostVisit as pre_arg
+    // and passed to the child PreVisits and PostVisits as parent_arg.
+    // At the top-most Regexp, parent_arg is arg passed to walk.
+    // If PreVisit sets *stop to true, the walk does not recurse
+    // into the children. Instead it behaves as though the return
+    // value from PreVisit is the return value from PostVisit.
+    // The default PreVisit returns parent_arg.
+    virtual T PreVisit(Regexp *re, T parent_arg, bool *stop);
+
+    // Virtual method called after visiting re's children.
+    // The pre_arg is the T that PreVisit returned.
+    // The child_args is a vector of the T that the child PostVisits returned.
+    // PostVisit takes ownership of pre_arg.
+    // PostVisit takes ownership of the Ts
+    // in *child_args, but not the vector itself.
+    // PostVisit passes ownership of its return value
+    // to its caller.
+    // The default PostVisit simply returns pre_arg.
+    virtual T PostVisit(Regexp *re, T parent_arg, T pre_arg, T *child_args, int nchild_args);
+
+    // Virtual method called to copy a T,
+    // when Walk notices that more than one child is the same re.
+    virtual T Copy(T arg);
+
+    // Virtual method called to do a "quick visit" of the re,
+    // but not its children. Only called once the visit budget
+    // has been used up and we're trying to abort the walk
+    // as quickly as possible. Should return a value that
+    // makes sense for the parent PostVisits still to be run.
+    // This function is (hopefully) only called by
+    // WalkExponential, but must be implemented by all clients,
+    // just in case.
+    virtual T ShortVisit(Regexp *re, T parent_arg) = 0;
+
+    // Walks over a regular expression.
+    // Top_arg is passed as parent_arg to PreVisit and PostVisit of re.
+    // Returns the T returned by PostVisit on re.
+    T Walk(Regexp *re, T top_arg);
+
+    // Like Walk, but doesn't use Copy. This can lead to
+    // exponential runtimes on cross-linked Regexps like the
+    // ones generated by Simplify. To help limit this,
+    // at most max_visits nodes will be visited and then
+    // the walk will be cut off early.
+    // If the walk *is* cut off early, ShortVisit(re)
+    // will be called on regexps that cannot be fully
+    // visited rather than calling PreVisit/PostVisit.
+    T WalkExponential(Regexp *re, T top_arg, int max_visits);
+
+    // Clears the stack. Should never be necessary, since
+    // Walk always enters and exits with an empty stack.
+    // Logs DFATAL if stack is not already clear.
+    void Reset();
+
+    // Returns whether walk was cut off.
+    bool stopped_early() { return stopped_early_; }
+
+private:
+    // Walk state for the entire traversal.
+    std::stack<WalkState<T>> stack_;
+    bool stopped_early_;
+    int max_visits_;
+
+    T WalkInternal(Regexp *re, T top_arg, bool use_copy);
+
+    Walker(const Walker &) = delete;
+    Walker &operator=(const Walker &) = delete;
+};
+
+template <typename T>
+T Regexp::Walker<T>::PreVisit(Regexp *re, T parent_arg, bool *stop) {
+    return parent_arg;
+}
+
+template <typename T>
+T Regexp::Walker<T>::PostVisit(Regexp *re, T parent_arg, T pre_arg, T *child_args, int nchild_args) {
+    return pre_arg;
+}
+
+template <typename T>
+T Regexp::Walker<T>::Copy(T arg) {
+    return arg;
+}
+
+// State about a single level in the traversal.
+template <typename T>
+struct WalkState {
+    WalkState(Regexp *re, T parent) : re(re), n(-1), parent_arg(parent), child_args(NULL) {}
+
+    Regexp *re;   // The regexp
+    int n;        // The index of the next child to process; -1 means need to PreVisit
+    T parent_arg; // Accumulated arguments.
+    T pre_arg;
+    T child_arg;  // One-element buffer for child_args.
+    T *child_args;
+};
+
+template <typename T>
+Regexp::Walker<T>::Walker() {
+    stopped_early_ = false;
+}
+
+template <typename T>
+Regexp::Walker<T>::~Walker() {
+    Reset();
+}
+
+// Clears the stack. Should never be necessary, since
+// Walk always enters and exits with an empty stack.
+// Logs DFATAL if stack is not already clear.
+template <typename T>
+void Regexp::Walker<T>::Reset() {
+    if (!stack_.empty()) {
+        LOG(DFATAL) << "Stack not empty.";
+        while (!stack_.empty()) {
+            if (stack_.top().re->nsub_ > 1)
+                delete[] stack_.top().child_args;
+            stack_.pop();
+        }
+    }
+}
+
+template <typename T>
+T Regexp::Walker<T>::WalkInternal(Regexp *re, T top_arg, bool use_copy) {
+    Reset();
+
+    if (re == NULL) {
+        LOG(DFATAL) << "Walk NULL";
+        return top_arg;
+    }
+
+    stack_.push(WalkState<T>(re, top_arg));
+
+    WalkState<T> *s;
+    for (;;) {
+        T t;
+        s = &stack_.top();
+        re = s->re;
+        switch (s->n) {
+        case -1: {
+            if (--max_visits_ < 0) {
+                stopped_early_ = true;
+                t = ShortVisit(re, s->parent_arg);
+                break;
+            }
+            bool stop = false;
+            s->pre_arg = PreVisit(re, s->parent_arg, &stop);
+            if (stop) {
+                t = s->pre_arg;
+                break;
+            }
+            s->n = 0;
+            s->child_args = NULL;
+            if (re->nsub_ == 1)
+                s->child_args = &s->child_arg;
+            else if (re->nsub_ > 1)
+                s->child_args = new T[re->nsub_];
+            FALLTHROUGH_INTENDED;
+        }
+        default: {
+            if (re->nsub_ > 0) {
+                Regexp **sub = re->sub();
+                if (s->n < re->nsub_) {
+                    if (use_copy && s->n > 0 && sub[s->n - 1] == sub[s->n]) {
+                        s->child_args[s->n] = Copy(s->child_args[s->n - 1]);
+                        s->n++;
+                    } else {
+                        stack_.push(WalkState<T>(sub[s->n], s->pre_arg));
+                    }
+                    continue;
+                }
+            }
+
+            t = PostVisit(re, s->parent_arg, s->pre_arg, s->child_args, s->n);
+            if (re->nsub_ > 1)
+                delete[] s->child_args;
+            break;
+        }
+        }
+
+        // We've finished stack_.top().
+        // Update next guy down.
+        stack_.pop();
+        if (stack_.empty())
+            return t;
+        s = &stack_.top();
+        if (s->child_args != NULL)
+            s->child_args[s->n] = t;
+        else
+            s->child_arg = t;
+        s->n++;
+    }
+}
+
+template <typename T>
+T Regexp::Walker<T>::Walk(Regexp *re, T top_arg) {
+    // Without the exponential walking behavior,
+    // this budget should be more than enough for any
+    // regexp, and yet not enough to get us in trouble
+    // as far as CPU time.
+    max_visits_ = 1000000;
+    return WalkInternal(re, top_arg, true);
+}
+
+template <typename T>
+T Regexp::Walker<T>::WalkExponential(Regexp *re, T top_arg, int max_visits) {
+    max_visits_ = max_visits;
+    return WalkInternal(re, top_arg, false);
+}
+
+} // namespace re2
+
+#endif // RE2_WALKER_INL_H_
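
For orientation, here is a minimal `Walker` subclass built on the API above. This is an editorial sketch, not part of the diff; `NodeCountWalker` is a hypothetical name. The walk computes one `int` per node by summing child results in `PostVisit`; `ShortVisit` must be overridden because it is pure virtual, and it only runs if the visit budget is exhausted.

```cpp
#include "re2/regexp.h"
#include "re2/walker-inl.h"

// Hypothetical example (not in this patch): count the nodes of a Regexp.
class NodeCountWalker : public re2::Regexp::Walker<int> {
public:
    virtual int PostVisit(re2::Regexp *re, int parent_arg, int pre_arg,
                          int *child_args, int nchild_args) {
        int n = 1; // this node
        for (int i = 0; i < nchild_args; i++)
            n += child_args[i]; // plus everything under each child
        return n;
    }
    virtual int ShortVisit(re2::Regexp *re, int parent_arg) {
        return 0; // walk was cut off; give the parent something harmless
    }
};

// Usage: NodeCountWalker w; int nodes = w.Walk(re, 0);
```

The default `PreVisit` and `Copy` suffice here; `Copy` only matters when `Walk` (the `use_copy` path) sees the same child twice in a row and reuses the previous result instead of revisiting it.
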
diff --git a/internal/cpp/stemmer/api.cpp b/internal/cpp/stemmer/api.cpp
new file mode 100644
index 00000000000..9107370465d
--- /dev/null
+++ b/internal/cpp/stemmer/api.cpp
@@ -0,0 +1,78 @@
+// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "header.h"
+
+#include <stdlib.h> /* for calloc, free */
+
+extern struct SN_env *SN_create_env(int S_size, int I_size, int B_size) {
+    struct SN_env *z = (struct SN_env *)calloc(1, sizeof(struct SN_env));
+    if (z == NULL)
+        return NULL;
+    z->p = create_s();
+    if (z->p == NULL)
+        goto error;
+    if (S_size) {
+        int i;
+        z->S = (symbol **)calloc(S_size, sizeof(symbol *));
+        if (z->S == NULL)
+            goto error;
+
+        for (i = 0; i < S_size; i++) {
+            z->S[i] = create_s();
+            if (z->S[i] == NULL)
+                goto error;
+        }
+    }
+
+    if (I_size) {
+        z->I = (int *)calloc(I_size, sizeof(int));
+        if (z->I == NULL)
+            goto error;
+    }
+
+    if (B_size) {
+        z->B = (unsigned char *)calloc(B_size, sizeof(unsigned char));
+        if (z->B == NULL)
+            goto error;
+    }
+
+    return z;
+error:
+    SN_close_env(z, S_size);
+    return NULL;
+}
+
+extern void SN_close_env(struct SN_env *z, int S_size) {
+    if (z == NULL)
+        return;
+    if (S_size) {
+        int i;
+        for (i = 0; i < S_size; i++) {
+            lose_s(z->S[i]);
+        }
+        free(z->S);
+    }
+    free(z->I);
+    free(z->B);
+    if (z->p)
+        lose_s(z->p);
+    free(z);
+}
+
+extern int SN_set_current(struct SN_env *z, int size, const symbol *s) {
+    int err = replace_s(z, 0, z->l, size, s, NULL);
+    z->c = 0;
+    return err;
+}
diff --git a/internal/cpp/stemmer/api.h b/internal/cpp/stemmer/api.h
new file mode 100644
index 00000000000..341ea6cf386
--- /dev/null
+++ b/internal/cpp/stemmer/api.h
@@ -0,0 +1,31 @@
+
+#pragma once
+
+typedef unsigned char symbol;
+
+/* Or replace 'char' above with 'short' for 16 bit characters.
+
+   More precisely, replace 'char' with whatever type guarantees the
+   character width you need. Note however that sizeof(symbol) should divide
+   HEAD, defined in header.h as 2*sizeof(int), without remainder, otherwise
+   there is an alignment problem. In the unlikely event of a problem here,
+   consult Martin Porter.
+
+*/
+
+struct SN_env {
+    symbol *p;
+    int c;
+    int l;
+    int lb;
+    int bra;
+    int ket;
+    symbol **S;
+    int *I;
+    unsigned char *B;
+};
+
+extern struct SN_env *SN_create_env(int S_size, int I_size, int B_size);
+extern void SN_close_env(struct SN_env *z, int S_size);
+
+extern int SN_set_current(struct SN_env *z, int size, const symbol *s);
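
The lifecycle of these three entry points, as the generated stemmers below use them. This is an editorial sketch, not part of the diff; `stem_one_word` is hypothetical, and treating a nonzero `SN_set_current` result as failure follows libstemmer's convention rather than anything stated in this patch.

```cpp
#include "api.h"
#include "stem_UTF_8_danish.h" // added later in this patch

// Hypothetical sketch (not in this patch): stem one UTF-8 word in place.
static int stem_one_word(const symbol *word, int len) {
    struct SN_env *z = danish_UTF_8_create_env(); // wraps SN_create_env(1, 2, 0)
    if (z == NULL)
        return -1;
    if (SN_set_current(z, len, word) != 0) { // copy word into z->p, cursor to 0
        danish_UTF_8_close_env(z);
        return -1;
    }
    int ret = danish_UTF_8_stem(z); // stemmed result is z->p[0 .. z->l)
    danish_UTF_8_close_env(z);      // frees z->p, z->S[i], z->I, z->B
    return ret;
}
```
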
diff --git a/internal/cpp/stemmer/header.h b/internal/cpp/stemmer/header.h
new file mode 100644
index 00000000000..82604bae93e
--- /dev/null
+++ b/internal/cpp/stemmer/header.h
@@ -0,0 +1,59 @@
+
+#pragma once
+
+#include <limits.h>
+
+#include "api.h"
+
+#define MAXINT INT_MAX
+#define MININT INT_MIN
+
+#define HEAD 2 * sizeof(int)
+
+#define SIZE(p) ((int *)(p))[-1]
+#define SET_SIZE(p, n) ((int *)(p))[-1] = n
+#define CAPACITY(p) ((int *)(p))[-2]
+
+struct among {
+    int s_size;      /* number of chars in string */
+    const symbol *s; /* search string */
+    int substring_i; /* index to longest matching substring */
+    int result;      /* result of the lookup */
+    int (*function)(struct SN_env *);
+};
+
+extern symbol *create_s(void);
+extern void lose_s(symbol *p);
+
+extern int skip_utf8(const symbol *p, int c, int lb, int l, int n);
+
+extern int in_grouping_U(struct SN_env *z, const unsigned char *s, int min, int max, int repeat);
+extern int in_grouping_b_U(struct SN_env *z, const unsigned char *s, int min, int max, int repeat);
+extern int out_grouping_U(struct SN_env *z, const unsigned char *s, int min, int max, int repeat);
+extern int out_grouping_b_U(struct SN_env *z, const unsigned char *s, int min, int max, int repeat);
+
+extern int in_grouping(struct SN_env *z, const unsigned char *s, int min, int max, int repeat);
+extern int in_grouping_b(struct SN_env *z, const unsigned char *s, int min, int max, int repeat);
+extern int out_grouping(struct SN_env *z, const unsigned char *s, int min, int max, int repeat);
+extern int out_grouping_b(struct SN_env *z, const unsigned char *s, int min, int max, int repeat);
+
+extern int eq_s(struct SN_env *z, int s_size, const symbol *s);
+extern int eq_s_b(struct SN_env *z, int s_size, const symbol *s);
+extern int eq_v(struct SN_env *z, const symbol *p);
+extern int eq_v_b(struct SN_env *z, const symbol *p);
+
+extern int find_among(struct SN_env *z, const struct among *v, int v_size);
+extern int find_among_b(struct SN_env *z, const struct among *v, int v_size);
+
+extern int replace_s(struct SN_env *z, int c_bra, int c_ket, int s_size, const symbol *s, int *adjustment);
+extern int slice_from_s(struct SN_env *z, int s_size, const symbol *s);
+extern int slice_from_v(struct SN_env *z, const symbol *p);
+extern int slice_del(struct SN_env *z);
+
+extern int insert_s(struct SN_env *z, int bra, int ket, int s_size, const symbol *s);
+extern int insert_v(struct SN_env *z, int bra, int ket, const symbol *p);
+
+extern symbol *slice_to(struct SN_env *z, symbol *p);
+extern symbol *assign_to(struct SN_env *z, symbol *p);
+
+extern void debug(struct SN_env *z, int number, int line_count);
diff --git a/internal/cpp/stemmer/stem_UTF_8_danish.cpp b/internal/cpp/stemmer/stem_UTF_8_danish.cpp
new file mode 100644
index 00000000000..b804fd70820
--- /dev/null
+++ b/internal/cpp/stemmer/stem_UTF_8_danish.cpp
@@ -0,0 +1,424 @@
+
+/* This file was generated automatically by the Snowball to ANSI C compiler */
+
+#include "header.h"
+
+#ifdef __cplusplus
extern "C" {
+#endif
+extern int danish_UTF_8_stem(struct SN_env *z);
+#ifdef __cplusplus
+}
+#endif
+static int r_undouble(struct SN_env *z);
+static int r_other_suffix(struct SN_env *z);
+static int r_consonant_pair(struct SN_env *z); +static int r_main_suffix(struct SN_env *z); +static int r_mark_regions(struct SN_env *z); +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *danish_UTF_8_create_env(void); +extern void danish_UTF_8_close_env(struct SN_env *z); + +#ifdef __cplusplus +} +#endif +static const symbol s_0_0[3] = {'h', 'e', 'd'}; +static const symbol s_0_1[5] = {'e', 't', 'h', 'e', 'd'}; +static const symbol s_0_2[4] = {'e', 'r', 'e', 'd'}; +static const symbol s_0_3[1] = {'e'}; +static const symbol s_0_4[5] = {'e', 'r', 'e', 'd', 'e'}; +static const symbol s_0_5[4] = {'e', 'n', 'd', 'e'}; +static const symbol s_0_6[6] = {'e', 'r', 'e', 'n', 'd', 'e'}; +static const symbol s_0_7[3] = {'e', 'n', 'e'}; +static const symbol s_0_8[4] = {'e', 'r', 'n', 'e'}; +static const symbol s_0_9[3] = {'e', 'r', 'e'}; +static const symbol s_0_10[2] = {'e', 'n'}; +static const symbol s_0_11[5] = {'h', 'e', 'd', 'e', 'n'}; +static const symbol s_0_12[4] = {'e', 'r', 'e', 'n'}; +static const symbol s_0_13[2] = {'e', 'r'}; +static const symbol s_0_14[5] = {'h', 'e', 'd', 'e', 'r'}; +static const symbol s_0_15[4] = {'e', 'r', 'e', 'r'}; +static const symbol s_0_16[1] = {'s'}; +static const symbol s_0_17[4] = {'h', 'e', 'd', 's'}; +static const symbol s_0_18[2] = {'e', 's'}; +static const symbol s_0_19[5] = {'e', 'n', 'd', 'e', 's'}; +static const symbol s_0_20[7] = {'e', 'r', 'e', 'n', 'd', 'e', 's'}; +static const symbol s_0_21[4] = {'e', 'n', 'e', 's'}; +static const symbol s_0_22[5] = {'e', 'r', 'n', 'e', 's'}; +static const symbol s_0_23[4] = {'e', 'r', 'e', 's'}; +static const symbol s_0_24[3] = {'e', 'n', 's'}; +static const symbol s_0_25[6] = {'h', 'e', 'd', 'e', 'n', 's'}; +static const symbol s_0_26[5] = {'e', 'r', 'e', 'n', 's'}; +static const symbol s_0_27[3] = {'e', 'r', 's'}; +static const symbol s_0_28[3] = {'e', 't', 's'}; +static const symbol s_0_29[5] = {'e', 'r', 'e', 't', 's'}; +static const symbol s_0_30[2] = {'e', 't'}; +static const symbol s_0_31[4] = {'e', 'r', 'e', 't'}; + +static const struct among a_0[32] = { + /* 0 */ {3, s_0_0, -1, 1, 0}, + /* 1 */ {5, s_0_1, 0, 1, 0}, + /* 2 */ {4, s_0_2, -1, 1, 0}, + /* 3 */ {1, s_0_3, -1, 1, 0}, + /* 4 */ {5, s_0_4, 3, 1, 0}, + /* 5 */ {4, s_0_5, 3, 1, 0}, + /* 6 */ {6, s_0_6, 5, 1, 0}, + /* 7 */ {3, s_0_7, 3, 1, 0}, + /* 8 */ {4, s_0_8, 3, 1, 0}, + /* 9 */ {3, s_0_9, 3, 1, 0}, + /* 10 */ {2, s_0_10, -1, 1, 0}, + /* 11 */ {5, s_0_11, 10, 1, 0}, + /* 12 */ {4, s_0_12, 10, 1, 0}, + /* 13 */ {2, s_0_13, -1, 1, 0}, + /* 14 */ {5, s_0_14, 13, 1, 0}, + /* 15 */ {4, s_0_15, 13, 1, 0}, + /* 16 */ {1, s_0_16, -1, 2, 0}, + /* 17 */ {4, s_0_17, 16, 1, 0}, + /* 18 */ {2, s_0_18, 16, 1, 0}, + /* 19 */ {5, s_0_19, 18, 1, 0}, + /* 20 */ {7, s_0_20, 19, 1, 0}, + /* 21 */ {4, s_0_21, 18, 1, 0}, + /* 22 */ {5, s_0_22, 18, 1, 0}, + /* 23 */ {4, s_0_23, 18, 1, 0}, + /* 24 */ {3, s_0_24, 16, 1, 0}, + /* 25 */ {6, s_0_25, 24, 1, 0}, + /* 26 */ {5, s_0_26, 24, 1, 0}, + /* 27 */ {3, s_0_27, 16, 1, 0}, + /* 28 */ {3, s_0_28, 16, 1, 0}, + /* 29 */ {5, s_0_29, 28, 1, 0}, + /* 30 */ {2, s_0_30, -1, 1, 0}, + /* 31 */ {4, s_0_31, 30, 1, 0}}; + +static const symbol s_1_0[2] = {'g', 'd'}; +static const symbol s_1_1[2] = {'d', 't'}; +static const symbol s_1_2[2] = {'g', 't'}; +static const symbol s_1_3[2] = {'k', 't'}; + +static const struct among a_1[4] = { + /* 0 */ {2, s_1_0, -1, -1, 0}, + /* 1 */ {2, s_1_1, -1, -1, 0}, + /* 2 */ {2, s_1_2, -1, -1, 0}, + /* 3 */ {2, s_1_3, -1, -1, 0}}; + +static const symbol s_2_0[2] = {'i', 'g'}; 
+static const symbol s_2_1[3] = {'l', 'i', 'g'}; +static const symbol s_2_2[4] = {'e', 'l', 'i', 'g'}; +static const symbol s_2_3[3] = {'e', 'l', 's'}; +static const symbol s_2_4[5] = {'l', 0xC3, 0xB8, 's', 't'}; + +static const struct among a_2[5] = { + /* 0 */ {2, s_2_0, -1, 1, 0}, + /* 1 */ {3, s_2_1, 0, 1, 0}, + /* 2 */ {4, s_2_2, 1, 1, 0}, + /* 3 */ {3, s_2_3, -1, 1, 0}, + /* 4 */ {5, s_2_4, -1, 2, 0}}; + +static const unsigned char g_v[] = {17, 65, 16, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 48, 0, 128}; + +static const unsigned char g_s_ending[] = {239, 254, 42, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16}; + +static const symbol s_0[] = {'s', 't'}; +static const symbol s_1[] = {'i', 'g'}; +static const symbol s_2[] = {'l', 0xC3, 0xB8, 's'}; + +static int r_mark_regions(struct SN_env *z) { + z->I[0] = z->l; + { + int c_test = z->c; /* test, line 33 */ + { + int ret = skip_utf8(z->p, z->c, 0, z->l, +3); + if (ret < 0) + return 0; + z->c = ret; /* hop, line 33 */ + } + z->I[1] = z->c; /* setmark x, line 33 */ + z->c = c_test; + } + if (out_grouping_U(z, g_v, 97, 248, 1) < 0) + return 0; /* goto */ /* grouping v, line 34 */ + { /* gopast */ /* non v, line 34 */ + int ret = in_grouping_U(z, g_v, 97, 248, 1); + if (ret < 0) + return 0; + z->c += ret; + } + z->I[0] = z->c; /* setmark p1, line 34 */ + /* try, line 35 */ + if (!(z->I[0] < z->I[1])) + goto lab0; + z->I[0] = z->I[1]; +lab0: + return 1; +} + +static int r_main_suffix(struct SN_env *z) { + int among_var; + { + int mlimit; /* setlimit, line 41 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 41 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 41 */ + if (z->c <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((1851440 >> (z->p[z->c - 1] & 0x1f)) & 1)) { + z->lb = mlimit; + return 0; + } + among_var = find_among_b(z, a_0, 32); /* substring, line 41 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 41 */ + z->lb = mlimit; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 48 */ + if (ret < 0) + return ret; + } break; + case 2: + if (in_grouping_b_U(z, g_s_ending, 97, 229, 0)) + return 0; + { + int ret = slice_del(z); /* delete, line 50 */ + if (ret < 0) + return ret; + } + break; + } + return 1; +} + +static int r_consonant_pair(struct SN_env *z) { + { + int m_test = z->l - z->c; /* test, line 55 */ + { + int mlimit; /* setlimit, line 56 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 56 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 56 */ + if (z->c - 1 <= z->lb || (z->p[z->c - 1] != 100 && z->p[z->c - 1] != 116)) { + z->lb = mlimit; + return 0; + } + if (!(find_among_b(z, a_1, 4))) { + z->lb = mlimit; + return 0; + } /* substring, line 56 */ + z->bra = z->c; /* ], line 56 */ + z->lb = mlimit; + } + z->c = z->l - m_test; + } + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) + return 0; + z->c = ret; /* next, line 62 */ + } + z->bra = z->c; /* ], line 62 */ + { + int ret = slice_del(z); /* delete, line 62 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_other_suffix(struct SN_env *z) { + int among_var; + { + int m1 = z->l - z->c; + (void)m1; /* do, line 66 */ + z->ket = z->c; /* [, line 66 */ + if (!(eq_s_b(z, 2, s_0))) + goto lab0; + z->bra = z->c; /* ], line 66 */ + if (!(eq_s_b(z, 2, s_1))) + goto lab0; + { + int ret = 
slice_del(z); /* delete, line 66 */ + if (ret < 0) + return ret; + } + lab0: + z->c = z->l - m1; + } + { + int mlimit; /* setlimit, line 67 */ + int m2 = z->l - z->c; + (void)m2; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 67 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m2; + z->ket = z->c; /* [, line 67 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((1572992 >> (z->p[z->c - 1] & 0x1f)) & 1)) { + z->lb = mlimit; + return 0; + } + among_var = find_among_b(z, a_2, 5); /* substring, line 67 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 67 */ + z->lb = mlimit; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 70 */ + if (ret < 0) + return ret; + } + { + int m3 = z->l - z->c; + (void)m3; /* do, line 70 */ + { + int ret = r_consonant_pair(z); + if (ret == 0) + goto lab1; /* call consonant_pair, line 70 */ + if (ret < 0) + return ret; + } + lab1: + z->c = z->l - m3; + } + break; + case 2: { + int ret = slice_from_s(z, 4, s_2); /* <-, line 72 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_undouble(struct SN_env *z) { + { + int mlimit; /* setlimit, line 76 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 76 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 76 */ + if (out_grouping_b_U(z, g_v, 97, 248, 0)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 76 */ + z->S[0] = slice_to(z, z->S[0]); /* -> ch, line 76 */ + if (z->S[0] == 0) + return -1; /* -> ch, line 76 */ + z->lb = mlimit; + } + if (!(eq_v_b(z, z->S[0]))) + return 0; /* name ch, line 77 */ + { + int ret = slice_del(z); /* delete, line 78 */ + if (ret < 0) + return ret; + } + return 1; +} + +extern int danish_UTF_8_stem(struct SN_env *z) { + { + int c1 = z->c; /* do, line 84 */ + { + int ret = r_mark_regions(z); + if (ret == 0) + goto lab0; /* call mark_regions, line 84 */ + if (ret < 0) + return ret; + } + lab0: + z->c = c1; + } + z->lb = z->c; + z->c = z->l; /* backwards, line 85 */ + + { + int m2 = z->l - z->c; + (void)m2; /* do, line 86 */ + { + int ret = r_main_suffix(z); + if (ret == 0) + goto lab1; /* call main_suffix, line 86 */ + if (ret < 0) + return ret; + } + lab1: + z->c = z->l - m2; + } + { + int m3 = z->l - z->c; + (void)m3; /* do, line 87 */ + { + int ret = r_consonant_pair(z); + if (ret == 0) + goto lab2; /* call consonant_pair, line 87 */ + if (ret < 0) + return ret; + } + lab2: + z->c = z->l - m3; + } + { + int m4 = z->l - z->c; + (void)m4; /* do, line 88 */ + { + int ret = r_other_suffix(z); + if (ret == 0) + goto lab3; /* call other_suffix, line 88 */ + if (ret < 0) + return ret; + } + lab3: + z->c = z->l - m4; + } + { + int m5 = z->l - z->c; + (void)m5; /* do, line 89 */ + { + int ret = r_undouble(z); + if (ret == 0) + goto lab4; /* call undouble, line 89 */ + if (ret < 0) + return ret; + } + lab4: + z->c = z->l - m5; + } + z->c = z->lb; + return 1; +} + +extern struct SN_env *danish_UTF_8_create_env(void) { return SN_create_env(1, 2, 0); } + +extern void danish_UTF_8_close_env(struct SN_env *z) { SN_close_env(z, 1); } diff --git a/internal/cpp/stemmer/stem_UTF_8_danish.h b/internal/cpp/stemmer/stem_UTF_8_danish.h new file mode 100644 index 00000000000..5d86b1c59c1 --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_danish.h @@ -0,0 +1,17 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#pragma once + 
+#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *danish_UTF_8_create_env(void); +extern void danish_UTF_8_close_env(struct SN_env *z); + +extern int danish_UTF_8_stem(struct SN_env *z); + +#ifdef __cplusplus +} +#endif diff --git a/internal/cpp/stemmer/stem_UTF_8_dutch.cpp b/internal/cpp/stemmer/stem_UTF_8_dutch.cpp new file mode 100644 index 00000000000..18d8cc663d3 --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_dutch.cpp @@ -0,0 +1,792 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#include "header.h" + +#ifdef __cplusplus +extern "C" { +#endif +extern int dutch_UTF_8_stem(struct SN_env *z); +#ifdef __cplusplus +} +#endif +static int r_standard_suffix(struct SN_env *z); +static int r_undouble(struct SN_env *z); +static int r_R2(struct SN_env *z); +static int r_R1(struct SN_env *z); +static int r_mark_regions(struct SN_env *z); +static int r_en_ending(struct SN_env *z); +static int r_e_ending(struct SN_env *z); +static int r_postlude(struct SN_env *z); +static int r_prelude(struct SN_env *z); +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *dutch_UTF_8_create_env(void); +extern void dutch_UTF_8_close_env(struct SN_env *z); + +#ifdef __cplusplus +} +#endif +static const symbol s_0_1[2] = {0xC3, 0xA1}; +static const symbol s_0_2[2] = {0xC3, 0xA4}; +static const symbol s_0_3[2] = {0xC3, 0xA9}; +static const symbol s_0_4[2] = {0xC3, 0xAB}; +static const symbol s_0_5[2] = {0xC3, 0xAD}; +static const symbol s_0_6[2] = {0xC3, 0xAF}; +static const symbol s_0_7[2] = {0xC3, 0xB3}; +static const symbol s_0_8[2] = {0xC3, 0xB6}; +static const symbol s_0_9[2] = {0xC3, 0xBA}; +static const symbol s_0_10[2] = {0xC3, 0xBC}; + +static const struct among a_0[11] = { + /* 0 */ {0, 0, -1, 6, 0}, + /* 1 */ {2, s_0_1, 0, 1, 0}, + /* 2 */ {2, s_0_2, 0, 1, 0}, + /* 3 */ {2, s_0_3, 0, 2, 0}, + /* 4 */ {2, s_0_4, 0, 2, 0}, + /* 5 */ {2, s_0_5, 0, 3, 0}, + /* 6 */ {2, s_0_6, 0, 3, 0}, + /* 7 */ {2, s_0_7, 0, 4, 0}, + /* 8 */ {2, s_0_8, 0, 4, 0}, + /* 9 */ {2, s_0_9, 0, 5, 0}, + /* 10 */ {2, s_0_10, 0, 5, 0}}; + +static const symbol s_1_1[1] = {'I'}; +static const symbol s_1_2[1] = {'Y'}; + +static const struct among a_1[3] = { + /* 0 */ {0, 0, -1, 3, 0}, + /* 1 */ {1, s_1_1, 0, 2, 0}, + /* 2 */ {1, s_1_2, 0, 1, 0}}; + +static const symbol s_2_0[2] = {'d', 'd'}; +static const symbol s_2_1[2] = {'k', 'k'}; +static const symbol s_2_2[2] = {'t', 't'}; + +static const struct among a_2[3] = { + /* 0 */ {2, s_2_0, -1, -1, 0}, + /* 1 */ {2, s_2_1, -1, -1, 0}, + /* 2 */ {2, s_2_2, -1, -1, 0}}; + +static const symbol s_3_0[3] = {'e', 'n', 'e'}; +static const symbol s_3_1[2] = {'s', 'e'}; +static const symbol s_3_2[2] = {'e', 'n'}; +static const symbol s_3_3[5] = {'h', 'e', 'd', 'e', 'n'}; +static const symbol s_3_4[1] = {'s'}; + +static const struct among a_3[5] = { + /* 0 */ {3, s_3_0, -1, 2, 0}, + /* 1 */ {2, s_3_1, -1, 3, 0}, + /* 2 */ {2, s_3_2, -1, 2, 0}, + /* 3 */ {5, s_3_3, 2, 1, 0}, + /* 4 */ {1, s_3_4, -1, 3, 0}}; + +static const symbol s_4_0[3] = {'e', 'n', 'd'}; +static const symbol s_4_1[2] = {'i', 'g'}; +static const symbol s_4_2[3] = {'i', 'n', 'g'}; +static const symbol s_4_3[4] = {'l', 'i', 'j', 'k'}; +static const symbol s_4_4[4] = {'b', 'a', 'a', 'r'}; +static const symbol s_4_5[3] = {'b', 'a', 'r'}; + +static const struct among a_4[6] = { + /* 0 */ {3, s_4_0, -1, 1, 0}, + /* 1 */ {2, s_4_1, -1, 2, 0}, + /* 2 */ {3, s_4_2, -1, 1, 0}, + /* 3 */ {4, s_4_3, -1, 3, 0}, + /* 4 */ {4, s_4_4, -1, 4, 0}, + /* 5 */ {3, s_4_5, -1, 5, 
0}}; + +static const symbol s_5_0[2] = {'a', 'a'}; +static const symbol s_5_1[2] = {'e', 'e'}; +static const symbol s_5_2[2] = {'o', 'o'}; +static const symbol s_5_3[2] = {'u', 'u'}; + +static const struct among a_5[4] = { + /* 0 */ {2, s_5_0, -1, -1, 0}, + /* 1 */ {2, s_5_1, -1, -1, 0}, + /* 2 */ {2, s_5_2, -1, -1, 0}, + /* 3 */ {2, s_5_3, -1, -1, 0}}; + +static const unsigned char g_v[] = {17, 65, 16, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128}; + +static const unsigned char g_v_I[] = {1, 0, 0, 17, 65, 16, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128}; + +static const unsigned char g_v_j[] = {17, 67, 16, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128}; + +static const symbol s_0[] = {'a'}; +static const symbol s_1[] = {'e'}; +static const symbol s_2[] = {'i'}; +static const symbol s_3[] = {'o'}; +static const symbol s_4[] = {'u'}; +static const symbol s_5[] = {'y'}; +static const symbol s_6[] = {'Y'}; +static const symbol s_7[] = {'i'}; +static const symbol s_8[] = {'I'}; +static const symbol s_9[] = {'y'}; +static const symbol s_10[] = {'Y'}; +static const symbol s_11[] = {'y'}; +static const symbol s_12[] = {'i'}; +static const symbol s_13[] = {'e'}; +static const symbol s_14[] = {'g', 'e', 'm'}; +static const symbol s_15[] = {'h', 'e', 'i', 'd'}; +static const symbol s_16[] = {'h', 'e', 'i', 'd'}; +static const symbol s_17[] = {'c'}; +static const symbol s_18[] = {'e', 'n'}; +static const symbol s_19[] = {'i', 'g'}; +static const symbol s_20[] = {'e'}; +static const symbol s_21[] = {'e'}; + +static int r_prelude(struct SN_env *z) { + int among_var; + { + int c_test = z->c; /* test, line 42 */ + while (1) { /* repeat, line 42 */ + int c1 = z->c; + z->bra = z->c; /* [, line 43 */ + if (z->c + 1 >= z->l || z->p[z->c + 1] >> 5 != 5 || !((340306450 >> (z->p[z->c + 1] & 0x1f)) & 1)) + among_var = 6; + else + among_var = find_among(z, a_0, 11); /* substring, line 43 */ + if (!(among_var)) + goto lab0; + z->ket = z->c; /* ], line 43 */ + switch (among_var) { + case 0: + goto lab0; + case 1: { + int ret = slice_from_s(z, 1, s_0); /* <-, line 45 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 1, s_1); /* <-, line 47 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 1, s_2); /* <-, line 49 */ + if (ret < 0) + return ret; + } break; + case 4: { + int ret = slice_from_s(z, 1, s_3); /* <-, line 51 */ + if (ret < 0) + return ret; + } break; + case 5: { + int ret = slice_from_s(z, 1, s_4); /* <-, line 53 */ + if (ret < 0) + return ret; + } break; + case 6: { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* next, line 54 */ + } break; + } + continue; + lab0: + z->c = c1; + break; + } + z->c = c_test; + } + { + int c_keep = z->c; /* try, line 57 */ + z->bra = z->c; /* [, line 57 */ + if (!(eq_s(z, 1, s_5))) { + z->c = c_keep; + goto lab1; + } + z->ket = z->c; /* ], line 57 */ + { + int ret = slice_from_s(z, 1, s_6); /* <-, line 57 */ + if (ret < 0) + return ret; + } + lab1:; + } + while (1) { /* repeat, line 58 */ + int c2 = z->c; + while (1) { /* goto, line 58 */ + int c3 = z->c; + if (in_grouping_U(z, g_v, 97, 232, 0)) + goto lab3; + z->bra = z->c; /* [, line 59 */ + { + int c4 = z->c; /* or, line 59 */ + if (!(eq_s(z, 1, s_7))) + goto lab5; + z->ket = z->c; /* ], line 59 */ + if (in_grouping_U(z, g_v, 97, 232, 0)) + goto lab5; + { + int ret = slice_from_s(z, 1, s_8); /* <-, line 59 */ + if (ret < 0) + return ret; + } + goto lab4; + lab5: + z->c = c4; + if (!(eq_s(z, 1, s_9))) + goto lab3; + 
z->ket = z->c; /* ], line 60 */ + { + int ret = slice_from_s(z, 1, s_10); /* <-, line 60 */ + if (ret < 0) + return ret; + } + } + lab4: + z->c = c3; + break; + lab3: + z->c = c3; + { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab2; + z->c = ret; /* goto, line 58 */ + } + } + continue; + lab2: + z->c = c2; + break; + } + return 1; +} + +static int r_mark_regions(struct SN_env *z) { + z->I[0] = z->l; + z->I[1] = z->l; + { /* gopast */ /* grouping v, line 69 */ + int ret = out_grouping_U(z, g_v, 97, 232, 1); + if (ret < 0) + return 0; + z->c += ret; + } + { /* gopast */ /* non v, line 69 */ + int ret = in_grouping_U(z, g_v, 97, 232, 1); + if (ret < 0) + return 0; + z->c += ret; + } + z->I[0] = z->c; /* setmark p1, line 69 */ + /* try, line 70 */ + if (!(z->I[0] < 3)) + goto lab0; + z->I[0] = 3; +lab0: { /* gopast */ /* grouping v, line 71 */ + int ret = out_grouping_U(z, g_v, 97, 232, 1); + if (ret < 0) + return 0; + z->c += ret; +} + { /* gopast */ /* non v, line 71 */ + int ret = in_grouping_U(z, g_v, 97, 232, 1); + if (ret < 0) + return 0; + z->c += ret; + } + z->I[1] = z->c; /* setmark p2, line 71 */ + return 1; +} + +static int r_postlude(struct SN_env *z) { + int among_var; + while (1) { /* repeat, line 75 */ + int c1 = z->c; + z->bra = z->c; /* [, line 77 */ + if (z->c >= z->l || (z->p[z->c + 0] != 73 && z->p[z->c + 0] != 89)) + among_var = 3; + else + among_var = find_among(z, a_1, 3); /* substring, line 77 */ + if (!(among_var)) + goto lab0; + z->ket = z->c; /* ], line 77 */ + switch (among_var) { + case 0: + goto lab0; + case 1: { + int ret = slice_from_s(z, 1, s_11); /* <-, line 78 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 1, s_12); /* <-, line 79 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* next, line 80 */ + } break; + } + continue; + lab0: + z->c = c1; + break; + } + return 1; +} + +static int r_R1(struct SN_env *z) { + if (!(z->I[0] <= z->c)) + return 0; + return 1; +} + +static int r_R2(struct SN_env *z) { + if (!(z->I[1] <= z->c)) + return 0; + return 1; +} + +static int r_undouble(struct SN_env *z) { + { + int m_test = z->l - z->c; /* test, line 91 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((1050640 >> (z->p[z->c - 1] & 0x1f)) & 1)) + return 0; + if (!(find_among_b(z, a_2, 3))) + return 0; /* among, line 91 */ + z->c = z->l - m_test; + } + z->ket = z->c; /* [, line 91 */ + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) + return 0; + z->c = ret; /* next, line 91 */ + } + z->bra = z->c; /* ], line 91 */ + { + int ret = slice_del(z); /* delete, line 91 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_e_ending(struct SN_env *z) { + z->B[0] = 0; /* unset e_found, line 95 */ + z->ket = z->c; /* [, line 96 */ + if (!(eq_s_b(z, 1, s_13))) + return 0; + z->bra = z->c; /* ], line 96 */ + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 96 */ + if (ret < 0) + return ret; + } + { + int m_test = z->l - z->c; /* test, line 96 */ + if (out_grouping_b_U(z, g_v, 97, 232, 0)) + return 0; + z->c = z->l - m_test; + } + { + int ret = slice_del(z); /* delete, line 96 */ + if (ret < 0) + return ret; + } + z->B[0] = 1; /* set e_found, line 97 */ + { + int ret = r_undouble(z); + if (ret == 0) + return 0; /* call undouble, line 98 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_en_ending(struct SN_env *z) { + { + int ret = r_R1(z); + if (ret 
== 0) + return 0; /* call R1, line 102 */ + if (ret < 0) + return ret; + } + { + int m1 = z->l - z->c; + (void)m1; /* and, line 102 */ + if (out_grouping_b_U(z, g_v, 97, 232, 0)) + return 0; + z->c = z->l - m1; + { + int m2 = z->l - z->c; + (void)m2; /* not, line 102 */ + if (!(eq_s_b(z, 3, s_14))) + goto lab0; + return 0; + lab0: + z->c = z->l - m2; + } + } + { + int ret = slice_del(z); /* delete, line 102 */ + if (ret < 0) + return ret; + } + { + int ret = r_undouble(z); + if (ret == 0) + return 0; /* call undouble, line 103 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_standard_suffix(struct SN_env *z) { + int among_var; + { + int m1 = z->l - z->c; + (void)m1; /* do, line 107 */ + z->ket = z->c; /* [, line 108 */ + if (z->c <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((540704 >> (z->p[z->c - 1] & 0x1f)) & 1)) + goto lab0; + among_var = find_among_b(z, a_3, 5); /* substring, line 108 */ + if (!(among_var)) + goto lab0; + z->bra = z->c; /* ], line 108 */ + switch (among_var) { + case 0: + goto lab0; + case 1: { + int ret = r_R1(z); + if (ret == 0) + goto lab0; /* call R1, line 110 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 4, s_15); /* <-, line 110 */ + if (ret < 0) + return ret; + } + break; + case 2: { + int ret = r_en_ending(z); + if (ret == 0) + goto lab0; /* call en_ending, line 113 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = r_R1(z); + if (ret == 0) + goto lab0; /* call R1, line 116 */ + if (ret < 0) + return ret; + } + if (out_grouping_b_U(z, g_v_j, 97, 232, 0)) + goto lab0; + { + int ret = slice_del(z); /* delete, line 116 */ + if (ret < 0) + return ret; + } + break; + } + lab0: + z->c = z->l - m1; + } + { + int m2 = z->l - z->c; + (void)m2; /* do, line 120 */ + { + int ret = r_e_ending(z); + if (ret == 0) + goto lab1; /* call e_ending, line 120 */ + if (ret < 0) + return ret; + } + lab1: + z->c = z->l - m2; + } + { + int m3 = z->l - z->c; + (void)m3; /* do, line 122 */ + z->ket = z->c; /* [, line 122 */ + if (!(eq_s_b(z, 4, s_16))) + goto lab2; + z->bra = z->c; /* ], line 122 */ + { + int ret = r_R2(z); + if (ret == 0) + goto lab2; /* call R2, line 122 */ + if (ret < 0) + return ret; + } + { + int m4 = z->l - z->c; + (void)m4; /* not, line 122 */ + if (!(eq_s_b(z, 1, s_17))) + goto lab3; + goto lab2; + lab3: + z->c = z->l - m4; + } + { + int ret = slice_del(z); /* delete, line 122 */ + if (ret < 0) + return ret; + } + z->ket = z->c; /* [, line 123 */ + if (!(eq_s_b(z, 2, s_18))) + goto lab2; + z->bra = z->c; /* ], line 123 */ + { + int ret = r_en_ending(z); + if (ret == 0) + goto lab2; /* call en_ending, line 123 */ + if (ret < 0) + return ret; + } + lab2: + z->c = z->l - m3; + } + { + int m5 = z->l - z->c; + (void)m5; /* do, line 126 */ + z->ket = z->c; /* [, line 127 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((264336 >> (z->p[z->c - 1] & 0x1f)) & 1)) + goto lab4; + among_var = find_among_b(z, a_4, 6); /* substring, line 127 */ + if (!(among_var)) + goto lab4; + z->bra = z->c; /* ], line 127 */ + switch (among_var) { + case 0: + goto lab4; + case 1: { + int ret = r_R2(z); + if (ret == 0) + goto lab4; /* call R2, line 129 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 129 */ + if (ret < 0) + return ret; + } + { + int m6 = z->l - z->c; + (void)m6; /* or, line 130 */ + z->ket = z->c; /* [, line 130 */ + if (!(eq_s_b(z, 2, s_19))) + goto lab6; + z->bra = z->c; /* ], line 130 */ + { + int ret = r_R2(z); + if (ret == 0) + goto lab6; /* call R2, line 130 */ + if (ret < 
0) + return ret; + } + { + int m7 = z->l - z->c; + (void)m7; /* not, line 130 */ + if (!(eq_s_b(z, 1, s_20))) + goto lab7; + goto lab6; + lab7: + z->c = z->l - m7; + } + { + int ret = slice_del(z); /* delete, line 130 */ + if (ret < 0) + return ret; + } + goto lab5; + lab6: + z->c = z->l - m6; + { + int ret = r_undouble(z); + if (ret == 0) + goto lab4; /* call undouble, line 130 */ + if (ret < 0) + return ret; + } + } + lab5: + break; + case 2: { + int ret = r_R2(z); + if (ret == 0) + goto lab4; /* call R2, line 133 */ + if (ret < 0) + return ret; + } + { + int m8 = z->l - z->c; + (void)m8; /* not, line 133 */ + if (!(eq_s_b(z, 1, s_21))) + goto lab8; + goto lab4; + lab8: + z->c = z->l - m8; + } + { + int ret = slice_del(z); /* delete, line 133 */ + if (ret < 0) + return ret; + } + break; + case 3: { + int ret = r_R2(z); + if (ret == 0) + goto lab4; /* call R2, line 136 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 136 */ + if (ret < 0) + return ret; + } + { + int ret = r_e_ending(z); + if (ret == 0) + goto lab4; /* call e_ending, line 136 */ + if (ret < 0) + return ret; + } + break; + case 4: { + int ret = r_R2(z); + if (ret == 0) + goto lab4; /* call R2, line 139 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 139 */ + if (ret < 0) + return ret; + } + break; + case 5: { + int ret = r_R2(z); + if (ret == 0) + goto lab4; /* call R2, line 142 */ + if (ret < 0) + return ret; + } + if (!(z->B[0])) + goto lab4; /* Boolean test e_found, line 142 */ + { + int ret = slice_del(z); /* delete, line 142 */ + if (ret < 0) + return ret; + } + break; + } + lab4: + z->c = z->l - m5; + } + { + int m9 = z->l - z->c; + (void)m9; /* do, line 146 */ + if (out_grouping_b_U(z, g_v_I, 73, 232, 0)) + goto lab9; + { + int m_test = z->l - z->c; /* test, line 148 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((2129954 >> (z->p[z->c - 1] & 0x1f)) & 1)) + goto lab9; + if (!(find_among_b(z, a_5, 4))) + goto lab9; /* among, line 149 */ + if (out_grouping_b_U(z, g_v, 97, 232, 0)) + goto lab9; + z->c = z->l - m_test; + } + z->ket = z->c; /* [, line 152 */ + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) + goto lab9; + z->c = ret; /* next, line 152 */ + } + z->bra = z->c; /* ], line 152 */ + { + int ret = slice_del(z); /* delete, line 152 */ + if (ret < 0) + return ret; + } + lab9: + z->c = z->l - m9; + } + return 1; +} + +extern int dutch_UTF_8_stem(struct SN_env *z) { + { + int c1 = z->c; /* do, line 159 */ + { + int ret = r_prelude(z); + if (ret == 0) + goto lab0; /* call prelude, line 159 */ + if (ret < 0) + return ret; + } + lab0: + z->c = c1; + } + { + int c2 = z->c; /* do, line 160 */ + { + int ret = r_mark_regions(z); + if (ret == 0) + goto lab1; /* call mark_regions, line 160 */ + if (ret < 0) + return ret; + } + lab1: + z->c = c2; + } + z->lb = z->c; + z->c = z->l; /* backwards, line 161 */ + + { + int m3 = z->l - z->c; + (void)m3; /* do, line 162 */ + { + int ret = r_standard_suffix(z); + if (ret == 0) + goto lab2; /* call standard_suffix, line 162 */ + if (ret < 0) + return ret; + } + lab2: + z->c = z->l - m3; + } + z->c = z->lb; + { + int c4 = z->c; /* do, line 163 */ + { + int ret = r_postlude(z); + if (ret == 0) + goto lab3; /* call postlude, line 163 */ + if (ret < 0) + return ret; + } + lab3: + z->c = c4; + } + return 1; +} + +extern struct SN_env *dutch_UTF_8_create_env(void) { return SN_create_env(0, 2, 1); } + +extern void dutch_UTF_8_close_env(struct SN_env *z) { SN_close_env(z, 0); } diff --git 
a/internal/cpp/stemmer/stem_UTF_8_dutch.h b/internal/cpp/stemmer/stem_UTF_8_dutch.h new file mode 100644 index 00000000000..468ac17572c --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_dutch.h @@ -0,0 +1,17 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *dutch_UTF_8_create_env(void); +extern void dutch_UTF_8_close_env(struct SN_env *z); + +extern int dutch_UTF_8_stem(struct SN_env *z); + +#ifdef __cplusplus +} +#endif diff --git a/internal/cpp/stemmer/stem_UTF_8_english.cpp b/internal/cpp/stemmer/stem_UTF_8_english.cpp new file mode 100644 index 00000000000..3eb186dd78d --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_english.cpp @@ -0,0 +1,1316 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#include "header.h" + +#ifdef __cplusplus +extern "C" { +#endif +extern int english_UTF_8_stem(struct SN_env *z); +#ifdef __cplusplus +} +#endif +static int r_exception2(struct SN_env *z); +static int r_exception1(struct SN_env *z); +static int r_Step_5(struct SN_env *z); +static int r_Step_4(struct SN_env *z); +static int r_Step_3(struct SN_env *z); +static int r_Step_2(struct SN_env *z); +static int r_Step_1c(struct SN_env *z); +static int r_Step_1b(struct SN_env *z); +static int r_Step_1a(struct SN_env *z); +static int r_R2(struct SN_env *z); +static int r_R1(struct SN_env *z); +static int r_shortv(struct SN_env *z); +static int r_mark_regions(struct SN_env *z); +static int r_postlude(struct SN_env *z); +static int r_prelude(struct SN_env *z); +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *english_UTF_8_create_env(void); +extern void english_UTF_8_close_env(struct SN_env *z); + +#ifdef __cplusplus +} +#endif +static const symbol s_0_0[5] = {'a', 'r', 's', 'e', 'n'}; +static const symbol s_0_1[6] = {'c', 'o', 'm', 'm', 'u', 'n'}; +static const symbol s_0_2[5] = {'g', 'e', 'n', 'e', 'r'}; + +static const struct among a_0[3] = { + /* 0 */ {5, s_0_0, -1, -1, 0}, + /* 1 */ {6, s_0_1, -1, -1, 0}, + /* 2 */ {5, s_0_2, -1, -1, 0}}; + +static const symbol s_1_0[1] = {'\''}; +static const symbol s_1_1[3] = {'\'', 's', '\''}; +static const symbol s_1_2[2] = {'\'', 's'}; + +static const struct among a_1[3] = { + /* 0 */ {1, s_1_0, -1, 1, 0}, + /* 1 */ {3, s_1_1, 0, 1, 0}, + /* 2 */ {2, s_1_2, -1, 1, 0}}; + +static const symbol s_2_0[3] = {'i', 'e', 'd'}; +static const symbol s_2_1[1] = {'s'}; +static const symbol s_2_2[3] = {'i', 'e', 's'}; +static const symbol s_2_3[4] = {'s', 's', 'e', 's'}; +static const symbol s_2_4[2] = {'s', 's'}; +static const symbol s_2_5[2] = {'u', 's'}; + +static const struct among a_2[6] = { + /* 0 */ {3, s_2_0, -1, 2, 0}, + /* 1 */ {1, s_2_1, -1, 3, 0}, + /* 2 */ {3, s_2_2, 1, 2, 0}, + /* 3 */ {4, s_2_3, 1, 1, 0}, + /* 4 */ {2, s_2_4, 1, -1, 0}, + /* 5 */ {2, s_2_5, 1, -1, 0}}; + +static const symbol s_3_1[2] = {'b', 'b'}; +static const symbol s_3_2[2] = {'d', 'd'}; +static const symbol s_3_3[2] = {'f', 'f'}; +static const symbol s_3_4[2] = {'g', 'g'}; +static const symbol s_3_5[2] = {'b', 'l'}; +static const symbol s_3_6[2] = {'m', 'm'}; +static const symbol s_3_7[2] = {'n', 'n'}; +static const symbol s_3_8[2] = {'p', 'p'}; +static const symbol s_3_9[2] = {'r', 'r'}; +static const symbol s_3_10[2] = {'a', 't'}; +static const symbol s_3_11[2] = {'t', 't'}; +static const symbol s_3_12[2] = {'i', 'z'}; + +static const struct among a_3[13] = { + /* 0 */ {0, 0, -1, 3, 0}, + /* 1 */ {2, s_3_1, 0, 2, 
0}, + /* 2 */ {2, s_3_2, 0, 2, 0}, + /* 3 */ {2, s_3_3, 0, 2, 0}, + /* 4 */ {2, s_3_4, 0, 2, 0}, + /* 5 */ {2, s_3_5, 0, 1, 0}, + /* 6 */ {2, s_3_6, 0, 2, 0}, + /* 7 */ {2, s_3_7, 0, 2, 0}, + /* 8 */ {2, s_3_8, 0, 2, 0}, + /* 9 */ {2, s_3_9, 0, 2, 0}, + /* 10 */ {2, s_3_10, 0, 1, 0}, + /* 11 */ {2, s_3_11, 0, 2, 0}, + /* 12 */ {2, s_3_12, 0, 1, 0}}; + +static const symbol s_4_0[2] = {'e', 'd'}; +static const symbol s_4_1[3] = {'e', 'e', 'd'}; +static const symbol s_4_2[3] = {'i', 'n', 'g'}; +static const symbol s_4_3[4] = {'e', 'd', 'l', 'y'}; +static const symbol s_4_4[5] = {'e', 'e', 'd', 'l', 'y'}; +static const symbol s_4_5[5] = {'i', 'n', 'g', 'l', 'y'}; + +static const struct among a_4[6] = { + /* 0 */ {2, s_4_0, -1, 2, 0}, + /* 1 */ {3, s_4_1, 0, 1, 0}, + /* 2 */ {3, s_4_2, -1, 2, 0}, + /* 3 */ {4, s_4_3, -1, 2, 0}, + /* 4 */ {5, s_4_4, 3, 1, 0}, + /* 5 */ {5, s_4_5, -1, 2, 0}}; + +static const symbol s_5_0[4] = {'a', 'n', 'c', 'i'}; +static const symbol s_5_1[4] = {'e', 'n', 'c', 'i'}; +static const symbol s_5_2[3] = {'o', 'g', 'i'}; +static const symbol s_5_3[2] = {'l', 'i'}; +static const symbol s_5_4[3] = {'b', 'l', 'i'}; +static const symbol s_5_5[4] = {'a', 'b', 'l', 'i'}; +static const symbol s_5_6[4] = {'a', 'l', 'l', 'i'}; +static const symbol s_5_7[5] = {'f', 'u', 'l', 'l', 'i'}; +static const symbol s_5_8[6] = {'l', 'e', 's', 's', 'l', 'i'}; +static const symbol s_5_9[5] = {'o', 'u', 's', 'l', 'i'}; +static const symbol s_5_10[5] = {'e', 'n', 't', 'l', 'i'}; +static const symbol s_5_11[5] = {'a', 'l', 'i', 't', 'i'}; +static const symbol s_5_12[6] = {'b', 'i', 'l', 'i', 't', 'i'}; +static const symbol s_5_13[5] = {'i', 'v', 'i', 't', 'i'}; +static const symbol s_5_14[6] = {'t', 'i', 'o', 'n', 'a', 'l'}; +static const symbol s_5_15[7] = {'a', 't', 'i', 'o', 'n', 'a', 'l'}; +static const symbol s_5_16[5] = {'a', 'l', 'i', 's', 'm'}; +static const symbol s_5_17[5] = {'a', 't', 'i', 'o', 'n'}; +static const symbol s_5_18[7] = {'i', 'z', 'a', 't', 'i', 'o', 'n'}; +static const symbol s_5_19[4] = {'i', 'z', 'e', 'r'}; +static const symbol s_5_20[4] = {'a', 't', 'o', 'r'}; +static const symbol s_5_21[7] = {'i', 'v', 'e', 'n', 'e', 's', 's'}; +static const symbol s_5_22[7] = {'f', 'u', 'l', 'n', 'e', 's', 's'}; +static const symbol s_5_23[7] = {'o', 'u', 's', 'n', 'e', 's', 's'}; + +static const struct among a_5[24] = { + /* 0 */ {4, s_5_0, -1, 3, 0}, + /* 1 */ {4, s_5_1, -1, 2, 0}, + /* 2 */ {3, s_5_2, -1, 13, 0}, + /* 3 */ {2, s_5_3, -1, 16, 0}, + /* 4 */ {3, s_5_4, 3, 12, 0}, + /* 5 */ {4, s_5_5, 4, 4, 0}, + /* 6 */ {4, s_5_6, 3, 8, 0}, + /* 7 */ {5, s_5_7, 3, 14, 0}, + /* 8 */ {6, s_5_8, 3, 15, 0}, + /* 9 */ {5, s_5_9, 3, 10, 0}, + /* 10 */ {5, s_5_10, 3, 5, 0}, + /* 11 */ {5, s_5_11, -1, 8, 0}, + /* 12 */ {6, s_5_12, -1, 12, 0}, + /* 13 */ {5, s_5_13, -1, 11, 0}, + /* 14 */ {6, s_5_14, -1, 1, 0}, + /* 15 */ {7, s_5_15, 14, 7, 0}, + /* 16 */ {5, s_5_16, -1, 8, 0}, + /* 17 */ {5, s_5_17, -1, 7, 0}, + /* 18 */ {7, s_5_18, 17, 6, 0}, + /* 19 */ {4, s_5_19, -1, 6, 0}, + /* 20 */ {4, s_5_20, -1, 7, 0}, + /* 21 */ {7, s_5_21, -1, 11, 0}, + /* 22 */ {7, s_5_22, -1, 9, 0}, + /* 23 */ {7, s_5_23, -1, 10, 0}}; + +static const symbol s_6_0[5] = {'i', 'c', 'a', 't', 'e'}; +static const symbol s_6_1[5] = {'a', 't', 'i', 'v', 'e'}; +static const symbol s_6_2[5] = {'a', 'l', 'i', 'z', 'e'}; +static const symbol s_6_3[5] = {'i', 'c', 'i', 't', 'i'}; +static const symbol s_6_4[4] = {'i', 'c', 'a', 'l'}; +static const symbol s_6_5[6] = {'t', 'i', 'o', 'n', 'a', 'l'}; +static const symbol 
s_6_6[7] = {'a', 't', 'i', 'o', 'n', 'a', 'l'}; +static const symbol s_6_7[3] = {'f', 'u', 'l'}; +static const symbol s_6_8[4] = {'n', 'e', 's', 's'}; + +static const struct among a_6[9] = { + /* 0 */ {5, s_6_0, -1, 4, 0}, + /* 1 */ {5, s_6_1, -1, 6, 0}, + /* 2 */ {5, s_6_2, -1, 3, 0}, + /* 3 */ {5, s_6_3, -1, 4, 0}, + /* 4 */ {4, s_6_4, -1, 4, 0}, + /* 5 */ {6, s_6_5, -1, 1, 0}, + /* 6 */ {7, s_6_6, 5, 2, 0}, + /* 7 */ {3, s_6_7, -1, 5, 0}, + /* 8 */ {4, s_6_8, -1, 5, 0}}; + +static const symbol s_7_0[2] = {'i', 'c'}; +static const symbol s_7_1[4] = {'a', 'n', 'c', 'e'}; +static const symbol s_7_2[4] = {'e', 'n', 'c', 'e'}; +static const symbol s_7_3[4] = {'a', 'b', 'l', 'e'}; +static const symbol s_7_4[4] = {'i', 'b', 'l', 'e'}; +static const symbol s_7_5[3] = {'a', 't', 'e'}; +static const symbol s_7_6[3] = {'i', 'v', 'e'}; +static const symbol s_7_7[3] = {'i', 'z', 'e'}; +static const symbol s_7_8[3] = {'i', 't', 'i'}; +static const symbol s_7_9[2] = {'a', 'l'}; +static const symbol s_7_10[3] = {'i', 's', 'm'}; +static const symbol s_7_11[3] = {'i', 'o', 'n'}; +static const symbol s_7_12[2] = {'e', 'r'}; +static const symbol s_7_13[3] = {'o', 'u', 's'}; +static const symbol s_7_14[3] = {'a', 'n', 't'}; +static const symbol s_7_15[3] = {'e', 'n', 't'}; +static const symbol s_7_16[4] = {'m', 'e', 'n', 't'}; +static const symbol s_7_17[5] = {'e', 'm', 'e', 'n', 't'}; + +static const struct among a_7[18] = { + /* 0 */ {2, s_7_0, -1, 1, 0}, + /* 1 */ {4, s_7_1, -1, 1, 0}, + /* 2 */ {4, s_7_2, -1, 1, 0}, + /* 3 */ {4, s_7_3, -1, 1, 0}, + /* 4 */ {4, s_7_4, -1, 1, 0}, + /* 5 */ {3, s_7_5, -1, 1, 0}, + /* 6 */ {3, s_7_6, -1, 1, 0}, + /* 7 */ {3, s_7_7, -1, 1, 0}, + /* 8 */ {3, s_7_8, -1, 1, 0}, + /* 9 */ {2, s_7_9, -1, 1, 0}, + /* 10 */ {3, s_7_10, -1, 1, 0}, + /* 11 */ {3, s_7_11, -1, 2, 0}, + /* 12 */ {2, s_7_12, -1, 1, 0}, + /* 13 */ {3, s_7_13, -1, 1, 0}, + /* 14 */ {3, s_7_14, -1, 1, 0}, + /* 15 */ {3, s_7_15, -1, 1, 0}, + /* 16 */ {4, s_7_16, 15, 1, 0}, + /* 17 */ {5, s_7_17, 16, 1, 0}}; + +static const symbol s_8_0[1] = {'e'}; +static const symbol s_8_1[1] = {'l'}; + +static const struct among a_8[2] = { + /* 0 */ {1, s_8_0, -1, 1, 0}, + /* 1 */ {1, s_8_1, -1, 2, 0}}; + +static const symbol s_9_0[7] = {'s', 'u', 'c', 'c', 'e', 'e', 'd'}; +static const symbol s_9_1[7] = {'p', 'r', 'o', 'c', 'e', 'e', 'd'}; +static const symbol s_9_2[6] = {'e', 'x', 'c', 'e', 'e', 'd'}; +static const symbol s_9_3[7] = {'c', 'a', 'n', 'n', 'i', 'n', 'g'}; +static const symbol s_9_4[6] = {'i', 'n', 'n', 'i', 'n', 'g'}; +static const symbol s_9_5[7] = {'e', 'a', 'r', 'r', 'i', 'n', 'g'}; +static const symbol s_9_6[7] = {'h', 'e', 'r', 'r', 'i', 'n', 'g'}; +static const symbol s_9_7[6] = {'o', 'u', 't', 'i', 'n', 'g'}; + +static const struct among a_9[8] = { + /* 0 */ {7, s_9_0, -1, -1, 0}, + /* 1 */ {7, s_9_1, -1, -1, 0}, + /* 2 */ {6, s_9_2, -1, -1, 0}, + /* 3 */ {7, s_9_3, -1, -1, 0}, + /* 4 */ {6, s_9_4, -1, -1, 0}, + /* 5 */ {7, s_9_5, -1, -1, 0}, + /* 6 */ {7, s_9_6, -1, -1, 0}, + /* 7 */ {6, s_9_7, -1, -1, 0}}; + +static const symbol s_10_0[5] = {'a', 'n', 'd', 'e', 's'}; +static const symbol s_10_1[5] = {'a', 't', 'l', 'a', 's'}; +static const symbol s_10_2[4] = {'b', 'i', 'a', 's'}; +static const symbol s_10_3[6] = {'c', 'o', 's', 'm', 'o', 's'}; +static const symbol s_10_4[5] = {'d', 'y', 'i', 'n', 'g'}; +static const symbol s_10_5[5] = {'e', 'a', 'r', 'l', 'y'}; +static const symbol s_10_6[6] = {'g', 'e', 'n', 't', 'l', 'y'}; +static const symbol s_10_7[4] = {'h', 'o', 'w', 'e'}; +static const 
symbol s_10_8[4] = {'i', 'd', 'l', 'y'}; +static const symbol s_10_9[5] = {'l', 'y', 'i', 'n', 'g'}; +static const symbol s_10_10[4] = {'n', 'e', 'w', 's'}; +static const symbol s_10_11[4] = {'o', 'n', 'l', 'y'}; +static const symbol s_10_12[6] = {'s', 'i', 'n', 'g', 'l', 'y'}; +static const symbol s_10_13[5] = {'s', 'k', 'i', 'e', 's'}; +static const symbol s_10_14[4] = {'s', 'k', 'i', 's'}; +static const symbol s_10_15[3] = {'s', 'k', 'y'}; +static const symbol s_10_16[5] = {'t', 'y', 'i', 'n', 'g'}; +static const symbol s_10_17[4] = {'u', 'g', 'l', 'y'}; + +static const struct among a_10[18] = { + /* 0 */ {5, s_10_0, -1, -1, 0}, + /* 1 */ {5, s_10_1, -1, -1, 0}, + /* 2 */ {4, s_10_2, -1, -1, 0}, + /* 3 */ {6, s_10_3, -1, -1, 0}, + /* 4 */ {5, s_10_4, -1, 3, 0}, + /* 5 */ {5, s_10_5, -1, 9, 0}, + /* 6 */ {6, s_10_6, -1, 7, 0}, + /* 7 */ {4, s_10_7, -1, -1, 0}, + /* 8 */ {4, s_10_8, -1, 6, 0}, + /* 9 */ {5, s_10_9, -1, 4, 0}, + /* 10 */ {4, s_10_10, -1, -1, 0}, + /* 11 */ {4, s_10_11, -1, 10, 0}, + /* 12 */ {6, s_10_12, -1, 11, 0}, + /* 13 */ {5, s_10_13, -1, 2, 0}, + /* 14 */ {4, s_10_14, -1, 1, 0}, + /* 15 */ {3, s_10_15, -1, -1, 0}, + /* 16 */ {5, s_10_16, -1, 5, 0}, + /* 17 */ {4, s_10_17, -1, 8, 0}}; + +static const unsigned char g_v[] = {17, 65, 16, 1}; + +static const unsigned char g_v_WXY[] = {1, 17, 65, 208, 1}; + +static const unsigned char g_valid_LI[] = {55, 141, 2}; + +static const symbol s_0[] = {'\''}; +static const symbol s_1[] = {'y'}; +static const symbol s_2[] = {'Y'}; +static const symbol s_3[] = {'y'}; +static const symbol s_4[] = {'Y'}; +static const symbol s_5[] = {'s', 's'}; +static const symbol s_6[] = {'i'}; +static const symbol s_7[] = {'i', 'e'}; +static const symbol s_8[] = {'e', 'e'}; +static const symbol s_9[] = {'e'}; +static const symbol s_10[] = {'e'}; +static const symbol s_11[] = {'y'}; +static const symbol s_12[] = {'Y'}; +static const symbol s_13[] = {'i'}; +static const symbol s_14[] = {'t', 'i', 'o', 'n'}; +static const symbol s_15[] = {'e', 'n', 'c', 'e'}; +static const symbol s_16[] = {'a', 'n', 'c', 'e'}; +static const symbol s_17[] = {'a', 'b', 'l', 'e'}; +static const symbol s_18[] = {'e', 'n', 't'}; +static const symbol s_19[] = {'i', 'z', 'e'}; +static const symbol s_20[] = {'a', 't', 'e'}; +static const symbol s_21[] = {'a', 'l'}; +static const symbol s_22[] = {'f', 'u', 'l'}; +static const symbol s_23[] = {'o', 'u', 's'}; +static const symbol s_24[] = {'i', 'v', 'e'}; +static const symbol s_25[] = {'b', 'l', 'e'}; +static const symbol s_26[] = {'l'}; +static const symbol s_27[] = {'o', 'g'}; +static const symbol s_28[] = {'f', 'u', 'l'}; +static const symbol s_29[] = {'l', 'e', 's', 's'}; +static const symbol s_30[] = {'t', 'i', 'o', 'n'}; +static const symbol s_31[] = {'a', 't', 'e'}; +static const symbol s_32[] = {'a', 'l'}; +static const symbol s_33[] = {'i', 'c'}; +static const symbol s_34[] = {'s'}; +static const symbol s_35[] = {'t'}; +static const symbol s_36[] = {'l'}; +static const symbol s_37[] = {'s', 'k', 'i'}; +static const symbol s_38[] = {'s', 'k', 'y'}; +static const symbol s_39[] = {'d', 'i', 'e'}; +static const symbol s_40[] = {'l', 'i', 'e'}; +static const symbol s_41[] = {'t', 'i', 'e'}; +static const symbol s_42[] = {'i', 'd', 'l'}; +static const symbol s_43[] = {'g', 'e', 'n', 't', 'l'}; +static const symbol s_44[] = {'u', 'g', 'l', 'i'}; +static const symbol s_45[] = {'e', 'a', 'r', 'l', 'i'}; +static const symbol s_46[] = {'o', 'n', 'l', 'i'}; +static const symbol s_47[] = {'s', 'i', 'n', 'g', 'l'}; +static const 
symbol s_48[] = {'Y'}; +static const symbol s_49[] = {'y'}; + +static int r_prelude(struct SN_env *z) { + z->B[0] = 0; /* unset Y_found, line 26 */ + { + int c1 = z->c; /* do, line 27 */ + z->bra = z->c; /* [, line 27 */ + if (!(eq_s(z, 1, s_0))) + goto lab0; + z->ket = z->c; /* ], line 27 */ + { + int ret = slice_del(z); /* delete, line 27 */ + if (ret < 0) + return ret; + } + lab0: + z->c = c1; + } + { + int c2 = z->c; /* do, line 28 */ + z->bra = z->c; /* [, line 28 */ + if (!(eq_s(z, 1, s_1))) + goto lab1; + z->ket = z->c; /* ], line 28 */ + { + int ret = slice_from_s(z, 1, s_2); /* <-, line 28 */ + if (ret < 0) + return ret; + } + z->B[0] = 1; /* set Y_found, line 28 */ + lab1: + z->c = c2; + } + { + int c3 = z->c; /* do, line 29 */ + while (1) { /* repeat, line 29 */ + int c4 = z->c; + while (1) { /* goto, line 29 */ + int c5 = z->c; + if (in_grouping_U(z, g_v, 97, 121, 0)) + goto lab4; + z->bra = z->c; /* [, line 29 */ + if (!(eq_s(z, 1, s_3))) + goto lab4; + z->ket = z->c; /* ], line 29 */ + z->c = c5; + break; + lab4: + z->c = c5; + { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab3; + z->c = ret; /* goto, line 29 */ + } + } + { + int ret = slice_from_s(z, 1, s_4); /* <-, line 29 */ + if (ret < 0) + return ret; + } + z->B[0] = 1; /* set Y_found, line 29 */ + continue; + lab3: + z->c = c4; + break; + } + z->c = c3; + } + return 1; +} + +static int r_mark_regions(struct SN_env *z) { + z->I[0] = z->l; + z->I[1] = z->l; + { + int c1 = z->c; /* do, line 35 */ + { + int c2 = z->c; /* or, line 41 */ + if (z->c + 4 >= z->l || z->p[z->c + 4] >> 5 != 3 || !((2375680 >> (z->p[z->c + 4] & 0x1f)) & 1)) + goto lab2; + if (!(find_among(z, a_0, 3))) + goto lab2; /* among, line 36 */ + goto lab1; + lab2: + z->c = c2; + { /* gopast */ /* grouping v, line 41 */ + int ret = out_grouping_U(z, g_v, 97, 121, 1); + if (ret < 0) + goto lab0; + z->c += ret; + } + { /* gopast */ /* non v, line 41 */ + int ret = in_grouping_U(z, g_v, 97, 121, 1); + if (ret < 0) + goto lab0; + z->c += ret; + } + } + lab1: + z->I[0] = z->c; /* setmark p1, line 42 */ + { /* gopast */ /* grouping v, line 43 */ + int ret = out_grouping_U(z, g_v, 97, 121, 1); + if (ret < 0) + goto lab0; + z->c += ret; + } + { /* gopast */ /* non v, line 43 */ + int ret = in_grouping_U(z, g_v, 97, 121, 1); + if (ret < 0) + goto lab0; + z->c += ret; + } + z->I[1] = z->c; /* setmark p2, line 43 */ + lab0: + z->c = c1; + } + return 1; +} + +static int r_shortv(struct SN_env *z) { + { + int m1 = z->l - z->c; + (void)m1; /* or, line 51 */ + if (out_grouping_b_U(z, g_v_WXY, 89, 121, 0)) + goto lab1; + if (in_grouping_b_U(z, g_v, 97, 121, 0)) + goto lab1; + if (out_grouping_b_U(z, g_v, 97, 121, 0)) + goto lab1; + goto lab0; + lab1: + z->c = z->l - m1; + if (out_grouping_b_U(z, g_v, 97, 121, 0)) + return 0; + if (in_grouping_b_U(z, g_v, 97, 121, 0)) + return 0; + if (z->c > z->lb) + return 0; /* atlimit, line 52 */ + } +lab0: + return 1; +} + +static int r_R1(struct SN_env *z) { + if (!(z->I[0] <= z->c)) + return 0; + return 1; +} + +static int r_R2(struct SN_env *z) { + if (!(z->I[1] <= z->c)) + return 0; + return 1; +} + +static int r_Step_1a(struct SN_env *z) { + int among_var; + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 59 */ + z->ket = z->c; /* [, line 60 */ + if (z->c <= z->lb || (z->p[z->c - 1] != 39 && z->p[z->c - 1] != 115)) { + z->c = z->l - m_keep; + goto lab0; + } + among_var = find_among_b(z, a_1, 3); /* substring, line 60 */ + if (!(among_var)) { + z->c = z->l - m_keep; + goto lab0; + } + z->bra = 
z->c; /* ], line 60 */ + switch (among_var) { + case 0: { + z->c = z->l - m_keep; + goto lab0; + } + case 1: { + int ret = slice_del(z); /* delete, line 62 */ + if (ret < 0) + return ret; + } break; + } + lab0:; + } + z->ket = z->c; /* [, line 65 */ + if (z->c <= z->lb || (z->p[z->c - 1] != 100 && z->p[z->c - 1] != 115)) + return 0; + among_var = find_among_b(z, a_2, 6); /* substring, line 65 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 65 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_from_s(z, 2, s_5); /* <-, line 66 */ + if (ret < 0) + return ret; + } break; + case 2: { + int m1 = z->l - z->c; + (void)m1; /* or, line 68 */ + { + int ret = skip_utf8(z->p, z->c, z->lb, z->l, -2); + if (ret < 0) + goto lab2; + z->c = ret; /* hop, line 68 */ + } + { + int ret = slice_from_s(z, 1, s_6); /* <-, line 68 */ + if (ret < 0) + return ret; + } + goto lab1; + lab2: + z->c = z->l - m1; + { + int ret = slice_from_s(z, 2, s_7); /* <-, line 68 */ + if (ret < 0) + return ret; + } + } + lab1: + break; + case 3: { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) + return 0; + z->c = ret; /* next, line 69 */ + } + { /* gopast */ /* grouping v, line 69 */ + int ret = out_grouping_b_U(z, g_v, 97, 121, 1); + if (ret < 0) + return 0; + z->c -= ret; + } + { + int ret = slice_del(z); /* delete, line 69 */ + if (ret < 0) + return ret; + } + break; + } + return 1; +} + +static int r_Step_1b(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 75 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((33554576 >> (z->p[z->c - 1] & 0x1f)) & 1)) + return 0; + among_var = find_among_b(z, a_4, 6); /* substring, line 75 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 75 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 77 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 2, s_8); /* <-, line 77 */ + if (ret < 0) + return ret; + } + break; + case 2: { + int m_test = z->l - z->c; /* test, line 80 */ + { /* gopast */ /* grouping v, line 80 */ + int ret = out_grouping_b_U(z, g_v, 97, 121, 1); + if (ret < 0) + return 0; + z->c -= ret; + } + z->c = z->l - m_test; + } + { + int ret = slice_del(z); /* delete, line 80 */ + if (ret < 0) + return ret; + } + { + int m_test = z->l - z->c; /* test, line 81 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((68514004 >> (z->p[z->c - 1] & 0x1f)) & 1)) + among_var = 3; + else + among_var = find_among_b(z, a_3, 13); /* substring, line 81 */ + if (!(among_var)) + return 0; + z->c = z->l - m_test; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int c_keep = z->c; + int ret = insert_s(z, z->c, z->c, 1, s_9); /* <+, line 83 */ + z->c = c_keep; + if (ret < 0) + return ret; + } break; + case 2: + z->ket = z->c; /* [, line 86 */ + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) + return 0; + z->c = ret; /* next, line 86 */ + } + z->bra = z->c; /* ], line 86 */ + { + int ret = slice_del(z); /* delete, line 86 */ + if (ret < 0) + return ret; + } + break; + case 3: + if (z->c != z->I[0]) + return 0; /* atmark, line 87 */ + { + int m_test = z->l - z->c; /* test, line 87 */ + { + int ret = r_shortv(z); + if (ret == 0) + return 0; /* call shortv, line 87 */ + if (ret < 0) + return ret; + } + z->c = z->l - m_test; + } + { + int c_keep = z->c; + int ret = insert_s(z, z->c, z->c, 1, s_10); /* <+, line 87 */ + z->c = c_keep; + if (ret < 0) + return ret; + } + break; + } + 
break; + } + return 1; +} + +static int r_Step_1c(struct SN_env *z) { + z->ket = z->c; /* [, line 94 */ + { + int m1 = z->l - z->c; + (void)m1; /* or, line 94 */ + if (!(eq_s_b(z, 1, s_11))) + goto lab1; + goto lab0; + lab1: + z->c = z->l - m1; + if (!(eq_s_b(z, 1, s_12))) + return 0; + } +lab0: + z->bra = z->c; /* ], line 94 */ + if (out_grouping_b_U(z, g_v, 97, 121, 0)) + return 0; + { + int m2 = z->l - z->c; + (void)m2; /* not, line 95 */ + if (z->c > z->lb) + goto lab2; /* atlimit, line 95 */ + return 0; + lab2: + z->c = z->l - m2; + } + { + int ret = slice_from_s(z, 1, s_13); /* <-, line 96 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_Step_2(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 100 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((815616 >> (z->p[z->c - 1] & 0x1f)) & 1)) + return 0; + among_var = find_among_b(z, a_5, 24); /* substring, line 100 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 100 */ + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 100 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_from_s(z, 4, s_14); /* <-, line 101 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 4, s_15); /* <-, line 102 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 4, s_16); /* <-, line 103 */ + if (ret < 0) + return ret; + } break; + case 4: { + int ret = slice_from_s(z, 4, s_17); /* <-, line 104 */ + if (ret < 0) + return ret; + } break; + case 5: { + int ret = slice_from_s(z, 3, s_18); /* <-, line 105 */ + if (ret < 0) + return ret; + } break; + case 6: { + int ret = slice_from_s(z, 3, s_19); /* <-, line 107 */ + if (ret < 0) + return ret; + } break; + case 7: { + int ret = slice_from_s(z, 3, s_20); /* <-, line 109 */ + if (ret < 0) + return ret; + } break; + case 8: { + int ret = slice_from_s(z, 2, s_21); /* <-, line 111 */ + if (ret < 0) + return ret; + } break; + case 9: { + int ret = slice_from_s(z, 3, s_22); /* <-, line 112 */ + if (ret < 0) + return ret; + } break; + case 10: { + int ret = slice_from_s(z, 3, s_23); /* <-, line 114 */ + if (ret < 0) + return ret; + } break; + case 11: { + int ret = slice_from_s(z, 3, s_24); /* <-, line 116 */ + if (ret < 0) + return ret; + } break; + case 12: { + int ret = slice_from_s(z, 3, s_25); /* <-, line 118 */ + if (ret < 0) + return ret; + } break; + case 13: + if (!(eq_s_b(z, 1, s_26))) + return 0; + { + int ret = slice_from_s(z, 2, s_27); /* <-, line 119 */ + if (ret < 0) + return ret; + } + break; + case 14: { + int ret = slice_from_s(z, 3, s_28); /* <-, line 120 */ + if (ret < 0) + return ret; + } break; + case 15: { + int ret = slice_from_s(z, 4, s_29); /* <-, line 121 */ + if (ret < 0) + return ret; + } break; + case 16: + if (in_grouping_b_U(z, g_valid_LI, 99, 116, 0)) + return 0; + { + int ret = slice_del(z); /* delete, line 122 */ + if (ret < 0) + return ret; + } + break; + } + return 1; +} + +static int r_Step_3(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 127 */ + if (z->c - 2 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((528928 >> (z->p[z->c - 1] & 0x1f)) & 1)) + return 0; + among_var = find_among_b(z, a_6, 9); /* substring, line 127 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 127 */ + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 127 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_from_s(z, 
4, s_30); /* <-, line 128 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 3, s_31); /* <-, line 129 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 2, s_32); /* <-, line 130 */ + if (ret < 0) + return ret; + } break; + case 4: { + int ret = slice_from_s(z, 2, s_33); /* <-, line 132 */ + if (ret < 0) + return ret; + } break; + case 5: { + int ret = slice_del(z); /* delete, line 134 */ + if (ret < 0) + return ret; + } break; + case 6: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 136 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 136 */ + if (ret < 0) + return ret; + } + break; + } + return 1; +} + +static int r_Step_4(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 141 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((1864232 >> (z->p[z->c - 1] & 0x1f)) & 1)) + return 0; + among_var = find_among_b(z, a_7, 18); /* substring, line 141 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 141 */ + { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 141 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 144 */ + if (ret < 0) + return ret; + } break; + case 2: { + int m1 = z->l - z->c; + (void)m1; /* or, line 145 */ + if (!(eq_s_b(z, 1, s_34))) + goto lab1; + goto lab0; + lab1: + z->c = z->l - m1; + if (!(eq_s_b(z, 1, s_35))) + return 0; + } + lab0: { + int ret = slice_del(z); /* delete, line 145 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_Step_5(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 150 */ + if (z->c <= z->lb || (z->p[z->c - 1] != 101 && z->p[z->c - 1] != 108)) + return 0; + among_var = find_among_b(z, a_8, 2); /* substring, line 150 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 150 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int m1 = z->l - z->c; + (void)m1; /* or, line 151 */ + { + int ret = r_R2(z); + if (ret == 0) + goto lab1; /* call R2, line 151 */ + if (ret < 0) + return ret; + } + goto lab0; + lab1: + z->c = z->l - m1; + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 151 */ + if (ret < 0) + return ret; + } + { + int m2 = z->l - z->c; + (void)m2; /* not, line 151 */ + { + int ret = r_shortv(z); + if (ret == 0) + goto lab2; /* call shortv, line 151 */ + if (ret < 0) + return ret; + } + return 0; + lab2: + z->c = z->l - m2; + } + } + lab0: { + int ret = slice_del(z); /* delete, line 151 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 152 */ + if (ret < 0) + return ret; + } + if (!(eq_s_b(z, 1, s_36))) + return 0; + { + int ret = slice_del(z); /* delete, line 152 */ + if (ret < 0) + return ret; + } + break; + } + return 1; +} + +static int r_exception2(struct SN_env *z) { + z->ket = z->c; /* [, line 158 */ + if (z->c - 5 <= z->lb || (z->p[z->c - 1] != 100 && z->p[z->c - 1] != 103)) + return 0; + if (!(find_among_b(z, a_9, 8))) + return 0; /* substring, line 158 */ + z->bra = z->c; /* ], line 158 */ + if (z->c > z->lb) + return 0; /* atlimit, line 158 */ + return 1; +} + +static int r_exception1(struct SN_env *z) { + int among_var; + z->bra = z->c; /* [, line 170 */ + if (z->c + 2 >= z->l || z->p[z->c + 2] >> 5 != 3 || !((42750482 >> (z->p[z->c + 2] & 0x1f)) & 1)) + return 0; + among_var = find_among(z, a_10, 18); /* substring, line 170 */ + if 
(!(among_var)) + return 0; + z->ket = z->c; /* ], line 170 */ + if (z->c < z->l) + return 0; /* atlimit, line 170 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_from_s(z, 3, s_37); /* <-, line 174 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 3, s_38); /* <-, line 175 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 3, s_39); /* <-, line 176 */ + if (ret < 0) + return ret; + } break; + case 4: { + int ret = slice_from_s(z, 3, s_40); /* <-, line 177 */ + if (ret < 0) + return ret; + } break; + case 5: { + int ret = slice_from_s(z, 3, s_41); /* <-, line 178 */ + if (ret < 0) + return ret; + } break; + case 6: { + int ret = slice_from_s(z, 3, s_42); /* <-, line 182 */ + if (ret < 0) + return ret; + } break; + case 7: { + int ret = slice_from_s(z, 5, s_43); /* <-, line 183 */ + if (ret < 0) + return ret; + } break; + case 8: { + int ret = slice_from_s(z, 4, s_44); /* <-, line 184 */ + if (ret < 0) + return ret; + } break; + case 9: { + int ret = slice_from_s(z, 5, s_45); /* <-, line 185 */ + if (ret < 0) + return ret; + } break; + case 10: { + int ret = slice_from_s(z, 4, s_46); /* <-, line 186 */ + if (ret < 0) + return ret; + } break; + case 11: { + int ret = slice_from_s(z, 5, s_47); /* <-, line 187 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_postlude(struct SN_env *z) { + if (!(z->B[0])) + return 0; /* Boolean test Y_found, line 203 */ + while (1) { /* repeat, line 203 */ + int c1 = z->c; + while (1) { /* goto, line 203 */ + int c2 = z->c; + z->bra = z->c; /* [, line 203 */ + if (!(eq_s(z, 1, s_48))) + goto lab1; + z->ket = z->c; /* ], line 203 */ + z->c = c2; + break; + lab1: + z->c = c2; + { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* goto, line 203 */ + } + } + { + int ret = slice_from_s(z, 1, s_49); /* <-, line 203 */ + if (ret < 0) + return ret; + } + continue; + lab0: + z->c = c1; + break; + } + return 1; +} + +extern int english_UTF_8_stem(struct SN_env *z) { + { + int c1 = z->c; /* or, line 207 */ + { + int ret = r_exception1(z); + if (ret == 0) + goto lab1; /* call exception1, line 207 */ + if (ret < 0) + return ret; + } + goto lab0; + lab1: + z->c = c1; + { + int c2 = z->c; /* not, line 208 */ + { + int ret = skip_utf8(z->p, z->c, 0, z->l, +3); + if (ret < 0) + goto lab3; + z->c = ret; /* hop, line 208 */ + } + goto lab2; + lab3: + z->c = c2; + } + goto lab0; + lab2: + z->c = c1; + { + int c3 = z->c; /* do, line 209 */ + { + int ret = r_prelude(z); + if (ret == 0) + goto lab4; /* call prelude, line 209 */ + if (ret < 0) + return ret; + } + lab4: + z->c = c3; + } + { + int c4 = z->c; /* do, line 210 */ + { + int ret = r_mark_regions(z); + if (ret == 0) + goto lab5; /* call mark_regions, line 210 */ + if (ret < 0) + return ret; + } + lab5: + z->c = c4; + } + z->lb = z->c; + z->c = z->l; /* backwards, line 211 */ + + { + int m5 = z->l - z->c; + (void)m5; /* do, line 213 */ + { + int ret = r_Step_1a(z); + if (ret == 0) + goto lab6; /* call Step_1a, line 213 */ + if (ret < 0) + return ret; + } + lab6: + z->c = z->l - m5; + } + { + int m6 = z->l - z->c; + (void)m6; /* or, line 215 */ + { + int ret = r_exception2(z); + if (ret == 0) + goto lab8; /* call exception2, line 215 */ + if (ret < 0) + return ret; + } + goto lab7; + lab8: + z->c = z->l - m6; + { + int m7 = z->l - z->c; + (void)m7; /* do, line 217 */ + { + int ret = r_Step_1b(z); + if (ret == 0) + goto lab9; /* call Step_1b, line 217 */ + if (ret < 
0) + return ret; + } + lab9: + z->c = z->l - m7; + } + { + int m8 = z->l - z->c; + (void)m8; /* do, line 218 */ + { + int ret = r_Step_1c(z); + if (ret == 0) + goto lab10; /* call Step_1c, line 218 */ + if (ret < 0) + return ret; + } + lab10: + z->c = z->l - m8; + } + { + int m9 = z->l - z->c; + (void)m9; /* do, line 220 */ + { + int ret = r_Step_2(z); + if (ret == 0) + goto lab11; /* call Step_2, line 220 */ + if (ret < 0) + return ret; + } + lab11: + z->c = z->l - m9; + } + { + int m10 = z->l - z->c; + (void)m10; /* do, line 221 */ + { + int ret = r_Step_3(z); + if (ret == 0) + goto lab12; /* call Step_3, line 221 */ + if (ret < 0) + return ret; + } + lab12: + z->c = z->l - m10; + } + { + int m11 = z->l - z->c; + (void)m11; /* do, line 222 */ + { + int ret = r_Step_4(z); + if (ret == 0) + goto lab13; /* call Step_4, line 222 */ + if (ret < 0) + return ret; + } + lab13: + z->c = z->l - m11; + } + { + int m12 = z->l - z->c; + (void)m12; /* do, line 224 */ + { + int ret = r_Step_5(z); + if (ret == 0) + goto lab14; /* call Step_5, line 224 */ + if (ret < 0) + return ret; + } + lab14: + z->c = z->l - m12; + } + } + lab7: + z->c = z->lb; + { + int c13 = z->c; /* do, line 227 */ + { + int ret = r_postlude(z); + if (ret == 0) + goto lab15; /* call postlude, line 227 */ + if (ret < 0) + return ret; + } + lab15: + z->c = c13; + } + } +lab0: + return 1; +} + +extern struct SN_env *english_UTF_8_create_env(void) { return SN_create_env(0, 2, 1); } + +extern void english_UTF_8_close_env(struct SN_env *z) { SN_close_env(z, 0); } diff --git a/internal/cpp/stemmer/stem_UTF_8_english.h b/internal/cpp/stemmer/stem_UTF_8_english.h new file mode 100644 index 00000000000..22a38a5b17f --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_english.h @@ -0,0 +1,17 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *english_UTF_8_create_env(void); +extern void english_UTF_8_close_env(struct SN_env *z); + +extern int english_UTF_8_stem(struct SN_env *z); + +#ifdef __cplusplus +} +#endif diff --git a/internal/cpp/stemmer/stem_UTF_8_finnish.cpp b/internal/cpp/stemmer/stem_UTF_8_finnish.cpp new file mode 100644 index 00000000000..1a858ec4ac4 --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_finnish.cpp @@ -0,0 +1,958 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#include "header.h" + +#ifdef __cplusplus +extern "C" { +#endif +extern int finnish_UTF_8_stem(struct SN_env *z); +#ifdef __cplusplus +} +#endif +static int r_tidy(struct SN_env *z); +static int r_other_endings(struct SN_env *z); +static int r_t_plural(struct SN_env *z); +static int r_i_plural(struct SN_env *z); +static int r_case_ending(struct SN_env *z); +static int r_VI(struct SN_env *z); +static int r_LONG(struct SN_env *z); +static int r_possessive(struct SN_env *z); +static int r_particle_etc(struct SN_env *z); +static int r_R2(struct SN_env *z); +static int r_mark_regions(struct SN_env *z); +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *finnish_UTF_8_create_env(void); +extern void finnish_UTF_8_close_env(struct SN_env *z); + +#ifdef __cplusplus +} +#endif +static const symbol s_0_0[2] = {'p', 'a'}; +static const symbol s_0_1[3] = {'s', 't', 'i'}; +static const symbol s_0_2[4] = {'k', 'a', 'a', 'n'}; +static const symbol s_0_3[3] = {'h', 'a', 'n'}; +static const symbol s_0_4[3] = {'k', 'i', 'n'}; +static const symbol s_0_5[4] = {'h', 0xC3, 0xA4, 'n'}; +static const symbol 
s_0_6[6] = {'k', 0xC3, 0xA4, 0xC3, 0xA4, 'n'}; +static const symbol s_0_7[2] = {'k', 'o'}; +static const symbol s_0_8[3] = {'p', 0xC3, 0xA4}; +static const symbol s_0_9[3] = {'k', 0xC3, 0xB6}; + +static const struct among a_0[10] = { + /* 0 */ {2, s_0_0, -1, 1, 0}, + /* 1 */ {3, s_0_1, -1, 2, 0}, + /* 2 */ {4, s_0_2, -1, 1, 0}, + /* 3 */ {3, s_0_3, -1, 1, 0}, + /* 4 */ {3, s_0_4, -1, 1, 0}, + /* 5 */ {4, s_0_5, -1, 1, 0}, + /* 6 */ {6, s_0_6, -1, 1, 0}, + /* 7 */ {2, s_0_7, -1, 1, 0}, + /* 8 */ {3, s_0_8, -1, 1, 0}, + /* 9 */ {3, s_0_9, -1, 1, 0}}; + +static const symbol s_1_0[3] = {'l', 'l', 'a'}; +static const symbol s_1_1[2] = {'n', 'a'}; +static const symbol s_1_2[3] = {'s', 's', 'a'}; +static const symbol s_1_3[2] = {'t', 'a'}; +static const symbol s_1_4[3] = {'l', 't', 'a'}; +static const symbol s_1_5[3] = {'s', 't', 'a'}; + +static const struct among a_1[6] = { + /* 0 */ {3, s_1_0, -1, -1, 0}, + /* 1 */ {2, s_1_1, -1, -1, 0}, + /* 2 */ {3, s_1_2, -1, -1, 0}, + /* 3 */ {2, s_1_3, -1, -1, 0}, + /* 4 */ {3, s_1_4, 3, -1, 0}, + /* 5 */ {3, s_1_5, 3, -1, 0}}; + +static const symbol s_2_0[4] = {'l', 'l', 0xC3, 0xA4}; +static const symbol s_2_1[3] = {'n', 0xC3, 0xA4}; +static const symbol s_2_2[4] = {'s', 's', 0xC3, 0xA4}; +static const symbol s_2_3[3] = {'t', 0xC3, 0xA4}; +static const symbol s_2_4[4] = {'l', 't', 0xC3, 0xA4}; +static const symbol s_2_5[4] = {'s', 't', 0xC3, 0xA4}; + +static const struct among a_2[6] = { + /* 0 */ {4, s_2_0, -1, -1, 0}, + /* 1 */ {3, s_2_1, -1, -1, 0}, + /* 2 */ {4, s_2_2, -1, -1, 0}, + /* 3 */ {3, s_2_3, -1, -1, 0}, + /* 4 */ {4, s_2_4, 3, -1, 0}, + /* 5 */ {4, s_2_5, 3, -1, 0}}; + +static const symbol s_3_0[3] = {'l', 'l', 'e'}; +static const symbol s_3_1[3] = {'i', 'n', 'e'}; + +static const struct among a_3[2] = { + /* 0 */ {3, s_3_0, -1, -1, 0}, + /* 1 */ {3, s_3_1, -1, -1, 0}}; + +static const symbol s_4_0[3] = {'n', 's', 'a'}; +static const symbol s_4_1[3] = {'m', 'm', 'e'}; +static const symbol s_4_2[3] = {'n', 'n', 'e'}; +static const symbol s_4_3[2] = {'n', 'i'}; +static const symbol s_4_4[2] = {'s', 'i'}; +static const symbol s_4_5[2] = {'a', 'n'}; +static const symbol s_4_6[2] = {'e', 'n'}; +static const symbol s_4_7[3] = {0xC3, 0xA4, 'n'}; +static const symbol s_4_8[4] = {'n', 's', 0xC3, 0xA4}; + +static const struct among a_4[9] = { + /* 0 */ {3, s_4_0, -1, 3, 0}, + /* 1 */ {3, s_4_1, -1, 3, 0}, + /* 2 */ {3, s_4_2, -1, 3, 0}, + /* 3 */ {2, s_4_3, -1, 2, 0}, + /* 4 */ {2, s_4_4, -1, 1, 0}, + /* 5 */ {2, s_4_5, -1, 4, 0}, + /* 6 */ {2, s_4_6, -1, 6, 0}, + /* 7 */ {3, s_4_7, -1, 5, 0}, + /* 8 */ {4, s_4_8, -1, 3, 0}}; + +static const symbol s_5_0[2] = {'a', 'a'}; +static const symbol s_5_1[2] = {'e', 'e'}; +static const symbol s_5_2[2] = {'i', 'i'}; +static const symbol s_5_3[2] = {'o', 'o'}; +static const symbol s_5_4[2] = {'u', 'u'}; +static const symbol s_5_5[4] = {0xC3, 0xA4, 0xC3, 0xA4}; +static const symbol s_5_6[4] = {0xC3, 0xB6, 0xC3, 0xB6}; + +static const struct among a_5[7] = { + /* 0 */ {2, s_5_0, -1, -1, 0}, + /* 1 */ {2, s_5_1, -1, -1, 0}, + /* 2 */ {2, s_5_2, -1, -1, 0}, + /* 3 */ {2, s_5_3, -1, -1, 0}, + /* 4 */ {2, s_5_4, -1, -1, 0}, + /* 5 */ {4, s_5_5, -1, -1, 0}, + /* 6 */ {4, s_5_6, -1, -1, 0}}; + +static const symbol s_6_0[1] = {'a'}; +static const symbol s_6_1[3] = {'l', 'l', 'a'}; +static const symbol s_6_2[2] = {'n', 'a'}; +static const symbol s_6_3[3] = {'s', 's', 'a'}; +static const symbol s_6_4[2] = {'t', 'a'}; +static const symbol s_6_5[3] = {'l', 't', 'a'}; +static const symbol s_6_6[3] = {'s', 't', 'a'}; +static 
const symbol s_6_7[3] = {'t', 't', 'a'}; +static const symbol s_6_8[3] = {'l', 'l', 'e'}; +static const symbol s_6_9[3] = {'i', 'n', 'e'}; +static const symbol s_6_10[3] = {'k', 's', 'i'}; +static const symbol s_6_11[1] = {'n'}; +static const symbol s_6_12[3] = {'h', 'a', 'n'}; +static const symbol s_6_13[3] = {'d', 'e', 'n'}; +static const symbol s_6_14[4] = {'s', 'e', 'e', 'n'}; +static const symbol s_6_15[3] = {'h', 'e', 'n'}; +static const symbol s_6_16[4] = {'t', 't', 'e', 'n'}; +static const symbol s_6_17[3] = {'h', 'i', 'n'}; +static const symbol s_6_18[4] = {'s', 'i', 'i', 'n'}; +static const symbol s_6_19[3] = {'h', 'o', 'n'}; +static const symbol s_6_20[4] = {'h', 0xC3, 0xA4, 'n'}; +static const symbol s_6_21[4] = {'h', 0xC3, 0xB6, 'n'}; +static const symbol s_6_22[2] = {0xC3, 0xA4}; +static const symbol s_6_23[4] = {'l', 'l', 0xC3, 0xA4}; +static const symbol s_6_24[3] = {'n', 0xC3, 0xA4}; +static const symbol s_6_25[4] = {'s', 's', 0xC3, 0xA4}; +static const symbol s_6_26[3] = {'t', 0xC3, 0xA4}; +static const symbol s_6_27[4] = {'l', 't', 0xC3, 0xA4}; +static const symbol s_6_28[4] = {'s', 't', 0xC3, 0xA4}; +static const symbol s_6_29[4] = {'t', 't', 0xC3, 0xA4}; + +static const struct among a_6[30] = { + /* 0 */ {1, s_6_0, -1, 8, 0}, + /* 1 */ {3, s_6_1, 0, -1, 0}, + /* 2 */ {2, s_6_2, 0, -1, 0}, + /* 3 */ {3, s_6_3, 0, -1, 0}, + /* 4 */ {2, s_6_4, 0, -1, 0}, + /* 5 */ {3, s_6_5, 4, -1, 0}, + /* 6 */ {3, s_6_6, 4, -1, 0}, + /* 7 */ {3, s_6_7, 4, 9, 0}, + /* 8 */ {3, s_6_8, -1, -1, 0}, + /* 9 */ {3, s_6_9, -1, -1, 0}, + /* 10 */ {3, s_6_10, -1, -1, 0}, + /* 11 */ {1, s_6_11, -1, 7, 0}, + /* 12 */ {3, s_6_12, 11, 1, 0}, + /* 13 */ {3, s_6_13, 11, -1, r_VI}, + /* 14 */ {4, s_6_14, 11, -1, r_LONG}, + /* 15 */ {3, s_6_15, 11, 2, 0}, + /* 16 */ {4, s_6_16, 11, -1, r_VI}, + /* 17 */ {3, s_6_17, 11, 3, 0}, + /* 18 */ {4, s_6_18, 11, -1, r_VI}, + /* 19 */ {3, s_6_19, 11, 4, 0}, + /* 20 */ {4, s_6_20, 11, 5, 0}, + /* 21 */ {4, s_6_21, 11, 6, 0}, + /* 22 */ {2, s_6_22, -1, 8, 0}, + /* 23 */ {4, s_6_23, 22, -1, 0}, + /* 24 */ {3, s_6_24, 22, -1, 0}, + /* 25 */ {4, s_6_25, 22, -1, 0}, + /* 26 */ {3, s_6_26, 22, -1, 0}, + /* 27 */ {4, s_6_27, 26, -1, 0}, + /* 28 */ {4, s_6_28, 26, -1, 0}, + /* 29 */ {4, s_6_29, 26, 9, 0}}; + +static const symbol s_7_0[3] = {'e', 'j', 'a'}; +static const symbol s_7_1[3] = {'m', 'm', 'a'}; +static const symbol s_7_2[4] = {'i', 'm', 'm', 'a'}; +static const symbol s_7_3[3] = {'m', 'p', 'a'}; +static const symbol s_7_4[4] = {'i', 'm', 'p', 'a'}; +static const symbol s_7_5[3] = {'m', 'm', 'i'}; +static const symbol s_7_6[4] = {'i', 'm', 'm', 'i'}; +static const symbol s_7_7[3] = {'m', 'p', 'i'}; +static const symbol s_7_8[4] = {'i', 'm', 'p', 'i'}; +static const symbol s_7_9[4] = {'e', 'j', 0xC3, 0xA4}; +static const symbol s_7_10[4] = {'m', 'm', 0xC3, 0xA4}; +static const symbol s_7_11[5] = {'i', 'm', 'm', 0xC3, 0xA4}; +static const symbol s_7_12[4] = {'m', 'p', 0xC3, 0xA4}; +static const symbol s_7_13[5] = {'i', 'm', 'p', 0xC3, 0xA4}; + +static const struct among a_7[14] = { + /* 0 */ {3, s_7_0, -1, -1, 0}, + /* 1 */ {3, s_7_1, -1, 1, 0}, + /* 2 */ {4, s_7_2, 1, -1, 0}, + /* 3 */ {3, s_7_3, -1, 1, 0}, + /* 4 */ {4, s_7_4, 3, -1, 0}, + /* 5 */ {3, s_7_5, -1, 1, 0}, + /* 6 */ {4, s_7_6, 5, -1, 0}, + /* 7 */ {3, s_7_7, -1, 1, 0}, + /* 8 */ {4, s_7_8, 7, -1, 0}, + /* 9 */ {4, s_7_9, -1, -1, 0}, + /* 10 */ {4, s_7_10, -1, 1, 0}, + /* 11 */ {5, s_7_11, 10, -1, 0}, + /* 12 */ {4, s_7_12, -1, 1, 0}, + /* 13 */ {5, s_7_13, 12, -1, 0}}; + +static const symbol s_8_0[1] = 
{'i'}; +static const symbol s_8_1[1] = {'j'}; + +static const struct among a_8[2] = { + /* 0 */ {1, s_8_0, -1, -1, 0}, + /* 1 */ {1, s_8_1, -1, -1, 0}}; + +static const symbol s_9_0[3] = {'m', 'm', 'a'}; +static const symbol s_9_1[4] = {'i', 'm', 'm', 'a'}; + +static const struct among a_9[2] = { + /* 0 */ {3, s_9_0, -1, 1, 0}, + /* 1 */ {4, s_9_1, 0, -1, 0}}; + +static const unsigned char g_AEI[] = {17, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8}; + +static const unsigned char g_V1[] = {17, 65, 16, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 32}; + +static const unsigned char g_V2[] = {17, 65, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 32}; + +static const unsigned char g_particle_end[] = {17, 97, 24, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 32}; + +static const symbol s_0[] = {'k'}; +static const symbol s_1[] = {'k', 's', 'e'}; +static const symbol s_2[] = {'k', 's', 'i'}; +static const symbol s_3[] = {'i'}; +static const symbol s_4[] = {'a'}; +static const symbol s_5[] = {'e'}; +static const symbol s_6[] = {'i'}; +static const symbol s_7[] = {'o'}; +static const symbol s_8[] = {0xC3, 0xA4}; +static const symbol s_9[] = {0xC3, 0xB6}; +static const symbol s_10[] = {'i', 'e'}; +static const symbol s_11[] = {'e'}; +static const symbol s_12[] = {'p', 'o'}; +static const symbol s_13[] = {'t'}; +static const symbol s_14[] = {'p', 'o'}; +static const symbol s_15[] = {'j'}; +static const symbol s_16[] = {'o'}; +static const symbol s_17[] = {'u'}; +static const symbol s_18[] = {'o'}; +static const symbol s_19[] = {'j'}; + +static int r_mark_regions(struct SN_env *z) { + z->I[0] = z->l; + z->I[1] = z->l; + if (out_grouping_U(z, g_V1, 97, 246, 1) < 0) + return 0; /* goto */ /* grouping V1, line 46 */ + { /* gopast */ /* non V1, line 46 */ + int ret = in_grouping_U(z, g_V1, 97, 246, 1); + if (ret < 0) + return 0; + z->c += ret; + } + z->I[0] = z->c; /* setmark p1, line 46 */ + if (out_grouping_U(z, g_V1, 97, 246, 1) < 0) + return 0; /* goto */ /* grouping V1, line 47 */ + { /* gopast */ /* non V1, line 47 */ + int ret = in_grouping_U(z, g_V1, 97, 246, 1); + if (ret < 0) + return 0; + z->c += ret; + } + z->I[1] = z->c; /* setmark p2, line 47 */ + return 1; +} + +static int r_R2(struct SN_env *z) { + if (!(z->I[1] <= z->c)) + return 0; + return 1; +} + +static int r_particle_etc(struct SN_env *z) { + int among_var; + { + int mlimit; /* setlimit, line 55 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 55 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 55 */ + among_var = find_among_b(z, a_0, 10); /* substring, line 55 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 55 */ + z->lb = mlimit; + } + switch (among_var) { + case 0: + return 0; + case 1: + if (in_grouping_b_U(z, g_particle_end, 97, 246, 0)) + return 0; + break; + case 2: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 64 */ + if (ret < 0) + return ret; + } break; + } + { + int ret = slice_del(z); /* delete, line 66 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_possessive(struct SN_env *z) { + int among_var; + { + int mlimit; /* setlimit, line 69 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 69 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 69 */ + among_var = find_among_b(z, a_4, 9); /* substring, line 69 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; 
+ } + z->bra = z->c; /* ], line 69 */ + z->lb = mlimit; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int m2 = z->l - z->c; + (void)m2; /* not, line 72 */ + if (!(eq_s_b(z, 1, s_0))) + goto lab0; + return 0; + lab0: + z->c = z->l - m2; + } + { + int ret = slice_del(z); /* delete, line 72 */ + if (ret < 0) + return ret; + } + break; + case 2: { + int ret = slice_del(z); /* delete, line 74 */ + if (ret < 0) + return ret; + } + z->ket = z->c; /* [, line 74 */ + if (!(eq_s_b(z, 3, s_1))) + return 0; + z->bra = z->c; /* ], line 74 */ + { + int ret = slice_from_s(z, 3, s_2); /* <-, line 74 */ + if (ret < 0) + return ret; + } + break; + case 3: { + int ret = slice_del(z); /* delete, line 78 */ + if (ret < 0) + return ret; + } break; + case 4: + if (z->c - 1 <= z->lb || z->p[z->c - 1] != 97) + return 0; + if (!(find_among_b(z, a_1, 6))) + return 0; /* among, line 81 */ + { + int ret = slice_del(z); /* delete, line 81 */ + if (ret < 0) + return ret; + } + break; + case 5: + if (z->c - 2 <= z->lb || z->p[z->c - 1] != 164) + return 0; + if (!(find_among_b(z, a_2, 6))) + return 0; /* among, line 83 */ + { + int ret = slice_del(z); /* delete, line 84 */ + if (ret < 0) + return ret; + } + break; + case 6: + if (z->c - 2 <= z->lb || z->p[z->c - 1] != 101) + return 0; + if (!(find_among_b(z, a_3, 2))) + return 0; /* among, line 86 */ + { + int ret = slice_del(z); /* delete, line 86 */ + if (ret < 0) + return ret; + } + break; + } + return 1; +} + +static int r_LONG(struct SN_env *z) { + if (!(find_among_b(z, a_5, 7))) + return 0; /* among, line 91 */ + return 1; +} + +static int r_VI(struct SN_env *z) { + if (!(eq_s_b(z, 1, s_3))) + return 0; + if (in_grouping_b_U(z, g_V2, 97, 246, 0)) + return 0; + return 1; +} + +static int r_case_ending(struct SN_env *z) { + int among_var; + { + int mlimit; /* setlimit, line 96 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 96 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 96 */ + among_var = find_among_b(z, a_6, 30); /* substring, line 96 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 96 */ + z->lb = mlimit; + } + switch (among_var) { + case 0: + return 0; + case 1: + if (!(eq_s_b(z, 1, s_4))) + return 0; + break; + case 2: + if (!(eq_s_b(z, 1, s_5))) + return 0; + break; + case 3: + if (!(eq_s_b(z, 1, s_6))) + return 0; + break; + case 4: + if (!(eq_s_b(z, 1, s_7))) + return 0; + break; + case 5: + if (!(eq_s_b(z, 2, s_8))) + return 0; + break; + case 6: + if (!(eq_s_b(z, 2, s_9))) + return 0; + break; + case 7: { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 111 */ + { + int m2 = z->l - z->c; + (void)m2; /* and, line 113 */ + { + int m3 = z->l - z->c; + (void)m3; /* or, line 112 */ + { + int ret = r_LONG(z); + if (ret == 0) + goto lab2; /* call LONG, line 111 */ + if (ret < 0) + return ret; + } + goto lab1; + lab2: + z->c = z->l - m3; + if (!(eq_s_b(z, 2, s_10))) { + z->c = z->l - m_keep; + goto lab0; + } + } + lab1: + z->c = z->l - m2; + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) { + z->c = z->l - m_keep; + goto lab0; + } + z->c = ret; /* next, line 113 */ + } + } + z->bra = z->c; /* ], line 113 */ + lab0:; + } break; + case 8: + if (in_grouping_b_U(z, g_V1, 97, 246, 0)) + return 0; + if (out_grouping_b_U(z, g_V1, 97, 246, 0)) + return 0; + break; + case 9: + if (!(eq_s_b(z, 1, s_11))) + return 0; + break; + } + { + int ret = slice_del(z); /* delete, line 138 */ + if (ret < 
0) + return ret; + } + z->B[0] = 1; /* set ending_removed, line 139 */ + return 1; +} + +static int r_other_endings(struct SN_env *z) { + int among_var; + { + int mlimit; /* setlimit, line 142 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[1]) + return 0; + z->c = z->I[1]; /* tomark, line 142 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 142 */ + among_var = find_among_b(z, a_7, 14); /* substring, line 142 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 142 */ + z->lb = mlimit; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int m2 = z->l - z->c; + (void)m2; /* not, line 146 */ + if (!(eq_s_b(z, 2, s_12))) + goto lab0; + return 0; + lab0: + z->c = z->l - m2; + } break; + } + { + int ret = slice_del(z); /* delete, line 151 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_i_plural(struct SN_env *z) { + { + int mlimit; /* setlimit, line 154 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 154 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 154 */ + if (z->c <= z->lb || (z->p[z->c - 1] != 105 && z->p[z->c - 1] != 106)) { + z->lb = mlimit; + return 0; + } + if (!(find_among_b(z, a_8, 2))) { + z->lb = mlimit; + return 0; + } /* substring, line 154 */ + z->bra = z->c; /* ], line 154 */ + z->lb = mlimit; + } + { + int ret = slice_del(z); /* delete, line 158 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_t_plural(struct SN_env *z) { + int among_var; + { + int mlimit; /* setlimit, line 161 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 161 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 162 */ + if (!(eq_s_b(z, 1, s_13))) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 162 */ + { + int m_test = z->l - z->c; /* test, line 162 */ + if (in_grouping_b_U(z, g_V1, 97, 246, 0)) { + z->lb = mlimit; + return 0; + } + z->c = z->l - m_test; + } + { + int ret = slice_del(z); /* delete, line 163 */ + if (ret < 0) + return ret; + } + z->lb = mlimit; + } + { + int mlimit; /* setlimit, line 165 */ + int m2 = z->l - z->c; + (void)m2; + if (z->c < z->I[1]) + return 0; + z->c = z->I[1]; /* tomark, line 165 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m2; + z->ket = z->c; /* [, line 165 */ + if (z->c - 2 <= z->lb || z->p[z->c - 1] != 97) { + z->lb = mlimit; + return 0; + } + among_var = find_among_b(z, a_9, 2); /* substring, line 165 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 165 */ + z->lb = mlimit; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int m3 = z->l - z->c; + (void)m3; /* not, line 167 */ + if (!(eq_s_b(z, 2, s_14))) + goto lab0; + return 0; + lab0: + z->c = z->l - m3; + } break; + } + { + int ret = slice_del(z); /* delete, line 170 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_tidy(struct SN_env *z) { + { + int mlimit; /* setlimit, line 173 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 173 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + { + int m2 = z->l - z->c; + (void)m2; /* do, line 174 */ + { + int m3 = z->l - z->c; + (void)m3; /* and, line 174 */ + { + int ret = r_LONG(z); + if (ret == 0) + goto lab0; /* call LONG, line 174 */ + if (ret < 0) + return ret; + } + z->c = z->l - m3; + z->ket = z->c; /* [, 
line 174 */ + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) + goto lab0; + z->c = ret; /* next, line 174 */ + } + z->bra = z->c; /* ], line 174 */ + { + int ret = slice_del(z); /* delete, line 174 */ + if (ret < 0) + return ret; + } + } + lab0: + z->c = z->l - m2; + } + { + int m4 = z->l - z->c; + (void)m4; /* do, line 175 */ + z->ket = z->c; /* [, line 175 */ + if (in_grouping_b_U(z, g_AEI, 97, 228, 0)) + goto lab1; + z->bra = z->c; /* ], line 175 */ + if (out_grouping_b_U(z, g_V1, 97, 246, 0)) + goto lab1; + { + int ret = slice_del(z); /* delete, line 175 */ + if (ret < 0) + return ret; + } + lab1: + z->c = z->l - m4; + } + { + int m5 = z->l - z->c; + (void)m5; /* do, line 176 */ + z->ket = z->c; /* [, line 176 */ + if (!(eq_s_b(z, 1, s_15))) + goto lab2; + z->bra = z->c; /* ], line 176 */ + { + int m6 = z->l - z->c; + (void)m6; /* or, line 176 */ + if (!(eq_s_b(z, 1, s_16))) + goto lab4; + goto lab3; + lab4: + z->c = z->l - m6; + if (!(eq_s_b(z, 1, s_17))) + goto lab2; + } + lab3: { + int ret = slice_del(z); /* delete, line 176 */ + if (ret < 0) + return ret; + } + lab2: + z->c = z->l - m5; + } + { + int m7 = z->l - z->c; + (void)m7; /* do, line 177 */ + z->ket = z->c; /* [, line 177 */ + if (!(eq_s_b(z, 1, s_18))) + goto lab5; + z->bra = z->c; /* ], line 177 */ + if (!(eq_s_b(z, 1, s_19))) + goto lab5; + { + int ret = slice_del(z); /* delete, line 177 */ + if (ret < 0) + return ret; + } + lab5: + z->c = z->l - m7; + } + z->lb = mlimit; + } + if (in_grouping_b_U(z, g_V1, 97, 246, 1) < 0) + return 0; /* goto */ /* non V1, line 179 */ + z->ket = z->c; /* [, line 179 */ + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) + return 0; + z->c = ret; /* next, line 179 */ + } + z->bra = z->c; /* ], line 179 */ + z->S[0] = slice_to(z, z->S[0]); /* -> x, line 179 */ + if (z->S[0] == 0) + return -1; /* -> x, line 179 */ + if (!(eq_v_b(z, z->S[0]))) + return 0; /* name x, line 179 */ + { + int ret = slice_del(z); /* delete, line 179 */ + if (ret < 0) + return ret; + } + return 1; +} + +extern int finnish_UTF_8_stem(struct SN_env *z) { + { + int c1 = z->c; /* do, line 185 */ + { + int ret = r_mark_regions(z); + if (ret == 0) + goto lab0; /* call mark_regions, line 185 */ + if (ret < 0) + return ret; + } + lab0: + z->c = c1; + } + z->B[0] = 0; /* unset ending_removed, line 186 */ + z->lb = z->c; + z->c = z->l; /* backwards, line 187 */ + + { + int m2 = z->l - z->c; + (void)m2; /* do, line 188 */ + { + int ret = r_particle_etc(z); + if (ret == 0) + goto lab1; /* call particle_etc, line 188 */ + if (ret < 0) + return ret; + } + lab1: + z->c = z->l - m2; + } + { + int m3 = z->l - z->c; + (void)m3; /* do, line 189 */ + { + int ret = r_possessive(z); + if (ret == 0) + goto lab2; /* call possessive, line 189 */ + if (ret < 0) + return ret; + } + lab2: + z->c = z->l - m3; + } + { + int m4 = z->l - z->c; + (void)m4; /* do, line 190 */ + { + int ret = r_case_ending(z); + if (ret == 0) + goto lab3; /* call case_ending, line 190 */ + if (ret < 0) + return ret; + } + lab3: + z->c = z->l - m4; + } + { + int m5 = z->l - z->c; + (void)m5; /* do, line 191 */ + { + int ret = r_other_endings(z); + if (ret == 0) + goto lab4; /* call other_endings, line 191 */ + if (ret < 0) + return ret; + } + lab4: + z->c = z->l - m5; + } + { + int m6 = z->l - z->c; + (void)m6; /* or, line 192 */ + if (!(z->B[0])) + goto lab6; /* Boolean test ending_removed, line 192 */ + { + int m7 = z->l - z->c; + (void)m7; /* do, line 192 */ + { + int ret = r_i_plural(z); + if (ret == 0) + goto lab7; /* call 
i_plural, line 192 */ + if (ret < 0) + return ret; + } + lab7: + z->c = z->l - m7; + } + goto lab5; + lab6: + z->c = z->l - m6; + { + int m8 = z->l - z->c; + (void)m8; /* do, line 192 */ + { + int ret = r_t_plural(z); + if (ret == 0) + goto lab8; /* call t_plural, line 192 */ + if (ret < 0) + return ret; + } + lab8: + z->c = z->l - m8; + } + } +lab5: { + int m9 = z->l - z->c; + (void)m9; /* do, line 193 */ + { + int ret = r_tidy(z); + if (ret == 0) + goto lab9; /* call tidy, line 193 */ + if (ret < 0) + return ret; + } +lab9: + z->c = z->l - m9; +} + z->c = z->lb; + return 1; +} + +extern struct SN_env *finnish_UTF_8_create_env(void) { return SN_create_env(1, 2, 1); } + +extern void finnish_UTF_8_close_env(struct SN_env *z) { SN_close_env(z, 1); } diff --git a/internal/cpp/stemmer/stem_UTF_8_finnish.h b/internal/cpp/stemmer/stem_UTF_8_finnish.h new file mode 100644 index 00000000000..6205ebd09f1 --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_finnish.h @@ -0,0 +1,17 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *finnish_UTF_8_create_env(void); +extern void finnish_UTF_8_close_env(struct SN_env *z); + +extern int finnish_UTF_8_stem(struct SN_env *z); + +#ifdef __cplusplus +} +#endif diff --git a/internal/cpp/stemmer/stem_UTF_8_french.cpp b/internal/cpp/stemmer/stem_UTF_8_french.cpp new file mode 100644 index 00000000000..849c40c4952 --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_french.cpp @@ -0,0 +1,1605 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#include "header.h" + +#ifdef __cplusplus +extern "C" { +#endif +extern int french_UTF_8_stem(struct SN_env *z); +#ifdef __cplusplus +} +#endif +static int r_un_accent(struct SN_env *z); +static int r_un_double(struct SN_env *z); +static int r_residual_suffix(struct SN_env *z); +static int r_verb_suffix(struct SN_env *z); +static int r_i_verb_suffix(struct SN_env *z); +static int r_standard_suffix(struct SN_env *z); +static int r_R2(struct SN_env *z); +static int r_R1(struct SN_env *z); +static int r_RV(struct SN_env *z); +static int r_mark_regions(struct SN_env *z); +static int r_postlude(struct SN_env *z); +static int r_prelude(struct SN_env *z); +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *french_UTF_8_create_env(void); +extern void french_UTF_8_close_env(struct SN_env *z); + +#ifdef __cplusplus +} +#endif +static const symbol s_0_0[3] = {'c', 'o', 'l'}; +static const symbol s_0_1[3] = {'p', 'a', 'r'}; +static const symbol s_0_2[3] = {'t', 'a', 'p'}; + +static const struct among a_0[3] = { + /* 0 */ {3, s_0_0, -1, -1, 0}, + /* 1 */ {3, s_0_1, -1, -1, 0}, + /* 2 */ {3, s_0_2, -1, -1, 0}}; + +static const symbol s_1_1[1] = {'I'}; +static const symbol s_1_2[1] = {'U'}; +static const symbol s_1_3[1] = {'Y'}; + +static const struct among a_1[4] = { + /* 0 */ {0, 0, -1, 4, 0}, + /* 1 */ {1, s_1_1, 0, 1, 0}, + /* 2 */ {1, s_1_2, 0, 2, 0}, + /* 3 */ {1, s_1_3, 0, 3, 0}}; + +static const symbol s_2_0[3] = {'i', 'q', 'U'}; +static const symbol s_2_1[3] = {'a', 'b', 'l'}; +static const symbol s_2_2[4] = {'I', 0xC3, 0xA8, 'r'}; +static const symbol s_2_3[4] = {'i', 0xC3, 0xA8, 'r'}; +static const symbol s_2_4[3] = {'e', 'u', 's'}; +static const symbol s_2_5[2] = {'i', 'v'}; + +static const struct among a_2[6] = { + /* 0 */ {3, s_2_0, -1, 3, 0}, + /* 1 */ {3, s_2_1, -1, 3, 0}, + /* 2 */ {4, s_2_2, -1, 4, 0}, + /* 3 */ {4, s_2_3, -1, 4, 0}, + /* 4 */ {3, s_2_4, 
-1, 2, 0}, + /* 5 */ {2, s_2_5, -1, 1, 0}}; + +static const symbol s_3_0[2] = {'i', 'c'}; +static const symbol s_3_1[4] = {'a', 'b', 'i', 'l'}; +static const symbol s_3_2[2] = {'i', 'v'}; + +static const struct among a_3[3] = { + /* 0 */ {2, s_3_0, -1, 2, 0}, + /* 1 */ {4, s_3_1, -1, 1, 0}, + /* 2 */ {2, s_3_2, -1, 3, 0}}; + +static const symbol s_4_0[4] = {'i', 'q', 'U', 'e'}; +static const symbol s_4_1[6] = {'a', 't', 'r', 'i', 'c', 'e'}; +static const symbol s_4_2[4] = {'a', 'n', 'c', 'e'}; +static const symbol s_4_3[4] = {'e', 'n', 'c', 'e'}; +static const symbol s_4_4[5] = {'l', 'o', 'g', 'i', 'e'}; +static const symbol s_4_5[4] = {'a', 'b', 'l', 'e'}; +static const symbol s_4_6[4] = {'i', 's', 'm', 'e'}; +static const symbol s_4_7[4] = {'e', 'u', 's', 'e'}; +static const symbol s_4_8[4] = {'i', 's', 't', 'e'}; +static const symbol s_4_9[3] = {'i', 'v', 'e'}; +static const symbol s_4_10[2] = {'i', 'f'}; +static const symbol s_4_11[5] = {'u', 's', 'i', 'o', 'n'}; +static const symbol s_4_12[5] = {'a', 't', 'i', 'o', 'n'}; +static const symbol s_4_13[5] = {'u', 't', 'i', 'o', 'n'}; +static const symbol s_4_14[5] = {'a', 't', 'e', 'u', 'r'}; +static const symbol s_4_15[5] = {'i', 'q', 'U', 'e', 's'}; +static const symbol s_4_16[7] = {'a', 't', 'r', 'i', 'c', 'e', 's'}; +static const symbol s_4_17[5] = {'a', 'n', 'c', 'e', 's'}; +static const symbol s_4_18[5] = {'e', 'n', 'c', 'e', 's'}; +static const symbol s_4_19[6] = {'l', 'o', 'g', 'i', 'e', 's'}; +static const symbol s_4_20[5] = {'a', 'b', 'l', 'e', 's'}; +static const symbol s_4_21[5] = {'i', 's', 'm', 'e', 's'}; +static const symbol s_4_22[5] = {'e', 'u', 's', 'e', 's'}; +static const symbol s_4_23[5] = {'i', 's', 't', 'e', 's'}; +static const symbol s_4_24[4] = {'i', 'v', 'e', 's'}; +static const symbol s_4_25[3] = {'i', 'f', 's'}; +static const symbol s_4_26[6] = {'u', 's', 'i', 'o', 'n', 's'}; +static const symbol s_4_27[6] = {'a', 't', 'i', 'o', 'n', 's'}; +static const symbol s_4_28[6] = {'u', 't', 'i', 'o', 'n', 's'}; +static const symbol s_4_29[6] = {'a', 't', 'e', 'u', 'r', 's'}; +static const symbol s_4_30[5] = {'m', 'e', 'n', 't', 's'}; +static const symbol s_4_31[6] = {'e', 'm', 'e', 'n', 't', 's'}; +static const symbol s_4_32[9] = {'i', 's', 's', 'e', 'm', 'e', 'n', 't', 's'}; +static const symbol s_4_33[5] = {'i', 't', 0xC3, 0xA9, 's'}; +static const symbol s_4_34[4] = {'m', 'e', 'n', 't'}; +static const symbol s_4_35[5] = {'e', 'm', 'e', 'n', 't'}; +static const symbol s_4_36[8] = {'i', 's', 's', 'e', 'm', 'e', 'n', 't'}; +static const symbol s_4_37[6] = {'a', 'm', 'm', 'e', 'n', 't'}; +static const symbol s_4_38[6] = {'e', 'm', 'm', 'e', 'n', 't'}; +static const symbol s_4_39[3] = {'a', 'u', 'x'}; +static const symbol s_4_40[4] = {'e', 'a', 'u', 'x'}; +static const symbol s_4_41[3] = {'e', 'u', 'x'}; +static const symbol s_4_42[4] = {'i', 't', 0xC3, 0xA9}; + +static const struct among a_4[43] = { + /* 0 */ {4, s_4_0, -1, 1, 0}, + /* 1 */ {6, s_4_1, -1, 2, 0}, + /* 2 */ {4, s_4_2, -1, 1, 0}, + /* 3 */ {4, s_4_3, -1, 5, 0}, + /* 4 */ {5, s_4_4, -1, 3, 0}, + /* 5 */ {4, s_4_5, -1, 1, 0}, + /* 6 */ {4, s_4_6, -1, 1, 0}, + /* 7 */ {4, s_4_7, -1, 11, 0}, + /* 8 */ {4, s_4_8, -1, 1, 0}, + /* 9 */ {3, s_4_9, -1, 8, 0}, + /* 10 */ {2, s_4_10, -1, 8, 0}, + /* 11 */ {5, s_4_11, -1, 4, 0}, + /* 12 */ {5, s_4_12, -1, 2, 0}, + /* 13 */ {5, s_4_13, -1, 4, 0}, + /* 14 */ {5, s_4_14, -1, 2, 0}, + /* 15 */ {5, s_4_15, -1, 1, 0}, + /* 16 */ {7, s_4_16, -1, 2, 0}, + /* 17 */ {5, s_4_17, -1, 1, 0}, + /* 18 */ {5, s_4_18, -1, 5, 0}, + /* 
19 */ {6, s_4_19, -1, 3, 0}, + /* 20 */ {5, s_4_20, -1, 1, 0}, + /* 21 */ {5, s_4_21, -1, 1, 0}, + /* 22 */ {5, s_4_22, -1, 11, 0}, + /* 23 */ {5, s_4_23, -1, 1, 0}, + /* 24 */ {4, s_4_24, -1, 8, 0}, + /* 25 */ {3, s_4_25, -1, 8, 0}, + /* 26 */ {6, s_4_26, -1, 4, 0}, + /* 27 */ {6, s_4_27, -1, 2, 0}, + /* 28 */ {6, s_4_28, -1, 4, 0}, + /* 29 */ {6, s_4_29, -1, 2, 0}, + /* 30 */ {5, s_4_30, -1, 15, 0}, + /* 31 */ {6, s_4_31, 30, 6, 0}, + /* 32 */ {9, s_4_32, 31, 12, 0}, + /* 33 */ {5, s_4_33, -1, 7, 0}, + /* 34 */ {4, s_4_34, -1, 15, 0}, + /* 35 */ {5, s_4_35, 34, 6, 0}, + /* 36 */ {8, s_4_36, 35, 12, 0}, + /* 37 */ {6, s_4_37, 34, 13, 0}, + /* 38 */ {6, s_4_38, 34, 14, 0}, + /* 39 */ {3, s_4_39, -1, 10, 0}, + /* 40 */ {4, s_4_40, 39, 9, 0}, + /* 41 */ {3, s_4_41, -1, 1, 0}, + /* 42 */ {4, s_4_42, -1, 7, 0}}; + +static const symbol s_5_0[3] = {'i', 'r', 'a'}; +static const symbol s_5_1[2] = {'i', 'e'}; +static const symbol s_5_2[4] = {'i', 's', 's', 'e'}; +static const symbol s_5_3[7] = {'i', 's', 's', 'a', 'n', 't', 'e'}; +static const symbol s_5_4[1] = {'i'}; +static const symbol s_5_5[4] = {'i', 'r', 'a', 'i'}; +static const symbol s_5_6[2] = {'i', 'r'}; +static const symbol s_5_7[4] = {'i', 'r', 'a', 's'}; +static const symbol s_5_8[3] = {'i', 'e', 's'}; +static const symbol s_5_9[5] = {0xC3, 0xAE, 'm', 'e', 's'}; +static const symbol s_5_10[5] = {'i', 's', 's', 'e', 's'}; +static const symbol s_5_11[8] = {'i', 's', 's', 'a', 'n', 't', 'e', 's'}; +static const symbol s_5_12[5] = {0xC3, 0xAE, 't', 'e', 's'}; +static const symbol s_5_13[2] = {'i', 's'}; +static const symbol s_5_14[5] = {'i', 'r', 'a', 'i', 's'}; +static const symbol s_5_15[6] = {'i', 's', 's', 'a', 'i', 's'}; +static const symbol s_5_16[6] = {'i', 'r', 'i', 'o', 'n', 's'}; +static const symbol s_5_17[7] = {'i', 's', 's', 'i', 'o', 'n', 's'}; +static const symbol s_5_18[5] = {'i', 'r', 'o', 'n', 's'}; +static const symbol s_5_19[6] = {'i', 's', 's', 'o', 'n', 's'}; +static const symbol s_5_20[7] = {'i', 's', 's', 'a', 'n', 't', 's'}; +static const symbol s_5_21[2] = {'i', 't'}; +static const symbol s_5_22[5] = {'i', 'r', 'a', 'i', 't'}; +static const symbol s_5_23[6] = {'i', 's', 's', 'a', 'i', 't'}; +static const symbol s_5_24[6] = {'i', 's', 's', 'a', 'n', 't'}; +static const symbol s_5_25[7] = {'i', 'r', 'a', 'I', 'e', 'n', 't'}; +static const symbol s_5_26[8] = {'i', 's', 's', 'a', 'I', 'e', 'n', 't'}; +static const symbol s_5_27[5] = {'i', 'r', 'e', 'n', 't'}; +static const symbol s_5_28[6] = {'i', 's', 's', 'e', 'n', 't'}; +static const symbol s_5_29[5] = {'i', 'r', 'o', 'n', 't'}; +static const symbol s_5_30[3] = {0xC3, 0xAE, 't'}; +static const symbol s_5_31[5] = {'i', 'r', 'i', 'e', 'z'}; +static const symbol s_5_32[6] = {'i', 's', 's', 'i', 'e', 'z'}; +static const symbol s_5_33[4] = {'i', 'r', 'e', 'z'}; +static const symbol s_5_34[5] = {'i', 's', 's', 'e', 'z'}; + +static const struct among a_5[35] = { + /* 0 */ {3, s_5_0, -1, 1, 0}, + /* 1 */ {2, s_5_1, -1, 1, 0}, + /* 2 */ {4, s_5_2, -1, 1, 0}, + /* 3 */ {7, s_5_3, -1, 1, 0}, + /* 4 */ {1, s_5_4, -1, 1, 0}, + /* 5 */ {4, s_5_5, 4, 1, 0}, + /* 6 */ {2, s_5_6, -1, 1, 0}, + /* 7 */ {4, s_5_7, -1, 1, 0}, + /* 8 */ {3, s_5_8, -1, 1, 0}, + /* 9 */ {5, s_5_9, -1, 1, 0}, + /* 10 */ {5, s_5_10, -1, 1, 0}, + /* 11 */ {8, s_5_11, -1, 1, 0}, + /* 12 */ {5, s_5_12, -1, 1, 0}, + /* 13 */ {2, s_5_13, -1, 1, 0}, + /* 14 */ {5, s_5_14, 13, 1, 0}, + /* 15 */ {6, s_5_15, 13, 1, 0}, + /* 16 */ {6, s_5_16, -1, 1, 0}, + /* 17 */ {7, s_5_17, -1, 1, 0}, + /* 18 */ {5, s_5_18, -1, 1, 
0}, + /* 19 */ {6, s_5_19, -1, 1, 0}, + /* 20 */ {7, s_5_20, -1, 1, 0}, + /* 21 */ {2, s_5_21, -1, 1, 0}, + /* 22 */ {5, s_5_22, 21, 1, 0}, + /* 23 */ {6, s_5_23, 21, 1, 0}, + /* 24 */ {6, s_5_24, -1, 1, 0}, + /* 25 */ {7, s_5_25, -1, 1, 0}, + /* 26 */ {8, s_5_26, -1, 1, 0}, + /* 27 */ {5, s_5_27, -1, 1, 0}, + /* 28 */ {6, s_5_28, -1, 1, 0}, + /* 29 */ {5, s_5_29, -1, 1, 0}, + /* 30 */ {3, s_5_30, -1, 1, 0}, + /* 31 */ {5, s_5_31, -1, 1, 0}, + /* 32 */ {6, s_5_32, -1, 1, 0}, + /* 33 */ {4, s_5_33, -1, 1, 0}, + /* 34 */ {5, s_5_34, -1, 1, 0}}; + +static const symbol s_6_0[1] = {'a'}; +static const symbol s_6_1[3] = {'e', 'r', 'a'}; +static const symbol s_6_2[4] = {'a', 's', 's', 'e'}; +static const symbol s_6_3[4] = {'a', 'n', 't', 'e'}; +static const symbol s_6_4[3] = {0xC3, 0xA9, 'e'}; +static const symbol s_6_5[2] = {'a', 'i'}; +static const symbol s_6_6[4] = {'e', 'r', 'a', 'i'}; +static const symbol s_6_7[2] = {'e', 'r'}; +static const symbol s_6_8[2] = {'a', 's'}; +static const symbol s_6_9[4] = {'e', 'r', 'a', 's'}; +static const symbol s_6_10[5] = {0xC3, 0xA2, 'm', 'e', 's'}; +static const symbol s_6_11[5] = {'a', 's', 's', 'e', 's'}; +static const symbol s_6_12[5] = {'a', 'n', 't', 'e', 's'}; +static const symbol s_6_13[5] = {0xC3, 0xA2, 't', 'e', 's'}; +static const symbol s_6_14[4] = {0xC3, 0xA9, 'e', 's'}; +static const symbol s_6_15[3] = {'a', 'i', 's'}; +static const symbol s_6_16[5] = {'e', 'r', 'a', 'i', 's'}; +static const symbol s_6_17[4] = {'i', 'o', 'n', 's'}; +static const symbol s_6_18[6] = {'e', 'r', 'i', 'o', 'n', 's'}; +static const symbol s_6_19[7] = {'a', 's', 's', 'i', 'o', 'n', 's'}; +static const symbol s_6_20[5] = {'e', 'r', 'o', 'n', 's'}; +static const symbol s_6_21[4] = {'a', 'n', 't', 's'}; +static const symbol s_6_22[3] = {0xC3, 0xA9, 's'}; +static const symbol s_6_23[3] = {'a', 'i', 't'}; +static const symbol s_6_24[5] = {'e', 'r', 'a', 'i', 't'}; +static const symbol s_6_25[3] = {'a', 'n', 't'}; +static const symbol s_6_26[5] = {'a', 'I', 'e', 'n', 't'}; +static const symbol s_6_27[7] = {'e', 'r', 'a', 'I', 'e', 'n', 't'}; +static const symbol s_6_28[6] = {0xC3, 0xA8, 'r', 'e', 'n', 't'}; +static const symbol s_6_29[6] = {'a', 's', 's', 'e', 'n', 't'}; +static const symbol s_6_30[5] = {'e', 'r', 'o', 'n', 't'}; +static const symbol s_6_31[3] = {0xC3, 0xA2, 't'}; +static const symbol s_6_32[2] = {'e', 'z'}; +static const symbol s_6_33[3] = {'i', 'e', 'z'}; +static const symbol s_6_34[5] = {'e', 'r', 'i', 'e', 'z'}; +static const symbol s_6_35[6] = {'a', 's', 's', 'i', 'e', 'z'}; +static const symbol s_6_36[4] = {'e', 'r', 'e', 'z'}; +static const symbol s_6_37[2] = {0xC3, 0xA9}; + +static const struct among a_6[38] = { + /* 0 */ {1, s_6_0, -1, 3, 0}, + /* 1 */ {3, s_6_1, 0, 2, 0}, + /* 2 */ {4, s_6_2, -1, 3, 0}, + /* 3 */ {4, s_6_3, -1, 3, 0}, + /* 4 */ {3, s_6_4, -1, 2, 0}, + /* 5 */ {2, s_6_5, -1, 3, 0}, + /* 6 */ {4, s_6_6, 5, 2, 0}, + /* 7 */ {2, s_6_7, -1, 2, 0}, + /* 8 */ {2, s_6_8, -1, 3, 0}, + /* 9 */ {4, s_6_9, 8, 2, 0}, + /* 10 */ {5, s_6_10, -1, 3, 0}, + /* 11 */ {5, s_6_11, -1, 3, 0}, + /* 12 */ {5, s_6_12, -1, 3, 0}, + /* 13 */ {5, s_6_13, -1, 3, 0}, + /* 14 */ {4, s_6_14, -1, 2, 0}, + /* 15 */ {3, s_6_15, -1, 3, 0}, + /* 16 */ {5, s_6_16, 15, 2, 0}, + /* 17 */ {4, s_6_17, -1, 1, 0}, + /* 18 */ {6, s_6_18, 17, 2, 0}, + /* 19 */ {7, s_6_19, 17, 3, 0}, + /* 20 */ {5, s_6_20, -1, 2, 0}, + /* 21 */ {4, s_6_21, -1, 3, 0}, + /* 22 */ {3, s_6_22, -1, 2, 0}, + /* 23 */ {3, s_6_23, -1, 3, 0}, + /* 24 */ {5, s_6_24, 23, 2, 0}, + /* 25 */ {3, s_6_25, 
-1, 3, 0}, + /* 26 */ {5, s_6_26, -1, 3, 0}, + /* 27 */ {7, s_6_27, 26, 2, 0}, + /* 28 */ {6, s_6_28, -1, 2, 0}, + /* 29 */ {6, s_6_29, -1, 3, 0}, + /* 30 */ {5, s_6_30, -1, 2, 0}, + /* 31 */ {3, s_6_31, -1, 3, 0}, + /* 32 */ {2, s_6_32, -1, 2, 0}, + /* 33 */ {3, s_6_33, 32, 2, 0}, + /* 34 */ {5, s_6_34, 33, 2, 0}, + /* 35 */ {6, s_6_35, 33, 3, 0}, + /* 36 */ {4, s_6_36, 32, 2, 0}, + /* 37 */ {2, s_6_37, -1, 2, 0}}; + +static const symbol s_7_0[1] = {'e'}; +static const symbol s_7_1[5] = {'I', 0xC3, 0xA8, 'r', 'e'}; +static const symbol s_7_2[5] = {'i', 0xC3, 0xA8, 'r', 'e'}; +static const symbol s_7_3[3] = {'i', 'o', 'n'}; +static const symbol s_7_4[3] = {'I', 'e', 'r'}; +static const symbol s_7_5[3] = {'i', 'e', 'r'}; +static const symbol s_7_6[2] = {0xC3, 0xAB}; + +static const struct among a_7[7] = { + /* 0 */ {1, s_7_0, -1, 3, 0}, + /* 1 */ {5, s_7_1, 0, 2, 0}, + /* 2 */ {5, s_7_2, 0, 2, 0}, + /* 3 */ {3, s_7_3, -1, 1, 0}, + /* 4 */ {3, s_7_4, -1, 2, 0}, + /* 5 */ {3, s_7_5, -1, 2, 0}, + /* 6 */ {2, s_7_6, -1, 4, 0}}; + +static const symbol s_8_0[3] = {'e', 'l', 'l'}; +static const symbol s_8_1[4] = {'e', 'i', 'l', 'l'}; +static const symbol s_8_2[3] = {'e', 'n', 'n'}; +static const symbol s_8_3[3] = {'o', 'n', 'n'}; +static const symbol s_8_4[3] = {'e', 't', 't'}; + +static const struct among a_8[5] = { + /* 0 */ {3, s_8_0, -1, -1, 0}, + /* 1 */ {4, s_8_1, -1, -1, 0}, + /* 2 */ {3, s_8_2, -1, -1, 0}, + /* 3 */ {3, s_8_3, -1, -1, 0}, + /* 4 */ {3, s_8_4, -1, -1, 0}}; + +static const unsigned char g_v[] = {17, 65, 16, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 130, 103, 8, 5}; + +static const unsigned char g_keep_with_s[] = {1, 65, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128}; + +static const symbol s_0[] = {'u'}; +static const symbol s_1[] = {'U'}; +static const symbol s_2[] = {'i'}; +static const symbol s_3[] = {'I'}; +static const symbol s_4[] = {'y'}; +static const symbol s_5[] = {'Y'}; +static const symbol s_6[] = {'y'}; +static const symbol s_7[] = {'Y'}; +static const symbol s_8[] = {'q'}; +static const symbol s_9[] = {'u'}; +static const symbol s_10[] = {'U'}; +static const symbol s_11[] = {'i'}; +static const symbol s_12[] = {'u'}; +static const symbol s_13[] = {'y'}; +static const symbol s_14[] = {'i', 'c'}; +static const symbol s_15[] = {'i', 'q', 'U'}; +static const symbol s_16[] = {'l', 'o', 'g'}; +static const symbol s_17[] = {'u'}; +static const symbol s_18[] = {'e', 'n', 't'}; +static const symbol s_19[] = {'a', 't'}; +static const symbol s_20[] = {'e', 'u', 'x'}; +static const symbol s_21[] = {'i'}; +static const symbol s_22[] = {'a', 'b', 'l'}; +static const symbol s_23[] = {'i', 'q', 'U'}; +static const symbol s_24[] = {'a', 't'}; +static const symbol s_25[] = {'i', 'c'}; +static const symbol s_26[] = {'i', 'q', 'U'}; +static const symbol s_27[] = {'e', 'a', 'u'}; +static const symbol s_28[] = {'a', 'l'}; +static const symbol s_29[] = {'e', 'u', 'x'}; +static const symbol s_30[] = {'a', 'n', 't'}; +static const symbol s_31[] = {'e', 'n', 't'}; +static const symbol s_32[] = {'e'}; +static const symbol s_33[] = {'s'}; +static const symbol s_34[] = {'s'}; +static const symbol s_35[] = {'t'}; +static const symbol s_36[] = {'i'}; +static const symbol s_37[] = {'g', 'u'}; +static const symbol s_38[] = {0xC3, 0xA9}; +static const symbol s_39[] = {0xC3, 0xA8}; +static const symbol s_40[] = {'e'}; +static const symbol s_41[] = {'Y'}; +static const symbol s_42[] = {'i'}; +static const symbol s_43[] = {0xC3, 0xA7}; +static const symbol s_44[] = {'c'}; + +static int 
r_prelude(struct SN_env *z) { + while (1) { /* repeat, line 38 */ + int c1 = z->c; + while (1) { /* goto, line 38 */ + int c2 = z->c; + { + int c3 = z->c; /* or, line 44 */ + if (in_grouping_U(z, g_v, 97, 251, 0)) + goto lab3; + z->bra = z->c; /* [, line 40 */ + { + int c4 = z->c; /* or, line 40 */ + if (!(eq_s(z, 1, s_0))) + goto lab5; + z->ket = z->c; /* ], line 40 */ + if (in_grouping_U(z, g_v, 97, 251, 0)) + goto lab5; + { + int ret = slice_from_s(z, 1, s_1); /* <-, line 40 */ + if (ret < 0) + return ret; + } + goto lab4; + lab5: + z->c = c4; + if (!(eq_s(z, 1, s_2))) + goto lab6; + z->ket = z->c; /* ], line 41 */ + if (in_grouping_U(z, g_v, 97, 251, 0)) + goto lab6; + { + int ret = slice_from_s(z, 1, s_3); /* <-, line 41 */ + if (ret < 0) + return ret; + } + goto lab4; + lab6: + z->c = c4; + if (!(eq_s(z, 1, s_4))) + goto lab3; + z->ket = z->c; /* ], line 42 */ + { + int ret = slice_from_s(z, 1, s_5); /* <-, line 42 */ + if (ret < 0) + return ret; + } + } + lab4: + goto lab2; + lab3: + z->c = c3; + z->bra = z->c; /* [, line 45 */ + if (!(eq_s(z, 1, s_6))) + goto lab7; + z->ket = z->c; /* ], line 45 */ + if (in_grouping_U(z, g_v, 97, 251, 0)) + goto lab7; + { + int ret = slice_from_s(z, 1, s_7); /* <-, line 45 */ + if (ret < 0) + return ret; + } + goto lab2; + lab7: + z->c = c3; + if (!(eq_s(z, 1, s_8))) + goto lab1; + z->bra = z->c; /* [, line 47 */ + if (!(eq_s(z, 1, s_9))) + goto lab1; + z->ket = z->c; /* ], line 47 */ + { + int ret = slice_from_s(z, 1, s_10); /* <-, line 47 */ + if (ret < 0) + return ret; + } + } + lab2: + z->c = c2; + break; + lab1: + z->c = c2; + { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* goto, line 38 */ + } + } + continue; + lab0: + z->c = c1; + break; + } + return 1; +} + +static int r_mark_regions(struct SN_env *z) { + z->I[0] = z->l; + z->I[1] = z->l; + z->I[2] = z->l; + { + int c1 = z->c; /* do, line 56 */ + { + int c2 = z->c; /* or, line 58 */ + if (in_grouping_U(z, g_v, 97, 251, 0)) + goto lab2; + if (in_grouping_U(z, g_v, 97, 251, 0)) + goto lab2; + { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab2; + z->c = ret; /* next, line 57 */ + } + goto lab1; + lab2: + z->c = c2; + if (z->c + 2 >= z->l || z->p[z->c + 2] >> 5 != 3 || !((331776 >> (z->p[z->c + 2] & 0x1f)) & 1)) + goto lab3; + if (!(find_among(z, a_0, 3))) + goto lab3; /* among, line 59 */ + goto lab1; + lab3: + z->c = c2; + { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* next, line 66 */ + } + { /* gopast */ /* grouping v, line 66 */ + int ret = out_grouping_U(z, g_v, 97, 251, 1); + if (ret < 0) + goto lab0; + z->c += ret; + } + } + lab1: + z->I[0] = z->c; /* setmark pV, line 67 */ + lab0: + z->c = c1; + } + { + int c3 = z->c; /* do, line 69 */ + { /* gopast */ /* grouping v, line 70 */ + int ret = out_grouping_U(z, g_v, 97, 251, 1); + if (ret < 0) + goto lab4; + z->c += ret; + } + { /* gopast */ /* non v, line 70 */ + int ret = in_grouping_U(z, g_v, 97, 251, 1); + if (ret < 0) + goto lab4; + z->c += ret; + } + z->I[1] = z->c; /* setmark p1, line 70 */ + { /* gopast */ /* grouping v, line 71 */ + int ret = out_grouping_U(z, g_v, 97, 251, 1); + if (ret < 0) + goto lab4; + z->c += ret; + } + { /* gopast */ /* non v, line 71 */ + int ret = in_grouping_U(z, g_v, 97, 251, 1); + if (ret < 0) + goto lab4; + z->c += ret; + } + z->I[2] = z->c; /* setmark p2, line 71 */ + lab4: + z->c = c3; + } + return 1; +} + +static int r_postlude(struct SN_env *z) { + int among_var; + while (1) { /* 
repeat, line 75 */ + int c1 = z->c; + z->bra = z->c; /* [, line 77 */ + if (z->c >= z->l || z->p[z->c + 0] >> 5 != 2 || !((35652096 >> (z->p[z->c + 0] & 0x1f)) & 1)) + among_var = 4; + else + among_var = find_among(z, a_1, 4); /* substring, line 77 */ + if (!(among_var)) + goto lab0; + z->ket = z->c; /* ], line 77 */ + switch (among_var) { + case 0: + goto lab0; + case 1: { + int ret = slice_from_s(z, 1, s_11); /* <-, line 78 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 1, s_12); /* <-, line 79 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 1, s_13); /* <-, line 80 */ + if (ret < 0) + return ret; + } break; + case 4: { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* next, line 81 */ + } break; + } + continue; + lab0: + z->c = c1; + break; + } + return 1; +} + +static int r_RV(struct SN_env *z) { + if (!(z->I[0] <= z->c)) + return 0; + return 1; +} + +static int r_R1(struct SN_env *z) { + if (!(z->I[1] <= z->c)) + return 0; + return 1; +} + +static int r_R2(struct SN_env *z) { + if (!(z->I[2] <= z->c)) + return 0; + return 1; +} + +static int r_standard_suffix(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 92 */ + among_var = find_among_b(z, a_4, 43); /* substring, line 92 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 92 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 96 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 96 */ + if (ret < 0) + return ret; + } + break; + case 2: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 99 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 99 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 100 */ + z->ket = z->c; /* [, line 100 */ + if (!(eq_s_b(z, 2, s_14))) { + z->c = z->l - m_keep; + goto lab0; + } + z->bra = z->c; /* ], line 100 */ + { + int m1 = z->l - z->c; + (void)m1; /* or, line 100 */ + { + int ret = r_R2(z); + if (ret == 0) + goto lab2; /* call R2, line 100 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 100 */ + if (ret < 0) + return ret; + } + goto lab1; + lab2: + z->c = z->l - m1; + { + int ret = slice_from_s(z, 3, s_15); /* <-, line 100 */ + if (ret < 0) + return ret; + } + } + lab1: + lab0:; + } + break; + case 3: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 104 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 3, s_16); /* <-, line 104 */ + if (ret < 0) + return ret; + } + break; + case 4: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 107 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 1, s_17); /* <-, line 107 */ + if (ret < 0) + return ret; + } + break; + case 5: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 110 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 3, s_18); /* <-, line 110 */ + if (ret < 0) + return ret; + } + break; + case 6: { + int ret = r_RV(z); + if (ret == 0) + return 0; /* call RV, line 114 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 114 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 115 */ + z->ket = z->c; /* [, line 116 */ + among_var = find_among_b(z, a_2, 6); /* substring, line 116 */ + if (!(among_var)) { 
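+                /* nothing in a_2 matched: undo the optional 'try' by restoring the cursor saved in m_keep */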
+ z->c = z->l - m_keep; + goto lab3; + } + z->bra = z->c; /* ], line 116 */ + switch (among_var) { + case 0: { + z->c = z->l - m_keep; + goto lab3; + } + case 1: { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab3; + } /* call R2, line 117 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 117 */ + if (ret < 0) + return ret; + } + z->ket = z->c; /* [, line 117 */ + if (!(eq_s_b(z, 2, s_19))) { + z->c = z->l - m_keep; + goto lab3; + } + z->bra = z->c; /* ], line 117 */ + { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab3; + } /* call R2, line 117 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 117 */ + if (ret < 0) + return ret; + } + break; + case 2: { + int m2 = z->l - z->c; + (void)m2; /* or, line 118 */ + { + int ret = r_R2(z); + if (ret == 0) + goto lab5; /* call R2, line 118 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 118 */ + if (ret < 0) + return ret; + } + goto lab4; + lab5: + z->c = z->l - m2; + { + int ret = r_R1(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab3; + } /* call R1, line 118 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 3, s_20); /* <-, line 118 */ + if (ret < 0) + return ret; + } + } + lab4: + break; + case 3: { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab3; + } /* call R2, line 120 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 120 */ + if (ret < 0) + return ret; + } + break; + case 4: { + int ret = r_RV(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab3; + } /* call RV, line 122 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 1, s_21); /* <-, line 122 */ + if (ret < 0) + return ret; + } + break; + } + lab3:; + } + break; + case 7: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 129 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 129 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 130 */ + z->ket = z->c; /* [, line 131 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((4198408 >> (z->p[z->c - 1] & 0x1f)) & 1)) { + z->c = z->l - m_keep; + goto lab6; + } + among_var = find_among_b(z, a_3, 3); /* substring, line 131 */ + if (!(among_var)) { + z->c = z->l - m_keep; + goto lab6; + } + z->bra = z->c; /* ], line 131 */ + switch (among_var) { + case 0: { + z->c = z->l - m_keep; + goto lab6; + } + case 1: { + int m3 = z->l - z->c; + (void)m3; /* or, line 132 */ + { + int ret = r_R2(z); + if (ret == 0) + goto lab8; /* call R2, line 132 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 132 */ + if (ret < 0) + return ret; + } + goto lab7; + lab8: + z->c = z->l - m3; + { + int ret = slice_from_s(z, 3, s_22); /* <-, line 132 */ + if (ret < 0) + return ret; + } + } + lab7: + break; + case 2: { + int m4 = z->l - z->c; + (void)m4; /* or, line 133 */ + { + int ret = r_R2(z); + if (ret == 0) + goto lab10; /* call R2, line 133 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 133 */ + if (ret < 0) + return ret; + } + goto lab9; + lab10: + z->c = z->l - m4; + { + int ret = slice_from_s(z, 3, s_23); /* <-, line 133 */ + if (ret < 0) + return ret; + } + } + lab9: + break; + case 3: { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab6; + } /* call R2, line 134 */ + if (ret < 0) + return ret; + } + { + int ret = 
slice_del(z); /* delete, line 134 */ + if (ret < 0) + return ret; + } + break; + } + lab6:; + } + break; + case 8: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 141 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 141 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 142 */ + z->ket = z->c; /* [, line 142 */ + if (!(eq_s_b(z, 2, s_24))) { + z->c = z->l - m_keep; + goto lab11; + } + z->bra = z->c; /* ], line 142 */ + { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab11; + } /* call R2, line 142 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 142 */ + if (ret < 0) + return ret; + } + z->ket = z->c; /* [, line 142 */ + if (!(eq_s_b(z, 2, s_25))) { + z->c = z->l - m_keep; + goto lab11; + } + z->bra = z->c; /* ], line 142 */ + { + int m5 = z->l - z->c; + (void)m5; /* or, line 142 */ + { + int ret = r_R2(z); + if (ret == 0) + goto lab13; /* call R2, line 142 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 142 */ + if (ret < 0) + return ret; + } + goto lab12; + lab13: + z->c = z->l - m5; + { + int ret = slice_from_s(z, 3, s_26); /* <-, line 142 */ + if (ret < 0) + return ret; + } + } + lab12: + lab11:; + } + break; + case 9: { + int ret = slice_from_s(z, 3, s_27); /* <-, line 144 */ + if (ret < 0) + return ret; + } break; + case 10: { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 145 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 2, s_28); /* <-, line 145 */ + if (ret < 0) + return ret; + } + break; + case 11: { + int m6 = z->l - z->c; + (void)m6; /* or, line 147 */ + { + int ret = r_R2(z); + if (ret == 0) + goto lab15; /* call R2, line 147 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 147 */ + if (ret < 0) + return ret; + } + goto lab14; + lab15: + z->c = z->l - m6; + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 147 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 3, s_29); /* <-, line 147 */ + if (ret < 0) + return ret; + } + } + lab14: + break; + case 12: { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 150 */ + if (ret < 0) + return ret; + } + if (out_grouping_b_U(z, g_v, 97, 251, 0)) + return 0; + { + int ret = slice_del(z); /* delete, line 150 */ + if (ret < 0) + return ret; + } + break; + case 13: { + int ret = r_RV(z); + if (ret == 0) + return 0; /* call RV, line 155 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 3, s_30); /* <-, line 155 */ + if (ret < 0) + return ret; + } + return 0; /* fail, line 155 */ + break; + case 14: { + int ret = r_RV(z); + if (ret == 0) + return 0; /* call RV, line 156 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 3, s_31); /* <-, line 156 */ + if (ret < 0) + return ret; + } + return 0; /* fail, line 156 */ + break; + case 15: { + int m_test = z->l - z->c; /* test, line 158 */ + if (in_grouping_b_U(z, g_v, 97, 251, 0)) + return 0; + { + int ret = r_RV(z); + if (ret == 0) + return 0; /* call RV, line 158 */ + if (ret < 0) + return ret; + } + z->c = z->l - m_test; + } + { + int ret = slice_del(z); /* delete, line 158 */ + if (ret < 0) + return ret; + } + return 0; /* fail, line 158 */ + break; + } + return 1; +} + +static int r_i_verb_suffix(struct SN_env *z) { + int among_var; + { + int mlimit; /* setlimit, line 163 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + 
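+        /* Snowball 'setlimit' over RV: lb is raised below to the start of RV (I[0]) so the backward
+           suffix search cannot look left of RV; the old limit is restored from mlimit before returning */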
z->c = z->I[0]; /* tomark, line 163 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 164 */ + if (z->c <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((68944418 >> (z->p[z->c - 1] & 0x1f)) & 1)) { + z->lb = mlimit; + return 0; + } + among_var = find_among_b(z, a_5, 35); /* substring, line 164 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 164 */ + switch (among_var) { + case 0: { + z->lb = mlimit; + return 0; + } + case 1: + if (out_grouping_b_U(z, g_v, 97, 251, 0)) { + z->lb = mlimit; + return 0; + } + { + int ret = slice_del(z); /* delete, line 170 */ + if (ret < 0) + return ret; + } + break; + } + z->lb = mlimit; + } + return 1; +} + +static int r_verb_suffix(struct SN_env *z) { + int among_var; + { + int mlimit; /* setlimit, line 174 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 174 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 175 */ + among_var = find_among_b(z, a_6, 38); /* substring, line 175 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 175 */ + switch (among_var) { + case 0: { + z->lb = mlimit; + return 0; + } + case 1: { + int ret = r_R2(z); + if (ret == 0) { + z->lb = mlimit; + return 0; + } /* call R2, line 177 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 177 */ + if (ret < 0) + return ret; + } + break; + case 2: { + int ret = slice_del(z); /* delete, line 185 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_del(z); /* delete, line 190 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 191 */ + z->ket = z->c; /* [, line 191 */ + if (!(eq_s_b(z, 1, s_32))) { + z->c = z->l - m_keep; + goto lab0; + } + z->bra = z->c; /* ], line 191 */ + { + int ret = slice_del(z); /* delete, line 191 */ + if (ret < 0) + return ret; + } + lab0:; + } + break; + } + z->lb = mlimit; + } + return 1; +} + +static int r_residual_suffix(struct SN_env *z) { + int among_var; + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 199 */ + z->ket = z->c; /* [, line 199 */ + if (!(eq_s_b(z, 1, s_33))) { + z->c = z->l - m_keep; + goto lab0; + } + z->bra = z->c; /* ], line 199 */ + { + int m_test = z->l - z->c; /* test, line 199 */ + if (out_grouping_b_U(z, g_keep_with_s, 97, 232, 0)) { + z->c = z->l - m_keep; + goto lab0; + } + z->c = z->l - m_test; + } + { + int ret = slice_del(z); /* delete, line 199 */ + if (ret < 0) + return ret; + } + lab0:; + } + { + int mlimit; /* setlimit, line 200 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 200 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 201 */ + among_var = find_among_b(z, a_7, 7); /* substring, line 201 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 201 */ + switch (among_var) { + case 0: { + z->lb = mlimit; + return 0; + } + case 1: { + int ret = r_R2(z); + if (ret == 0) { + z->lb = mlimit; + return 0; + } /* call R2, line 202 */ + if (ret < 0) + return ret; + } + { + int m2 = z->l - z->c; + (void)m2; /* or, line 202 */ + if (!(eq_s_b(z, 1, s_34))) + goto lab2; + goto lab1; + lab2: + z->c = z->l - m2; + if (!(eq_s_b(z, 1, s_35))) { + z->lb = mlimit; + return 0; + } + } + lab1: { + int ret = slice_del(z); /* delete, line 202 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = 
slice_from_s(z, 1, s_36); /* <-, line 204 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_del(z); /* delete, line 205 */ + if (ret < 0) + return ret; + } break; + case 4: + if (!(eq_s_b(z, 2, s_37))) { + z->lb = mlimit; + return 0; + } + { + int ret = slice_del(z); /* delete, line 206 */ + if (ret < 0) + return ret; + } + break; + } + z->lb = mlimit; + } + return 1; +} + +static int r_un_double(struct SN_env *z) { + { + int m_test = z->l - z->c; /* test, line 212 */ + if (z->c - 2 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((1069056 >> (z->p[z->c - 1] & 0x1f)) & 1)) + return 0; + if (!(find_among_b(z, a_8, 5))) + return 0; /* among, line 212 */ + z->c = z->l - m_test; + } + z->ket = z->c; /* [, line 212 */ + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) + return 0; + z->c = ret; /* next, line 212 */ + } + z->bra = z->c; /* ], line 212 */ + { + int ret = slice_del(z); /* delete, line 212 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_un_accent(struct SN_env *z) { + { + int i = 1; + while (1) { /* atleast, line 216 */ + if (out_grouping_b_U(z, g_v, 97, 251, 0)) + goto lab0; + i--; + continue; + lab0: + break; + } + if (i > 0) + return 0; + } + z->ket = z->c; /* [, line 217 */ + { + int m1 = z->l - z->c; + (void)m1; /* or, line 217 */ + if (!(eq_s_b(z, 2, s_38))) + goto lab2; + goto lab1; + lab2: + z->c = z->l - m1; + if (!(eq_s_b(z, 2, s_39))) + return 0; + } +lab1: + z->bra = z->c; /* ], line 217 */ + { + int ret = slice_from_s(z, 1, s_40); /* <-, line 217 */ + if (ret < 0) + return ret; + } + return 1; +} + +extern int french_UTF_8_stem(struct SN_env *z) { + { + int c1 = z->c; /* do, line 223 */ + { + int ret = r_prelude(z); + if (ret == 0) + goto lab0; /* call prelude, line 223 */ + if (ret < 0) + return ret; + } + lab0: + z->c = c1; + } + { + int c2 = z->c; /* do, line 224 */ + { + int ret = r_mark_regions(z); + if (ret == 0) + goto lab1; /* call mark_regions, line 224 */ + if (ret < 0) + return ret; + } + lab1: + z->c = c2; + } + z->lb = z->c; + z->c = z->l; /* backwards, line 225 */ + + { + int m3 = z->l - z->c; + (void)m3; /* do, line 227 */ + { + int m4 = z->l - z->c; + (void)m4; /* or, line 237 */ + { + int m5 = z->l - z->c; + (void)m5; /* and, line 233 */ + { + int m6 = z->l - z->c; + (void)m6; /* or, line 229 */ + { + int ret = r_standard_suffix(z); + if (ret == 0) + goto lab6; /* call standard_suffix, line 229 */ + if (ret < 0) + return ret; + } + goto lab5; + lab6: + z->c = z->l - m6; + { + int ret = r_i_verb_suffix(z); + if (ret == 0) + goto lab7; /* call i_verb_suffix, line 230 */ + if (ret < 0) + return ret; + } + goto lab5; + lab7: + z->c = z->l - m6; + { + int ret = r_verb_suffix(z); + if (ret == 0) + goto lab4; /* call verb_suffix, line 231 */ + if (ret < 0) + return ret; + } + } + lab5: + z->c = z->l - m5; + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 234 */ + z->ket = z->c; /* [, line 234 */ + { + int m7 = z->l - z->c; + (void)m7; /* or, line 234 */ + if (!(eq_s_b(z, 1, s_41))) + goto lab10; + z->bra = z->c; /* ], line 234 */ + { + int ret = slice_from_s(z, 1, s_42); /* <-, line 234 */ + if (ret < 0) + return ret; + } + goto lab9; + lab10: + z->c = z->l - m7; + if (!(eq_s_b(z, 2, s_43))) { + z->c = z->l - m_keep; + goto lab8; + } + z->bra = z->c; /* ], line 235 */ + { + int ret = slice_from_s(z, 1, s_44); /* <-, line 235 */ + if (ret < 0) + return ret; + } + } + lab9: + lab8:; + } + } + goto lab3; + lab4: + z->c = z->l - m4; + { + int ret = r_residual_suffix(z); + if (ret == 0) + goto lab2; 
/* call residual_suffix, line 238 */ + if (ret < 0) + return ret; + } + } + lab3: + lab2: + z->c = z->l - m3; + } + { + int m8 = z->l - z->c; + (void)m8; /* do, line 243 */ + { + int ret = r_un_double(z); + if (ret == 0) + goto lab11; /* call un_double, line 243 */ + if (ret < 0) + return ret; + } + lab11: + z->c = z->l - m8; + } + { + int m9 = z->l - z->c; + (void)m9; /* do, line 244 */ + { + int ret = r_un_accent(z); + if (ret == 0) + goto lab12; /* call un_accent, line 244 */ + if (ret < 0) + return ret; + } + lab12: + z->c = z->l - m9; + } + z->c = z->lb; + { + int c10 = z->c; /* do, line 246 */ + { + int ret = r_postlude(z); + if (ret == 0) + goto lab13; /* call postlude, line 246 */ + if (ret < 0) + return ret; + } + lab13: + z->c = c10; + } + return 1; +} + +extern struct SN_env *french_UTF_8_create_env(void) { return SN_create_env(0, 3, 0); } + +extern void french_UTF_8_close_env(struct SN_env *z) { SN_close_env(z, 0); } diff --git a/internal/cpp/stemmer/stem_UTF_8_french.h b/internal/cpp/stemmer/stem_UTF_8_french.h new file mode 100644 index 00000000000..780b078745f --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_french.h @@ -0,0 +1,17 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *french_UTF_8_create_env(void); +extern void french_UTF_8_close_env(struct SN_env *z); + +extern int french_UTF_8_stem(struct SN_env *z); + +#ifdef __cplusplus +} +#endif diff --git a/internal/cpp/stemmer/stem_UTF_8_german.cpp b/internal/cpp/stemmer/stem_UTF_8_german.cpp new file mode 100644 index 00000000000..63a273ecec8 --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_german.cpp @@ -0,0 +1,626 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#include "header.h" + +#ifdef __cplusplus +extern "C" { +#endif +extern int german_UTF_8_stem(struct SN_env *z); +#ifdef __cplusplus +} +#endif +static int r_standard_suffix(struct SN_env *z); +static int r_R2(struct SN_env *z); +static int r_R1(struct SN_env *z); +static int r_mark_regions(struct SN_env *z); +static int r_postlude(struct SN_env *z); +static int r_prelude(struct SN_env *z); +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *german_UTF_8_create_env(void); +extern void german_UTF_8_close_env(struct SN_env *z); + +#ifdef __cplusplus +} +#endif +static const symbol s_0_1[1] = {'U'}; +static const symbol s_0_2[1] = {'Y'}; +static const symbol s_0_3[2] = {0xC3, 0xA4}; +static const symbol s_0_4[2] = {0xC3, 0xB6}; +static const symbol s_0_5[2] = {0xC3, 0xBC}; + +static const struct among a_0[6] = { + /* 0 */ {0, 0, -1, 6, 0}, + /* 1 */ {1, s_0_1, 0, 2, 0}, + /* 2 */ {1, s_0_2, 0, 1, 0}, + /* 3 */ {2, s_0_3, 0, 3, 0}, + /* 4 */ {2, s_0_4, 0, 4, 0}, + /* 5 */ {2, s_0_5, 0, 5, 0}}; + +static const symbol s_1_0[1] = {'e'}; +static const symbol s_1_1[2] = {'e', 'm'}; +static const symbol s_1_2[2] = {'e', 'n'}; +static const symbol s_1_3[3] = {'e', 'r', 'n'}; +static const symbol s_1_4[2] = {'e', 'r'}; +static const symbol s_1_5[1] = {'s'}; +static const symbol s_1_6[2] = {'e', 's'}; + +static const struct among a_1[7] = { + /* 0 */ {1, s_1_0, -1, 1, 0}, + /* 1 */ {2, s_1_1, -1, 1, 0}, + /* 2 */ {2, s_1_2, -1, 1, 0}, + /* 3 */ {3, s_1_3, -1, 1, 0}, + /* 4 */ {2, s_1_4, -1, 1, 0}, + /* 5 */ {1, s_1_5, -1, 2, 0}, + /* 6 */ {2, s_1_6, 5, 1, 0}}; + +static const symbol s_2_0[2] = {'e', 'n'}; +static const symbol s_2_1[2] = {'e', 'r'}; +static const symbol s_2_2[2] = {'s', 't'}; 
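+/* Each 'among' entry is {suffix length, suffix bytes, index of the shorter entry whose string this
+   one extends (-1 if none), result code, optional routine}; find_among_b matches these longest-first
+   against the right end of the word. */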
+static const symbol s_2_3[3] = {'e', 's', 't'}; + +static const struct among a_2[4] = { + /* 0 */ {2, s_2_0, -1, 1, 0}, + /* 1 */ {2, s_2_1, -1, 1, 0}, + /* 2 */ {2, s_2_2, -1, 2, 0}, + /* 3 */ {3, s_2_3, 2, 1, 0}}; + +static const symbol s_3_0[2] = {'i', 'g'}; +static const symbol s_3_1[4] = {'l', 'i', 'c', 'h'}; + +static const struct among a_3[2] = { + /* 0 */ {2, s_3_0, -1, 1, 0}, + /* 1 */ {4, s_3_1, -1, 1, 0}}; + +static const symbol s_4_0[3] = {'e', 'n', 'd'}; +static const symbol s_4_1[2] = {'i', 'g'}; +static const symbol s_4_2[3] = {'u', 'n', 'g'}; +static const symbol s_4_3[4] = {'l', 'i', 'c', 'h'}; +static const symbol s_4_4[4] = {'i', 's', 'c', 'h'}; +static const symbol s_4_5[2] = {'i', 'k'}; +static const symbol s_4_6[4] = {'h', 'e', 'i', 't'}; +static const symbol s_4_7[4] = {'k', 'e', 'i', 't'}; + +static const struct among a_4[8] = { + /* 0 */ {3, s_4_0, -1, 1, 0}, + /* 1 */ {2, s_4_1, -1, 2, 0}, + /* 2 */ {3, s_4_2, -1, 1, 0}, + /* 3 */ {4, s_4_3, -1, 3, 0}, + /* 4 */ {4, s_4_4, -1, 2, 0}, + /* 5 */ {2, s_4_5, -1, 2, 0}, + /* 6 */ {4, s_4_6, -1, 3, 0}, + /* 7 */ {4, s_4_7, -1, 4, 0}}; + +static const unsigned char g_v[] = {17, 65, 16, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 32, 8}; + +static const unsigned char g_s_ending[] = {117, 30, 5}; + +static const unsigned char g_st_ending[] = {117, 30, 4}; + +static const symbol s_0[] = {0xC3, 0x9F}; +static const symbol s_1[] = {'s', 's'}; +static const symbol s_2[] = {'u'}; +static const symbol s_3[] = {'U'}; +static const symbol s_4[] = {'y'}; +static const symbol s_5[] = {'Y'}; +static const symbol s_6[] = {'y'}; +static const symbol s_7[] = {'u'}; +static const symbol s_8[] = {'a'}; +static const symbol s_9[] = {'o'}; +static const symbol s_10[] = {'u'}; +static const symbol s_11[] = {'i', 'g'}; +static const symbol s_12[] = {'e'}; +static const symbol s_13[] = {'e'}; +static const symbol s_14[] = {'e', 'r'}; +static const symbol s_15[] = {'e', 'n'}; + +static int r_prelude(struct SN_env *z) { + { + int c_test = z->c; /* test, line 30 */ + while (1) { /* repeat, line 30 */ + int c1 = z->c; + { + int c2 = z->c; /* or, line 33 */ + z->bra = z->c; /* [, line 32 */ + if (!(eq_s(z, 2, s_0))) + goto lab2; + z->ket = z->c; /* ], line 32 */ + { + int ret = slice_from_s(z, 2, s_1); /* <-, line 32 */ + if (ret < 0) + return ret; + } + goto lab1; + lab2: + z->c = c2; + { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* next, line 33 */ + } + } + lab1: + continue; + lab0: + z->c = c1; + break; + } + z->c = c_test; + } + while (1) { /* repeat, line 36 */ + int c3 = z->c; + while (1) { /* goto, line 36 */ + int c4 = z->c; + if (in_grouping_U(z, g_v, 97, 252, 0)) + goto lab4; + z->bra = z->c; /* [, line 37 */ + { + int c5 = z->c; /* or, line 37 */ + if (!(eq_s(z, 1, s_2))) + goto lab6; + z->ket = z->c; /* ], line 37 */ + if (in_grouping_U(z, g_v, 97, 252, 0)) + goto lab6; + { + int ret = slice_from_s(z, 1, s_3); /* <-, line 37 */ + if (ret < 0) + return ret; + } + goto lab5; + lab6: + z->c = c5; + if (!(eq_s(z, 1, s_4))) + goto lab4; + z->ket = z->c; /* ], line 38 */ + if (in_grouping_U(z, g_v, 97, 252, 0)) + goto lab4; + { + int ret = slice_from_s(z, 1, s_5); /* <-, line 38 */ + if (ret < 0) + return ret; + } + } + lab5: + z->c = c4; + break; + lab4: + z->c = c4; + { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab3; + z->c = ret; /* goto, line 36 */ + } + } + continue; + lab3: + z->c = c3; + break; + } + return 1; +} + +static int r_mark_regions(struct SN_env *z) 
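+/* German regions per the Snowball algorithm: R1 starts after the first non-vowel that follows a
+   vowel, pushed right to at least three characters in (the x mark); R2 is computed the same way
+   starting from R1. */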
{ + z->I[0] = z->l; + z->I[1] = z->l; + { + int c_test = z->c; /* test, line 47 */ + { + int ret = skip_utf8(z->p, z->c, 0, z->l, +3); + if (ret < 0) + return 0; + z->c = ret; /* hop, line 47 */ + } + z->I[2] = z->c; /* setmark x, line 47 */ + z->c = c_test; + } + { /* gopast */ /* grouping v, line 49 */ + int ret = out_grouping_U(z, g_v, 97, 252, 1); + if (ret < 0) + return 0; + z->c += ret; + } + { /* gopast */ /* non v, line 49 */ + int ret = in_grouping_U(z, g_v, 97, 252, 1); + if (ret < 0) + return 0; + z->c += ret; + } + z->I[0] = z->c; /* setmark p1, line 49 */ + /* try, line 50 */ + if (!(z->I[0] < z->I[2])) + goto lab0; + z->I[0] = z->I[2]; +lab0: { /* gopast */ /* grouping v, line 51 */ + int ret = out_grouping_U(z, g_v, 97, 252, 1); + if (ret < 0) + return 0; + z->c += ret; +} + { /* gopast */ /* non v, line 51 */ + int ret = in_grouping_U(z, g_v, 97, 252, 1); + if (ret < 0) + return 0; + z->c += ret; + } + z->I[1] = z->c; /* setmark p2, line 51 */ + return 1; +} + +static int r_postlude(struct SN_env *z) { + int among_var; + while (1) { /* repeat, line 55 */ + int c1 = z->c; + z->bra = z->c; /* [, line 57 */ + among_var = find_among(z, a_0, 6); /* substring, line 57 */ + if (!(among_var)) + goto lab0; + z->ket = z->c; /* ], line 57 */ + switch (among_var) { + case 0: + goto lab0; + case 1: { + int ret = slice_from_s(z, 1, s_6); /* <-, line 58 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 1, s_7); /* <-, line 59 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 1, s_8); /* <-, line 60 */ + if (ret < 0) + return ret; + } break; + case 4: { + int ret = slice_from_s(z, 1, s_9); /* <-, line 61 */ + if (ret < 0) + return ret; + } break; + case 5: { + int ret = slice_from_s(z, 1, s_10); /* <-, line 62 */ + if (ret < 0) + return ret; + } break; + case 6: { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* next, line 63 */ + } break; + } + continue; + lab0: + z->c = c1; + break; + } + return 1; +} + +static int r_R1(struct SN_env *z) { + if (!(z->I[0] <= z->c)) + return 0; + return 1; +} + +static int r_R2(struct SN_env *z) { + if (!(z->I[1] <= z->c)) + return 0; + return 1; +} + +static int r_standard_suffix(struct SN_env *z) { + int among_var; + { + int m1 = z->l - z->c; + (void)m1; /* do, line 74 */ + z->ket = z->c; /* [, line 75 */ + if (z->c <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((811040 >> (z->p[z->c - 1] & 0x1f)) & 1)) + goto lab0; + among_var = find_among_b(z, a_1, 7); /* substring, line 75 */ + if (!(among_var)) + goto lab0; + z->bra = z->c; /* ], line 75 */ + { + int ret = r_R1(z); + if (ret == 0) + goto lab0; /* call R1, line 75 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + goto lab0; + case 1: { + int ret = slice_del(z); /* delete, line 77 */ + if (ret < 0) + return ret; + } break; + case 2: + if (in_grouping_b_U(z, g_s_ending, 98, 116, 0)) + goto lab0; + { + int ret = slice_del(z); /* delete, line 80 */ + if (ret < 0) + return ret; + } + break; + } + lab0: + z->c = z->l - m1; + } + { + int m2 = z->l - z->c; + (void)m2; /* do, line 84 */ + z->ket = z->c; /* [, line 85 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((1327104 >> (z->p[z->c - 1] & 0x1f)) & 1)) + goto lab1; + among_var = find_among_b(z, a_2, 4); /* substring, line 85 */ + if (!(among_var)) + goto lab1; + z->bra = z->c; /* ], line 85 */ + { + int ret = r_R1(z); + if (ret == 0) + goto lab1; /* call R1, line 85 */ + if (ret < 0) + return ret; + } + switch 
(among_var) { + case 0: + goto lab1; + case 1: { + int ret = slice_del(z); /* delete, line 87 */ + if (ret < 0) + return ret; + } break; + case 2: + if (in_grouping_b_U(z, g_st_ending, 98, 116, 0)) + goto lab1; + { + int ret = skip_utf8(z->p, z->c, z->lb, z->l, -3); + if (ret < 0) + goto lab1; + z->c = ret; /* hop, line 90 */ + } + { + int ret = slice_del(z); /* delete, line 90 */ + if (ret < 0) + return ret; + } + break; + } + lab1: + z->c = z->l - m2; + } + { + int m3 = z->l - z->c; + (void)m3; /* do, line 94 */ + z->ket = z->c; /* [, line 95 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((1051024 >> (z->p[z->c - 1] & 0x1f)) & 1)) + goto lab2; + among_var = find_among_b(z, a_4, 8); /* substring, line 95 */ + if (!(among_var)) + goto lab2; + z->bra = z->c; /* ], line 95 */ + { + int ret = r_R2(z); + if (ret == 0) + goto lab2; /* call R2, line 95 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + goto lab2; + case 1: { + int ret = slice_del(z); /* delete, line 97 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 98 */ + z->ket = z->c; /* [, line 98 */ + if (!(eq_s_b(z, 2, s_11))) { + z->c = z->l - m_keep; + goto lab3; + } + z->bra = z->c; /* ], line 98 */ + { + int m4 = z->l - z->c; + (void)m4; /* not, line 98 */ + if (!(eq_s_b(z, 1, s_12))) + goto lab4; + { + z->c = z->l - m_keep; + goto lab3; + } + lab4: + z->c = z->l - m4; + } + { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab3; + } /* call R2, line 98 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 98 */ + if (ret < 0) + return ret; + } + lab3:; + } + break; + case 2: { + int m5 = z->l - z->c; + (void)m5; /* not, line 101 */ + if (!(eq_s_b(z, 1, s_13))) + goto lab5; + goto lab2; + lab5: + z->c = z->l - m5; + } + { + int ret = slice_del(z); /* delete, line 101 */ + if (ret < 0) + return ret; + } + break; + case 3: { + int ret = slice_del(z); /* delete, line 104 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 105 */ + z->ket = z->c; /* [, line 106 */ + { + int m6 = z->l - z->c; + (void)m6; /* or, line 106 */ + if (!(eq_s_b(z, 2, s_14))) + goto lab8; + goto lab7; + lab8: + z->c = z->l - m6; + if (!(eq_s_b(z, 2, s_15))) { + z->c = z->l - m_keep; + goto lab6; + } + } + lab7: + z->bra = z->c; /* ], line 106 */ + { + int ret = r_R1(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab6; + } /* call R1, line 106 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 106 */ + if (ret < 0) + return ret; + } + lab6:; + } + break; + case 4: { + int ret = slice_del(z); /* delete, line 110 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 111 */ + z->ket = z->c; /* [, line 112 */ + if (z->c - 1 <= z->lb || (z->p[z->c - 1] != 103 && z->p[z->c - 1] != 104)) { + z->c = z->l - m_keep; + goto lab9; + } + among_var = find_among_b(z, a_3, 2); /* substring, line 112 */ + if (!(among_var)) { + z->c = z->l - m_keep; + goto lab9; + } + z->bra = z->c; /* ], line 112 */ + { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab9; + } /* call R2, line 112 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: { + z->c = z->l - m_keep; + goto lab9; + } + case 1: { + int ret = slice_del(z); /* delete, line 114 */ + if (ret < 0) + return ret; + } break; + } + lab9:; + } + break; + } + lab2: + z->c = z->l - m3; + } + return 1; +} + +extern int 
german_UTF_8_stem(struct SN_env *z) { + { + int c1 = z->c; /* do, line 125 */ + { + int ret = r_prelude(z); + if (ret == 0) + goto lab0; /* call prelude, line 125 */ + if (ret < 0) + return ret; + } + lab0: + z->c = c1; + } + { + int c2 = z->c; /* do, line 126 */ + { + int ret = r_mark_regions(z); + if (ret == 0) + goto lab1; /* call mark_regions, line 126 */ + if (ret < 0) + return ret; + } + lab1: + z->c = c2; + } + z->lb = z->c; + z->c = z->l; /* backwards, line 127 */ + + { + int m3 = z->l - z->c; + (void)m3; /* do, line 128 */ + { + int ret = r_standard_suffix(z); + if (ret == 0) + goto lab2; /* call standard_suffix, line 128 */ + if (ret < 0) + return ret; + } + lab2: + z->c = z->l - m3; + } + z->c = z->lb; + { + int c4 = z->c; /* do, line 129 */ + { + int ret = r_postlude(z); + if (ret == 0) + goto lab3; /* call postlude, line 129 */ + if (ret < 0) + return ret; + } + lab3: + z->c = c4; + } + return 1; +} + +extern struct SN_env *german_UTF_8_create_env(void) { return SN_create_env(0, 3, 0); } + +extern void german_UTF_8_close_env(struct SN_env *z) { SN_close_env(z, 0); } diff --git a/internal/cpp/stemmer/stem_UTF_8_german.h b/internal/cpp/stemmer/stem_UTF_8_german.h new file mode 100644 index 00000000000..69df3507e89 --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_german.h @@ -0,0 +1,17 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *german_UTF_8_create_env(void); +extern void german_UTF_8_close_env(struct SN_env *z); + +extern int german_UTF_8_stem(struct SN_env *z); + +#ifdef __cplusplus +} +#endif diff --git a/internal/cpp/stemmer/stem_UTF_8_hungarian.cpp b/internal/cpp/stemmer/stem_UTF_8_hungarian.cpp new file mode 100644 index 00000000000..a97ad36982f --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_hungarian.cpp @@ -0,0 +1,1353 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#include "header.h" + +#ifdef __cplusplus +extern "C" { +#endif +extern int hungarian_UTF_8_stem(struct SN_env *z); +#ifdef __cplusplus +} +#endif +static int r_double(struct SN_env *z); +static int r_undouble(struct SN_env *z); +static int r_factive(struct SN_env *z); +static int r_instrum(struct SN_env *z); +static int r_plur_owner(struct SN_env *z); +static int r_sing_owner(struct SN_env *z); +static int r_owned(struct SN_env *z); +static int r_plural(struct SN_env *z); +static int r_case_other(struct SN_env *z); +static int r_case_special(struct SN_env *z); +static int r_case(struct SN_env *z); +static int r_v_ending(struct SN_env *z); +static int r_R1(struct SN_env *z); +static int r_mark_regions(struct SN_env *z); +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *hungarian_UTF_8_create_env(void); +extern void hungarian_UTF_8_close_env(struct SN_env *z); + +#ifdef __cplusplus +} +#endif +static const symbol s_0_0[2] = {'c', 's'}; +static const symbol s_0_1[3] = {'d', 'z', 's'}; +static const symbol s_0_2[2] = {'g', 'y'}; +static const symbol s_0_3[2] = {'l', 'y'}; +static const symbol s_0_4[2] = {'n', 'y'}; +static const symbol s_0_5[2] = {'s', 'z'}; +static const symbol s_0_6[2] = {'t', 'y'}; +static const symbol s_0_7[2] = {'z', 's'}; + +static const struct among a_0[8] = { + /* 0 */ {2, s_0_0, -1, -1, 0}, + /* 1 */ {3, s_0_1, -1, -1, 0}, + /* 2 */ {2, s_0_2, -1, -1, 0}, + /* 3 */ {2, s_0_3, -1, -1, 0}, + /* 4 */ {2, s_0_4, -1, -1, 0}, + /* 5 */ {2, s_0_5, -1, -1, 0}, + /* 6 */ {2, s_0_6, -1, -1, 0}, + /* 7 */ {2, 
s_0_7, -1, -1, 0}}; + +static const symbol s_1_0[2] = {0xC3, 0xA1}; +static const symbol s_1_1[2] = {0xC3, 0xA9}; + +static const struct among a_1[2] = { + /* 0 */ {2, s_1_0, -1, 1, 0}, + /* 1 */ {2, s_1_1, -1, 2, 0}}; + +static const symbol s_2_0[2] = {'b', 'b'}; +static const symbol s_2_1[2] = {'c', 'c'}; +static const symbol s_2_2[2] = {'d', 'd'}; +static const symbol s_2_3[2] = {'f', 'f'}; +static const symbol s_2_4[2] = {'g', 'g'}; +static const symbol s_2_5[2] = {'j', 'j'}; +static const symbol s_2_6[2] = {'k', 'k'}; +static const symbol s_2_7[2] = {'l', 'l'}; +static const symbol s_2_8[2] = {'m', 'm'}; +static const symbol s_2_9[2] = {'n', 'n'}; +static const symbol s_2_10[2] = {'p', 'p'}; +static const symbol s_2_11[2] = {'r', 'r'}; +static const symbol s_2_12[3] = {'c', 'c', 's'}; +static const symbol s_2_13[2] = {'s', 's'}; +static const symbol s_2_14[3] = {'z', 'z', 's'}; +static const symbol s_2_15[2] = {'t', 't'}; +static const symbol s_2_16[2] = {'v', 'v'}; +static const symbol s_2_17[3] = {'g', 'g', 'y'}; +static const symbol s_2_18[3] = {'l', 'l', 'y'}; +static const symbol s_2_19[3] = {'n', 'n', 'y'}; +static const symbol s_2_20[3] = {'t', 't', 'y'}; +static const symbol s_2_21[3] = {'s', 's', 'z'}; +static const symbol s_2_22[2] = {'z', 'z'}; + +static const struct among a_2[23] = { + /* 0 */ {2, s_2_0, -1, -1, 0}, + /* 1 */ {2, s_2_1, -1, -1, 0}, + /* 2 */ {2, s_2_2, -1, -1, 0}, + /* 3 */ {2, s_2_3, -1, -1, 0}, + /* 4 */ {2, s_2_4, -1, -1, 0}, + /* 5 */ {2, s_2_5, -1, -1, 0}, + /* 6 */ {2, s_2_6, -1, -1, 0}, + /* 7 */ {2, s_2_7, -1, -1, 0}, + /* 8 */ {2, s_2_8, -1, -1, 0}, + /* 9 */ {2, s_2_9, -1, -1, 0}, + /* 10 */ {2, s_2_10, -1, -1, 0}, + /* 11 */ {2, s_2_11, -1, -1, 0}, + /* 12 */ {3, s_2_12, -1, -1, 0}, + /* 13 */ {2, s_2_13, -1, -1, 0}, + /* 14 */ {3, s_2_14, -1, -1, 0}, + /* 15 */ {2, s_2_15, -1, -1, 0}, + /* 16 */ {2, s_2_16, -1, -1, 0}, + /* 17 */ {3, s_2_17, -1, -1, 0}, + /* 18 */ {3, s_2_18, -1, -1, 0}, + /* 19 */ {3, s_2_19, -1, -1, 0}, + /* 20 */ {3, s_2_20, -1, -1, 0}, + /* 21 */ {3, s_2_21, -1, -1, 0}, + /* 22 */ {2, s_2_22, -1, -1, 0}}; + +static const symbol s_3_0[2] = {'a', 'l'}; +static const symbol s_3_1[2] = {'e', 'l'}; + +static const struct among a_3[2] = { + /* 0 */ {2, s_3_0, -1, 1, 0}, + /* 1 */ {2, s_3_1, -1, 2, 0}}; + +static const symbol s_4_0[2] = {'b', 'a'}; +static const symbol s_4_1[2] = {'r', 'a'}; +static const symbol s_4_2[2] = {'b', 'e'}; +static const symbol s_4_3[2] = {'r', 'e'}; +static const symbol s_4_4[2] = {'i', 'g'}; +static const symbol s_4_5[3] = {'n', 'a', 'k'}; +static const symbol s_4_6[3] = {'n', 'e', 'k'}; +static const symbol s_4_7[3] = {'v', 'a', 'l'}; +static const symbol s_4_8[3] = {'v', 'e', 'l'}; +static const symbol s_4_9[2] = {'u', 'l'}; +static const symbol s_4_10[4] = {'n', 0xC3, 0xA1, 'l'}; +static const symbol s_4_11[4] = {'n', 0xC3, 0xA9, 'l'}; +static const symbol s_4_12[4] = {'b', 0xC3, 0xB3, 'l'}; +static const symbol s_4_13[4] = {'r', 0xC3, 0xB3, 'l'}; +static const symbol s_4_14[4] = {'t', 0xC3, 0xB3, 'l'}; +static const symbol s_4_15[4] = {'b', 0xC3, 0xB5, 'l'}; +static const symbol s_4_16[4] = {'r', 0xC3, 0xB5, 'l'}; +static const symbol s_4_17[4] = {'t', 0xC3, 0xB5, 'l'}; +static const symbol s_4_18[3] = {0xC3, 0xBC, 'l'}; +static const symbol s_4_19[1] = {'n'}; +static const symbol s_4_20[2] = {'a', 'n'}; +static const symbol s_4_21[3] = {'b', 'a', 'n'}; +static const symbol s_4_22[2] = {'e', 'n'}; +static const symbol s_4_23[3] = {'b', 'e', 'n'}; +static const symbol s_4_24[7] = {'k', 0xC3, 0xA9, 
'p', 'p', 'e', 'n'}; +static const symbol s_4_25[2] = {'o', 'n'}; +static const symbol s_4_26[3] = {0xC3, 0xB6, 'n'}; +static const symbol s_4_27[5] = {'k', 0xC3, 0xA9, 'p', 'p'}; +static const symbol s_4_28[3] = {'k', 'o', 'r'}; +static const symbol s_4_29[1] = {'t'}; +static const symbol s_4_30[2] = {'a', 't'}; +static const symbol s_4_31[2] = {'e', 't'}; +static const symbol s_4_32[5] = {'k', 0xC3, 0xA9, 'n', 't'}; +static const symbol s_4_33[7] = {'a', 'n', 'k', 0xC3, 0xA9, 'n', 't'}; +static const symbol s_4_34[7] = {'e', 'n', 'k', 0xC3, 0xA9, 'n', 't'}; +static const symbol s_4_35[7] = {'o', 'n', 'k', 0xC3, 0xA9, 'n', 't'}; +static const symbol s_4_36[2] = {'o', 't'}; +static const symbol s_4_37[4] = {0xC3, 0xA9, 'r', 't'}; +static const symbol s_4_38[3] = {0xC3, 0xB6, 't'}; +static const symbol s_4_39[3] = {'h', 'e', 'z'}; +static const symbol s_4_40[3] = {'h', 'o', 'z'}; +static const symbol s_4_41[4] = {'h', 0xC3, 0xB6, 'z'}; +static const symbol s_4_42[3] = {'v', 0xC3, 0xA1}; +static const symbol s_4_43[3] = {'v', 0xC3, 0xA9}; + +static const struct among a_4[44] = { + /* 0 */ {2, s_4_0, -1, -1, 0}, + /* 1 */ {2, s_4_1, -1, -1, 0}, + /* 2 */ {2, s_4_2, -1, -1, 0}, + /* 3 */ {2, s_4_3, -1, -1, 0}, + /* 4 */ {2, s_4_4, -1, -1, 0}, + /* 5 */ {3, s_4_5, -1, -1, 0}, + /* 6 */ {3, s_4_6, -1, -1, 0}, + /* 7 */ {3, s_4_7, -1, -1, 0}, + /* 8 */ {3, s_4_8, -1, -1, 0}, + /* 9 */ {2, s_4_9, -1, -1, 0}, + /* 10 */ {4, s_4_10, -1, -1, 0}, + /* 11 */ {4, s_4_11, -1, -1, 0}, + /* 12 */ {4, s_4_12, -1, -1, 0}, + /* 13 */ {4, s_4_13, -1, -1, 0}, + /* 14 */ {4, s_4_14, -1, -1, 0}, + /* 15 */ {4, s_4_15, -1, -1, 0}, + /* 16 */ {4, s_4_16, -1, -1, 0}, + /* 17 */ {4, s_4_17, -1, -1, 0}, + /* 18 */ {3, s_4_18, -1, -1, 0}, + /* 19 */ {1, s_4_19, -1, -1, 0}, + /* 20 */ {2, s_4_20, 19, -1, 0}, + /* 21 */ {3, s_4_21, 20, -1, 0}, + /* 22 */ {2, s_4_22, 19, -1, 0}, + /* 23 */ {3, s_4_23, 22, -1, 0}, + /* 24 */ {7, s_4_24, 22, -1, 0}, + /* 25 */ {2, s_4_25, 19, -1, 0}, + /* 26 */ {3, s_4_26, 19, -1, 0}, + /* 27 */ {5, s_4_27, -1, -1, 0}, + /* 28 */ {3, s_4_28, -1, -1, 0}, + /* 29 */ {1, s_4_29, -1, -1, 0}, + /* 30 */ {2, s_4_30, 29, -1, 0}, + /* 31 */ {2, s_4_31, 29, -1, 0}, + /* 32 */ {5, s_4_32, 29, -1, 0}, + /* 33 */ {7, s_4_33, 32, -1, 0}, + /* 34 */ {7, s_4_34, 32, -1, 0}, + /* 35 */ {7, s_4_35, 32, -1, 0}, + /* 36 */ {2, s_4_36, 29, -1, 0}, + /* 37 */ {4, s_4_37, 29, -1, 0}, + /* 38 */ {3, s_4_38, 29, -1, 0}, + /* 39 */ {3, s_4_39, -1, -1, 0}, + /* 40 */ {3, s_4_40, -1, -1, 0}, + /* 41 */ {4, s_4_41, -1, -1, 0}, + /* 42 */ {3, s_4_42, -1, -1, 0}, + /* 43 */ {3, s_4_43, -1, -1, 0}}; + +static const symbol s_5_0[3] = {0xC3, 0xA1, 'n'}; +static const symbol s_5_1[3] = {0xC3, 0xA9, 'n'}; +static const symbol s_5_2[8] = {0xC3, 0xA1, 'n', 'k', 0xC3, 0xA9, 'n', 't'}; + +static const struct among a_5[3] = { + /* 0 */ {3, s_5_0, -1, 2, 0}, + /* 1 */ {3, s_5_1, -1, 1, 0}, + /* 2 */ {8, s_5_2, -1, 3, 0}}; + +static const symbol s_6_0[4] = {'s', 't', 'u', 'l'}; +static const symbol s_6_1[5] = {'a', 's', 't', 'u', 'l'}; +static const symbol s_6_2[6] = {0xC3, 0xA1, 's', 't', 'u', 'l'}; +static const symbol s_6_3[5] = {'s', 't', 0xC3, 0xBC, 'l'}; +static const symbol s_6_4[6] = {'e', 's', 't', 0xC3, 0xBC, 'l'}; +static const symbol s_6_5[7] = {0xC3, 0xA9, 's', 't', 0xC3, 0xBC, 'l'}; + +static const struct among a_6[6] = { + /* 0 */ {4, s_6_0, -1, 2, 0}, + /* 1 */ {5, s_6_1, 0, 1, 0}, + /* 2 */ {6, s_6_2, 0, 3, 0}, + /* 3 */ {5, s_6_3, -1, 2, 0}, + /* 4 */ {6, s_6_4, 3, 1, 0}, + /* 5 */ {7, s_6_5, 3, 4, 0}}; + +static 
const symbol s_7_0[2] = {0xC3, 0xA1}; +static const symbol s_7_1[2] = {0xC3, 0xA9}; + +static const struct among a_7[2] = { + /* 0 */ {2, s_7_0, -1, 1, 0}, + /* 1 */ {2, s_7_1, -1, 2, 0}}; + +static const symbol s_8_0[1] = {'k'}; +static const symbol s_8_1[2] = {'a', 'k'}; +static const symbol s_8_2[2] = {'e', 'k'}; +static const symbol s_8_3[2] = {'o', 'k'}; +static const symbol s_8_4[3] = {0xC3, 0xA1, 'k'}; +static const symbol s_8_5[3] = {0xC3, 0xA9, 'k'}; +static const symbol s_8_6[3] = {0xC3, 0xB6, 'k'}; + +static const struct among a_8[7] = { + /* 0 */ {1, s_8_0, -1, 7, 0}, + /* 1 */ {2, s_8_1, 0, 4, 0}, + /* 2 */ {2, s_8_2, 0, 6, 0}, + /* 3 */ {2, s_8_3, 0, 5, 0}, + /* 4 */ {3, s_8_4, 0, 1, 0}, + /* 5 */ {3, s_8_5, 0, 2, 0}, + /* 6 */ {3, s_8_6, 0, 3, 0}}; + +static const symbol s_9_0[3] = {0xC3, 0xA9, 'i'}; +static const symbol s_9_1[5] = {0xC3, 0xA1, 0xC3, 0xA9, 'i'}; +static const symbol s_9_2[5] = {0xC3, 0xA9, 0xC3, 0xA9, 'i'}; +static const symbol s_9_3[2] = {0xC3, 0xA9}; +static const symbol s_9_4[3] = {'k', 0xC3, 0xA9}; +static const symbol s_9_5[4] = {'a', 'k', 0xC3, 0xA9}; +static const symbol s_9_6[4] = {'e', 'k', 0xC3, 0xA9}; +static const symbol s_9_7[4] = {'o', 'k', 0xC3, 0xA9}; +static const symbol s_9_8[5] = {0xC3, 0xA1, 'k', 0xC3, 0xA9}; +static const symbol s_9_9[5] = {0xC3, 0xA9, 'k', 0xC3, 0xA9}; +static const symbol s_9_10[5] = {0xC3, 0xB6, 'k', 0xC3, 0xA9}; +static const symbol s_9_11[4] = {0xC3, 0xA9, 0xC3, 0xA9}; + +static const struct among a_9[12] = { + /* 0 */ {3, s_9_0, -1, 7, 0}, + /* 1 */ {5, s_9_1, 0, 6, 0}, + /* 2 */ {5, s_9_2, 0, 5, 0}, + /* 3 */ {2, s_9_3, -1, 9, 0}, + /* 4 */ {3, s_9_4, 3, 4, 0}, + /* 5 */ {4, s_9_5, 4, 1, 0}, + /* 6 */ {4, s_9_6, 4, 1, 0}, + /* 7 */ {4, s_9_7, 4, 1, 0}, + /* 8 */ {5, s_9_8, 4, 3, 0}, + /* 9 */ {5, s_9_9, 4, 2, 0}, + /* 10 */ {5, s_9_10, 4, 1, 0}, + /* 11 */ {4, s_9_11, 3, 8, 0}}; + +static const symbol s_10_0[1] = {'a'}; +static const symbol s_10_1[2] = {'j', 'a'}; +static const symbol s_10_2[1] = {'d'}; +static const symbol s_10_3[2] = {'a', 'd'}; +static const symbol s_10_4[2] = {'e', 'd'}; +static const symbol s_10_5[2] = {'o', 'd'}; +static const symbol s_10_6[3] = {0xC3, 0xA1, 'd'}; +static const symbol s_10_7[3] = {0xC3, 0xA9, 'd'}; +static const symbol s_10_8[3] = {0xC3, 0xB6, 'd'}; +static const symbol s_10_9[1] = {'e'}; +static const symbol s_10_10[2] = {'j', 'e'}; +static const symbol s_10_11[2] = {'n', 'k'}; +static const symbol s_10_12[3] = {'u', 'n', 'k'}; +static const symbol s_10_13[4] = {0xC3, 0xA1, 'n', 'k'}; +static const symbol s_10_14[4] = {0xC3, 0xA9, 'n', 'k'}; +static const symbol s_10_15[4] = {0xC3, 0xBC, 'n', 'k'}; +static const symbol s_10_16[2] = {'u', 'k'}; +static const symbol s_10_17[3] = {'j', 'u', 'k'}; +static const symbol s_10_18[5] = {0xC3, 0xA1, 'j', 'u', 'k'}; +static const symbol s_10_19[3] = {0xC3, 0xBC, 'k'}; +static const symbol s_10_20[4] = {'j', 0xC3, 0xBC, 'k'}; +static const symbol s_10_21[6] = {0xC3, 0xA9, 'j', 0xC3, 0xBC, 'k'}; +static const symbol s_10_22[1] = {'m'}; +static const symbol s_10_23[2] = {'a', 'm'}; +static const symbol s_10_24[2] = {'e', 'm'}; +static const symbol s_10_25[2] = {'o', 'm'}; +static const symbol s_10_26[3] = {0xC3, 0xA1, 'm'}; +static const symbol s_10_27[3] = {0xC3, 0xA9, 'm'}; +static const symbol s_10_28[1] = {'o'}; +static const symbol s_10_29[2] = {0xC3, 0xA1}; +static const symbol s_10_30[2] = {0xC3, 0xA9}; + +static const struct among a_10[31] = { + /* 0 */ {1, s_10_0, -1, 18, 0}, + /* 1 */ {2, s_10_1, 0, 17, 0}, + /* 2 */ {1, 
s_10_2, -1, 16, 0}, + /* 3 */ {2, s_10_3, 2, 13, 0}, + /* 4 */ {2, s_10_4, 2, 13, 0}, + /* 5 */ {2, s_10_5, 2, 13, 0}, + /* 6 */ {3, s_10_6, 2, 14, 0}, + /* 7 */ {3, s_10_7, 2, 15, 0}, + /* 8 */ {3, s_10_8, 2, 13, 0}, + /* 9 */ {1, s_10_9, -1, 18, 0}, + /* 10 */ {2, s_10_10, 9, 17, 0}, + /* 11 */ {2, s_10_11, -1, 4, 0}, + /* 12 */ {3, s_10_12, 11, 1, 0}, + /* 13 */ {4, s_10_13, 11, 2, 0}, + /* 14 */ {4, s_10_14, 11, 3, 0}, + /* 15 */ {4, s_10_15, 11, 1, 0}, + /* 16 */ {2, s_10_16, -1, 8, 0}, + /* 17 */ {3, s_10_17, 16, 7, 0}, + /* 18 */ {5, s_10_18, 17, 5, 0}, + /* 19 */ {3, s_10_19, -1, 8, 0}, + /* 20 */ {4, s_10_20, 19, 7, 0}, + /* 21 */ {6, s_10_21, 20, 6, 0}, + /* 22 */ {1, s_10_22, -1, 12, 0}, + /* 23 */ {2, s_10_23, 22, 9, 0}, + /* 24 */ {2, s_10_24, 22, 9, 0}, + /* 25 */ {2, s_10_25, 22, 9, 0}, + /* 26 */ {3, s_10_26, 22, 10, 0}, + /* 27 */ {3, s_10_27, 22, 11, 0}, + /* 28 */ {1, s_10_28, -1, 18, 0}, + /* 29 */ {2, s_10_29, -1, 19, 0}, + /* 30 */ {2, s_10_30, -1, 20, 0}}; + +static const symbol s_11_0[2] = {'i', 'd'}; +static const symbol s_11_1[3] = {'a', 'i', 'd'}; +static const symbol s_11_2[4] = {'j', 'a', 'i', 'd'}; +static const symbol s_11_3[3] = {'e', 'i', 'd'}; +static const symbol s_11_4[4] = {'j', 'e', 'i', 'd'}; +static const symbol s_11_5[4] = {0xC3, 0xA1, 'i', 'd'}; +static const symbol s_11_6[4] = {0xC3, 0xA9, 'i', 'd'}; +static const symbol s_11_7[1] = {'i'}; +static const symbol s_11_8[2] = {'a', 'i'}; +static const symbol s_11_9[3] = {'j', 'a', 'i'}; +static const symbol s_11_10[2] = {'e', 'i'}; +static const symbol s_11_11[3] = {'j', 'e', 'i'}; +static const symbol s_11_12[3] = {0xC3, 0xA1, 'i'}; +static const symbol s_11_13[3] = {0xC3, 0xA9, 'i'}; +static const symbol s_11_14[4] = {'i', 't', 'e', 'k'}; +static const symbol s_11_15[5] = {'e', 'i', 't', 'e', 'k'}; +static const symbol s_11_16[6] = {'j', 'e', 'i', 't', 'e', 'k'}; +static const symbol s_11_17[6] = {0xC3, 0xA9, 'i', 't', 'e', 'k'}; +static const symbol s_11_18[2] = {'i', 'k'}; +static const symbol s_11_19[3] = {'a', 'i', 'k'}; +static const symbol s_11_20[4] = {'j', 'a', 'i', 'k'}; +static const symbol s_11_21[3] = {'e', 'i', 'k'}; +static const symbol s_11_22[4] = {'j', 'e', 'i', 'k'}; +static const symbol s_11_23[4] = {0xC3, 0xA1, 'i', 'k'}; +static const symbol s_11_24[4] = {0xC3, 0xA9, 'i', 'k'}; +static const symbol s_11_25[3] = {'i', 'n', 'k'}; +static const symbol s_11_26[4] = {'a', 'i', 'n', 'k'}; +static const symbol s_11_27[5] = {'j', 'a', 'i', 'n', 'k'}; +static const symbol s_11_28[4] = {'e', 'i', 'n', 'k'}; +static const symbol s_11_29[5] = {'j', 'e', 'i', 'n', 'k'}; +static const symbol s_11_30[5] = {0xC3, 0xA1, 'i', 'n', 'k'}; +static const symbol s_11_31[5] = {0xC3, 0xA9, 'i', 'n', 'k'}; +static const symbol s_11_32[5] = {'a', 'i', 't', 'o', 'k'}; +static const symbol s_11_33[6] = {'j', 'a', 'i', 't', 'o', 'k'}; +static const symbol s_11_34[6] = {0xC3, 0xA1, 'i', 't', 'o', 'k'}; +static const symbol s_11_35[2] = {'i', 'm'}; +static const symbol s_11_36[3] = {'a', 'i', 'm'}; +static const symbol s_11_37[4] = {'j', 'a', 'i', 'm'}; +static const symbol s_11_38[3] = {'e', 'i', 'm'}; +static const symbol s_11_39[4] = {'j', 'e', 'i', 'm'}; +static const symbol s_11_40[4] = {0xC3, 0xA1, 'i', 'm'}; +static const symbol s_11_41[4] = {0xC3, 0xA9, 'i', 'm'}; + +static const struct among a_11[42] = { + /* 0 */ {2, s_11_0, -1, 10, 0}, + /* 1 */ {3, s_11_1, 0, 9, 0}, + /* 2 */ {4, s_11_2, 1, 6, 0}, + /* 3 */ {3, s_11_3, 0, 9, 0}, + /* 4 */ {4, s_11_4, 3, 6, 0}, + /* 5 */ {4, s_11_5, 0, 7, 0}, + /* 6 
*/ {4, s_11_6, 0, 8, 0}, + /* 7 */ {1, s_11_7, -1, 15, 0}, + /* 8 */ {2, s_11_8, 7, 14, 0}, + /* 9 */ {3, s_11_9, 8, 11, 0}, + /* 10 */ {2, s_11_10, 7, 14, 0}, + /* 11 */ {3, s_11_11, 10, 11, 0}, + /* 12 */ {3, s_11_12, 7, 12, 0}, + /* 13 */ {3, s_11_13, 7, 13, 0}, + /* 14 */ {4, s_11_14, -1, 24, 0}, + /* 15 */ {5, s_11_15, 14, 21, 0}, + /* 16 */ {6, s_11_16, 15, 20, 0}, + /* 17 */ {6, s_11_17, 14, 23, 0}, + /* 18 */ {2, s_11_18, -1, 29, 0}, + /* 19 */ {3, s_11_19, 18, 26, 0}, + /* 20 */ {4, s_11_20, 19, 25, 0}, + /* 21 */ {3, s_11_21, 18, 26, 0}, + /* 22 */ {4, s_11_22, 21, 25, 0}, + /* 23 */ {4, s_11_23, 18, 27, 0}, + /* 24 */ {4, s_11_24, 18, 28, 0}, + /* 25 */ {3, s_11_25, -1, 20, 0}, + /* 26 */ {4, s_11_26, 25, 17, 0}, + /* 27 */ {5, s_11_27, 26, 16, 0}, + /* 28 */ {4, s_11_28, 25, 17, 0}, + /* 29 */ {5, s_11_29, 28, 16, 0}, + /* 30 */ {5, s_11_30, 25, 18, 0}, + /* 31 */ {5, s_11_31, 25, 19, 0}, + /* 32 */ {5, s_11_32, -1, 21, 0}, + /* 33 */ {6, s_11_33, 32, 20, 0}, + /* 34 */ {6, s_11_34, -1, 22, 0}, + /* 35 */ {2, s_11_35, -1, 5, 0}, + /* 36 */ {3, s_11_36, 35, 4, 0}, + /* 37 */ {4, s_11_37, 36, 1, 0}, + /* 38 */ {3, s_11_38, 35, 4, 0}, + /* 39 */ {4, s_11_39, 38, 1, 0}, + /* 40 */ {4, s_11_40, 35, 2, 0}, + /* 41 */ {4, s_11_41, 35, 3, 0}}; + +static const unsigned char g_v[] = {17, 65, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 17, 52, 14}; + +static const symbol s_0[] = {'a'}; +static const symbol s_1[] = {'e'}; +static const symbol s_2[] = {'e'}; +static const symbol s_3[] = {'a'}; +static const symbol s_4[] = {'a'}; +static const symbol s_5[] = {'a'}; +static const symbol s_6[] = {'e'}; +static const symbol s_7[] = {'a'}; +static const symbol s_8[] = {'e'}; +static const symbol s_9[] = {'e'}; +static const symbol s_10[] = {'a'}; +static const symbol s_11[] = {'e'}; +static const symbol s_12[] = {'a'}; +static const symbol s_13[] = {'e'}; +static const symbol s_14[] = {'a'}; +static const symbol s_15[] = {'e'}; +static const symbol s_16[] = {'a'}; +static const symbol s_17[] = {'e'}; +static const symbol s_18[] = {'a'}; +static const symbol s_19[] = {'e'}; +static const symbol s_20[] = {'a'}; +static const symbol s_21[] = {'e'}; +static const symbol s_22[] = {'a'}; +static const symbol s_23[] = {'e'}; +static const symbol s_24[] = {'a'}; +static const symbol s_25[] = {'e'}; +static const symbol s_26[] = {'a'}; +static const symbol s_27[] = {'e'}; +static const symbol s_28[] = {'a'}; +static const symbol s_29[] = {'e'}; +static const symbol s_30[] = {'a'}; +static const symbol s_31[] = {'e'}; +static const symbol s_32[] = {'a'}; +static const symbol s_33[] = {'e'}; +static const symbol s_34[] = {'a'}; +static const symbol s_35[] = {'e'}; + +static int r_mark_regions(struct SN_env *z) { + z->I[0] = z->l; + { + int c1 = z->c; /* or, line 51 */ + if (in_grouping_U(z, g_v, 97, 252, 0)) + goto lab1; + if (in_grouping_U(z, g_v, 97, 252, 1) < 0) + goto lab1; /* goto */ /* non v, line 48 */ + { + int c2 = z->c; /* or, line 49 */ + if (z->c + 1 >= z->l || z->p[z->c + 1] >> 5 != 3 || !((101187584 >> (z->p[z->c + 1] & 0x1f)) & 1)) + goto lab3; + if (!(find_among(z, a_0, 8))) + goto lab3; /* among, line 49 */ + goto lab2; + lab3: + z->c = c2; + { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab1; + z->c = ret; /* next, line 49 */ + } + } + lab2: + z->I[0] = z->c; /* setmark p1, line 50 */ + goto lab0; + lab1: + z->c = c1; + if (out_grouping_U(z, g_v, 97, 252, 0)) + return 0; + { /* gopast */ /* grouping v, line 53 */ + int ret = out_grouping_U(z, g_v, 97, 252, 1); + 
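+        /* word starts with a consonant: advance past the first vowel; R1 begins right after it */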
if (ret < 0) + return 0; + z->c += ret; + } + z->I[0] = z->c; /* setmark p1, line 53 */ + } +lab0: + return 1; +} + +static int r_R1(struct SN_env *z) { + if (!(z->I[0] <= z->c)) + return 0; + return 1; +} + +static int r_v_ending(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 61 */ + if (z->c - 1 <= z->lb || (z->p[z->c - 1] != 161 && z->p[z->c - 1] != 169)) + return 0; + among_var = find_among_b(z, a_1, 2); /* substring, line 61 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 61 */ + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 61 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_from_s(z, 1, s_0); /* <-, line 62 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 1, s_1); /* <-, line 63 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_double(struct SN_env *z) { + { + int m_test = z->l - z->c; /* test, line 68 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((106790108 >> (z->p[z->c - 1] & 0x1f)) & 1)) + return 0; + if (!(find_among_b(z, a_2, 23))) + return 0; /* among, line 68 */ + z->c = z->l - m_test; + } + return 1; +} + +static int r_undouble(struct SN_env *z) { + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) + return 0; + z->c = ret; /* next, line 73 */ + } + z->ket = z->c; /* [, line 73 */ + { + int ret = skip_utf8(z->p, z->c, z->lb, z->l, -1); + if (ret < 0) + return 0; + z->c = ret; /* hop, line 73 */ + } + z->bra = z->c; /* ], line 73 */ + { + int ret = slice_del(z); /* delete, line 73 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_instrum(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 77 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] != 108) + return 0; + among_var = find_among_b(z, a_3, 2); /* substring, line 77 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 77 */ + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 77 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = r_double(z); + if (ret == 0) + return 0; /* call double, line 78 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = r_double(z); + if (ret == 0) + return 0; /* call double, line 79 */ + if (ret < 0) + return ret; + } break; + } + { + int ret = slice_del(z); /* delete, line 81 */ + if (ret < 0) + return ret; + } + { + int ret = r_undouble(z); + if (ret == 0) + return 0; /* call undouble, line 82 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_case(struct SN_env *z) { + z->ket = z->c; /* [, line 87 */ + if (!(find_among_b(z, a_4, 44))) + return 0; /* substring, line 87 */ + z->bra = z->c; /* ], line 87 */ + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 87 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 111 */ + if (ret < 0) + return ret; + } + { + int ret = r_v_ending(z); + if (ret == 0) + return 0; /* call v_ending, line 112 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_case_special(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 116 */ + if (z->c - 2 <= z->lb || (z->p[z->c - 1] != 110 && z->p[z->c - 1] != 116)) + return 0; + among_var = find_among_b(z, a_5, 3); /* substring, line 116 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 116 */ + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 116 */ + if (ret < 0) + return ret; + } 
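+    /* the matched ending lies in R1: rewrite it as its unaccented base vowel
+       ("én" -> "e", "án" -> "a", "ánként" -> "a") */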
+ switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_from_s(z, 1, s_2); /* <-, line 117 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 1, s_3); /* <-, line 118 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 1, s_4); /* <-, line 119 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_case_other(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 124 */ + if (z->c - 3 <= z->lb || z->p[z->c - 1] != 108) + return 0; + among_var = find_among_b(z, a_6, 6); /* substring, line 124 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 124 */ + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 124 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 125 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_del(z); /* delete, line 126 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 1, s_5); /* <-, line 127 */ + if (ret < 0) + return ret; + } break; + case 4: { + int ret = slice_from_s(z, 1, s_6); /* <-, line 128 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_factive(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 133 */ + if (z->c - 1 <= z->lb || (z->p[z->c - 1] != 161 && z->p[z->c - 1] != 169)) + return 0; + among_var = find_among_b(z, a_7, 2); /* substring, line 133 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 133 */ + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 133 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = r_double(z); + if (ret == 0) + return 0; /* call double, line 134 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = r_double(z); + if (ret == 0) + return 0; /* call double, line 135 */ + if (ret < 0) + return ret; + } break; + } + { + int ret = slice_del(z); /* delete, line 137 */ + if (ret < 0) + return ret; + } + { + int ret = r_undouble(z); + if (ret == 0) + return 0; /* call undouble, line 138 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_plural(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 142 */ + if (z->c <= z->lb || z->p[z->c - 1] != 107) + return 0; + among_var = find_among_b(z, a_8, 7); /* substring, line 142 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 142 */ + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 142 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_from_s(z, 1, s_7); /* <-, line 143 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 1, s_8); /* <-, line 144 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_del(z); /* delete, line 145 */ + if (ret < 0) + return ret; + } break; + case 4: { + int ret = slice_del(z); /* delete, line 146 */ + if (ret < 0) + return ret; + } break; + case 5: { + int ret = slice_del(z); /* delete, line 147 */ + if (ret < 0) + return ret; + } break; + case 6: { + int ret = slice_del(z); /* delete, line 148 */ + if (ret < 0) + return ret; + } break; + case 7: { + int ret = slice_del(z); /* delete, line 149 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_owned(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 154 */ + if (z->c - 1 <= z->lb || (z->p[z->c - 1] 
!= 105 && z->p[z->c - 1] != 169)) + return 0; + among_var = find_among_b(z, a_9, 12); /* substring, line 154 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 154 */ + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 154 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 155 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 1, s_9); /* <-, line 156 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 1, s_10); /* <-, line 157 */ + if (ret < 0) + return ret; + } break; + case 4: { + int ret = slice_del(z); /* delete, line 158 */ + if (ret < 0) + return ret; + } break; + case 5: { + int ret = slice_from_s(z, 1, s_11); /* <-, line 159 */ + if (ret < 0) + return ret; + } break; + case 6: { + int ret = slice_from_s(z, 1, s_12); /* <-, line 160 */ + if (ret < 0) + return ret; + } break; + case 7: { + int ret = slice_del(z); /* delete, line 161 */ + if (ret < 0) + return ret; + } break; + case 8: { + int ret = slice_from_s(z, 1, s_13); /* <-, line 162 */ + if (ret < 0) + return ret; + } break; + case 9: { + int ret = slice_del(z); /* delete, line 163 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_sing_owner(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 168 */ + among_var = find_among_b(z, a_10, 31); /* substring, line 168 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 168 */ + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 168 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 169 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 1, s_14); /* <-, line 170 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 1, s_15); /* <-, line 171 */ + if (ret < 0) + return ret; + } break; + case 4: { + int ret = slice_del(z); /* delete, line 172 */ + if (ret < 0) + return ret; + } break; + case 5: { + int ret = slice_from_s(z, 1, s_16); /* <-, line 173 */ + if (ret < 0) + return ret; + } break; + case 6: { + int ret = slice_from_s(z, 1, s_17); /* <-, line 174 */ + if (ret < 0) + return ret; + } break; + case 7: { + int ret = slice_del(z); /* delete, line 175 */ + if (ret < 0) + return ret; + } break; + case 8: { + int ret = slice_del(z); /* delete, line 176 */ + if (ret < 0) + return ret; + } break; + case 9: { + int ret = slice_del(z); /* delete, line 177 */ + if (ret < 0) + return ret; + } break; + case 10: { + int ret = slice_from_s(z, 1, s_18); /* <-, line 178 */ + if (ret < 0) + return ret; + } break; + case 11: { + int ret = slice_from_s(z, 1, s_19); /* <-, line 179 */ + if (ret < 0) + return ret; + } break; + case 12: { + int ret = slice_del(z); /* delete, line 180 */ + if (ret < 0) + return ret; + } break; + case 13: { + int ret = slice_del(z); /* delete, line 181 */ + if (ret < 0) + return ret; + } break; + case 14: { + int ret = slice_from_s(z, 1, s_20); /* <-, line 182 */ + if (ret < 0) + return ret; + } break; + case 15: { + int ret = slice_from_s(z, 1, s_21); /* <-, line 183 */ + if (ret < 0) + return ret; + } break; + case 16: { + int ret = slice_del(z); /* delete, line 184 */ + if (ret < 0) + return ret; + } break; + case 17: { + int ret = slice_del(z); /* delete, line 185 */ + if (ret < 0) + return ret; + } break; + case 18: { + int ret = slice_del(z); /* delete, line 186 */ + if (ret < 0) 
+ return ret; + } break; + case 19: { + int ret = slice_from_s(z, 1, s_22); /* <-, line 187 */ + if (ret < 0) + return ret; + } break; + case 20: { + int ret = slice_from_s(z, 1, s_23); /* <-, line 188 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_plur_owner(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 193 */ + if (z->c <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((10768 >> (z->p[z->c - 1] & 0x1f)) & 1)) + return 0; + among_var = find_among_b(z, a_11, 42); /* substring, line 193 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 193 */ + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 193 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 194 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 1, s_24); /* <-, line 195 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 1, s_25); /* <-, line 196 */ + if (ret < 0) + return ret; + } break; + case 4: { + int ret = slice_del(z); /* delete, line 197 */ + if (ret < 0) + return ret; + } break; + case 5: { + int ret = slice_del(z); /* delete, line 198 */ + if (ret < 0) + return ret; + } break; + case 6: { + int ret = slice_del(z); /* delete, line 199 */ + if (ret < 0) + return ret; + } break; + case 7: { + int ret = slice_from_s(z, 1, s_26); /* <-, line 200 */ + if (ret < 0) + return ret; + } break; + case 8: { + int ret = slice_from_s(z, 1, s_27); /* <-, line 201 */ + if (ret < 0) + return ret; + } break; + case 9: { + int ret = slice_del(z); /* delete, line 202 */ + if (ret < 0) + return ret; + } break; + case 10: { + int ret = slice_del(z); /* delete, line 203 */ + if (ret < 0) + return ret; + } break; + case 11: { + int ret = slice_del(z); /* delete, line 204 */ + if (ret < 0) + return ret; + } break; + case 12: { + int ret = slice_from_s(z, 1, s_28); /* <-, line 205 */ + if (ret < 0) + return ret; + } break; + case 13: { + int ret = slice_from_s(z, 1, s_29); /* <-, line 206 */ + if (ret < 0) + return ret; + } break; + case 14: { + int ret = slice_del(z); /* delete, line 207 */ + if (ret < 0) + return ret; + } break; + case 15: { + int ret = slice_del(z); /* delete, line 208 */ + if (ret < 0) + return ret; + } break; + case 16: { + int ret = slice_del(z); /* delete, line 209 */ + if (ret < 0) + return ret; + } break; + case 17: { + int ret = slice_del(z); /* delete, line 210 */ + if (ret < 0) + return ret; + } break; + case 18: { + int ret = slice_from_s(z, 1, s_30); /* <-, line 211 */ + if (ret < 0) + return ret; + } break; + case 19: { + int ret = slice_from_s(z, 1, s_31); /* <-, line 212 */ + if (ret < 0) + return ret; + } break; + case 20: { + int ret = slice_del(z); /* delete, line 214 */ + if (ret < 0) + return ret; + } break; + case 21: { + int ret = slice_del(z); /* delete, line 215 */ + if (ret < 0) + return ret; + } break; + case 22: { + int ret = slice_from_s(z, 1, s_32); /* <-, line 216 */ + if (ret < 0) + return ret; + } break; + case 23: { + int ret = slice_from_s(z, 1, s_33); /* <-, line 217 */ + if (ret < 0) + return ret; + } break; + case 24: { + int ret = slice_del(z); /* delete, line 218 */ + if (ret < 0) + return ret; + } break; + case 25: { + int ret = slice_del(z); /* delete, line 219 */ + if (ret < 0) + return ret; + } break; + case 26: { + int ret = slice_del(z); /* delete, line 220 */ + if (ret < 0) + return ret; + } break; + case 27: { + int ret = slice_from_s(z, 1, s_34); /* <-, line 221 */ + 
if (ret < 0) + return ret; + } break; + case 28: { + int ret = slice_from_s(z, 1, s_35); /* <-, line 222 */ + if (ret < 0) + return ret; + } break; + case 29: { + int ret = slice_del(z); /* delete, line 223 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +extern int hungarian_UTF_8_stem(struct SN_env *z) { + { + int c1 = z->c; /* do, line 229 */ + { + int ret = r_mark_regions(z); + if (ret == 0) + goto lab0; /* call mark_regions, line 229 */ + if (ret < 0) + return ret; + } + lab0: + z->c = c1; + } + z->lb = z->c; + z->c = z->l; /* backwards, line 230 */ + + { + int m2 = z->l - z->c; + (void)m2; /* do, line 231 */ + { + int ret = r_instrum(z); + if (ret == 0) + goto lab1; /* call instrum, line 231 */ + if (ret < 0) + return ret; + } + lab1: + z->c = z->l - m2; + } + { + int m3 = z->l - z->c; + (void)m3; /* do, line 232 */ + { + int ret = r_case(z); + if (ret == 0) + goto lab2; /* call case, line 232 */ + if (ret < 0) + return ret; + } + lab2: + z->c = z->l - m3; + } + { + int m4 = z->l - z->c; + (void)m4; /* do, line 233 */ + { + int ret = r_case_special(z); + if (ret == 0) + goto lab3; /* call case_special, line 233 */ + if (ret < 0) + return ret; + } + lab3: + z->c = z->l - m4; + } + { + int m5 = z->l - z->c; + (void)m5; /* do, line 234 */ + { + int ret = r_case_other(z); + if (ret == 0) + goto lab4; /* call case_other, line 234 */ + if (ret < 0) + return ret; + } + lab4: + z->c = z->l - m5; + } + { + int m6 = z->l - z->c; + (void)m6; /* do, line 235 */ + { + int ret = r_factive(z); + if (ret == 0) + goto lab5; /* call factive, line 235 */ + if (ret < 0) + return ret; + } + lab5: + z->c = z->l - m6; + } + { + int m7 = z->l - z->c; + (void)m7; /* do, line 236 */ + { + int ret = r_owned(z); + if (ret == 0) + goto lab6; /* call owned, line 236 */ + if (ret < 0) + return ret; + } + lab6: + z->c = z->l - m7; + } + { + int m8 = z->l - z->c; + (void)m8; /* do, line 237 */ + { + int ret = r_sing_owner(z); + if (ret == 0) + goto lab7; /* call sing_owner, line 237 */ + if (ret < 0) + return ret; + } + lab7: + z->c = z->l - m8; + } + { + int m9 = z->l - z->c; + (void)m9; /* do, line 238 */ + { + int ret = r_plur_owner(z); + if (ret == 0) + goto lab8; /* call plur_owner, line 238 */ + if (ret < 0) + return ret; + } + lab8: + z->c = z->l - m9; + } + { + int m10 = z->l - z->c; + (void)m10; /* do, line 239 */ + { + int ret = r_plural(z); + if (ret == 0) + goto lab9; /* call plural, line 239 */ + if (ret < 0) + return ret; + } + lab9: + z->c = z->l - m10; + } + z->c = z->lb; + return 1; +} + +extern struct SN_env *hungarian_UTF_8_create_env(void) { return SN_create_env(0, 1, 0); } + +extern void hungarian_UTF_8_close_env(struct SN_env *z) { SN_close_env(z, 0); } diff --git a/internal/cpp/stemmer/stem_UTF_8_hungarian.h b/internal/cpp/stemmer/stem_UTF_8_hungarian.h new file mode 100644 index 00000000000..8f994a56c2e --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_hungarian.h @@ -0,0 +1,17 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *hungarian_UTF_8_create_env(void); +extern void hungarian_UTF_8_close_env(struct SN_env *z); + +extern int hungarian_UTF_8_stem(struct SN_env *z); + +#ifdef __cplusplus +} +#endif diff --git a/internal/cpp/stemmer/stem_UTF_8_italian.cpp b/internal/cpp/stemmer/stem_UTF_8_italian.cpp new file mode 100644 index 00000000000..249dde23f38 --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_italian.cpp @@ -0,0 +1,1288 @@ + +/* This file was 
generated automatically by the Snowball to ANSI C compiler */ + +#include "header.h" + +#ifdef __cplusplus +extern "C" { +#endif +extern int italian_UTF_8_stem(struct SN_env *z); +#ifdef __cplusplus +} +#endif +static int r_vowel_suffix(struct SN_env *z); +static int r_verb_suffix(struct SN_env *z); +static int r_standard_suffix(struct SN_env *z); +static int r_attached_pronoun(struct SN_env *z); +static int r_R2(struct SN_env *z); +static int r_R1(struct SN_env *z); +static int r_RV(struct SN_env *z); +static int r_mark_regions(struct SN_env *z); +static int r_postlude(struct SN_env *z); +static int r_prelude(struct SN_env *z); +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *italian_UTF_8_create_env(void); +extern void italian_UTF_8_close_env(struct SN_env *z); + +#ifdef __cplusplus +} +#endif +static const symbol s_0_1[2] = {'q', 'u'}; +static const symbol s_0_2[2] = {0xC3, 0xA1}; +static const symbol s_0_3[2] = {0xC3, 0xA9}; +static const symbol s_0_4[2] = {0xC3, 0xAD}; +static const symbol s_0_5[2] = {0xC3, 0xB3}; +static const symbol s_0_6[2] = {0xC3, 0xBA}; + +static const struct among a_0[7] = { + /* 0 */ {0, 0, -1, 7, 0}, + /* 1 */ {2, s_0_1, 0, 6, 0}, + /* 2 */ {2, s_0_2, 0, 1, 0}, + /* 3 */ {2, s_0_3, 0, 2, 0}, + /* 4 */ {2, s_0_4, 0, 3, 0}, + /* 5 */ {2, s_0_5, 0, 4, 0}, + /* 6 */ {2, s_0_6, 0, 5, 0}}; + +static const symbol s_1_1[1] = {'I'}; +static const symbol s_1_2[1] = {'U'}; + +static const struct among a_1[3] = { + /* 0 */ {0, 0, -1, 3, 0}, + /* 1 */ {1, s_1_1, 0, 1, 0}, + /* 2 */ {1, s_1_2, 0, 2, 0}}; + +static const symbol s_2_0[2] = {'l', 'a'}; +static const symbol s_2_1[4] = {'c', 'e', 'l', 'a'}; +static const symbol s_2_2[6] = {'g', 'l', 'i', 'e', 'l', 'a'}; +static const symbol s_2_3[4] = {'m', 'e', 'l', 'a'}; +static const symbol s_2_4[4] = {'t', 'e', 'l', 'a'}; +static const symbol s_2_5[4] = {'v', 'e', 'l', 'a'}; +static const symbol s_2_6[2] = {'l', 'e'}; +static const symbol s_2_7[4] = {'c', 'e', 'l', 'e'}; +static const symbol s_2_8[6] = {'g', 'l', 'i', 'e', 'l', 'e'}; +static const symbol s_2_9[4] = {'m', 'e', 'l', 'e'}; +static const symbol s_2_10[4] = {'t', 'e', 'l', 'e'}; +static const symbol s_2_11[4] = {'v', 'e', 'l', 'e'}; +static const symbol s_2_12[2] = {'n', 'e'}; +static const symbol s_2_13[4] = {'c', 'e', 'n', 'e'}; +static const symbol s_2_14[6] = {'g', 'l', 'i', 'e', 'n', 'e'}; +static const symbol s_2_15[4] = {'m', 'e', 'n', 'e'}; +static const symbol s_2_16[4] = {'s', 'e', 'n', 'e'}; +static const symbol s_2_17[4] = {'t', 'e', 'n', 'e'}; +static const symbol s_2_18[4] = {'v', 'e', 'n', 'e'}; +static const symbol s_2_19[2] = {'c', 'i'}; +static const symbol s_2_20[2] = {'l', 'i'}; +static const symbol s_2_21[4] = {'c', 'e', 'l', 'i'}; +static const symbol s_2_22[6] = {'g', 'l', 'i', 'e', 'l', 'i'}; +static const symbol s_2_23[4] = {'m', 'e', 'l', 'i'}; +static const symbol s_2_24[4] = {'t', 'e', 'l', 'i'}; +static const symbol s_2_25[4] = {'v', 'e', 'l', 'i'}; +static const symbol s_2_26[3] = {'g', 'l', 'i'}; +static const symbol s_2_27[2] = {'m', 'i'}; +static const symbol s_2_28[2] = {'s', 'i'}; +static const symbol s_2_29[2] = {'t', 'i'}; +static const symbol s_2_30[2] = {'v', 'i'}; +static const symbol s_2_31[2] = {'l', 'o'}; +static const symbol s_2_32[4] = {'c', 'e', 'l', 'o'}; +static const symbol s_2_33[6] = {'g', 'l', 'i', 'e', 'l', 'o'}; +static const symbol s_2_34[4] = {'m', 'e', 'l', 'o'}; +static const symbol s_2_35[4] = {'t', 'e', 'l', 'o'}; +static const symbol s_2_36[4] = {'v', 'e', 'l', 'o'}; + +static const 
struct among a_2[37] = { + /* 0 */ {2, s_2_0, -1, -1, 0}, + /* 1 */ {4, s_2_1, 0, -1, 0}, + /* 2 */ {6, s_2_2, 0, -1, 0}, + /* 3 */ {4, s_2_3, 0, -1, 0}, + /* 4 */ {4, s_2_4, 0, -1, 0}, + /* 5 */ {4, s_2_5, 0, -1, 0}, + /* 6 */ {2, s_2_6, -1, -1, 0}, + /* 7 */ {4, s_2_7, 6, -1, 0}, + /* 8 */ {6, s_2_8, 6, -1, 0}, + /* 9 */ {4, s_2_9, 6, -1, 0}, + /* 10 */ {4, s_2_10, 6, -1, 0}, + /* 11 */ {4, s_2_11, 6, -1, 0}, + /* 12 */ {2, s_2_12, -1, -1, 0}, + /* 13 */ {4, s_2_13, 12, -1, 0}, + /* 14 */ {6, s_2_14, 12, -1, 0}, + /* 15 */ {4, s_2_15, 12, -1, 0}, + /* 16 */ {4, s_2_16, 12, -1, 0}, + /* 17 */ {4, s_2_17, 12, -1, 0}, + /* 18 */ {4, s_2_18, 12, -1, 0}, + /* 19 */ {2, s_2_19, -1, -1, 0}, + /* 20 */ {2, s_2_20, -1, -1, 0}, + /* 21 */ {4, s_2_21, 20, -1, 0}, + /* 22 */ {6, s_2_22, 20, -1, 0}, + /* 23 */ {4, s_2_23, 20, -1, 0}, + /* 24 */ {4, s_2_24, 20, -1, 0}, + /* 25 */ {4, s_2_25, 20, -1, 0}, + /* 26 */ {3, s_2_26, 20, -1, 0}, + /* 27 */ {2, s_2_27, -1, -1, 0}, + /* 28 */ {2, s_2_28, -1, -1, 0}, + /* 29 */ {2, s_2_29, -1, -1, 0}, + /* 30 */ {2, s_2_30, -1, -1, 0}, + /* 31 */ {2, s_2_31, -1, -1, 0}, + /* 32 */ {4, s_2_32, 31, -1, 0}, + /* 33 */ {6, s_2_33, 31, -1, 0}, + /* 34 */ {4, s_2_34, 31, -1, 0}, + /* 35 */ {4, s_2_35, 31, -1, 0}, + /* 36 */ {4, s_2_36, 31, -1, 0}}; + +static const symbol s_3_0[4] = {'a', 'n', 'd', 'o'}; +static const symbol s_3_1[4] = {'e', 'n', 'd', 'o'}; +static const symbol s_3_2[2] = {'a', 'r'}; +static const symbol s_3_3[2] = {'e', 'r'}; +static const symbol s_3_4[2] = {'i', 'r'}; + +static const struct among a_3[5] = { + /* 0 */ {4, s_3_0, -1, 1, 0}, + /* 1 */ {4, s_3_1, -1, 1, 0}, + /* 2 */ {2, s_3_2, -1, 2, 0}, + /* 3 */ {2, s_3_3, -1, 2, 0}, + /* 4 */ {2, s_3_4, -1, 2, 0}}; + +static const symbol s_4_0[2] = {'i', 'c'}; +static const symbol s_4_1[4] = {'a', 'b', 'i', 'l'}; +static const symbol s_4_2[2] = {'o', 's'}; +static const symbol s_4_3[2] = {'i', 'v'}; + +static const struct among a_4[4] = { + /* 0 */ {2, s_4_0, -1, -1, 0}, + /* 1 */ {4, s_4_1, -1, -1, 0}, + /* 2 */ {2, s_4_2, -1, -1, 0}, + /* 3 */ {2, s_4_3, -1, 1, 0}}; + +static const symbol s_5_0[2] = {'i', 'c'}; +static const symbol s_5_1[4] = {'a', 'b', 'i', 'l'}; +static const symbol s_5_2[2] = {'i', 'v'}; + +static const struct among a_5[3] = { + /* 0 */ {2, s_5_0, -1, 1, 0}, + /* 1 */ {4, s_5_1, -1, 1, 0}, + /* 2 */ {2, s_5_2, -1, 1, 0}}; + +static const symbol s_6_0[3] = {'i', 'c', 'a'}; +static const symbol s_6_1[5] = {'l', 'o', 'g', 'i', 'a'}; +static const symbol s_6_2[3] = {'o', 's', 'a'}; +static const symbol s_6_3[4] = {'i', 's', 't', 'a'}; +static const symbol s_6_4[3] = {'i', 'v', 'a'}; +static const symbol s_6_5[4] = {'a', 'n', 'z', 'a'}; +static const symbol s_6_6[4] = {'e', 'n', 'z', 'a'}; +static const symbol s_6_7[3] = {'i', 'c', 'e'}; +static const symbol s_6_8[6] = {'a', 't', 'r', 'i', 'c', 'e'}; +static const symbol s_6_9[4] = {'i', 'c', 'h', 'e'}; +static const symbol s_6_10[5] = {'l', 'o', 'g', 'i', 'e'}; +static const symbol s_6_11[5] = {'a', 'b', 'i', 'l', 'e'}; +static const symbol s_6_12[5] = {'i', 'b', 'i', 'l', 'e'}; +static const symbol s_6_13[6] = {'u', 's', 'i', 'o', 'n', 'e'}; +static const symbol s_6_14[6] = {'a', 'z', 'i', 'o', 'n', 'e'}; +static const symbol s_6_15[6] = {'u', 'z', 'i', 'o', 'n', 'e'}; +static const symbol s_6_16[5] = {'a', 't', 'o', 'r', 'e'}; +static const symbol s_6_17[3] = {'o', 's', 'e'}; +static const symbol s_6_18[4] = {'a', 'n', 't', 'e'}; +static const symbol s_6_19[5] = {'m', 'e', 'n', 't', 'e'}; +static const symbol s_6_20[6] = {'a', 
'm', 'e', 'n', 't', 'e'}; +static const symbol s_6_21[4] = {'i', 's', 't', 'e'}; +static const symbol s_6_22[3] = {'i', 'v', 'e'}; +static const symbol s_6_23[4] = {'a', 'n', 'z', 'e'}; +static const symbol s_6_24[4] = {'e', 'n', 'z', 'e'}; +static const symbol s_6_25[3] = {'i', 'c', 'i'}; +static const symbol s_6_26[6] = {'a', 't', 'r', 'i', 'c', 'i'}; +static const symbol s_6_27[4] = {'i', 'c', 'h', 'i'}; +static const symbol s_6_28[5] = {'a', 'b', 'i', 'l', 'i'}; +static const symbol s_6_29[5] = {'i', 'b', 'i', 'l', 'i'}; +static const symbol s_6_30[4] = {'i', 's', 'm', 'i'}; +static const symbol s_6_31[6] = {'u', 's', 'i', 'o', 'n', 'i'}; +static const symbol s_6_32[6] = {'a', 'z', 'i', 'o', 'n', 'i'}; +static const symbol s_6_33[6] = {'u', 'z', 'i', 'o', 'n', 'i'}; +static const symbol s_6_34[5] = {'a', 't', 'o', 'r', 'i'}; +static const symbol s_6_35[3] = {'o', 's', 'i'}; +static const symbol s_6_36[4] = {'a', 'n', 't', 'i'}; +static const symbol s_6_37[6] = {'a', 'm', 'e', 'n', 't', 'i'}; +static const symbol s_6_38[6] = {'i', 'm', 'e', 'n', 't', 'i'}; +static const symbol s_6_39[4] = {'i', 's', 't', 'i'}; +static const symbol s_6_40[3] = {'i', 'v', 'i'}; +static const symbol s_6_41[3] = {'i', 'c', 'o'}; +static const symbol s_6_42[4] = {'i', 's', 'm', 'o'}; +static const symbol s_6_43[3] = {'o', 's', 'o'}; +static const symbol s_6_44[6] = {'a', 'm', 'e', 'n', 't', 'o'}; +static const symbol s_6_45[6] = {'i', 'm', 'e', 'n', 't', 'o'}; +static const symbol s_6_46[3] = {'i', 'v', 'o'}; +static const symbol s_6_47[4] = {'i', 't', 0xC3, 0xA0}; +static const symbol s_6_48[5] = {'i', 's', 't', 0xC3, 0xA0}; +static const symbol s_6_49[5] = {'i', 's', 't', 0xC3, 0xA8}; +static const symbol s_6_50[5] = {'i', 's', 't', 0xC3, 0xAC}; + +static const struct among a_6[51] = { + /* 0 */ {3, s_6_0, -1, 1, 0}, + /* 1 */ {5, s_6_1, -1, 3, 0}, + /* 2 */ {3, s_6_2, -1, 1, 0}, + /* 3 */ {4, s_6_3, -1, 1, 0}, + /* 4 */ {3, s_6_4, -1, 9, 0}, + /* 5 */ {4, s_6_5, -1, 1, 0}, + /* 6 */ {4, s_6_6, -1, 5, 0}, + /* 7 */ {3, s_6_7, -1, 1, 0}, + /* 8 */ {6, s_6_8, 7, 1, 0}, + /* 9 */ {4, s_6_9, -1, 1, 0}, + /* 10 */ {5, s_6_10, -1, 3, 0}, + /* 11 */ {5, s_6_11, -1, 1, 0}, + /* 12 */ {5, s_6_12, -1, 1, 0}, + /* 13 */ {6, s_6_13, -1, 4, 0}, + /* 14 */ {6, s_6_14, -1, 2, 0}, + /* 15 */ {6, s_6_15, -1, 4, 0}, + /* 16 */ {5, s_6_16, -1, 2, 0}, + /* 17 */ {3, s_6_17, -1, 1, 0}, + /* 18 */ {4, s_6_18, -1, 1, 0}, + /* 19 */ {5, s_6_19, -1, 1, 0}, + /* 20 */ {6, s_6_20, 19, 7, 0}, + /* 21 */ {4, s_6_21, -1, 1, 0}, + /* 22 */ {3, s_6_22, -1, 9, 0}, + /* 23 */ {4, s_6_23, -1, 1, 0}, + /* 24 */ {4, s_6_24, -1, 5, 0}, + /* 25 */ {3, s_6_25, -1, 1, 0}, + /* 26 */ {6, s_6_26, 25, 1, 0}, + /* 27 */ {4, s_6_27, -1, 1, 0}, + /* 28 */ {5, s_6_28, -1, 1, 0}, + /* 29 */ {5, s_6_29, -1, 1, 0}, + /* 30 */ {4, s_6_30, -1, 1, 0}, + /* 31 */ {6, s_6_31, -1, 4, 0}, + /* 32 */ {6, s_6_32, -1, 2, 0}, + /* 33 */ {6, s_6_33, -1, 4, 0}, + /* 34 */ {5, s_6_34, -1, 2, 0}, + /* 35 */ {3, s_6_35, -1, 1, 0}, + /* 36 */ {4, s_6_36, -1, 1, 0}, + /* 37 */ {6, s_6_37, -1, 6, 0}, + /* 38 */ {6, s_6_38, -1, 6, 0}, + /* 39 */ {4, s_6_39, -1, 1, 0}, + /* 40 */ {3, s_6_40, -1, 9, 0}, + /* 41 */ {3, s_6_41, -1, 1, 0}, + /* 42 */ {4, s_6_42, -1, 1, 0}, + /* 43 */ {3, s_6_43, -1, 1, 0}, + /* 44 */ {6, s_6_44, -1, 6, 0}, + /* 45 */ {6, s_6_45, -1, 6, 0}, + /* 46 */ {3, s_6_46, -1, 9, 0}, + /* 47 */ {4, s_6_47, -1, 8, 0}, + /* 48 */ {5, s_6_48, -1, 1, 0}, + /* 49 */ {5, s_6_49, -1, 1, 0}, + /* 50 */ {5, s_6_50, -1, 1, 0}}; + +static const symbol s_7_0[4] = {'i', 
's', 'c', 'a'}; +static const symbol s_7_1[4] = {'e', 'n', 'd', 'a'}; +static const symbol s_7_2[3] = {'a', 't', 'a'}; +static const symbol s_7_3[3] = {'i', 't', 'a'}; +static const symbol s_7_4[3] = {'u', 't', 'a'}; +static const symbol s_7_5[3] = {'a', 'v', 'a'}; +static const symbol s_7_6[3] = {'e', 'v', 'a'}; +static const symbol s_7_7[3] = {'i', 'v', 'a'}; +static const symbol s_7_8[6] = {'e', 'r', 'e', 'b', 'b', 'e'}; +static const symbol s_7_9[6] = {'i', 'r', 'e', 'b', 'b', 'e'}; +static const symbol s_7_10[4] = {'i', 's', 'c', 'e'}; +static const symbol s_7_11[4] = {'e', 'n', 'd', 'e'}; +static const symbol s_7_12[3] = {'a', 'r', 'e'}; +static const symbol s_7_13[3] = {'e', 'r', 'e'}; +static const symbol s_7_14[3] = {'i', 'r', 'e'}; +static const symbol s_7_15[4] = {'a', 's', 's', 'e'}; +static const symbol s_7_16[3] = {'a', 't', 'e'}; +static const symbol s_7_17[5] = {'a', 'v', 'a', 't', 'e'}; +static const symbol s_7_18[5] = {'e', 'v', 'a', 't', 'e'}; +static const symbol s_7_19[5] = {'i', 'v', 'a', 't', 'e'}; +static const symbol s_7_20[3] = {'e', 't', 'e'}; +static const symbol s_7_21[5] = {'e', 'r', 'e', 't', 'e'}; +static const symbol s_7_22[5] = {'i', 'r', 'e', 't', 'e'}; +static const symbol s_7_23[3] = {'i', 't', 'e'}; +static const symbol s_7_24[6] = {'e', 'r', 'e', 's', 't', 'e'}; +static const symbol s_7_25[6] = {'i', 'r', 'e', 's', 't', 'e'}; +static const symbol s_7_26[3] = {'u', 't', 'e'}; +static const symbol s_7_27[4] = {'e', 'r', 'a', 'i'}; +static const symbol s_7_28[4] = {'i', 'r', 'a', 'i'}; +static const symbol s_7_29[4] = {'i', 's', 'c', 'i'}; +static const symbol s_7_30[4] = {'e', 'n', 'd', 'i'}; +static const symbol s_7_31[4] = {'e', 'r', 'e', 'i'}; +static const symbol s_7_32[4] = {'i', 'r', 'e', 'i'}; +static const symbol s_7_33[4] = {'a', 's', 's', 'i'}; +static const symbol s_7_34[3] = {'a', 't', 'i'}; +static const symbol s_7_35[3] = {'i', 't', 'i'}; +static const symbol s_7_36[6] = {'e', 'r', 'e', 's', 't', 'i'}; +static const symbol s_7_37[6] = {'i', 'r', 'e', 's', 't', 'i'}; +static const symbol s_7_38[3] = {'u', 't', 'i'}; +static const symbol s_7_39[3] = {'a', 'v', 'i'}; +static const symbol s_7_40[3] = {'e', 'v', 'i'}; +static const symbol s_7_41[3] = {'i', 'v', 'i'}; +static const symbol s_7_42[4] = {'i', 's', 'c', 'o'}; +static const symbol s_7_43[4] = {'a', 'n', 'd', 'o'}; +static const symbol s_7_44[4] = {'e', 'n', 'd', 'o'}; +static const symbol s_7_45[4] = {'Y', 'a', 'm', 'o'}; +static const symbol s_7_46[4] = {'i', 'a', 'm', 'o'}; +static const symbol s_7_47[5] = {'a', 'v', 'a', 'm', 'o'}; +static const symbol s_7_48[5] = {'e', 'v', 'a', 'm', 'o'}; +static const symbol s_7_49[5] = {'i', 'v', 'a', 'm', 'o'}; +static const symbol s_7_50[5] = {'e', 'r', 'e', 'm', 'o'}; +static const symbol s_7_51[5] = {'i', 'r', 'e', 'm', 'o'}; +static const symbol s_7_52[6] = {'a', 's', 's', 'i', 'm', 'o'}; +static const symbol s_7_53[4] = {'a', 'm', 'm', 'o'}; +static const symbol s_7_54[4] = {'e', 'm', 'm', 'o'}; +static const symbol s_7_55[6] = {'e', 'r', 'e', 'm', 'm', 'o'}; +static const symbol s_7_56[6] = {'i', 'r', 'e', 'm', 'm', 'o'}; +static const symbol s_7_57[4] = {'i', 'm', 'm', 'o'}; +static const symbol s_7_58[3] = {'a', 'n', 'o'}; +static const symbol s_7_59[6] = {'i', 's', 'c', 'a', 'n', 'o'}; +static const symbol s_7_60[5] = {'a', 'v', 'a', 'n', 'o'}; +static const symbol s_7_61[5] = {'e', 'v', 'a', 'n', 'o'}; +static const symbol s_7_62[5] = {'i', 'v', 'a', 'n', 'o'}; +static const symbol s_7_63[6] = {'e', 'r', 'a', 'n', 'n', 'o'}; +static 
const symbol s_7_64[6] = {'i', 'r', 'a', 'n', 'n', 'o'}; +static const symbol s_7_65[3] = {'o', 'n', 'o'}; +static const symbol s_7_66[6] = {'i', 's', 'c', 'o', 'n', 'o'}; +static const symbol s_7_67[5] = {'a', 'r', 'o', 'n', 'o'}; +static const symbol s_7_68[5] = {'e', 'r', 'o', 'n', 'o'}; +static const symbol s_7_69[5] = {'i', 'r', 'o', 'n', 'o'}; +static const symbol s_7_70[8] = {'e', 'r', 'e', 'b', 'b', 'e', 'r', 'o'}; +static const symbol s_7_71[8] = {'i', 'r', 'e', 'b', 'b', 'e', 'r', 'o'}; +static const symbol s_7_72[6] = {'a', 's', 's', 'e', 'r', 'o'}; +static const symbol s_7_73[6] = {'e', 's', 's', 'e', 'r', 'o'}; +static const symbol s_7_74[6] = {'i', 's', 's', 'e', 'r', 'o'}; +static const symbol s_7_75[3] = {'a', 't', 'o'}; +static const symbol s_7_76[3] = {'i', 't', 'o'}; +static const symbol s_7_77[3] = {'u', 't', 'o'}; +static const symbol s_7_78[3] = {'a', 'v', 'o'}; +static const symbol s_7_79[3] = {'e', 'v', 'o'}; +static const symbol s_7_80[3] = {'i', 'v', 'o'}; +static const symbol s_7_81[2] = {'a', 'r'}; +static const symbol s_7_82[2] = {'i', 'r'}; +static const symbol s_7_83[4] = {'e', 'r', 0xC3, 0xA0}; +static const symbol s_7_84[4] = {'i', 'r', 0xC3, 0xA0}; +static const symbol s_7_85[4] = {'e', 'r', 0xC3, 0xB2}; +static const symbol s_7_86[4] = {'i', 'r', 0xC3, 0xB2}; + +static const struct among a_7[87] = { + /* 0 */ {4, s_7_0, -1, 1, 0}, + /* 1 */ {4, s_7_1, -1, 1, 0}, + /* 2 */ {3, s_7_2, -1, 1, 0}, + /* 3 */ {3, s_7_3, -1, 1, 0}, + /* 4 */ {3, s_7_4, -1, 1, 0}, + /* 5 */ {3, s_7_5, -1, 1, 0}, + /* 6 */ {3, s_7_6, -1, 1, 0}, + /* 7 */ {3, s_7_7, -1, 1, 0}, + /* 8 */ {6, s_7_8, -1, 1, 0}, + /* 9 */ {6, s_7_9, -1, 1, 0}, + /* 10 */ {4, s_7_10, -1, 1, 0}, + /* 11 */ {4, s_7_11, -1, 1, 0}, + /* 12 */ {3, s_7_12, -1, 1, 0}, + /* 13 */ {3, s_7_13, -1, 1, 0}, + /* 14 */ {3, s_7_14, -1, 1, 0}, + /* 15 */ {4, s_7_15, -1, 1, 0}, + /* 16 */ {3, s_7_16, -1, 1, 0}, + /* 17 */ {5, s_7_17, 16, 1, 0}, + /* 18 */ {5, s_7_18, 16, 1, 0}, + /* 19 */ {5, s_7_19, 16, 1, 0}, + /* 20 */ {3, s_7_20, -1, 1, 0}, + /* 21 */ {5, s_7_21, 20, 1, 0}, + /* 22 */ {5, s_7_22, 20, 1, 0}, + /* 23 */ {3, s_7_23, -1, 1, 0}, + /* 24 */ {6, s_7_24, -1, 1, 0}, + /* 25 */ {6, s_7_25, -1, 1, 0}, + /* 26 */ {3, s_7_26, -1, 1, 0}, + /* 27 */ {4, s_7_27, -1, 1, 0}, + /* 28 */ {4, s_7_28, -1, 1, 0}, + /* 29 */ {4, s_7_29, -1, 1, 0}, + /* 30 */ {4, s_7_30, -1, 1, 0}, + /* 31 */ {4, s_7_31, -1, 1, 0}, + /* 32 */ {4, s_7_32, -1, 1, 0}, + /* 33 */ {4, s_7_33, -1, 1, 0}, + /* 34 */ {3, s_7_34, -1, 1, 0}, + /* 35 */ {3, s_7_35, -1, 1, 0}, + /* 36 */ {6, s_7_36, -1, 1, 0}, + /* 37 */ {6, s_7_37, -1, 1, 0}, + /* 38 */ {3, s_7_38, -1, 1, 0}, + /* 39 */ {3, s_7_39, -1, 1, 0}, + /* 40 */ {3, s_7_40, -1, 1, 0}, + /* 41 */ {3, s_7_41, -1, 1, 0}, + /* 42 */ {4, s_7_42, -1, 1, 0}, + /* 43 */ {4, s_7_43, -1, 1, 0}, + /* 44 */ {4, s_7_44, -1, 1, 0}, + /* 45 */ {4, s_7_45, -1, 1, 0}, + /* 46 */ {4, s_7_46, -1, 1, 0}, + /* 47 */ {5, s_7_47, -1, 1, 0}, + /* 48 */ {5, s_7_48, -1, 1, 0}, + /* 49 */ {5, s_7_49, -1, 1, 0}, + /* 50 */ {5, s_7_50, -1, 1, 0}, + /* 51 */ {5, s_7_51, -1, 1, 0}, + /* 52 */ {6, s_7_52, -1, 1, 0}, + /* 53 */ {4, s_7_53, -1, 1, 0}, + /* 54 */ {4, s_7_54, -1, 1, 0}, + /* 55 */ {6, s_7_55, 54, 1, 0}, + /* 56 */ {6, s_7_56, 54, 1, 0}, + /* 57 */ {4, s_7_57, -1, 1, 0}, + /* 58 */ {3, s_7_58, -1, 1, 0}, + /* 59 */ {6, s_7_59, 58, 1, 0}, + /* 60 */ {5, s_7_60, 58, 1, 0}, + /* 61 */ {5, s_7_61, 58, 1, 0}, + /* 62 */ {5, s_7_62, 58, 1, 0}, + /* 63 */ {6, s_7_63, -1, 1, 0}, + /* 64 */ {6, s_7_64, -1, 1, 0}, + /* 65 
*/ {3, s_7_65, -1, 1, 0}, + /* 66 */ {6, s_7_66, 65, 1, 0}, + /* 67 */ {5, s_7_67, 65, 1, 0}, + /* 68 */ {5, s_7_68, 65, 1, 0}, + /* 69 */ {5, s_7_69, 65, 1, 0}, + /* 70 */ {8, s_7_70, -1, 1, 0}, + /* 71 */ {8, s_7_71, -1, 1, 0}, + /* 72 */ {6, s_7_72, -1, 1, 0}, + /* 73 */ {6, s_7_73, -1, 1, 0}, + /* 74 */ {6, s_7_74, -1, 1, 0}, + /* 75 */ {3, s_7_75, -1, 1, 0}, + /* 76 */ {3, s_7_76, -1, 1, 0}, + /* 77 */ {3, s_7_77, -1, 1, 0}, + /* 78 */ {3, s_7_78, -1, 1, 0}, + /* 79 */ {3, s_7_79, -1, 1, 0}, + /* 80 */ {3, s_7_80, -1, 1, 0}, + /* 81 */ {2, s_7_81, -1, 1, 0}, + /* 82 */ {2, s_7_82, -1, 1, 0}, + /* 83 */ {4, s_7_83, -1, 1, 0}, + /* 84 */ {4, s_7_84, -1, 1, 0}, + /* 85 */ {4, s_7_85, -1, 1, 0}, + /* 86 */ {4, s_7_86, -1, 1, 0}}; + +static const unsigned char g_v[] = {17, 65, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 128, 8, 2, 1}; + +static const unsigned char g_AEIO[] = {17, 65, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 128, 8, 2}; + +static const unsigned char g_CG[] = {17}; + +static const symbol s_0[] = {0xC3, 0xA0}; +static const symbol s_1[] = {0xC3, 0xA8}; +static const symbol s_2[] = {0xC3, 0xAC}; +static const symbol s_3[] = {0xC3, 0xB2}; +static const symbol s_4[] = {0xC3, 0xB9}; +static const symbol s_5[] = {'q', 'U'}; +static const symbol s_6[] = {'u'}; +static const symbol s_7[] = {'U'}; +static const symbol s_8[] = {'i'}; +static const symbol s_9[] = {'I'}; +static const symbol s_10[] = {'i'}; +static const symbol s_11[] = {'u'}; +static const symbol s_12[] = {'e'}; +static const symbol s_13[] = {'i', 'c'}; +static const symbol s_14[] = {'l', 'o', 'g'}; +static const symbol s_15[] = {'u'}; +static const symbol s_16[] = {'e', 'n', 't', 'e'}; +static const symbol s_17[] = {'a', 't'}; +static const symbol s_18[] = {'a', 't'}; +static const symbol s_19[] = {'i', 'c'}; +static const symbol s_20[] = {'i'}; +static const symbol s_21[] = {'h'}; + +static int r_prelude(struct SN_env *z) { + int among_var; + { + int c_test = z->c; /* test, line 35 */ + while (1) { /* repeat, line 35 */ + int c1 = z->c; + z->bra = z->c; /* [, line 36 */ + among_var = find_among(z, a_0, 7); /* substring, line 36 */ + if (!(among_var)) + goto lab0; + z->ket = z->c; /* ], line 36 */ + switch (among_var) { + case 0: + goto lab0; + case 1: { + int ret = slice_from_s(z, 2, s_0); /* <-, line 37 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 2, s_1); /* <-, line 38 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 2, s_2); /* <-, line 39 */ + if (ret < 0) + return ret; + } break; + case 4: { + int ret = slice_from_s(z, 2, s_3); /* <-, line 40 */ + if (ret < 0) + return ret; + } break; + case 5: { + int ret = slice_from_s(z, 2, s_4); /* <-, line 41 */ + if (ret < 0) + return ret; + } break; + case 6: { + int ret = slice_from_s(z, 2, s_5); /* <-, line 42 */ + if (ret < 0) + return ret; + } break; + case 7: { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* next, line 43 */ + } break; + } + continue; + lab0: + z->c = c1; + break; + } + z->c = c_test; + } + while (1) { /* repeat, line 46 */ + int c2 = z->c; + while (1) { /* goto, line 46 */ + int c3 = z->c; + if (in_grouping_U(z, g_v, 97, 249, 0)) + goto lab2; + z->bra = z->c; /* [, line 47 */ + { + int c4 = z->c; /* or, line 47 */ + if (!(eq_s(z, 1, s_6))) + goto lab4; + z->ket = z->c; /* ], line 47 */ + if (in_grouping_U(z, g_v, 97, 249, 0)) + goto lab4; + { + int ret = slice_from_s(z, 1, s_7); /* <-, line 47 */ + if (ret < 0) + return ret; + } + 
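+                    /* descriptive note: the branch above rewrote a 'u' between vowels as 'U'; the jump below skips the alternative 'i' -> 'I' rewrite of the Snowball `or` */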
goto lab3; + lab4: + z->c = c4; + if (!(eq_s(z, 1, s_8))) + goto lab2; + z->ket = z->c; /* ], line 48 */ + if (in_grouping_U(z, g_v, 97, 249, 0)) + goto lab2; + { + int ret = slice_from_s(z, 1, s_9); /* <-, line 48 */ + if (ret < 0) + return ret; + } + } + lab3: + z->c = c3; + break; + lab2: + z->c = c3; + { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab1; + z->c = ret; /* goto, line 46 */ + } + } + continue; + lab1: + z->c = c2; + break; + } + return 1; +} + +static int r_mark_regions(struct SN_env *z) { + z->I[0] = z->l; + z->I[1] = z->l; + z->I[2] = z->l; + { + int c1 = z->c; /* do, line 58 */ + { + int c2 = z->c; /* or, line 60 */ + if (in_grouping_U(z, g_v, 97, 249, 0)) + goto lab2; + { + int c3 = z->c; /* or, line 59 */ + if (out_grouping_U(z, g_v, 97, 249, 0)) + goto lab4; + { /* gopast */ /* grouping v, line 59 */ + int ret = out_grouping_U(z, g_v, 97, 249, 1); + if (ret < 0) + goto lab4; + z->c += ret; + } + goto lab3; + lab4: + z->c = c3; + if (in_grouping_U(z, g_v, 97, 249, 0)) + goto lab2; + { /* gopast */ /* non v, line 59 */ + int ret = in_grouping_U(z, g_v, 97, 249, 1); + if (ret < 0) + goto lab2; + z->c += ret; + } + } + lab3: + goto lab1; + lab2: + z->c = c2; + if (out_grouping_U(z, g_v, 97, 249, 0)) + goto lab0; + { + int c4 = z->c; /* or, line 61 */ + if (out_grouping_U(z, g_v, 97, 249, 0)) + goto lab6; + { /* gopast */ /* grouping v, line 61 */ + int ret = out_grouping_U(z, g_v, 97, 249, 1); + if (ret < 0) + goto lab6; + z->c += ret; + } + goto lab5; + lab6: + z->c = c4; + if (in_grouping_U(z, g_v, 97, 249, 0)) + goto lab0; + { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* next, line 61 */ + } + } + lab5:; + } + lab1: + z->I[0] = z->c; /* setmark pV, line 62 */ + lab0: + z->c = c1; + } + { + int c5 = z->c; /* do, line 64 */ + { /* gopast */ /* grouping v, line 65 */ + int ret = out_grouping_U(z, g_v, 97, 249, 1); + if (ret < 0) + goto lab7; + z->c += ret; + } + { /* gopast */ /* non v, line 65 */ + int ret = in_grouping_U(z, g_v, 97, 249, 1); + if (ret < 0) + goto lab7; + z->c += ret; + } + z->I[1] = z->c; /* setmark p1, line 65 */ + { /* gopast */ /* grouping v, line 66 */ + int ret = out_grouping_U(z, g_v, 97, 249, 1); + if (ret < 0) + goto lab7; + z->c += ret; + } + { /* gopast */ /* non v, line 66 */ + int ret = in_grouping_U(z, g_v, 97, 249, 1); + if (ret < 0) + goto lab7; + z->c += ret; + } + z->I[2] = z->c; /* setmark p2, line 66 */ + lab7: + z->c = c5; + } + return 1; +} + +static int r_postlude(struct SN_env *z) { + int among_var; + while (1) { /* repeat, line 70 */ + int c1 = z->c; + z->bra = z->c; /* [, line 72 */ + if (z->c >= z->l || (z->p[z->c + 0] != 73 && z->p[z->c + 0] != 85)) + among_var = 3; + else + among_var = find_among(z, a_1, 3); /* substring, line 72 */ + if (!(among_var)) + goto lab0; + z->ket = z->c; /* ], line 72 */ + switch (among_var) { + case 0: + goto lab0; + case 1: { + int ret = slice_from_s(z, 1, s_10); /* <-, line 73 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 1, s_11); /* <-, line 74 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* next, line 75 */ + } break; + } + continue; + lab0: + z->c = c1; + break; + } + return 1; +} + +static int r_RV(struct SN_env *z) { + if (!(z->I[0] <= z->c)) + return 0; + return 1; +} + +static int r_R1(struct SN_env *z) { + if (!(z->I[1] <= z->c)) + return 0; + return 1; +} + +static int r_R2(struct 
SN_env *z) { + if (!(z->I[2] <= z->c)) + return 0; + return 1; +} + +static int r_attached_pronoun(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 87 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((33314 >> (z->p[z->c - 1] & 0x1f)) & 1)) + return 0; + if (!(find_among_b(z, a_2, 37))) + return 0; /* substring, line 87 */ + z->bra = z->c; /* ], line 87 */ + if (z->c - 1 <= z->lb || (z->p[z->c - 1] != 111 && z->p[z->c - 1] != 114)) + return 0; + among_var = find_among_b(z, a_3, 5); /* among, line 97 */ + if (!(among_var)) + return 0; + { + int ret = r_RV(z); + if (ret == 0) + return 0; /* call RV, line 97 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 98 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 1, s_12); /* <-, line 99 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_standard_suffix(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 104 */ + among_var = find_among_b(z, a_6, 51); /* substring, line 104 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 104 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 111 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 111 */ + if (ret < 0) + return ret; + } + break; + case 2: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 113 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 113 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 114 */ + z->ket = z->c; /* [, line 114 */ + if (!(eq_s_b(z, 2, s_13))) { + z->c = z->l - m_keep; + goto lab0; + } + z->bra = z->c; /* ], line 114 */ + { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab0; + } /* call R2, line 114 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 114 */ + if (ret < 0) + return ret; + } + lab0:; + } + break; + case 3: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 117 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 3, s_14); /* <-, line 117 */ + if (ret < 0) + return ret; + } + break; + case 4: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 119 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 1, s_15); /* <-, line 119 */ + if (ret < 0) + return ret; + } + break; + case 5: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 121 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 4, s_16); /* <-, line 121 */ + if (ret < 0) + return ret; + } + break; + case 6: { + int ret = r_RV(z); + if (ret == 0) + return 0; /* call RV, line 123 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 123 */ + if (ret < 0) + return ret; + } + break; + case 7: { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 125 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 125 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 126 */ + z->ket = z->c; /* [, line 127 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((4722696 >> (z->p[z->c - 1] & 0x1f)) & 1)) { + z->c = z->l - m_keep; + goto lab1; + } + among_var = find_among_b(z, a_4, 4); /* substring, line 127 */ + if (!(among_var)) { + z->c = z->l - m_keep; + 
goto lab1; + } + z->bra = z->c; /* ], line 127 */ + { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab1; + } /* call R2, line 127 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 127 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: { + z->c = z->l - m_keep; + goto lab1; + } + case 1: + z->ket = z->c; /* [, line 128 */ + if (!(eq_s_b(z, 2, s_17))) { + z->c = z->l - m_keep; + goto lab1; + } + z->bra = z->c; /* ], line 128 */ + { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab1; + } /* call R2, line 128 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 128 */ + if (ret < 0) + return ret; + } + break; + } + lab1:; + } + break; + case 8: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 134 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 134 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 135 */ + z->ket = z->c; /* [, line 136 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((4198408 >> (z->p[z->c - 1] & 0x1f)) & 1)) { + z->c = z->l - m_keep; + goto lab2; + } + among_var = find_among_b(z, a_5, 3); /* substring, line 136 */ + if (!(among_var)) { + z->c = z->l - m_keep; + goto lab2; + } + z->bra = z->c; /* ], line 136 */ + switch (among_var) { + case 0: { + z->c = z->l - m_keep; + goto lab2; + } + case 1: { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab2; + } /* call R2, line 137 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 137 */ + if (ret < 0) + return ret; + } + break; + } + lab2:; + } + break; + case 9: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 142 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 142 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 143 */ + z->ket = z->c; /* [, line 143 */ + if (!(eq_s_b(z, 2, s_18))) { + z->c = z->l - m_keep; + goto lab3; + } + z->bra = z->c; /* ], line 143 */ + { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab3; + } /* call R2, line 143 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 143 */ + if (ret < 0) + return ret; + } + z->ket = z->c; /* [, line 143 */ + if (!(eq_s_b(z, 2, s_19))) { + z->c = z->l - m_keep; + goto lab3; + } + z->bra = z->c; /* ], line 143 */ + { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab3; + } /* call R2, line 143 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 143 */ + if (ret < 0) + return ret; + } + lab3:; + } + break; + } + return 1; +} + +static int r_verb_suffix(struct SN_env *z) { + int among_var; + { + int mlimit; /* setlimit, line 148 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 148 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 149 */ + among_var = find_among_b(z, a_7, 87); /* substring, line 149 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 149 */ + switch (among_var) { + case 0: { + z->lb = mlimit; + return 0; + } + case 1: { + int ret = slice_del(z); /* delete, line 163 */ + if (ret < 0) + return ret; + } break; + } + z->lb = mlimit; + } + return 1; +} + +static int r_vowel_suffix(struct SN_env *z) { + { + int m_keep = z->l - z->c; /* (void) 
m_keep;*/ /* try, line 171 */ + z->ket = z->c; /* [, line 172 */ + if (in_grouping_b_U(z, g_AEIO, 97, 242, 0)) { + z->c = z->l - m_keep; + goto lab0; + } + z->bra = z->c; /* ], line 172 */ + { + int ret = r_RV(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab0; + } /* call RV, line 172 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 172 */ + if (ret < 0) + return ret; + } + z->ket = z->c; /* [, line 173 */ + if (!(eq_s_b(z, 1, s_20))) { + z->c = z->l - m_keep; + goto lab0; + } + z->bra = z->c; /* ], line 173 */ + { + int ret = r_RV(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab0; + } /* call RV, line 173 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 173 */ + if (ret < 0) + return ret; + } + lab0:; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 175 */ + z->ket = z->c; /* [, line 176 */ + if (!(eq_s_b(z, 1, s_21))) { + z->c = z->l - m_keep; + goto lab1; + } + z->bra = z->c; /* ], line 176 */ + if (in_grouping_b_U(z, g_CG, 99, 103, 0)) { + z->c = z->l - m_keep; + goto lab1; + } + { + int ret = r_RV(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab1; + } /* call RV, line 176 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 176 */ + if (ret < 0) + return ret; + } + lab1:; + } + return 1; +} + +extern int italian_UTF_8_stem(struct SN_env *z) { + { + int c1 = z->c; /* do, line 182 */ + { + int ret = r_prelude(z); + if (ret == 0) + goto lab0; /* call prelude, line 182 */ + if (ret < 0) + return ret; + } + lab0: + z->c = c1; + } + { + int c2 = z->c; /* do, line 183 */ + { + int ret = r_mark_regions(z); + if (ret == 0) + goto lab1; /* call mark_regions, line 183 */ + if (ret < 0) + return ret; + } + lab1: + z->c = c2; + } + z->lb = z->c; + z->c = z->l; /* backwards, line 184 */ + + { + int m3 = z->l - z->c; + (void)m3; /* do, line 185 */ + { + int ret = r_attached_pronoun(z); + if (ret == 0) + goto lab2; /* call attached_pronoun, line 185 */ + if (ret < 0) + return ret; + } + lab2: + z->c = z->l - m3; + } + { + int m4 = z->l - z->c; + (void)m4; /* do, line 186 */ + { + int m5 = z->l - z->c; + (void)m5; /* or, line 186 */ + { + int ret = r_standard_suffix(z); + if (ret == 0) + goto lab5; /* call standard_suffix, line 186 */ + if (ret < 0) + return ret; + } + goto lab4; + lab5: + z->c = z->l - m5; + { + int ret = r_verb_suffix(z); + if (ret == 0) + goto lab3; /* call verb_suffix, line 186 */ + if (ret < 0) + return ret; + } + } + lab4: + lab3: + z->c = z->l - m4; + } + { + int m6 = z->l - z->c; + (void)m6; /* do, line 187 */ + { + int ret = r_vowel_suffix(z); + if (ret == 0) + goto lab6; /* call vowel_suffix, line 187 */ + if (ret < 0) + return ret; + } + lab6: + z->c = z->l - m6; + } + z->c = z->lb; + { + int c7 = z->c; /* do, line 189 */ + { + int ret = r_postlude(z); + if (ret == 0) + goto lab7; /* call postlude, line 189 */ + if (ret < 0) + return ret; + } + lab7: + z->c = c7; + } + return 1; +} + +extern struct SN_env *italian_UTF_8_create_env(void) { return SN_create_env(0, 3, 0); } + +extern void italian_UTF_8_close_env(struct SN_env *z) { SN_close_env(z, 0); } diff --git a/internal/cpp/stemmer/stem_UTF_8_italian.h b/internal/cpp/stemmer/stem_UTF_8_italian.h new file mode 100644 index 00000000000..1f79599ace8 --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_italian.h @@ -0,0 +1,17 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +extern struct 
SN_env *italian_UTF_8_create_env(void); +extern void italian_UTF_8_close_env(struct SN_env *z); + +extern int italian_UTF_8_stem(struct SN_env *z); + +#ifdef __cplusplus +} +#endif diff --git a/internal/cpp/stemmer/stem_UTF_8_norwegian.cpp b/internal/cpp/stemmer/stem_UTF_8_norwegian.cpp new file mode 100644 index 00000000000..4fbc9cd4b1c --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_norwegian.cpp @@ -0,0 +1,357 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#include "header.h" + +#ifdef __cplusplus +extern "C" { +#endif +extern int norwegian_UTF_8_stem(struct SN_env *z); +#ifdef __cplusplus +} +#endif +static int r_other_suffix(struct SN_env *z); +static int r_consonant_pair(struct SN_env *z); +static int r_main_suffix(struct SN_env *z); +static int r_mark_regions(struct SN_env *z); +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *norwegian_UTF_8_create_env(void); +extern void norwegian_UTF_8_close_env(struct SN_env *z); + +#ifdef __cplusplus +} +#endif +static const symbol s_0_0[1] = {'a'}; +static const symbol s_0_1[1] = {'e'}; +static const symbol s_0_2[3] = {'e', 'd', 'e'}; +static const symbol s_0_3[4] = {'a', 'n', 'd', 'e'}; +static const symbol s_0_4[4] = {'e', 'n', 'd', 'e'}; +static const symbol s_0_5[3] = {'a', 'n', 'e'}; +static const symbol s_0_6[3] = {'e', 'n', 'e'}; +static const symbol s_0_7[6] = {'h', 'e', 't', 'e', 'n', 'e'}; +static const symbol s_0_8[4] = {'e', 'r', 't', 'e'}; +static const symbol s_0_9[2] = {'e', 'n'}; +static const symbol s_0_10[5] = {'h', 'e', 't', 'e', 'n'}; +static const symbol s_0_11[2] = {'a', 'r'}; +static const symbol s_0_12[2] = {'e', 'r'}; +static const symbol s_0_13[5] = {'h', 'e', 't', 'e', 'r'}; +static const symbol s_0_14[1] = {'s'}; +static const symbol s_0_15[2] = {'a', 's'}; +static const symbol s_0_16[2] = {'e', 's'}; +static const symbol s_0_17[4] = {'e', 'd', 'e', 's'}; +static const symbol s_0_18[5] = {'e', 'n', 'd', 'e', 's'}; +static const symbol s_0_19[4] = {'e', 'n', 'e', 's'}; +static const symbol s_0_20[7] = {'h', 'e', 't', 'e', 'n', 'e', 's'}; +static const symbol s_0_21[3] = {'e', 'n', 's'}; +static const symbol s_0_22[6] = {'h', 'e', 't', 'e', 'n', 's'}; +static const symbol s_0_23[3] = {'e', 'r', 's'}; +static const symbol s_0_24[3] = {'e', 't', 's'}; +static const symbol s_0_25[2] = {'e', 't'}; +static const symbol s_0_26[3] = {'h', 'e', 't'}; +static const symbol s_0_27[3] = {'e', 'r', 't'}; +static const symbol s_0_28[3] = {'a', 's', 't'}; + +static const struct among a_0[29] = { + /* 0 */ {1, s_0_0, -1, 1, 0}, + /* 1 */ {1, s_0_1, -1, 1, 0}, + /* 2 */ {3, s_0_2, 1, 1, 0}, + /* 3 */ {4, s_0_3, 1, 1, 0}, + /* 4 */ {4, s_0_4, 1, 1, 0}, + /* 5 */ {3, s_0_5, 1, 1, 0}, + /* 6 */ {3, s_0_6, 1, 1, 0}, + /* 7 */ {6, s_0_7, 6, 1, 0}, + /* 8 */ {4, s_0_8, 1, 3, 0}, + /* 9 */ {2, s_0_9, -1, 1, 0}, + /* 10 */ {5, s_0_10, 9, 1, 0}, + /* 11 */ {2, s_0_11, -1, 1, 0}, + /* 12 */ {2, s_0_12, -1, 1, 0}, + /* 13 */ {5, s_0_13, 12, 1, 0}, + /* 14 */ {1, s_0_14, -1, 2, 0}, + /* 15 */ {2, s_0_15, 14, 1, 0}, + /* 16 */ {2, s_0_16, 14, 1, 0}, + /* 17 */ {4, s_0_17, 16, 1, 0}, + /* 18 */ {5, s_0_18, 16, 1, 0}, + /* 19 */ {4, s_0_19, 16, 1, 0}, + /* 20 */ {7, s_0_20, 19, 1, 0}, + /* 21 */ {3, s_0_21, 14, 1, 0}, + /* 22 */ {6, s_0_22, 21, 1, 0}, + /* 23 */ {3, s_0_23, 14, 1, 0}, + /* 24 */ {3, s_0_24, 14, 1, 0}, + /* 25 */ {2, s_0_25, -1, 1, 0}, + /* 26 */ {3, s_0_26, 25, 1, 0}, + /* 27 */ {3, s_0_27, -1, 3, 0}, + /* 28 */ {3, s_0_28, -1, 1, 0}}; + +static const symbol s_1_0[2] = 
{'d', 't'}; +static const symbol s_1_1[2] = {'v', 't'}; + +static const struct among a_1[2] = { + /* 0 */ {2, s_1_0, -1, -1, 0}, + /* 1 */ {2, s_1_1, -1, -1, 0}}; + +static const symbol s_2_0[3] = {'l', 'e', 'g'}; +static const symbol s_2_1[4] = {'e', 'l', 'e', 'g'}; +static const symbol s_2_2[2] = {'i', 'g'}; +static const symbol s_2_3[3] = {'e', 'i', 'g'}; +static const symbol s_2_4[3] = {'l', 'i', 'g'}; +static const symbol s_2_5[4] = {'e', 'l', 'i', 'g'}; +static const symbol s_2_6[3] = {'e', 'l', 's'}; +static const symbol s_2_7[3] = {'l', 'o', 'v'}; +static const symbol s_2_8[4] = {'e', 'l', 'o', 'v'}; +static const symbol s_2_9[4] = {'s', 'l', 'o', 'v'}; +static const symbol s_2_10[7] = {'h', 'e', 't', 's', 'l', 'o', 'v'}; + +static const struct among a_2[11] = { + /* 0 */ {3, s_2_0, -1, 1, 0}, + /* 1 */ {4, s_2_1, 0, 1, 0}, + /* 2 */ {2, s_2_2, -1, 1, 0}, + /* 3 */ {3, s_2_3, 2, 1, 0}, + /* 4 */ {3, s_2_4, 2, 1, 0}, + /* 5 */ {4, s_2_5, 4, 1, 0}, + /* 6 */ {3, s_2_6, -1, 1, 0}, + /* 7 */ {3, s_2_7, -1, 1, 0}, + /* 8 */ {4, s_2_8, 7, 1, 0}, + /* 9 */ {4, s_2_9, 7, 1, 0}, + /* 10 */ {7, s_2_10, 9, 1, 0}}; + +static const unsigned char g_v[] = {17, 65, 16, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 48, 0, 128}; + +static const unsigned char g_s_ending[] = {119, 125, 149, 1}; + +static const symbol s_0[] = {'k'}; +static const symbol s_1[] = {'e', 'r'}; + +static int r_mark_regions(struct SN_env *z) { + z->I[0] = z->l; + { + int c_test = z->c; /* test, line 30 */ + { + int ret = skip_utf8(z->p, z->c, 0, z->l, +3); + if (ret < 0) + return 0; + z->c = ret; /* hop, line 30 */ + } + z->I[1] = z->c; /* setmark x, line 30 */ + z->c = c_test; + } + if (out_grouping_U(z, g_v, 97, 248, 1) < 0) + return 0; /* goto */ /* grouping v, line 31 */ + { /* gopast */ /* non v, line 31 */ + int ret = in_grouping_U(z, g_v, 97, 248, 1); + if (ret < 0) + return 0; + z->c += ret; + } + z->I[0] = z->c; /* setmark p1, line 31 */ + /* try, line 32 */ + if (!(z->I[0] < z->I[1])) + goto lab0; + z->I[0] = z->I[1]; +lab0: + return 1; +} + +static int r_main_suffix(struct SN_env *z) { + int among_var; + { + int mlimit; /* setlimit, line 38 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 38 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 38 */ + if (z->c <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((1851426 >> (z->p[z->c - 1] & 0x1f)) & 1)) { + z->lb = mlimit; + return 0; + } + among_var = find_among_b(z, a_0, 29); /* substring, line 38 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 38 */ + z->lb = mlimit; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 44 */ + if (ret < 0) + return ret; + } break; + case 2: { + int m2 = z->l - z->c; + (void)m2; /* or, line 46 */ + if (in_grouping_b_U(z, g_s_ending, 98, 122, 0)) + goto lab1; + goto lab0; + lab1: + z->c = z->l - m2; + if (!(eq_s_b(z, 1, s_0))) + return 0; + if (out_grouping_b_U(z, g_v, 97, 248, 0)) + return 0; + } + lab0: { + int ret = slice_del(z); /* delete, line 46 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 2, s_1); /* <-, line 48 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_consonant_pair(struct SN_env *z) { + { + int m_test = z->l - z->c; /* test, line 53 */ + { + int mlimit; /* setlimit, line 54 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 54 */ + 
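+            /* descriptive note: setlimit in action -- lb is temporarily raised to the p1 mark so the dt/vt lookup below cannot match outside R1 */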
mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 54 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] != 116) { + z->lb = mlimit; + return 0; + } + if (!(find_among_b(z, a_1, 2))) { + z->lb = mlimit; + return 0; + } /* substring, line 54 */ + z->bra = z->c; /* ], line 54 */ + z->lb = mlimit; + } + z->c = z->l - m_test; + } + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) + return 0; + z->c = ret; /* next, line 59 */ + } + z->bra = z->c; /* ], line 59 */ + { + int ret = slice_del(z); /* delete, line 59 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_other_suffix(struct SN_env *z) { + int among_var; + { + int mlimit; /* setlimit, line 63 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 63 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 63 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((4718720 >> (z->p[z->c - 1] & 0x1f)) & 1)) { + z->lb = mlimit; + return 0; + } + among_var = find_among_b(z, a_2, 11); /* substring, line 63 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 63 */ + z->lb = mlimit; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 67 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +extern int norwegian_UTF_8_stem(struct SN_env *z) { + { + int c1 = z->c; /* do, line 74 */ + { + int ret = r_mark_regions(z); + if (ret == 0) + goto lab0; /* call mark_regions, line 74 */ + if (ret < 0) + return ret; + } + lab0: + z->c = c1; + } + z->lb = z->c; + z->c = z->l; /* backwards, line 75 */ + + { + int m2 = z->l - z->c; + (void)m2; /* do, line 76 */ + { + int ret = r_main_suffix(z); + if (ret == 0) + goto lab1; /* call main_suffix, line 76 */ + if (ret < 0) + return ret; + } + lab1: + z->c = z->l - m2; + } + { + int m3 = z->l - z->c; + (void)m3; /* do, line 77 */ + { + int ret = r_consonant_pair(z); + if (ret == 0) + goto lab2; /* call consonant_pair, line 77 */ + if (ret < 0) + return ret; + } + lab2: + z->c = z->l - m3; + } + { + int m4 = z->l - z->c; + (void)m4; /* do, line 78 */ + { + int ret = r_other_suffix(z); + if (ret == 0) + goto lab3; /* call other_suffix, line 78 */ + if (ret < 0) + return ret; + } + lab3: + z->c = z->l - m4; + } + z->c = z->lb; + return 1; +} + +extern struct SN_env *norwegian_UTF_8_create_env(void) { return SN_create_env(0, 2, 0); } + +extern void norwegian_UTF_8_close_env(struct SN_env *z) { SN_close_env(z, 0); } diff --git a/internal/cpp/stemmer/stem_UTF_8_norwegian.h b/internal/cpp/stemmer/stem_UTF_8_norwegian.h new file mode 100644 index 00000000000..e9ce2f8fa6b --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_norwegian.h @@ -0,0 +1,17 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *norwegian_UTF_8_create_env(void); +extern void norwegian_UTF_8_close_env(struct SN_env *z); + +extern int norwegian_UTF_8_stem(struct SN_env *z); + +#ifdef __cplusplus +} +#endif diff --git a/internal/cpp/stemmer/stem_UTF_8_porter.cpp b/internal/cpp/stemmer/stem_UTF_8_porter.cpp new file mode 100644 index 00000000000..a13ce35eb0b --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_porter.cpp @@ -0,0 +1,888 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#include "header.h" + +#ifdef __cplusplus +extern "C" { +#endif +extern int 
porter_UTF_8_stem(struct SN_env *z); +#ifdef __cplusplus +} +#endif +static int r_Step_5b(struct SN_env *z); +static int r_Step_5a(struct SN_env *z); +static int r_Step_4(struct SN_env *z); +static int r_Step_3(struct SN_env *z); +static int r_Step_2(struct SN_env *z); +static int r_Step_1c(struct SN_env *z); +static int r_Step_1b(struct SN_env *z); +static int r_Step_1a(struct SN_env *z); +static int r_R2(struct SN_env *z); +static int r_R1(struct SN_env *z); +static int r_shortv(struct SN_env *z); +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *porter_UTF_8_create_env(void); +extern void porter_UTF_8_close_env(struct SN_env *z); + +#ifdef __cplusplus +} +#endif +static const symbol s_0_0[1] = {'s'}; +static const symbol s_0_1[3] = {'i', 'e', 's'}; +static const symbol s_0_2[4] = {'s', 's', 'e', 's'}; +static const symbol s_0_3[2] = {'s', 's'}; + +static const struct among a_0[4] = { + /* 0 */ {1, s_0_0, -1, 3, 0}, + /* 1 */ {3, s_0_1, 0, 2, 0}, + /* 2 */ {4, s_0_2, 0, 1, 0}, + /* 3 */ {2, s_0_3, 0, -1, 0}}; + +static const symbol s_1_1[2] = {'b', 'b'}; +static const symbol s_1_2[2] = {'d', 'd'}; +static const symbol s_1_3[2] = {'f', 'f'}; +static const symbol s_1_4[2] = {'g', 'g'}; +static const symbol s_1_5[2] = {'b', 'l'}; +static const symbol s_1_6[2] = {'m', 'm'}; +static const symbol s_1_7[2] = {'n', 'n'}; +static const symbol s_1_8[2] = {'p', 'p'}; +static const symbol s_1_9[2] = {'r', 'r'}; +static const symbol s_1_10[2] = {'a', 't'}; +static const symbol s_1_11[2] = {'t', 't'}; +static const symbol s_1_12[2] = {'i', 'z'}; + +static const struct among a_1[13] = { + /* 0 */ {0, 0, -1, 3, 0}, + /* 1 */ {2, s_1_1, 0, 2, 0}, + /* 2 */ {2, s_1_2, 0, 2, 0}, + /* 3 */ {2, s_1_3, 0, 2, 0}, + /* 4 */ {2, s_1_4, 0, 2, 0}, + /* 5 */ {2, s_1_5, 0, 1, 0}, + /* 6 */ {2, s_1_6, 0, 2, 0}, + /* 7 */ {2, s_1_7, 0, 2, 0}, + /* 8 */ {2, s_1_8, 0, 2, 0}, + /* 9 */ {2, s_1_9, 0, 2, 0}, + /* 10 */ {2, s_1_10, 0, 1, 0}, + /* 11 */ {2, s_1_11, 0, 2, 0}, + /* 12 */ {2, s_1_12, 0, 1, 0}}; + +static const symbol s_2_0[2] = {'e', 'd'}; +static const symbol s_2_1[3] = {'e', 'e', 'd'}; +static const symbol s_2_2[3] = {'i', 'n', 'g'}; + +static const struct among a_2[3] = { + /* 0 */ {2, s_2_0, -1, 2, 0}, + /* 1 */ {3, s_2_1, 0, 1, 0}, + /* 2 */ {3, s_2_2, -1, 2, 0}}; + +static const symbol s_3_0[4] = {'a', 'n', 'c', 'i'}; +static const symbol s_3_1[4] = {'e', 'n', 'c', 'i'}; +static const symbol s_3_2[4] = {'a', 'b', 'l', 'i'}; +static const symbol s_3_3[3] = {'e', 'l', 'i'}; +static const symbol s_3_4[4] = {'a', 'l', 'l', 'i'}; +static const symbol s_3_5[5] = {'o', 'u', 's', 'l', 'i'}; +static const symbol s_3_6[5] = {'e', 'n', 't', 'l', 'i'}; +static const symbol s_3_7[5] = {'a', 'l', 'i', 't', 'i'}; +static const symbol s_3_8[6] = {'b', 'i', 'l', 'i', 't', 'i'}; +static const symbol s_3_9[5] = {'i', 'v', 'i', 't', 'i'}; +static const symbol s_3_10[6] = {'t', 'i', 'o', 'n', 'a', 'l'}; +static const symbol s_3_11[7] = {'a', 't', 'i', 'o', 'n', 'a', 'l'}; +static const symbol s_3_12[5] = {'a', 'l', 'i', 's', 'm'}; +static const symbol s_3_13[5] = {'a', 't', 'i', 'o', 'n'}; +static const symbol s_3_14[7] = {'i', 'z', 'a', 't', 'i', 'o', 'n'}; +static const symbol s_3_15[4] = {'i', 'z', 'e', 'r'}; +static const symbol s_3_16[4] = {'a', 't', 'o', 'r'}; +static const symbol s_3_17[7] = {'i', 'v', 'e', 'n', 'e', 's', 's'}; +static const symbol s_3_18[7] = {'f', 'u', 'l', 'n', 'e', 's', 's'}; +static const symbol s_3_19[7] = {'o', 'u', 's', 'n', 'e', 's', 's'}; + +static const struct among a_3[20] 
= { + /* 0 */ {4, s_3_0, -1, 3, 0}, + /* 1 */ {4, s_3_1, -1, 2, 0}, + /* 2 */ {4, s_3_2, -1, 4, 0}, + /* 3 */ {3, s_3_3, -1, 6, 0}, + /* 4 */ {4, s_3_4, -1, 9, 0}, + /* 5 */ {5, s_3_5, -1, 12, 0}, + /* 6 */ {5, s_3_6, -1, 5, 0}, + /* 7 */ {5, s_3_7, -1, 10, 0}, + /* 8 */ {6, s_3_8, -1, 14, 0}, + /* 9 */ {5, s_3_9, -1, 13, 0}, + /* 10 */ {6, s_3_10, -1, 1, 0}, + /* 11 */ {7, s_3_11, 10, 8, 0}, + /* 12 */ {5, s_3_12, -1, 10, 0}, + /* 13 */ {5, s_3_13, -1, 8, 0}, + /* 14 */ {7, s_3_14, 13, 7, 0}, + /* 15 */ {4, s_3_15, -1, 7, 0}, + /* 16 */ {4, s_3_16, -1, 8, 0}, + /* 17 */ {7, s_3_17, -1, 13, 0}, + /* 18 */ {7, s_3_18, -1, 11, 0}, + /* 19 */ {7, s_3_19, -1, 12, 0}}; + +static const symbol s_4_0[5] = {'i', 'c', 'a', 't', 'e'}; +static const symbol s_4_1[5] = {'a', 't', 'i', 'v', 'e'}; +static const symbol s_4_2[5] = {'a', 'l', 'i', 'z', 'e'}; +static const symbol s_4_3[5] = {'i', 'c', 'i', 't', 'i'}; +static const symbol s_4_4[4] = {'i', 'c', 'a', 'l'}; +static const symbol s_4_5[3] = {'f', 'u', 'l'}; +static const symbol s_4_6[4] = {'n', 'e', 's', 's'}; + +static const struct among a_4[7] = { + /* 0 */ {5, s_4_0, -1, 2, 0}, + /* 1 */ {5, s_4_1, -1, 3, 0}, + /* 2 */ {5, s_4_2, -1, 1, 0}, + /* 3 */ {5, s_4_3, -1, 2, 0}, + /* 4 */ {4, s_4_4, -1, 2, 0}, + /* 5 */ {3, s_4_5, -1, 3, 0}, + /* 6 */ {4, s_4_6, -1, 3, 0}}; + +static const symbol s_5_0[2] = {'i', 'c'}; +static const symbol s_5_1[4] = {'a', 'n', 'c', 'e'}; +static const symbol s_5_2[4] = {'e', 'n', 'c', 'e'}; +static const symbol s_5_3[4] = {'a', 'b', 'l', 'e'}; +static const symbol s_5_4[4] = {'i', 'b', 'l', 'e'}; +static const symbol s_5_5[3] = {'a', 't', 'e'}; +static const symbol s_5_6[3] = {'i', 'v', 'e'}; +static const symbol s_5_7[3] = {'i', 'z', 'e'}; +static const symbol s_5_8[3] = {'i', 't', 'i'}; +static const symbol s_5_9[2] = {'a', 'l'}; +static const symbol s_5_10[3] = {'i', 's', 'm'}; +static const symbol s_5_11[3] = {'i', 'o', 'n'}; +static const symbol s_5_12[2] = {'e', 'r'}; +static const symbol s_5_13[3] = {'o', 'u', 's'}; +static const symbol s_5_14[3] = {'a', 'n', 't'}; +static const symbol s_5_15[3] = {'e', 'n', 't'}; +static const symbol s_5_16[4] = {'m', 'e', 'n', 't'}; +static const symbol s_5_17[5] = {'e', 'm', 'e', 'n', 't'}; +static const symbol s_5_18[2] = {'o', 'u'}; + +static const struct among a_5[19] = { + /* 0 */ {2, s_5_0, -1, 1, 0}, + /* 1 */ {4, s_5_1, -1, 1, 0}, + /* 2 */ {4, s_5_2, -1, 1, 0}, + /* 3 */ {4, s_5_3, -1, 1, 0}, + /* 4 */ {4, s_5_4, -1, 1, 0}, + /* 5 */ {3, s_5_5, -1, 1, 0}, + /* 6 */ {3, s_5_6, -1, 1, 0}, + /* 7 */ {3, s_5_7, -1, 1, 0}, + /* 8 */ {3, s_5_8, -1, 1, 0}, + /* 9 */ {2, s_5_9, -1, 1, 0}, + /* 10 */ {3, s_5_10, -1, 1, 0}, + /* 11 */ {3, s_5_11, -1, 2, 0}, + /* 12 */ {2, s_5_12, -1, 1, 0}, + /* 13 */ {3, s_5_13, -1, 1, 0}, + /* 14 */ {3, s_5_14, -1, 1, 0}, + /* 15 */ {3, s_5_15, -1, 1, 0}, + /* 16 */ {4, s_5_16, 15, 1, 0}, + /* 17 */ {5, s_5_17, 16, 1, 0}, + /* 18 */ {2, s_5_18, -1, 1, 0}}; + +static const unsigned char g_v[] = {17, 65, 16, 1}; + +static const unsigned char g_v_WXY[] = {1, 17, 65, 208, 1}; + +static const symbol s_0[] = {'s', 's'}; +static const symbol s_1[] = {'i'}; +static const symbol s_2[] = {'e', 'e'}; +static const symbol s_3[] = {'e'}; +static const symbol s_4[] = {'e'}; +static const symbol s_5[] = {'y'}; +static const symbol s_6[] = {'Y'}; +static const symbol s_7[] = {'i'}; +static const symbol s_8[] = {'t', 'i', 'o', 'n'}; +static const symbol s_9[] = {'e', 'n', 'c', 'e'}; +static const symbol s_10[] = {'a', 'n', 'c', 'e'}; +static const symbol 
s_11[] = {'a', 'b', 'l', 'e'}; +static const symbol s_12[] = {'e', 'n', 't'}; +static const symbol s_13[] = {'e'}; +static const symbol s_14[] = {'i', 'z', 'e'}; +static const symbol s_15[] = {'a', 't', 'e'}; +static const symbol s_16[] = {'a', 'l'}; +static const symbol s_17[] = {'a', 'l'}; +static const symbol s_18[] = {'f', 'u', 'l'}; +static const symbol s_19[] = {'o', 'u', 's'}; +static const symbol s_20[] = {'i', 'v', 'e'}; +static const symbol s_21[] = {'b', 'l', 'e'}; +static const symbol s_22[] = {'a', 'l'}; +static const symbol s_23[] = {'i', 'c'}; +static const symbol s_24[] = {'s'}; +static const symbol s_25[] = {'t'}; +static const symbol s_26[] = {'e'}; +static const symbol s_27[] = {'l'}; +static const symbol s_28[] = {'l'}; +static const symbol s_29[] = {'y'}; +static const symbol s_30[] = {'Y'}; +static const symbol s_31[] = {'y'}; +static const symbol s_32[] = {'Y'}; +static const symbol s_33[] = {'Y'}; +static const symbol s_34[] = {'y'}; + +static int r_shortv(struct SN_env *z) { + if (out_grouping_b_U(z, g_v_WXY, 89, 121, 0)) + return 0; + if (in_grouping_b_U(z, g_v, 97, 121, 0)) + return 0; + if (out_grouping_b_U(z, g_v, 97, 121, 0)) + return 0; + return 1; +} + +static int r_R1(struct SN_env *z) { + if (!(z->I[0] <= z->c)) + return 0; + return 1; +} + +static int r_R2(struct SN_env *z) { + if (!(z->I[1] <= z->c)) + return 0; + return 1; +} + +static int r_Step_1a(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 25 */ + if (z->c <= z->lb || z->p[z->c - 1] != 115) + return 0; + among_var = find_among_b(z, a_0, 4); /* substring, line 25 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 25 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_from_s(z, 2, s_0); /* <-, line 26 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 1, s_1); /* <-, line 27 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_del(z); /* delete, line 29 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_Step_1b(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 34 */ + if (z->c - 1 <= z->lb || (z->p[z->c - 1] != 100 && z->p[z->c - 1] != 103)) + return 0; + among_var = find_among_b(z, a_2, 3); /* substring, line 34 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 34 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 35 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 2, s_2); /* <-, line 35 */ + if (ret < 0) + return ret; + } + break; + case 2: { + int m_test = z->l - z->c; /* test, line 38 */ + { /* gopast */ /* grouping v, line 38 */ + int ret = out_grouping_b_U(z, g_v, 97, 121, 1); + if (ret < 0) + return 0; + z->c -= ret; + } + z->c = z->l - m_test; + } + { + int ret = slice_del(z); /* delete, line 38 */ + if (ret < 0) + return ret; + } + { + int m_test = z->l - z->c; /* test, line 39 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((68514004 >> (z->p[z->c - 1] & 0x1f)) & 1)) + among_var = 3; + else + among_var = find_among_b(z, a_1, 13); /* substring, line 39 */ + if (!(among_var)) + return 0; + z->c = z->l - m_test; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int c_keep = z->c; + int ret = insert_s(z, z->c, z->c, 1, s_3); /* <+, line 41 */ + z->c = c_keep; + if (ret < 0) + return ret; + } break; + case 2: + z->ket = z->c; /* [, line 44 */ + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) + 
return 0; + z->c = ret; /* next, line 44 */ + } + z->bra = z->c; /* ], line 44 */ + { + int ret = slice_del(z); /* delete, line 44 */ + if (ret < 0) + return ret; + } + break; + case 3: + if (z->c != z->I[0]) + return 0; /* atmark, line 45 */ + { + int m_test = z->l - z->c; /* test, line 45 */ + { + int ret = r_shortv(z); + if (ret == 0) + return 0; /* call shortv, line 45 */ + if (ret < 0) + return ret; + } + z->c = z->l - m_test; + } + { + int c_keep = z->c; + int ret = insert_s(z, z->c, z->c, 1, s_4); /* <+, line 45 */ + z->c = c_keep; + if (ret < 0) + return ret; + } + break; + } + break; + } + return 1; +} + +static int r_Step_1c(struct SN_env *z) { + z->ket = z->c; /* [, line 52 */ + { + int m1 = z->l - z->c; + (void)m1; /* or, line 52 */ + if (!(eq_s_b(z, 1, s_5))) + goto lab1; + goto lab0; + lab1: + z->c = z->l - m1; + if (!(eq_s_b(z, 1, s_6))) + return 0; + } +lab0: + z->bra = z->c; /* ], line 52 */ + { /* gopast */ /* grouping v, line 53 */ + int ret = out_grouping_b_U(z, g_v, 97, 121, 1); + if (ret < 0) + return 0; + z->c -= ret; + } + { + int ret = slice_from_s(z, 1, s_7); /* <-, line 54 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_Step_2(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 58 */ + if (z->c - 2 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((815616 >> (z->p[z->c - 1] & 0x1f)) & 1)) + return 0; + among_var = find_among_b(z, a_3, 20); /* substring, line 58 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 58 */ + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 58 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_from_s(z, 4, s_8); /* <-, line 59 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 4, s_9); /* <-, line 60 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 4, s_10); /* <-, line 61 */ + if (ret < 0) + return ret; + } break; + case 4: { + int ret = slice_from_s(z, 4, s_11); /* <-, line 62 */ + if (ret < 0) + return ret; + } break; + case 5: { + int ret = slice_from_s(z, 3, s_12); /* <-, line 63 */ + if (ret < 0) + return ret; + } break; + case 6: { + int ret = slice_from_s(z, 1, s_13); /* <-, line 64 */ + if (ret < 0) + return ret; + } break; + case 7: { + int ret = slice_from_s(z, 3, s_14); /* <-, line 66 */ + if (ret < 0) + return ret; + } break; + case 8: { + int ret = slice_from_s(z, 3, s_15); /* <-, line 68 */ + if (ret < 0) + return ret; + } break; + case 9: { + int ret = slice_from_s(z, 2, s_16); /* <-, line 69 */ + if (ret < 0) + return ret; + } break; + case 10: { + int ret = slice_from_s(z, 2, s_17); /* <-, line 71 */ + if (ret < 0) + return ret; + } break; + case 11: { + int ret = slice_from_s(z, 3, s_18); /* <-, line 72 */ + if (ret < 0) + return ret; + } break; + case 12: { + int ret = slice_from_s(z, 3, s_19); /* <-, line 74 */ + if (ret < 0) + return ret; + } break; + case 13: { + int ret = slice_from_s(z, 3, s_20); /* <-, line 76 */ + if (ret < 0) + return ret; + } break; + case 14: { + int ret = slice_from_s(z, 3, s_21); /* <-, line 77 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_Step_3(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 82 */ + if (z->c - 2 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((528928 >> (z->p[z->c - 1] & 0x1f)) & 1)) + return 0; + among_var = find_among_b(z, a_4, 7); /* substring, line 82 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 82 */ + { + int ret = r_R1(z); + 
if (ret == 0) + return 0; /* call R1, line 82 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_from_s(z, 2, s_22); /* <-, line 83 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 2, s_23); /* <-, line 85 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_del(z); /* delete, line 87 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_Step_4(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 92 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((3961384 >> (z->p[z->c - 1] & 0x1f)) & 1)) + return 0; + among_var = find_among_b(z, a_5, 19); /* substring, line 92 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 92 */ + { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 92 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 95 */ + if (ret < 0) + return ret; + } break; + case 2: { + int m1 = z->l - z->c; + (void)m1; /* or, line 96 */ + if (!(eq_s_b(z, 1, s_24))) + goto lab1; + goto lab0; + lab1: + z->c = z->l - m1; + if (!(eq_s_b(z, 1, s_25))) + return 0; + } + lab0: { + int ret = slice_del(z); /* delete, line 96 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_Step_5a(struct SN_env *z) { + z->ket = z->c; /* [, line 101 */ + if (!(eq_s_b(z, 1, s_26))) + return 0; + z->bra = z->c; /* ], line 101 */ + { + int m1 = z->l - z->c; + (void)m1; /* or, line 102 */ + { + int ret = r_R2(z); + if (ret == 0) + goto lab1; /* call R2, line 102 */ + if (ret < 0) + return ret; + } + goto lab0; + lab1: + z->c = z->l - m1; + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 102 */ + if (ret < 0) + return ret; + } + { + int m2 = z->l - z->c; + (void)m2; /* not, line 102 */ + { + int ret = r_shortv(z); + if (ret == 0) + goto lab2; /* call shortv, line 102 */ + if (ret < 0) + return ret; + } + return 0; + lab2: + z->c = z->l - m2; + } + } +lab0: { + int ret = slice_del(z); /* delete, line 103 */ + if (ret < 0) + return ret; +} + return 1; +} + +static int r_Step_5b(struct SN_env *z) { + z->ket = z->c; /* [, line 107 */ + if (!(eq_s_b(z, 1, s_27))) + return 0; + z->bra = z->c; /* ], line 107 */ + { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 108 */ + if (ret < 0) + return ret; + } + if (!(eq_s_b(z, 1, s_28))) + return 0; + { + int ret = slice_del(z); /* delete, line 109 */ + if (ret < 0) + return ret; + } + return 1; +} + +extern int porter_UTF_8_stem(struct SN_env *z) { + z->B[0] = 0; /* unset Y_found, line 115 */ + { + int c1 = z->c; /* do, line 116 */ + z->bra = z->c; /* [, line 116 */ + if (!(eq_s(z, 1, s_29))) + goto lab0; + z->ket = z->c; /* ], line 116 */ + { + int ret = slice_from_s(z, 1, s_30); /* <-, line 116 */ + if (ret < 0) + return ret; + } + z->B[0] = 1; /* set Y_found, line 116 */ + lab0: + z->c = c1; + } + { + int c2 = z->c; /* do, line 117 */ + while (1) { /* repeat, line 117 */ + int c3 = z->c; + while (1) { /* goto, line 117 */ + int c4 = z->c; + if (in_grouping_U(z, g_v, 97, 121, 0)) + goto lab3; + z->bra = z->c; /* [, line 117 */ + if (!(eq_s(z, 1, s_31))) + goto lab3; + z->ket = z->c; /* ], line 117 */ + z->c = c4; + break; + lab3: + z->c = c4; + { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab2; + z->c = ret; /* goto, line 117 */ + } + } + { + int ret = slice_from_s(z, 1, s_32); /* <-, line 117 */ + if (ret < 0) + return ret; + 
} + z->B[0] = 1; /* set Y_found, line 117 */ + continue; + lab2: + z->c = c3; + break; + } + z->c = c2; + } + z->I[0] = z->l; + z->I[1] = z->l; + { + int c5 = z->c; /* do, line 121 */ + { /* gopast */ /* grouping v, line 122 */ + int ret = out_grouping_U(z, g_v, 97, 121, 1); + if (ret < 0) + goto lab4; + z->c += ret; + } + { /* gopast */ /* non v, line 122 */ + int ret = in_grouping_U(z, g_v, 97, 121, 1); + if (ret < 0) + goto lab4; + z->c += ret; + } + z->I[0] = z->c; /* setmark p1, line 122 */ + { /* gopast */ /* grouping v, line 123 */ + int ret = out_grouping_U(z, g_v, 97, 121, 1); + if (ret < 0) + goto lab4; + z->c += ret; + } + { /* gopast */ /* non v, line 123 */ + int ret = in_grouping_U(z, g_v, 97, 121, 1); + if (ret < 0) + goto lab4; + z->c += ret; + } + z->I[1] = z->c; /* setmark p2, line 123 */ + lab4: + z->c = c5; + } + z->lb = z->c; + z->c = z->l; /* backwards, line 126 */ + + { + int m6 = z->l - z->c; + (void)m6; /* do, line 127 */ + { + int ret = r_Step_1a(z); + if (ret == 0) + goto lab5; /* call Step_1a, line 127 */ + if (ret < 0) + return ret; + } + lab5: + z->c = z->l - m6; + } + { + int m7 = z->l - z->c; + (void)m7; /* do, line 128 */ + { + int ret = r_Step_1b(z); + if (ret == 0) + goto lab6; /* call Step_1b, line 128 */ + if (ret < 0) + return ret; + } + lab6: + z->c = z->l - m7; + } + { + int m8 = z->l - z->c; + (void)m8; /* do, line 129 */ + { + int ret = r_Step_1c(z); + if (ret == 0) + goto lab7; /* call Step_1c, line 129 */ + if (ret < 0) + return ret; + } + lab7: + z->c = z->l - m8; + } + { + int m9 = z->l - z->c; + (void)m9; /* do, line 130 */ + { + int ret = r_Step_2(z); + if (ret == 0) + goto lab8; /* call Step_2, line 130 */ + if (ret < 0) + return ret; + } + lab8: + z->c = z->l - m9; + } + { + int m10 = z->l - z->c; + (void)m10; /* do, line 131 */ + { + int ret = r_Step_3(z); + if (ret == 0) + goto lab9; /* call Step_3, line 131 */ + if (ret < 0) + return ret; + } + lab9: + z->c = z->l - m10; + } + { + int m11 = z->l - z->c; + (void)m11; /* do, line 132 */ + { + int ret = r_Step_4(z); + if (ret == 0) + goto lab10; /* call Step_4, line 132 */ + if (ret < 0) + return ret; + } + lab10: + z->c = z->l - m11; + } + { + int m12 = z->l - z->c; + (void)m12; /* do, line 133 */ + { + int ret = r_Step_5a(z); + if (ret == 0) + goto lab11; /* call Step_5a, line 133 */ + if (ret < 0) + return ret; + } + lab11: + z->c = z->l - m12; + } + { + int m13 = z->l - z->c; + (void)m13; /* do, line 134 */ + { + int ret = r_Step_5b(z); + if (ret == 0) + goto lab12; /* call Step_5b, line 134 */ + if (ret < 0) + return ret; + } + lab12: + z->c = z->l - m13; + } + z->c = z->lb; + { + int c14 = z->c; /* do, line 137 */ + if (!(z->B[0])) + goto lab13; /* Boolean test Y_found, line 137 */ + while (1) { /* repeat, line 137 */ + int c15 = z->c; + while (1) { /* goto, line 137 */ + int c16 = z->c; + z->bra = z->c; /* [, line 137 */ + if (!(eq_s(z, 1, s_33))) + goto lab15; + z->ket = z->c; /* ], line 137 */ + z->c = c16; + break; + lab15: + z->c = c16; + { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab14; + z->c = ret; /* goto, line 137 */ + } + } + { + int ret = slice_from_s(z, 1, s_34); /* <-, line 137 */ + if (ret < 0) + return ret; + } + continue; + lab14: + z->c = c15; + break; + } + lab13: + z->c = c14; + } + return 1; +} + +extern struct SN_env *porter_UTF_8_create_env(void) { return SN_create_env(0, 2, 1); } + +extern void porter_UTF_8_close_env(struct SN_env *z) { SN_close_env(z, 0); } diff --git a/internal/cpp/stemmer/stem_UTF_8_porter.h 
b/internal/cpp/stemmer/stem_UTF_8_porter.h new file mode 100644 index 00000000000..f5a3cbcaf6c --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_porter.h @@ -0,0 +1,17 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *porter_UTF_8_create_env(void); +extern void porter_UTF_8_close_env(struct SN_env *z); + +extern int porter_UTF_8_stem(struct SN_env *z); + +#ifdef __cplusplus +} +#endif diff --git a/internal/cpp/stemmer/stem_UTF_8_portuguese.cpp b/internal/cpp/stemmer/stem_UTF_8_portuguese.cpp new file mode 100644 index 00000000000..dfba9643518 --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_portuguese.cpp @@ -0,0 +1,1217 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#include "header.h" + +#ifdef __cplusplus +extern "C" { +#endif +extern int portuguese_UTF_8_stem(struct SN_env *z); +#ifdef __cplusplus +} +#endif +static int r_residual_form(struct SN_env *z); +static int r_residual_suffix(struct SN_env *z); +static int r_verb_suffix(struct SN_env *z); +static int r_standard_suffix(struct SN_env *z); +static int r_R2(struct SN_env *z); +static int r_R1(struct SN_env *z); +static int r_RV(struct SN_env *z); +static int r_mark_regions(struct SN_env *z); +static int r_postlude(struct SN_env *z); +static int r_prelude(struct SN_env *z); +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *portuguese_UTF_8_create_env(void); +extern void portuguese_UTF_8_close_env(struct SN_env *z); + +#ifdef __cplusplus +} +#endif +static const symbol s_0_1[2] = {0xC3, 0xA3}; +static const symbol s_0_2[2] = {0xC3, 0xB5}; + +static const struct among a_0[3] = { + /* 0 */ {0, 0, -1, 3, 0}, + /* 1 */ {2, s_0_1, 0, 1, 0}, + /* 2 */ {2, s_0_2, 0, 2, 0}}; + +static const symbol s_1_1[2] = {'a', '~'}; +static const symbol s_1_2[2] = {'o', '~'}; + +static const struct among a_1[3] = { + /* 0 */ {0, 0, -1, 3, 0}, + /* 1 */ {2, s_1_1, 0, 1, 0}, + /* 2 */ {2, s_1_2, 0, 2, 0}}; + +static const symbol s_2_0[2] = {'i', 'c'}; +static const symbol s_2_1[2] = {'a', 'd'}; +static const symbol s_2_2[2] = {'o', 's'}; +static const symbol s_2_3[2] = {'i', 'v'}; + +static const struct among a_2[4] = { + /* 0 */ {2, s_2_0, -1, -1, 0}, + /* 1 */ {2, s_2_1, -1, -1, 0}, + /* 2 */ {2, s_2_2, -1, -1, 0}, + /* 3 */ {2, s_2_3, -1, 1, 0}}; + +static const symbol s_3_0[4] = {'a', 'n', 't', 'e'}; +static const symbol s_3_1[4] = {'a', 'v', 'e', 'l'}; +static const symbol s_3_2[5] = {0xC3, 0xAD, 'v', 'e', 'l'}; + +static const struct among a_3[3] = { + /* 0 */ {4, s_3_0, -1, 1, 0}, + /* 1 */ {4, s_3_1, -1, 1, 0}, + /* 2 */ {5, s_3_2, -1, 1, 0}}; + +static const symbol s_4_0[2] = {'i', 'c'}; +static const symbol s_4_1[4] = {'a', 'b', 'i', 'l'}; +static const symbol s_4_2[2] = {'i', 'v'}; + +static const struct among a_4[3] = { + /* 0 */ {2, s_4_0, -1, 1, 0}, + /* 1 */ {4, s_4_1, -1, 1, 0}, + /* 2 */ {2, s_4_2, -1, 1, 0}}; + +static const symbol s_5_0[3] = {'i', 'c', 'a'}; +static const symbol s_5_1[6] = {0xC3, 0xA2, 'n', 'c', 'i', 'a'}; +static const symbol s_5_2[6] = {0xC3, 0xAA, 'n', 'c', 'i', 'a'}; +static const symbol s_5_3[3] = {'i', 'r', 'a'}; +static const symbol s_5_4[5] = {'a', 'd', 'o', 'r', 'a'}; +static const symbol s_5_5[3] = {'o', 's', 'a'}; +static const symbol s_5_6[4] = {'i', 's', 't', 'a'}; +static const symbol s_5_7[3] = {'i', 'v', 'a'}; +static const symbol s_5_8[3] = {'e', 'z', 'a'}; +static const symbol s_5_9[6] = {'l', 'o', 'g', 0xC3, 0xAD, 'a'}; 
+static const symbol s_5_10[5] = {'i', 'd', 'a', 'd', 'e'}; +static const symbol s_5_11[4] = {'a', 'n', 't', 'e'}; +static const symbol s_5_12[5] = {'m', 'e', 'n', 't', 'e'}; +static const symbol s_5_13[6] = {'a', 'm', 'e', 'n', 't', 'e'}; +static const symbol s_5_14[5] = {0xC3, 0xA1, 'v', 'e', 'l'}; +static const symbol s_5_15[5] = {0xC3, 0xAD, 'v', 'e', 'l'}; +static const symbol s_5_16[6] = {'u', 'c', 'i', 0xC3, 0xB3, 'n'}; +static const symbol s_5_17[3] = {'i', 'c', 'o'}; +static const symbol s_5_18[4] = {'i', 's', 'm', 'o'}; +static const symbol s_5_19[3] = {'o', 's', 'o'}; +static const symbol s_5_20[6] = {'a', 'm', 'e', 'n', 't', 'o'}; +static const symbol s_5_21[6] = {'i', 'm', 'e', 'n', 't', 'o'}; +static const symbol s_5_22[3] = {'i', 'v', 'o'}; +static const symbol s_5_23[6] = {'a', 0xC3, 0xA7, 'a', '~', 'o'}; +static const symbol s_5_24[4] = {'a', 'd', 'o', 'r'}; +static const symbol s_5_25[4] = {'i', 'c', 'a', 's'}; +static const symbol s_5_26[7] = {0xC3, 0xAA, 'n', 'c', 'i', 'a', 's'}; +static const symbol s_5_27[4] = {'i', 'r', 'a', 's'}; +static const symbol s_5_28[6] = {'a', 'd', 'o', 'r', 'a', 's'}; +static const symbol s_5_29[4] = {'o', 's', 'a', 's'}; +static const symbol s_5_30[5] = {'i', 's', 't', 'a', 's'}; +static const symbol s_5_31[4] = {'i', 'v', 'a', 's'}; +static const symbol s_5_32[4] = {'e', 'z', 'a', 's'}; +static const symbol s_5_33[7] = {'l', 'o', 'g', 0xC3, 0xAD, 'a', 's'}; +static const symbol s_5_34[6] = {'i', 'd', 'a', 'd', 'e', 's'}; +static const symbol s_5_35[7] = {'u', 'c', 'i', 'o', 'n', 'e', 's'}; +static const symbol s_5_36[6] = {'a', 'd', 'o', 'r', 'e', 's'}; +static const symbol s_5_37[5] = {'a', 'n', 't', 'e', 's'}; +static const symbol s_5_38[7] = {'a', 0xC3, 0xA7, 'o', '~', 'e', 's'}; +static const symbol s_5_39[4] = {'i', 'c', 'o', 's'}; +static const symbol s_5_40[5] = {'i', 's', 'm', 'o', 's'}; +static const symbol s_5_41[4] = {'o', 's', 'o', 's'}; +static const symbol s_5_42[7] = {'a', 'm', 'e', 'n', 't', 'o', 's'}; +static const symbol s_5_43[7] = {'i', 'm', 'e', 'n', 't', 'o', 's'}; +static const symbol s_5_44[4] = {'i', 'v', 'o', 's'}; + +static const struct among a_5[45] = { + /* 0 */ {3, s_5_0, -1, 1, 0}, + /* 1 */ {6, s_5_1, -1, 1, 0}, + /* 2 */ {6, s_5_2, -1, 4, 0}, + /* 3 */ {3, s_5_3, -1, 9, 0}, + /* 4 */ {5, s_5_4, -1, 1, 0}, + /* 5 */ {3, s_5_5, -1, 1, 0}, + /* 6 */ {4, s_5_6, -1, 1, 0}, + /* 7 */ {3, s_5_7, -1, 8, 0}, + /* 8 */ {3, s_5_8, -1, 1, 0}, + /* 9 */ {6, s_5_9, -1, 2, 0}, + /* 10 */ {5, s_5_10, -1, 7, 0}, + /* 11 */ {4, s_5_11, -1, 1, 0}, + /* 12 */ {5, s_5_12, -1, 6, 0}, + /* 13 */ {6, s_5_13, 12, 5, 0}, + /* 14 */ {5, s_5_14, -1, 1, 0}, + /* 15 */ {5, s_5_15, -1, 1, 0}, + /* 16 */ {6, s_5_16, -1, 3, 0}, + /* 17 */ {3, s_5_17, -1, 1, 0}, + /* 18 */ {4, s_5_18, -1, 1, 0}, + /* 19 */ {3, s_5_19, -1, 1, 0}, + /* 20 */ {6, s_5_20, -1, 1, 0}, + /* 21 */ {6, s_5_21, -1, 1, 0}, + /* 22 */ {3, s_5_22, -1, 8, 0}, + /* 23 */ {6, s_5_23, -1, 1, 0}, + /* 24 */ {4, s_5_24, -1, 1, 0}, + /* 25 */ {4, s_5_25, -1, 1, 0}, + /* 26 */ {7, s_5_26, -1, 4, 0}, + /* 27 */ {4, s_5_27, -1, 9, 0}, + /* 28 */ {6, s_5_28, -1, 1, 0}, + /* 29 */ {4, s_5_29, -1, 1, 0}, + /* 30 */ {5, s_5_30, -1, 1, 0}, + /* 31 */ {4, s_5_31, -1, 8, 0}, + /* 32 */ {4, s_5_32, -1, 1, 0}, + /* 33 */ {7, s_5_33, -1, 2, 0}, + /* 34 */ {6, s_5_34, -1, 7, 0}, + /* 35 */ {7, s_5_35, -1, 3, 0}, + /* 36 */ {6, s_5_36, -1, 1, 0}, + /* 37 */ {5, s_5_37, -1, 1, 0}, + /* 38 */ {7, s_5_38, -1, 1, 0}, + /* 39 */ {4, s_5_39, -1, 1, 0}, + /* 40 */ {5, s_5_40, -1, 1, 0}, + /* 41 */ 
{4, s_5_41, -1, 1, 0}, + /* 42 */ {7, s_5_42, -1, 1, 0}, + /* 43 */ {7, s_5_43, -1, 1, 0}, + /* 44 */ {4, s_5_44, -1, 8, 0}}; + +static const symbol s_6_0[3] = {'a', 'd', 'a'}; +static const symbol s_6_1[3] = {'i', 'd', 'a'}; +static const symbol s_6_2[2] = {'i', 'a'}; +static const symbol s_6_3[4] = {'a', 'r', 'i', 'a'}; +static const symbol s_6_4[4] = {'e', 'r', 'i', 'a'}; +static const symbol s_6_5[4] = {'i', 'r', 'i', 'a'}; +static const symbol s_6_6[3] = {'a', 'r', 'a'}; +static const symbol s_6_7[3] = {'e', 'r', 'a'}; +static const symbol s_6_8[3] = {'i', 'r', 'a'}; +static const symbol s_6_9[3] = {'a', 'v', 'a'}; +static const symbol s_6_10[4] = {'a', 's', 's', 'e'}; +static const symbol s_6_11[4] = {'e', 's', 's', 'e'}; +static const symbol s_6_12[4] = {'i', 's', 's', 'e'}; +static const symbol s_6_13[4] = {'a', 's', 't', 'e'}; +static const symbol s_6_14[4] = {'e', 's', 't', 'e'}; +static const symbol s_6_15[4] = {'i', 's', 't', 'e'}; +static const symbol s_6_16[2] = {'e', 'i'}; +static const symbol s_6_17[4] = {'a', 'r', 'e', 'i'}; +static const symbol s_6_18[4] = {'e', 'r', 'e', 'i'}; +static const symbol s_6_19[4] = {'i', 'r', 'e', 'i'}; +static const symbol s_6_20[2] = {'a', 'm'}; +static const symbol s_6_21[3] = {'i', 'a', 'm'}; +static const symbol s_6_22[5] = {'a', 'r', 'i', 'a', 'm'}; +static const symbol s_6_23[5] = {'e', 'r', 'i', 'a', 'm'}; +static const symbol s_6_24[5] = {'i', 'r', 'i', 'a', 'm'}; +static const symbol s_6_25[4] = {'a', 'r', 'a', 'm'}; +static const symbol s_6_26[4] = {'e', 'r', 'a', 'm'}; +static const symbol s_6_27[4] = {'i', 'r', 'a', 'm'}; +static const symbol s_6_28[4] = {'a', 'v', 'a', 'm'}; +static const symbol s_6_29[2] = {'e', 'm'}; +static const symbol s_6_30[4] = {'a', 'r', 'e', 'm'}; +static const symbol s_6_31[4] = {'e', 'r', 'e', 'm'}; +static const symbol s_6_32[4] = {'i', 'r', 'e', 'm'}; +static const symbol s_6_33[5] = {'a', 's', 's', 'e', 'm'}; +static const symbol s_6_34[5] = {'e', 's', 's', 'e', 'm'}; +static const symbol s_6_35[5] = {'i', 's', 's', 'e', 'm'}; +static const symbol s_6_36[3] = {'a', 'd', 'o'}; +static const symbol s_6_37[3] = {'i', 'd', 'o'}; +static const symbol s_6_38[4] = {'a', 'n', 'd', 'o'}; +static const symbol s_6_39[4] = {'e', 'n', 'd', 'o'}; +static const symbol s_6_40[4] = {'i', 'n', 'd', 'o'}; +static const symbol s_6_41[5] = {'a', 'r', 'a', '~', 'o'}; +static const symbol s_6_42[5] = {'e', 'r', 'a', '~', 'o'}; +static const symbol s_6_43[5] = {'i', 'r', 'a', '~', 'o'}; +static const symbol s_6_44[2] = {'a', 'r'}; +static const symbol s_6_45[2] = {'e', 'r'}; +static const symbol s_6_46[2] = {'i', 'r'}; +static const symbol s_6_47[2] = {'a', 's'}; +static const symbol s_6_48[4] = {'a', 'd', 'a', 's'}; +static const symbol s_6_49[4] = {'i', 'd', 'a', 's'}; +static const symbol s_6_50[3] = {'i', 'a', 's'}; +static const symbol s_6_51[5] = {'a', 'r', 'i', 'a', 's'}; +static const symbol s_6_52[5] = {'e', 'r', 'i', 'a', 's'}; +static const symbol s_6_53[5] = {'i', 'r', 'i', 'a', 's'}; +static const symbol s_6_54[4] = {'a', 'r', 'a', 's'}; +static const symbol s_6_55[4] = {'e', 'r', 'a', 's'}; +static const symbol s_6_56[4] = {'i', 'r', 'a', 's'}; +static const symbol s_6_57[4] = {'a', 'v', 'a', 's'}; +static const symbol s_6_58[2] = {'e', 's'}; +static const symbol s_6_59[5] = {'a', 'r', 'd', 'e', 's'}; +static const symbol s_6_60[5] = {'e', 'r', 'd', 'e', 's'}; +static const symbol s_6_61[5] = {'i', 'r', 'd', 'e', 's'}; +static const symbol s_6_62[4] = {'a', 'r', 'e', 's'}; +static const symbol s_6_63[4] = 
{'e', 'r', 'e', 's'}; +static const symbol s_6_64[4] = {'i', 'r', 'e', 's'}; +static const symbol s_6_65[5] = {'a', 's', 's', 'e', 's'}; +static const symbol s_6_66[5] = {'e', 's', 's', 'e', 's'}; +static const symbol s_6_67[5] = {'i', 's', 's', 'e', 's'}; +static const symbol s_6_68[5] = {'a', 's', 't', 'e', 's'}; +static const symbol s_6_69[5] = {'e', 's', 't', 'e', 's'}; +static const symbol s_6_70[5] = {'i', 's', 't', 'e', 's'}; +static const symbol s_6_71[2] = {'i', 's'}; +static const symbol s_6_72[3] = {'a', 'i', 's'}; +static const symbol s_6_73[3] = {'e', 'i', 's'}; +static const symbol s_6_74[5] = {'a', 'r', 'e', 'i', 's'}; +static const symbol s_6_75[5] = {'e', 'r', 'e', 'i', 's'}; +static const symbol s_6_76[5] = {'i', 'r', 'e', 'i', 's'}; +static const symbol s_6_77[6] = {0xC3, 0xA1, 'r', 'e', 'i', 's'}; +static const symbol s_6_78[6] = {0xC3, 0xA9, 'r', 'e', 'i', 's'}; +static const symbol s_6_79[6] = {0xC3, 0xAD, 'r', 'e', 'i', 's'}; +static const symbol s_6_80[7] = {0xC3, 0xA1, 's', 's', 'e', 'i', 's'}; +static const symbol s_6_81[7] = {0xC3, 0xA9, 's', 's', 'e', 'i', 's'}; +static const symbol s_6_82[7] = {0xC3, 0xAD, 's', 's', 'e', 'i', 's'}; +static const symbol s_6_83[6] = {0xC3, 0xA1, 'v', 'e', 'i', 's'}; +static const symbol s_6_84[5] = {0xC3, 0xAD, 'e', 'i', 's'}; +static const symbol s_6_85[7] = {'a', 'r', 0xC3, 0xAD, 'e', 'i', 's'}; +static const symbol s_6_86[7] = {'e', 'r', 0xC3, 0xAD, 'e', 'i', 's'}; +static const symbol s_6_87[7] = {'i', 'r', 0xC3, 0xAD, 'e', 'i', 's'}; +static const symbol s_6_88[4] = {'a', 'd', 'o', 's'}; +static const symbol s_6_89[4] = {'i', 'd', 'o', 's'}; +static const symbol s_6_90[4] = {'a', 'm', 'o', 's'}; +static const symbol s_6_91[7] = {0xC3, 0xA1, 'r', 'a', 'm', 'o', 's'}; +static const symbol s_6_92[7] = {0xC3, 0xA9, 'r', 'a', 'm', 'o', 's'}; +static const symbol s_6_93[7] = {0xC3, 0xAD, 'r', 'a', 'm', 'o', 's'}; +static const symbol s_6_94[7] = {0xC3, 0xA1, 'v', 'a', 'm', 'o', 's'}; +static const symbol s_6_95[6] = {0xC3, 0xAD, 'a', 'm', 'o', 's'}; +static const symbol s_6_96[8] = {'a', 'r', 0xC3, 0xAD, 'a', 'm', 'o', 's'}; +static const symbol s_6_97[8] = {'e', 'r', 0xC3, 0xAD, 'a', 'm', 'o', 's'}; +static const symbol s_6_98[8] = {'i', 'r', 0xC3, 0xAD, 'a', 'm', 'o', 's'}; +static const symbol s_6_99[4] = {'e', 'm', 'o', 's'}; +static const symbol s_6_100[6] = {'a', 'r', 'e', 'm', 'o', 's'}; +static const symbol s_6_101[6] = {'e', 'r', 'e', 'm', 'o', 's'}; +static const symbol s_6_102[6] = {'i', 'r', 'e', 'm', 'o', 's'}; +static const symbol s_6_103[8] = {0xC3, 0xA1, 's', 's', 'e', 'm', 'o', 's'}; +static const symbol s_6_104[8] = {0xC3, 0xAA, 's', 's', 'e', 'm', 'o', 's'}; +static const symbol s_6_105[8] = {0xC3, 0xAD, 's', 's', 'e', 'm', 'o', 's'}; +static const symbol s_6_106[4] = {'i', 'm', 'o', 's'}; +static const symbol s_6_107[5] = {'a', 'r', 'm', 'o', 's'}; +static const symbol s_6_108[5] = {'e', 'r', 'm', 'o', 's'}; +static const symbol s_6_109[5] = {'i', 'r', 'm', 'o', 's'}; +static const symbol s_6_110[5] = {0xC3, 0xA1, 'm', 'o', 's'}; +static const symbol s_6_111[5] = {'a', 'r', 0xC3, 0xA1, 's'}; +static const symbol s_6_112[5] = {'e', 'r', 0xC3, 0xA1, 's'}; +static const symbol s_6_113[5] = {'i', 'r', 0xC3, 0xA1, 's'}; +static const symbol s_6_114[2] = {'e', 'u'}; +static const symbol s_6_115[2] = {'i', 'u'}; +static const symbol s_6_116[2] = {'o', 'u'}; +static const symbol s_6_117[4] = {'a', 'r', 0xC3, 0xA1}; +static const symbol s_6_118[4] = {'e', 'r', 0xC3, 0xA1}; +static const symbol s_6_119[4] = {'i', 'r', 
0xC3, 0xA1}; + +static const struct among a_6[120] = { + /* 0 */ {3, s_6_0, -1, 1, 0}, + /* 1 */ {3, s_6_1, -1, 1, 0}, + /* 2 */ {2, s_6_2, -1, 1, 0}, + /* 3 */ {4, s_6_3, 2, 1, 0}, + /* 4 */ {4, s_6_4, 2, 1, 0}, + /* 5 */ {4, s_6_5, 2, 1, 0}, + /* 6 */ {3, s_6_6, -1, 1, 0}, + /* 7 */ {3, s_6_7, -1, 1, 0}, + /* 8 */ {3, s_6_8, -1, 1, 0}, + /* 9 */ {3, s_6_9, -1, 1, 0}, + /* 10 */ {4, s_6_10, -1, 1, 0}, + /* 11 */ {4, s_6_11, -1, 1, 0}, + /* 12 */ {4, s_6_12, -1, 1, 0}, + /* 13 */ {4, s_6_13, -1, 1, 0}, + /* 14 */ {4, s_6_14, -1, 1, 0}, + /* 15 */ {4, s_6_15, -1, 1, 0}, + /* 16 */ {2, s_6_16, -1, 1, 0}, + /* 17 */ {4, s_6_17, 16, 1, 0}, + /* 18 */ {4, s_6_18, 16, 1, 0}, + /* 19 */ {4, s_6_19, 16, 1, 0}, + /* 20 */ {2, s_6_20, -1, 1, 0}, + /* 21 */ {3, s_6_21, 20, 1, 0}, + /* 22 */ {5, s_6_22, 21, 1, 0}, + /* 23 */ {5, s_6_23, 21, 1, 0}, + /* 24 */ {5, s_6_24, 21, 1, 0}, + /* 25 */ {4, s_6_25, 20, 1, 0}, + /* 26 */ {4, s_6_26, 20, 1, 0}, + /* 27 */ {4, s_6_27, 20, 1, 0}, + /* 28 */ {4, s_6_28, 20, 1, 0}, + /* 29 */ {2, s_6_29, -1, 1, 0}, + /* 30 */ {4, s_6_30, 29, 1, 0}, + /* 31 */ {4, s_6_31, 29, 1, 0}, + /* 32 */ {4, s_6_32, 29, 1, 0}, + /* 33 */ {5, s_6_33, 29, 1, 0}, + /* 34 */ {5, s_6_34, 29, 1, 0}, + /* 35 */ {5, s_6_35, 29, 1, 0}, + /* 36 */ {3, s_6_36, -1, 1, 0}, + /* 37 */ {3, s_6_37, -1, 1, 0}, + /* 38 */ {4, s_6_38, -1, 1, 0}, + /* 39 */ {4, s_6_39, -1, 1, 0}, + /* 40 */ {4, s_6_40, -1, 1, 0}, + /* 41 */ {5, s_6_41, -1, 1, 0}, + /* 42 */ {5, s_6_42, -1, 1, 0}, + /* 43 */ {5, s_6_43, -1, 1, 0}, + /* 44 */ {2, s_6_44, -1, 1, 0}, + /* 45 */ {2, s_6_45, -1, 1, 0}, + /* 46 */ {2, s_6_46, -1, 1, 0}, + /* 47 */ {2, s_6_47, -1, 1, 0}, + /* 48 */ {4, s_6_48, 47, 1, 0}, + /* 49 */ {4, s_6_49, 47, 1, 0}, + /* 50 */ {3, s_6_50, 47, 1, 0}, + /* 51 */ {5, s_6_51, 50, 1, 0}, + /* 52 */ {5, s_6_52, 50, 1, 0}, + /* 53 */ {5, s_6_53, 50, 1, 0}, + /* 54 */ {4, s_6_54, 47, 1, 0}, + /* 55 */ {4, s_6_55, 47, 1, 0}, + /* 56 */ {4, s_6_56, 47, 1, 0}, + /* 57 */ {4, s_6_57, 47, 1, 0}, + /* 58 */ {2, s_6_58, -1, 1, 0}, + /* 59 */ {5, s_6_59, 58, 1, 0}, + /* 60 */ {5, s_6_60, 58, 1, 0}, + /* 61 */ {5, s_6_61, 58, 1, 0}, + /* 62 */ {4, s_6_62, 58, 1, 0}, + /* 63 */ {4, s_6_63, 58, 1, 0}, + /* 64 */ {4, s_6_64, 58, 1, 0}, + /* 65 */ {5, s_6_65, 58, 1, 0}, + /* 66 */ {5, s_6_66, 58, 1, 0}, + /* 67 */ {5, s_6_67, 58, 1, 0}, + /* 68 */ {5, s_6_68, 58, 1, 0}, + /* 69 */ {5, s_6_69, 58, 1, 0}, + /* 70 */ {5, s_6_70, 58, 1, 0}, + /* 71 */ {2, s_6_71, -1, 1, 0}, + /* 72 */ {3, s_6_72, 71, 1, 0}, + /* 73 */ {3, s_6_73, 71, 1, 0}, + /* 74 */ {5, s_6_74, 73, 1, 0}, + /* 75 */ {5, s_6_75, 73, 1, 0}, + /* 76 */ {5, s_6_76, 73, 1, 0}, + /* 77 */ {6, s_6_77, 73, 1, 0}, + /* 78 */ {6, s_6_78, 73, 1, 0}, + /* 79 */ {6, s_6_79, 73, 1, 0}, + /* 80 */ {7, s_6_80, 73, 1, 0}, + /* 81 */ {7, s_6_81, 73, 1, 0}, + /* 82 */ {7, s_6_82, 73, 1, 0}, + /* 83 */ {6, s_6_83, 73, 1, 0}, + /* 84 */ {5, s_6_84, 73, 1, 0}, + /* 85 */ {7, s_6_85, 84, 1, 0}, + /* 86 */ {7, s_6_86, 84, 1, 0}, + /* 87 */ {7, s_6_87, 84, 1, 0}, + /* 88 */ {4, s_6_88, -1, 1, 0}, + /* 89 */ {4, s_6_89, -1, 1, 0}, + /* 90 */ {4, s_6_90, -1, 1, 0}, + /* 91 */ {7, s_6_91, 90, 1, 0}, + /* 92 */ {7, s_6_92, 90, 1, 0}, + /* 93 */ {7, s_6_93, 90, 1, 0}, + /* 94 */ {7, s_6_94, 90, 1, 0}, + /* 95 */ {6, s_6_95, 90, 1, 0}, + /* 96 */ {8, s_6_96, 95, 1, 0}, + /* 97 */ {8, s_6_97, 95, 1, 0}, + /* 98 */ {8, s_6_98, 95, 1, 0}, + /* 99 */ {4, s_6_99, -1, 1, 0}, + /*100 */ {6, s_6_100, 99, 1, 0}, + /*101 */ {6, s_6_101, 99, 1, 0}, + /*102 */ {6, s_6_102, 99, 1, 0}, + /*103 */ {8, 
s_6_103, 99, 1, 0}, + /*104 */ {8, s_6_104, 99, 1, 0}, + /*105 */ {8, s_6_105, 99, 1, 0}, + /*106 */ {4, s_6_106, -1, 1, 0}, + /*107 */ {5, s_6_107, -1, 1, 0}, + /*108 */ {5, s_6_108, -1, 1, 0}, + /*109 */ {5, s_6_109, -1, 1, 0}, + /*110 */ {5, s_6_110, -1, 1, 0}, + /*111 */ {5, s_6_111, -1, 1, 0}, + /*112 */ {5, s_6_112, -1, 1, 0}, + /*113 */ {5, s_6_113, -1, 1, 0}, + /*114 */ {2, s_6_114, -1, 1, 0}, + /*115 */ {2, s_6_115, -1, 1, 0}, + /*116 */ {2, s_6_116, -1, 1, 0}, + /*117 */ {4, s_6_117, -1, 1, 0}, + /*118 */ {4, s_6_118, -1, 1, 0}, + /*119 */ {4, s_6_119, -1, 1, 0}}; + +static const symbol s_7_0[1] = {'a'}; +static const symbol s_7_1[1] = {'i'}; +static const symbol s_7_2[1] = {'o'}; +static const symbol s_7_3[2] = {'o', 's'}; +static const symbol s_7_4[2] = {0xC3, 0xA1}; +static const symbol s_7_5[2] = {0xC3, 0xAD}; +static const symbol s_7_6[2] = {0xC3, 0xB3}; + +static const struct among a_7[7] = { + /* 0 */ {1, s_7_0, -1, 1, 0}, + /* 1 */ {1, s_7_1, -1, 1, 0}, + /* 2 */ {1, s_7_2, -1, 1, 0}, + /* 3 */ {2, s_7_3, -1, 1, 0}, + /* 4 */ {2, s_7_4, -1, 1, 0}, + /* 5 */ {2, s_7_5, -1, 1, 0}, + /* 6 */ {2, s_7_6, -1, 1, 0}}; + +static const symbol s_8_0[1] = {'e'}; +static const symbol s_8_1[2] = {0xC3, 0xA7}; +static const symbol s_8_2[2] = {0xC3, 0xA9}; +static const symbol s_8_3[2] = {0xC3, 0xAA}; + +static const struct among a_8[4] = { + /* 0 */ {1, s_8_0, -1, 1, 0}, + /* 1 */ {2, s_8_1, -1, 2, 0}, + /* 2 */ {2, s_8_2, -1, 1, 0}, + /* 3 */ {2, s_8_3, -1, 1, 0}}; + +static const unsigned char g_v[] = {17, 65, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 19, 12, 2}; + +static const symbol s_0[] = {'a', '~'}; +static const symbol s_1[] = {'o', '~'}; +static const symbol s_2[] = {0xC3, 0xA3}; +static const symbol s_3[] = {0xC3, 0xB5}; +static const symbol s_4[] = {'l', 'o', 'g'}; +static const symbol s_5[] = {'u'}; +static const symbol s_6[] = {'e', 'n', 't', 'e'}; +static const symbol s_7[] = {'a', 't'}; +static const symbol s_8[] = {'a', 't'}; +static const symbol s_9[] = {'e'}; +static const symbol s_10[] = {'i', 'r'}; +static const symbol s_11[] = {'u'}; +static const symbol s_12[] = {'g'}; +static const symbol s_13[] = {'i'}; +static const symbol s_14[] = {'c'}; +static const symbol s_15[] = {'c'}; +static const symbol s_16[] = {'i'}; +static const symbol s_17[] = {'c'}; + +static int r_prelude(struct SN_env *z) { + int among_var; + while (1) { /* repeat, line 36 */ + int c1 = z->c; + z->bra = z->c; /* [, line 37 */ + if (z->c + 1 >= z->l || (z->p[z->c + 1] != 163 && z->p[z->c + 1] != 181)) + among_var = 3; + else + among_var = find_among(z, a_0, 3); /* substring, line 37 */ + if (!(among_var)) + goto lab0; + z->ket = z->c; /* ], line 37 */ + switch (among_var) { + case 0: + goto lab0; + case 1: { + int ret = slice_from_s(z, 2, s_0); /* <-, line 38 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 2, s_1); /* <-, line 39 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* next, line 40 */ + } break; + } + continue; + lab0: + z->c = c1; + break; + } + return 1; +} + +static int r_mark_regions(struct SN_env *z) { + z->I[0] = z->l; + z->I[1] = z->l; + z->I[2] = z->l; + { + int c1 = z->c; /* do, line 50 */ + { + int c2 = z->c; /* or, line 52 */ + if (in_grouping_U(z, g_v, 97, 250, 0)) + goto lab2; + { + int c3 = z->c; /* or, line 51 */ + if (out_grouping_U(z, g_v, 97, 250, 0)) + goto lab4; + { /* gopast */ /* grouping v, line 51 */ + int ret = 
out_grouping_U(z, g_v, 97, 250, 1); + if (ret < 0) + goto lab4; + z->c += ret; + } + goto lab3; + lab4: + z->c = c3; + if (in_grouping_U(z, g_v, 97, 250, 0)) + goto lab2; + { /* gopast */ /* non v, line 51 */ + int ret = in_grouping_U(z, g_v, 97, 250, 1); + if (ret < 0) + goto lab2; + z->c += ret; + } + } + lab3: + goto lab1; + lab2: + z->c = c2; + if (out_grouping_U(z, g_v, 97, 250, 0)) + goto lab0; + { + int c4 = z->c; /* or, line 53 */ + if (out_grouping_U(z, g_v, 97, 250, 0)) + goto lab6; + { /* gopast */ /* grouping v, line 53 */ + int ret = out_grouping_U(z, g_v, 97, 250, 1); + if (ret < 0) + goto lab6; + z->c += ret; + } + goto lab5; + lab6: + z->c = c4; + if (in_grouping_U(z, g_v, 97, 250, 0)) + goto lab0; + { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* next, line 53 */ + } + } + lab5:; + } + lab1: + z->I[0] = z->c; /* setmark pV, line 54 */ + lab0: + z->c = c1; + } + { + int c5 = z->c; /* do, line 56 */ + { /* gopast */ /* grouping v, line 57 */ + int ret = out_grouping_U(z, g_v, 97, 250, 1); + if (ret < 0) + goto lab7; + z->c += ret; + } + { /* gopast */ /* non v, line 57 */ + int ret = in_grouping_U(z, g_v, 97, 250, 1); + if (ret < 0) + goto lab7; + z->c += ret; + } + z->I[1] = z->c; /* setmark p1, line 57 */ + { /* gopast */ /* grouping v, line 58 */ + int ret = out_grouping_U(z, g_v, 97, 250, 1); + if (ret < 0) + goto lab7; + z->c += ret; + } + { /* gopast */ /* non v, line 58 */ + int ret = in_grouping_U(z, g_v, 97, 250, 1); + if (ret < 0) + goto lab7; + z->c += ret; + } + z->I[2] = z->c; /* setmark p2, line 58 */ + lab7: + z->c = c5; + } + return 1; +} + +static int r_postlude(struct SN_env *z) { + int among_var; + while (1) { /* repeat, line 62 */ + int c1 = z->c; + z->bra = z->c; /* [, line 63 */ + if (z->c + 1 >= z->l || z->p[z->c + 1] != 126) + among_var = 3; + else + among_var = find_among(z, a_1, 3); /* substring, line 63 */ + if (!(among_var)) + goto lab0; + z->ket = z->c; /* ], line 63 */ + switch (among_var) { + case 0: + goto lab0; + case 1: { + int ret = slice_from_s(z, 2, s_2); /* <-, line 64 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 2, s_3); /* <-, line 65 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* next, line 66 */ + } break; + } + continue; + lab0: + z->c = c1; + break; + } + return 1; +} + +static int r_RV(struct SN_env *z) { + if (!(z->I[0] <= z->c)) + return 0; + return 1; +} + +static int r_R1(struct SN_env *z) { + if (!(z->I[1] <= z->c)) + return 0; + return 1; +} + +static int r_R2(struct SN_env *z) { + if (!(z->I[2] <= z->c)) + return 0; + return 1; +} + +static int r_standard_suffix(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 77 */ + if (z->c - 2 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((839714 >> (z->p[z->c - 1] & 0x1f)) & 1)) + return 0; + among_var = find_among_b(z, a_5, 45); /* substring, line 77 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 77 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 93 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 93 */ + if (ret < 0) + return ret; + } + break; + case 2: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 98 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 3, s_4); /* <-, line 98 */ + if (ret < 0) + return ret; + } + break; + case 3: 
{ + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 102 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 1, s_5); /* <-, line 102 */ + if (ret < 0) + return ret; + } + break; + case 4: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 106 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 4, s_6); /* <-, line 106 */ + if (ret < 0) + return ret; + } + break; + case 5: { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 110 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 110 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 111 */ + z->ket = z->c; /* [, line 112 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((4718616 >> (z->p[z->c - 1] & 0x1f)) & 1)) { + z->c = z->l - m_keep; + goto lab0; + } + among_var = find_among_b(z, a_2, 4); /* substring, line 112 */ + if (!(among_var)) { + z->c = z->l - m_keep; + goto lab0; + } + z->bra = z->c; /* ], line 112 */ + { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab0; + } /* call R2, line 112 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 112 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: { + z->c = z->l - m_keep; + goto lab0; + } + case 1: + z->ket = z->c; /* [, line 113 */ + if (!(eq_s_b(z, 2, s_7))) { + z->c = z->l - m_keep; + goto lab0; + } + z->bra = z->c; /* ], line 113 */ + { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab0; + } /* call R2, line 113 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 113 */ + if (ret < 0) + return ret; + } + break; + } + lab0:; + } + break; + case 6: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 122 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 122 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 123 */ + z->ket = z->c; /* [, line 124 */ + if (z->c - 3 <= z->lb || (z->p[z->c - 1] != 101 && z->p[z->c - 1] != 108)) { + z->c = z->l - m_keep; + goto lab1; + } + among_var = find_among_b(z, a_3, 3); /* substring, line 124 */ + if (!(among_var)) { + z->c = z->l - m_keep; + goto lab1; + } + z->bra = z->c; /* ], line 124 */ + switch (among_var) { + case 0: { + z->c = z->l - m_keep; + goto lab1; + } + case 1: { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab1; + } /* call R2, line 127 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 127 */ + if (ret < 0) + return ret; + } + break; + } + lab1:; + } + break; + case 7: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 134 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 134 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 135 */ + z->ket = z->c; /* [, line 136 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((4198408 >> (z->p[z->c - 1] & 0x1f)) & 1)) { + z->c = z->l - m_keep; + goto lab2; + } + among_var = find_among_b(z, a_4, 3); /* substring, line 136 */ + if (!(among_var)) { + z->c = z->l - m_keep; + goto lab2; + } + z->bra = z->c; /* ], line 136 */ + switch (among_var) { + case 0: { + z->c = z->l - m_keep; + goto lab2; + } + case 1: { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab2; + } /* call R2, line 139 */ + if (ret < 0) + return ret; + } + { + int 
ret = slice_del(z); /* delete, line 139 */ + if (ret < 0) + return ret; + } + break; + } + lab2:; + } + break; + case 8: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 146 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 146 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 147 */ + z->ket = z->c; /* [, line 148 */ + if (!(eq_s_b(z, 2, s_8))) { + z->c = z->l - m_keep; + goto lab3; + } + z->bra = z->c; /* ], line 148 */ + { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab3; + } /* call R2, line 148 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 148 */ + if (ret < 0) + return ret; + } + lab3:; + } + break; + case 9: { + int ret = r_RV(z); + if (ret == 0) + return 0; /* call RV, line 153 */ + if (ret < 0) + return ret; + } + if (!(eq_s_b(z, 1, s_9))) + return 0; + { + int ret = slice_from_s(z, 2, s_10); /* <-, line 154 */ + if (ret < 0) + return ret; + } + break; + } + return 1; +} + +static int r_verb_suffix(struct SN_env *z) { + int among_var; + { + int mlimit; /* setlimit, line 159 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 159 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 160 */ + among_var = find_among_b(z, a_6, 120); /* substring, line 160 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 160 */ + switch (among_var) { + case 0: { + z->lb = mlimit; + return 0; + } + case 1: { + int ret = slice_del(z); /* delete, line 179 */ + if (ret < 0) + return ret; + } break; + } + z->lb = mlimit; + } + return 1; +} + +static int r_residual_suffix(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 184 */ + among_var = find_among_b(z, a_7, 7); /* substring, line 184 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 184 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = r_RV(z); + if (ret == 0) + return 0; /* call RV, line 187 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 187 */ + if (ret < 0) + return ret; + } + break; + } + return 1; +} + +static int r_residual_form(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 192 */ + among_var = find_among_b(z, a_8, 4); /* substring, line 192 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 192 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = r_RV(z); + if (ret == 0) + return 0; /* call RV, line 194 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 194 */ + if (ret < 0) + return ret; + } + z->ket = z->c; /* [, line 194 */ + { + int m1 = z->l - z->c; + (void)m1; /* or, line 194 */ + if (!(eq_s_b(z, 1, s_11))) + goto lab1; + z->bra = z->c; /* ], line 194 */ + { + int m_test = z->l - z->c; /* test, line 194 */ + if (!(eq_s_b(z, 1, s_12))) + goto lab1; + z->c = z->l - m_test; + } + goto lab0; + lab1: + z->c = z->l - m1; + if (!(eq_s_b(z, 1, s_13))) + return 0; + z->bra = z->c; /* ], line 195 */ + { + int m_test = z->l - z->c; /* test, line 195 */ + if (!(eq_s_b(z, 1, s_14))) + return 0; + z->c = z->l - m_test; + } + } + lab0: { + int ret = r_RV(z); + if (ret == 0) + return 0; /* call RV, line 195 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 195 */ + if (ret < 0) + return ret; + } + break; + case 2: { + int ret = slice_from_s(z, 1, s_15); /* <-, line 196 */ + if 
(ret < 0) + return ret; + } break; + } + return 1; +} + +extern int portuguese_UTF_8_stem(struct SN_env *z) { + { + int c1 = z->c; /* do, line 202 */ + { + int ret = r_prelude(z); + if (ret == 0) + goto lab0; /* call prelude, line 202 */ + if (ret < 0) + return ret; + } + lab0: + z->c = c1; + } + { + int c2 = z->c; /* do, line 203 */ + { + int ret = r_mark_regions(z); + if (ret == 0) + goto lab1; /* call mark_regions, line 203 */ + if (ret < 0) + return ret; + } + lab1: + z->c = c2; + } + z->lb = z->c; + z->c = z->l; /* backwards, line 204 */ + + { + int m3 = z->l - z->c; + (void)m3; /* do, line 205 */ + { + int m4 = z->l - z->c; + (void)m4; /* or, line 209 */ + { + int m5 = z->l - z->c; + (void)m5; /* and, line 207 */ + { + int m6 = z->l - z->c; + (void)m6; /* or, line 206 */ + { + int ret = r_standard_suffix(z); + if (ret == 0) + goto lab6; /* call standard_suffix, line 206 */ + if (ret < 0) + return ret; + } + goto lab5; + lab6: + z->c = z->l - m6; + { + int ret = r_verb_suffix(z); + if (ret == 0) + goto lab4; /* call verb_suffix, line 206 */ + if (ret < 0) + return ret; + } + } + lab5: + z->c = z->l - m5; + { + int m7 = z->l - z->c; + (void)m7; /* do, line 207 */ + z->ket = z->c; /* [, line 207 */ + if (!(eq_s_b(z, 1, s_16))) + goto lab7; + z->bra = z->c; /* ], line 207 */ + { + int m_test = z->l - z->c; /* test, line 207 */ + if (!(eq_s_b(z, 1, s_17))) + goto lab7; + z->c = z->l - m_test; + } + { + int ret = r_RV(z); + if (ret == 0) + goto lab7; /* call RV, line 207 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 207 */ + if (ret < 0) + return ret; + } + lab7: + z->c = z->l - m7; + } + } + goto lab3; + lab4: + z->c = z->l - m4; + { + int ret = r_residual_suffix(z); + if (ret == 0) + goto lab2; /* call residual_suffix, line 209 */ + if (ret < 0) + return ret; + } + } + lab3: + lab2: + z->c = z->l - m3; + } + { + int m8 = z->l - z->c; + (void)m8; /* do, line 211 */ + { + int ret = r_residual_form(z); + if (ret == 0) + goto lab8; /* call residual_form, line 211 */ + if (ret < 0) + return ret; + } + lab8: + z->c = z->l - m8; + } + z->c = z->lb; + { + int c9 = z->c; /* do, line 213 */ + { + int ret = r_postlude(z); + if (ret == 0) + goto lab9; /* call postlude, line 213 */ + if (ret < 0) + return ret; + } + lab9: + z->c = c9; + } + return 1; +} + +extern struct SN_env *portuguese_UTF_8_create_env(void) { return SN_create_env(0, 3, 0); } + +extern void portuguese_UTF_8_close_env(struct SN_env *z) { SN_close_env(z, 0); } diff --git a/internal/cpp/stemmer/stem_UTF_8_portuguese.h b/internal/cpp/stemmer/stem_UTF_8_portuguese.h new file mode 100644 index 00000000000..8b17cdd0e03 --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_portuguese.h @@ -0,0 +1,17 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *portuguese_UTF_8_create_env(void); +extern void portuguese_UTF_8_close_env(struct SN_env *z); + +extern int portuguese_UTF_8_stem(struct SN_env *z); + +#ifdef __cplusplus +} +#endif diff --git a/internal/cpp/stemmer/stem_UTF_8_romanian.cpp b/internal/cpp/stemmer/stem_UTF_8_romanian.cpp new file mode 100644 index 00000000000..d414959d595 --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_romanian.cpp @@ -0,0 +1,1111 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#include "header.h" + +#ifdef __cplusplus +extern "C" { +#endif +extern int romanian_UTF_8_stem(struct SN_env *z); +#ifdef __cplusplus +} 
+#endif +static int r_vowel_suffix(struct SN_env *z); +static int r_verb_suffix(struct SN_env *z); +static int r_combo_suffix(struct SN_env *z); +static int r_standard_suffix(struct SN_env *z); +static int r_step_0(struct SN_env *z); +static int r_R2(struct SN_env *z); +static int r_R1(struct SN_env *z); +static int r_RV(struct SN_env *z); +static int r_mark_regions(struct SN_env *z); +static int r_postlude(struct SN_env *z); +static int r_prelude(struct SN_env *z); +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *romanian_UTF_8_create_env(void); +extern void romanian_UTF_8_close_env(struct SN_env *z); + +#ifdef __cplusplus +} +#endif +static const symbol s_0_1[1] = {'I'}; +static const symbol s_0_2[1] = {'U'}; + +static const struct among a_0[3] = { + /* 0 */ {0, 0, -1, 3, 0}, + /* 1 */ {1, s_0_1, 0, 1, 0}, + /* 2 */ {1, s_0_2, 0, 2, 0}}; + +static const symbol s_1_0[2] = {'e', 'a'}; +static const symbol s_1_1[5] = {'a', 0xC5, 0xA3, 'i', 'a'}; +static const symbol s_1_2[3] = {'a', 'u', 'a'}; +static const symbol s_1_3[3] = {'i', 'u', 'a'}; +static const symbol s_1_4[5] = {'a', 0xC5, 0xA3, 'i', 'e'}; +static const symbol s_1_5[3] = {'e', 'l', 'e'}; +static const symbol s_1_6[3] = {'i', 'l', 'e'}; +static const symbol s_1_7[4] = {'i', 'i', 'l', 'e'}; +static const symbol s_1_8[3] = {'i', 'e', 'i'}; +static const symbol s_1_9[4] = {'a', 't', 'e', 'i'}; +static const symbol s_1_10[2] = {'i', 'i'}; +static const symbol s_1_11[4] = {'u', 'l', 'u', 'i'}; +static const symbol s_1_12[2] = {'u', 'l'}; +static const symbol s_1_13[4] = {'e', 'l', 'o', 'r'}; +static const symbol s_1_14[4] = {'i', 'l', 'o', 'r'}; +static const symbol s_1_15[5] = {'i', 'i', 'l', 'o', 'r'}; + +static const struct among a_1[16] = { + /* 0 */ {2, s_1_0, -1, 3, 0}, + /* 1 */ {5, s_1_1, -1, 7, 0}, + /* 2 */ {3, s_1_2, -1, 2, 0}, + /* 3 */ {3, s_1_3, -1, 4, 0}, + /* 4 */ {5, s_1_4, -1, 7, 0}, + /* 5 */ {3, s_1_5, -1, 3, 0}, + /* 6 */ {3, s_1_6, -1, 5, 0}, + /* 7 */ {4, s_1_7, 6, 4, 0}, + /* 8 */ {3, s_1_8, -1, 4, 0}, + /* 9 */ {4, s_1_9, -1, 6, 0}, + /* 10 */ {2, s_1_10, -1, 4, 0}, + /* 11 */ {4, s_1_11, -1, 1, 0}, + /* 12 */ {2, s_1_12, -1, 1, 0}, + /* 13 */ {4, s_1_13, -1, 3, 0}, + /* 14 */ {4, s_1_14, -1, 4, 0}, + /* 15 */ {5, s_1_15, 14, 4, 0}}; + +static const symbol s_2_0[5] = {'i', 'c', 'a', 'l', 'a'}; +static const symbol s_2_1[5] = {'i', 'c', 'i', 'v', 'a'}; +static const symbol s_2_2[5] = {'a', 't', 'i', 'v', 'a'}; +static const symbol s_2_3[5] = {'i', 't', 'i', 'v', 'a'}; +static const symbol s_2_4[5] = {'i', 'c', 'a', 'l', 'e'}; +static const symbol s_2_5[7] = {'a', 0xC5, 0xA3, 'i', 'u', 'n', 'e'}; +static const symbol s_2_6[7] = {'i', 0xC5, 0xA3, 'i', 'u', 'n', 'e'}; +static const symbol s_2_7[6] = {'a', 't', 'o', 'a', 'r', 'e'}; +static const symbol s_2_8[6] = {'i', 't', 'o', 'a', 'r', 'e'}; +static const symbol s_2_9[7] = {0xC4, 0x83, 't', 'o', 'a', 'r', 'e'}; +static const symbol s_2_10[7] = {'i', 'c', 'i', 't', 'a', 't', 'e'}; +static const symbol s_2_11[9] = {'a', 'b', 'i', 'l', 'i', 't', 'a', 't', 'e'}; +static const symbol s_2_12[9] = {'i', 'b', 'i', 'l', 'i', 't', 'a', 't', 'e'}; +static const symbol s_2_13[7] = {'i', 'v', 'i', 't', 'a', 't', 'e'}; +static const symbol s_2_14[5] = {'i', 'c', 'i', 'v', 'e'}; +static const symbol s_2_15[5] = {'a', 't', 'i', 'v', 'e'}; +static const symbol s_2_16[5] = {'i', 't', 'i', 'v', 'e'}; +static const symbol s_2_17[5] = {'i', 'c', 'a', 'l', 'i'}; +static const symbol s_2_18[5] = {'a', 't', 'o', 'r', 'i'}; +static const symbol s_2_19[7] = {'i', 'c', 
'a', 't', 'o', 'r', 'i'}; +static const symbol s_2_20[5] = {'i', 't', 'o', 'r', 'i'}; +static const symbol s_2_21[6] = {0xC4, 0x83, 't', 'o', 'r', 'i'}; +static const symbol s_2_22[7] = {'i', 'c', 'i', 't', 'a', 't', 'i'}; +static const symbol s_2_23[9] = {'a', 'b', 'i', 'l', 'i', 't', 'a', 't', 'i'}; +static const symbol s_2_24[7] = {'i', 'v', 'i', 't', 'a', 't', 'i'}; +static const symbol s_2_25[5] = {'i', 'c', 'i', 'v', 'i'}; +static const symbol s_2_26[5] = {'a', 't', 'i', 'v', 'i'}; +static const symbol s_2_27[5] = {'i', 't', 'i', 'v', 'i'}; +static const symbol s_2_28[7] = {'i', 'c', 'i', 't', 0xC4, 0x83, 'i'}; +static const symbol s_2_29[9] = {'a', 'b', 'i', 'l', 'i', 't', 0xC4, 0x83, 'i'}; +static const symbol s_2_30[7] = {'i', 'v', 'i', 't', 0xC4, 0x83, 'i'}; +static const symbol s_2_31[9] = {'i', 'c', 'i', 't', 0xC4, 0x83, 0xC5, 0xA3, 'i'}; +static const symbol s_2_32[11] = {'a', 'b', 'i', 'l', 'i', 't', 0xC4, 0x83, 0xC5, 0xA3, 'i'}; +static const symbol s_2_33[9] = {'i', 'v', 'i', 't', 0xC4, 0x83, 0xC5, 0xA3, 'i'}; +static const symbol s_2_34[4] = {'i', 'c', 'a', 'l'}; +static const symbol s_2_35[4] = {'a', 't', 'o', 'r'}; +static const symbol s_2_36[6] = {'i', 'c', 'a', 't', 'o', 'r'}; +static const symbol s_2_37[4] = {'i', 't', 'o', 'r'}; +static const symbol s_2_38[5] = {0xC4, 0x83, 't', 'o', 'r'}; +static const symbol s_2_39[4] = {'i', 'c', 'i', 'v'}; +static const symbol s_2_40[4] = {'a', 't', 'i', 'v'}; +static const symbol s_2_41[4] = {'i', 't', 'i', 'v'}; +static const symbol s_2_42[6] = {'i', 'c', 'a', 'l', 0xC4, 0x83}; +static const symbol s_2_43[6] = {'i', 'c', 'i', 'v', 0xC4, 0x83}; +static const symbol s_2_44[6] = {'a', 't', 'i', 'v', 0xC4, 0x83}; +static const symbol s_2_45[6] = {'i', 't', 'i', 'v', 0xC4, 0x83}; + +static const struct among a_2[46] = { + /* 0 */ {5, s_2_0, -1, 4, 0}, + /* 1 */ {5, s_2_1, -1, 4, 0}, + /* 2 */ {5, s_2_2, -1, 5, 0}, + /* 3 */ {5, s_2_3, -1, 6, 0}, + /* 4 */ {5, s_2_4, -1, 4, 0}, + /* 5 */ {7, s_2_5, -1, 5, 0}, + /* 6 */ {7, s_2_6, -1, 6, 0}, + /* 7 */ {6, s_2_7, -1, 5, 0}, + /* 8 */ {6, s_2_8, -1, 6, 0}, + /* 9 */ {7, s_2_9, -1, 5, 0}, + /* 10 */ {7, s_2_10, -1, 4, 0}, + /* 11 */ {9, s_2_11, -1, 1, 0}, + /* 12 */ {9, s_2_12, -1, 2, 0}, + /* 13 */ {7, s_2_13, -1, 3, 0}, + /* 14 */ {5, s_2_14, -1, 4, 0}, + /* 15 */ {5, s_2_15, -1, 5, 0}, + /* 16 */ {5, s_2_16, -1, 6, 0}, + /* 17 */ {5, s_2_17, -1, 4, 0}, + /* 18 */ {5, s_2_18, -1, 5, 0}, + /* 19 */ {7, s_2_19, 18, 4, 0}, + /* 20 */ {5, s_2_20, -1, 6, 0}, + /* 21 */ {6, s_2_21, -1, 5, 0}, + /* 22 */ {7, s_2_22, -1, 4, 0}, + /* 23 */ {9, s_2_23, -1, 1, 0}, + /* 24 */ {7, s_2_24, -1, 3, 0}, + /* 25 */ {5, s_2_25, -1, 4, 0}, + /* 26 */ {5, s_2_26, -1, 5, 0}, + /* 27 */ {5, s_2_27, -1, 6, 0}, + /* 28 */ {7, s_2_28, -1, 4, 0}, + /* 29 */ {9, s_2_29, -1, 1, 0}, + /* 30 */ {7, s_2_30, -1, 3, 0}, + /* 31 */ {9, s_2_31, -1, 4, 0}, + /* 32 */ {11, s_2_32, -1, 1, 0}, + /* 33 */ {9, s_2_33, -1, 3, 0}, + /* 34 */ {4, s_2_34, -1, 4, 0}, + /* 35 */ {4, s_2_35, -1, 5, 0}, + /* 36 */ {6, s_2_36, 35, 4, 0}, + /* 37 */ {4, s_2_37, -1, 6, 0}, + /* 38 */ {5, s_2_38, -1, 5, 0}, + /* 39 */ {4, s_2_39, -1, 4, 0}, + /* 40 */ {4, s_2_40, -1, 5, 0}, + /* 41 */ {4, s_2_41, -1, 6, 0}, + /* 42 */ {6, s_2_42, -1, 4, 0}, + /* 43 */ {6, s_2_43, -1, 4, 0}, + /* 44 */ {6, s_2_44, -1, 5, 0}, + /* 45 */ {6, s_2_45, -1, 6, 0}}; + +static const symbol s_3_0[3] = {'i', 'c', 'a'}; +static const symbol s_3_1[5] = {'a', 'b', 'i', 'l', 'a'}; +static const symbol s_3_2[5] = {'i', 'b', 'i', 'l', 'a'}; +static const symbol s_3_3[4] = 
{'o', 'a', 's', 'a'}; +static const symbol s_3_4[3] = {'a', 't', 'a'}; +static const symbol s_3_5[3] = {'i', 't', 'a'}; +static const symbol s_3_6[4] = {'a', 'n', 't', 'a'}; +static const symbol s_3_7[4] = {'i', 's', 't', 'a'}; +static const symbol s_3_8[3] = {'u', 't', 'a'}; +static const symbol s_3_9[3] = {'i', 'v', 'a'}; +static const symbol s_3_10[2] = {'i', 'c'}; +static const symbol s_3_11[3] = {'i', 'c', 'e'}; +static const symbol s_3_12[5] = {'a', 'b', 'i', 'l', 'e'}; +static const symbol s_3_13[5] = {'i', 'b', 'i', 'l', 'e'}; +static const symbol s_3_14[4] = {'i', 's', 'm', 'e'}; +static const symbol s_3_15[4] = {'i', 'u', 'n', 'e'}; +static const symbol s_3_16[4] = {'o', 'a', 's', 'e'}; +static const symbol s_3_17[3] = {'a', 't', 'e'}; +static const symbol s_3_18[5] = {'i', 't', 'a', 't', 'e'}; +static const symbol s_3_19[3] = {'i', 't', 'e'}; +static const symbol s_3_20[4] = {'a', 'n', 't', 'e'}; +static const symbol s_3_21[4] = {'i', 's', 't', 'e'}; +static const symbol s_3_22[3] = {'u', 't', 'e'}; +static const symbol s_3_23[3] = {'i', 'v', 'e'}; +static const symbol s_3_24[3] = {'i', 'c', 'i'}; +static const symbol s_3_25[5] = {'a', 'b', 'i', 'l', 'i'}; +static const symbol s_3_26[5] = {'i', 'b', 'i', 'l', 'i'}; +static const symbol s_3_27[4] = {'i', 'u', 'n', 'i'}; +static const symbol s_3_28[5] = {'a', 't', 'o', 'r', 'i'}; +static const symbol s_3_29[3] = {'o', 's', 'i'}; +static const symbol s_3_30[3] = {'a', 't', 'i'}; +static const symbol s_3_31[5] = {'i', 't', 'a', 't', 'i'}; +static const symbol s_3_32[3] = {'i', 't', 'i'}; +static const symbol s_3_33[4] = {'a', 'n', 't', 'i'}; +static const symbol s_3_34[4] = {'i', 's', 't', 'i'}; +static const symbol s_3_35[3] = {'u', 't', 'i'}; +static const symbol s_3_36[5] = {'i', 0xC5, 0x9F, 't', 'i'}; +static const symbol s_3_37[3] = {'i', 'v', 'i'}; +static const symbol s_3_38[5] = {'i', 't', 0xC4, 0x83, 'i'}; +static const symbol s_3_39[4] = {'o', 0xC5, 0x9F, 'i'}; +static const symbol s_3_40[7] = {'i', 't', 0xC4, 0x83, 0xC5, 0xA3, 'i'}; +static const symbol s_3_41[4] = {'a', 'b', 'i', 'l'}; +static const symbol s_3_42[4] = {'i', 'b', 'i', 'l'}; +static const symbol s_3_43[3] = {'i', 's', 'm'}; +static const symbol s_3_44[4] = {'a', 't', 'o', 'r'}; +static const symbol s_3_45[2] = {'o', 's'}; +static const symbol s_3_46[2] = {'a', 't'}; +static const symbol s_3_47[2] = {'i', 't'}; +static const symbol s_3_48[3] = {'a', 'n', 't'}; +static const symbol s_3_49[3] = {'i', 's', 't'}; +static const symbol s_3_50[2] = {'u', 't'}; +static const symbol s_3_51[2] = {'i', 'v'}; +static const symbol s_3_52[4] = {'i', 'c', 0xC4, 0x83}; +static const symbol s_3_53[6] = {'a', 'b', 'i', 'l', 0xC4, 0x83}; +static const symbol s_3_54[6] = {'i', 'b', 'i', 'l', 0xC4, 0x83}; +static const symbol s_3_55[5] = {'o', 'a', 's', 0xC4, 0x83}; +static const symbol s_3_56[4] = {'a', 't', 0xC4, 0x83}; +static const symbol s_3_57[4] = {'i', 't', 0xC4, 0x83}; +static const symbol s_3_58[5] = {'a', 'n', 't', 0xC4, 0x83}; +static const symbol s_3_59[5] = {'i', 's', 't', 0xC4, 0x83}; +static const symbol s_3_60[4] = {'u', 't', 0xC4, 0x83}; +static const symbol s_3_61[4] = {'i', 'v', 0xC4, 0x83}; + +static const struct among a_3[62] = { + /* 0 */ {3, s_3_0, -1, 1, 0}, + /* 1 */ {5, s_3_1, -1, 1, 0}, + /* 2 */ {5, s_3_2, -1, 1, 0}, + /* 3 */ {4, s_3_3, -1, 1, 0}, + /* 4 */ {3, s_3_4, -1, 1, 0}, + /* 5 */ {3, s_3_5, -1, 1, 0}, + /* 6 */ {4, s_3_6, -1, 1, 0}, + /* 7 */ {4, s_3_7, -1, 3, 0}, + /* 8 */ {3, s_3_8, -1, 1, 0}, + /* 9 */ {3, s_3_9, -1, 1, 0}, + /* 10 */ 
{2, s_3_10, -1, 1, 0}, + /* 11 */ {3, s_3_11, -1, 1, 0}, + /* 12 */ {5, s_3_12, -1, 1, 0}, + /* 13 */ {5, s_3_13, -1, 1, 0}, + /* 14 */ {4, s_3_14, -1, 3, 0}, + /* 15 */ {4, s_3_15, -1, 2, 0}, + /* 16 */ {4, s_3_16, -1, 1, 0}, + /* 17 */ {3, s_3_17, -1, 1, 0}, + /* 18 */ {5, s_3_18, 17, 1, 0}, + /* 19 */ {3, s_3_19, -1, 1, 0}, + /* 20 */ {4, s_3_20, -1, 1, 0}, + /* 21 */ {4, s_3_21, -1, 3, 0}, + /* 22 */ {3, s_3_22, -1, 1, 0}, + /* 23 */ {3, s_3_23, -1, 1, 0}, + /* 24 */ {3, s_3_24, -1, 1, 0}, + /* 25 */ {5, s_3_25, -1, 1, 0}, + /* 26 */ {5, s_3_26, -1, 1, 0}, + /* 27 */ {4, s_3_27, -1, 2, 0}, + /* 28 */ {5, s_3_28, -1, 1, 0}, + /* 29 */ {3, s_3_29, -1, 1, 0}, + /* 30 */ {3, s_3_30, -1, 1, 0}, + /* 31 */ {5, s_3_31, 30, 1, 0}, + /* 32 */ {3, s_3_32, -1, 1, 0}, + /* 33 */ {4, s_3_33, -1, 1, 0}, + /* 34 */ {4, s_3_34, -1, 3, 0}, + /* 35 */ {3, s_3_35, -1, 1, 0}, + /* 36 */ {5, s_3_36, -1, 3, 0}, + /* 37 */ {3, s_3_37, -1, 1, 0}, + /* 38 */ {5, s_3_38, -1, 1, 0}, + /* 39 */ {4, s_3_39, -1, 1, 0}, + /* 40 */ {7, s_3_40, -1, 1, 0}, + /* 41 */ {4, s_3_41, -1, 1, 0}, + /* 42 */ {4, s_3_42, -1, 1, 0}, + /* 43 */ {3, s_3_43, -1, 3, 0}, + /* 44 */ {4, s_3_44, -1, 1, 0}, + /* 45 */ {2, s_3_45, -1, 1, 0}, + /* 46 */ {2, s_3_46, -1, 1, 0}, + /* 47 */ {2, s_3_47, -1, 1, 0}, + /* 48 */ {3, s_3_48, -1, 1, 0}, + /* 49 */ {3, s_3_49, -1, 3, 0}, + /* 50 */ {2, s_3_50, -1, 1, 0}, + /* 51 */ {2, s_3_51, -1, 1, 0}, + /* 52 */ {4, s_3_52, -1, 1, 0}, + /* 53 */ {6, s_3_53, -1, 1, 0}, + /* 54 */ {6, s_3_54, -1, 1, 0}, + /* 55 */ {5, s_3_55, -1, 1, 0}, + /* 56 */ {4, s_3_56, -1, 1, 0}, + /* 57 */ {4, s_3_57, -1, 1, 0}, + /* 58 */ {5, s_3_58, -1, 1, 0}, + /* 59 */ {5, s_3_59, -1, 3, 0}, + /* 60 */ {4, s_3_60, -1, 1, 0}, + /* 61 */ {4, s_3_61, -1, 1, 0}}; + +static const symbol s_4_0[2] = {'e', 'a'}; +static const symbol s_4_1[2] = {'i', 'a'}; +static const symbol s_4_2[3] = {'e', 's', 'c'}; +static const symbol s_4_3[4] = {0xC4, 0x83, 's', 'c'}; +static const symbol s_4_4[3] = {'i', 'n', 'd'}; +static const symbol s_4_5[4] = {0xC3, 0xA2, 'n', 'd'}; +static const symbol s_4_6[3] = {'a', 'r', 'e'}; +static const symbol s_4_7[3] = {'e', 'r', 'e'}; +static const symbol s_4_8[3] = {'i', 'r', 'e'}; +static const symbol s_4_9[4] = {0xC3, 0xA2, 'r', 'e'}; +static const symbol s_4_10[2] = {'s', 'e'}; +static const symbol s_4_11[3] = {'a', 's', 'e'}; +static const symbol s_4_12[4] = {'s', 'e', 's', 'e'}; +static const symbol s_4_13[3] = {'i', 's', 'e'}; +static const symbol s_4_14[3] = {'u', 's', 'e'}; +static const symbol s_4_15[4] = {0xC3, 0xA2, 's', 'e'}; +static const symbol s_4_16[5] = {'e', 0xC5, 0x9F, 't', 'e'}; +static const symbol s_4_17[6] = {0xC4, 0x83, 0xC5, 0x9F, 't', 'e'}; +static const symbol s_4_18[3] = {'e', 'z', 'e'}; +static const symbol s_4_19[2] = {'a', 'i'}; +static const symbol s_4_20[3] = {'e', 'a', 'i'}; +static const symbol s_4_21[3] = {'i', 'a', 'i'}; +static const symbol s_4_22[3] = {'s', 'e', 'i'}; +static const symbol s_4_23[5] = {'e', 0xC5, 0x9F, 't', 'i'}; +static const symbol s_4_24[6] = {0xC4, 0x83, 0xC5, 0x9F, 't', 'i'}; +static const symbol s_4_25[2] = {'u', 'i'}; +static const symbol s_4_26[3] = {'e', 'z', 'i'}; +static const symbol s_4_27[4] = {'a', 0xC5, 0x9F, 'i'}; +static const symbol s_4_28[5] = {'s', 'e', 0xC5, 0x9F, 'i'}; +static const symbol s_4_29[6] = {'a', 's', 'e', 0xC5, 0x9F, 'i'}; +static const symbol s_4_30[7] = {'s', 'e', 's', 'e', 0xC5, 0x9F, 'i'}; +static const symbol s_4_31[6] = {'i', 's', 'e', 0xC5, 0x9F, 'i'}; +static const symbol s_4_32[6] = {'u', 's', 'e', 0xC5, 
0x9F, 'i'}; +static const symbol s_4_33[7] = {0xC3, 0xA2, 's', 'e', 0xC5, 0x9F, 'i'}; +static const symbol s_4_34[4] = {'i', 0xC5, 0x9F, 'i'}; +static const symbol s_4_35[4] = {'u', 0xC5, 0x9F, 'i'}; +static const symbol s_4_36[5] = {0xC3, 0xA2, 0xC5, 0x9F, 'i'}; +static const symbol s_4_37[3] = {0xC3, 0xA2, 'i'}; +static const symbol s_4_38[4] = {'a', 0xC5, 0xA3, 'i'}; +static const symbol s_4_39[5] = {'e', 'a', 0xC5, 0xA3, 'i'}; +static const symbol s_4_40[5] = {'i', 'a', 0xC5, 0xA3, 'i'}; +static const symbol s_4_41[4] = {'e', 0xC5, 0xA3, 'i'}; +static const symbol s_4_42[4] = {'i', 0xC5, 0xA3, 'i'}; +static const symbol s_4_43[7] = {'a', 'r', 0xC4, 0x83, 0xC5, 0xA3, 'i'}; +static const symbol s_4_44[8] = {'s', 'e', 'r', 0xC4, 0x83, 0xC5, 0xA3, 'i'}; +static const symbol s_4_45[9] = {'a', 's', 'e', 'r', 0xC4, 0x83, 0xC5, 0xA3, 'i'}; +static const symbol s_4_46[10] = {'s', 'e', 's', 'e', 'r', 0xC4, 0x83, 0xC5, 0xA3, 'i'}; +static const symbol s_4_47[9] = {'i', 's', 'e', 'r', 0xC4, 0x83, 0xC5, 0xA3, 'i'}; +static const symbol s_4_48[9] = {'u', 's', 'e', 'r', 0xC4, 0x83, 0xC5, 0xA3, 'i'}; +static const symbol s_4_49[10] = {0xC3, 0xA2, 's', 'e', 'r', 0xC4, 0x83, 0xC5, 0xA3, 'i'}; +static const symbol s_4_50[7] = {'i', 'r', 0xC4, 0x83, 0xC5, 0xA3, 'i'}; +static const symbol s_4_51[7] = {'u', 'r', 0xC4, 0x83, 0xC5, 0xA3, 'i'}; +static const symbol s_4_52[8] = {0xC3, 0xA2, 'r', 0xC4, 0x83, 0xC5, 0xA3, 'i'}; +static const symbol s_4_53[5] = {0xC3, 0xA2, 0xC5, 0xA3, 'i'}; +static const symbol s_4_54[2] = {'a', 'm'}; +static const symbol s_4_55[3] = {'e', 'a', 'm'}; +static const symbol s_4_56[3] = {'i', 'a', 'm'}; +static const symbol s_4_57[2] = {'e', 'm'}; +static const symbol s_4_58[4] = {'a', 's', 'e', 'm'}; +static const symbol s_4_59[5] = {'s', 'e', 's', 'e', 'm'}; +static const symbol s_4_60[4] = {'i', 's', 'e', 'm'}; +static const symbol s_4_61[4] = {'u', 's', 'e', 'm'}; +static const symbol s_4_62[5] = {0xC3, 0xA2, 's', 'e', 'm'}; +static const symbol s_4_63[2] = {'i', 'm'}; +static const symbol s_4_64[3] = {0xC4, 0x83, 'm'}; +static const symbol s_4_65[5] = {'a', 'r', 0xC4, 0x83, 'm'}; +static const symbol s_4_66[6] = {'s', 'e', 'r', 0xC4, 0x83, 'm'}; +static const symbol s_4_67[7] = {'a', 's', 'e', 'r', 0xC4, 0x83, 'm'}; +static const symbol s_4_68[8] = {'s', 'e', 's', 'e', 'r', 0xC4, 0x83, 'm'}; +static const symbol s_4_69[7] = {'i', 's', 'e', 'r', 0xC4, 0x83, 'm'}; +static const symbol s_4_70[7] = {'u', 's', 'e', 'r', 0xC4, 0x83, 'm'}; +static const symbol s_4_71[8] = {0xC3, 0xA2, 's', 'e', 'r', 0xC4, 0x83, 'm'}; +static const symbol s_4_72[5] = {'i', 'r', 0xC4, 0x83, 'm'}; +static const symbol s_4_73[5] = {'u', 'r', 0xC4, 0x83, 'm'}; +static const symbol s_4_74[6] = {0xC3, 0xA2, 'r', 0xC4, 0x83, 'm'}; +static const symbol s_4_75[3] = {0xC3, 0xA2, 'm'}; +static const symbol s_4_76[2] = {'a', 'u'}; +static const symbol s_4_77[3] = {'e', 'a', 'u'}; +static const symbol s_4_78[3] = {'i', 'a', 'u'}; +static const symbol s_4_79[4] = {'i', 'n', 'd', 'u'}; +static const symbol s_4_80[5] = {0xC3, 0xA2, 'n', 'd', 'u'}; +static const symbol s_4_81[2] = {'e', 'z'}; +static const symbol s_4_82[6] = {'e', 'a', 's', 'c', 0xC4, 0x83}; +static const symbol s_4_83[4] = {'a', 'r', 0xC4, 0x83}; +static const symbol s_4_84[5] = {'s', 'e', 'r', 0xC4, 0x83}; +static const symbol s_4_85[6] = {'a', 's', 'e', 'r', 0xC4, 0x83}; +static const symbol s_4_86[7] = {'s', 'e', 's', 'e', 'r', 0xC4, 0x83}; +static const symbol s_4_87[6] = {'i', 's', 'e', 'r', 0xC4, 0x83}; +static const symbol s_4_88[6] = {'u', 's', 
'e', 'r', 0xC4, 0x83}; +static const symbol s_4_89[7] = {0xC3, 0xA2, 's', 'e', 'r', 0xC4, 0x83}; +static const symbol s_4_90[4] = {'i', 'r', 0xC4, 0x83}; +static const symbol s_4_91[4] = {'u', 'r', 0xC4, 0x83}; +static const symbol s_4_92[5] = {0xC3, 0xA2, 'r', 0xC4, 0x83}; +static const symbol s_4_93[5] = {'e', 'a', 'z', 0xC4, 0x83}; + +static const struct among a_4[94] = { + /* 0 */ {2, s_4_0, -1, 1, 0}, + /* 1 */ {2, s_4_1, -1, 1, 0}, + /* 2 */ {3, s_4_2, -1, 1, 0}, + /* 3 */ {4, s_4_3, -1, 1, 0}, + /* 4 */ {3, s_4_4, -1, 1, 0}, + /* 5 */ {4, s_4_5, -1, 1, 0}, + /* 6 */ {3, s_4_6, -1, 1, 0}, + /* 7 */ {3, s_4_7, -1, 1, 0}, + /* 8 */ {3, s_4_8, -1, 1, 0}, + /* 9 */ {4, s_4_9, -1, 1, 0}, + /* 10 */ {2, s_4_10, -1, 2, 0}, + /* 11 */ {3, s_4_11, 10, 1, 0}, + /* 12 */ {4, s_4_12, 10, 2, 0}, + /* 13 */ {3, s_4_13, 10, 1, 0}, + /* 14 */ {3, s_4_14, 10, 1, 0}, + /* 15 */ {4, s_4_15, 10, 1, 0}, + /* 16 */ {5, s_4_16, -1, 1, 0}, + /* 17 */ {6, s_4_17, -1, 1, 0}, + /* 18 */ {3, s_4_18, -1, 1, 0}, + /* 19 */ {2, s_4_19, -1, 1, 0}, + /* 20 */ {3, s_4_20, 19, 1, 0}, + /* 21 */ {3, s_4_21, 19, 1, 0}, + /* 22 */ {3, s_4_22, -1, 2, 0}, + /* 23 */ {5, s_4_23, -1, 1, 0}, + /* 24 */ {6, s_4_24, -1, 1, 0}, + /* 25 */ {2, s_4_25, -1, 1, 0}, + /* 26 */ {3, s_4_26, -1, 1, 0}, + /* 27 */ {4, s_4_27, -1, 1, 0}, + /* 28 */ {5, s_4_28, -1, 2, 0}, + /* 29 */ {6, s_4_29, 28, 1, 0}, + /* 30 */ {7, s_4_30, 28, 2, 0}, + /* 31 */ {6, s_4_31, 28, 1, 0}, + /* 32 */ {6, s_4_32, 28, 1, 0}, + /* 33 */ {7, s_4_33, 28, 1, 0}, + /* 34 */ {4, s_4_34, -1, 1, 0}, + /* 35 */ {4, s_4_35, -1, 1, 0}, + /* 36 */ {5, s_4_36, -1, 1, 0}, + /* 37 */ {3, s_4_37, -1, 1, 0}, + /* 38 */ {4, s_4_38, -1, 2, 0}, + /* 39 */ {5, s_4_39, 38, 1, 0}, + /* 40 */ {5, s_4_40, 38, 1, 0}, + /* 41 */ {4, s_4_41, -1, 2, 0}, + /* 42 */ {4, s_4_42, -1, 2, 0}, + /* 43 */ {7, s_4_43, -1, 1, 0}, + /* 44 */ {8, s_4_44, -1, 2, 0}, + /* 45 */ {9, s_4_45, 44, 1, 0}, + /* 46 */ {10, s_4_46, 44, 2, 0}, + /* 47 */ {9, s_4_47, 44, 1, 0}, + /* 48 */ {9, s_4_48, 44, 1, 0}, + /* 49 */ {10, s_4_49, 44, 1, 0}, + /* 50 */ {7, s_4_50, -1, 1, 0}, + /* 51 */ {7, s_4_51, -1, 1, 0}, + /* 52 */ {8, s_4_52, -1, 1, 0}, + /* 53 */ {5, s_4_53, -1, 2, 0}, + /* 54 */ {2, s_4_54, -1, 1, 0}, + /* 55 */ {3, s_4_55, 54, 1, 0}, + /* 56 */ {3, s_4_56, 54, 1, 0}, + /* 57 */ {2, s_4_57, -1, 2, 0}, + /* 58 */ {4, s_4_58, 57, 1, 0}, + /* 59 */ {5, s_4_59, 57, 2, 0}, + /* 60 */ {4, s_4_60, 57, 1, 0}, + /* 61 */ {4, s_4_61, 57, 1, 0}, + /* 62 */ {5, s_4_62, 57, 1, 0}, + /* 63 */ {2, s_4_63, -1, 2, 0}, + /* 64 */ {3, s_4_64, -1, 2, 0}, + /* 65 */ {5, s_4_65, 64, 1, 0}, + /* 66 */ {6, s_4_66, 64, 2, 0}, + /* 67 */ {7, s_4_67, 66, 1, 0}, + /* 68 */ {8, s_4_68, 66, 2, 0}, + /* 69 */ {7, s_4_69, 66, 1, 0}, + /* 70 */ {7, s_4_70, 66, 1, 0}, + /* 71 */ {8, s_4_71, 66, 1, 0}, + /* 72 */ {5, s_4_72, 64, 1, 0}, + /* 73 */ {5, s_4_73, 64, 1, 0}, + /* 74 */ {6, s_4_74, 64, 1, 0}, + /* 75 */ {3, s_4_75, -1, 2, 0}, + /* 76 */ {2, s_4_76, -1, 1, 0}, + /* 77 */ {3, s_4_77, 76, 1, 0}, + /* 78 */ {3, s_4_78, 76, 1, 0}, + /* 79 */ {4, s_4_79, -1, 1, 0}, + /* 80 */ {5, s_4_80, -1, 1, 0}, + /* 81 */ {2, s_4_81, -1, 1, 0}, + /* 82 */ {6, s_4_82, -1, 1, 0}, + /* 83 */ {4, s_4_83, -1, 1, 0}, + /* 84 */ {5, s_4_84, -1, 2, 0}, + /* 85 */ {6, s_4_85, 84, 1, 0}, + /* 86 */ {7, s_4_86, 84, 2, 0}, + /* 87 */ {6, s_4_87, 84, 1, 0}, + /* 88 */ {6, s_4_88, 84, 1, 0}, + /* 89 */ {7, s_4_89, 84, 1, 0}, + /* 90 */ {4, s_4_90, -1, 1, 0}, + /* 91 */ {4, s_4_91, -1, 1, 0}, + /* 92 */ {5, s_4_92, -1, 1, 0}, + /* 93 */ {5, s_4_93, -1, 1, 
0}}; + +static const symbol s_5_0[1] = {'a'}; +static const symbol s_5_1[1] = {'e'}; +static const symbol s_5_2[2] = {'i', 'e'}; +static const symbol s_5_3[1] = {'i'}; +static const symbol s_5_4[2] = {0xC4, 0x83}; + +static const struct among a_5[5] = { + /* 0 */ {1, s_5_0, -1, 1, 0}, + /* 1 */ {1, s_5_1, -1, 1, 0}, + /* 2 */ {2, s_5_2, 1, 1, 0}, + /* 3 */ {1, s_5_3, -1, 1, 0}, + /* 4 */ {2, s_5_4, -1, 1, 0}}; + +static const unsigned char g_v[] = {17, 65, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 32, 0, 0, 4}; + +static const symbol s_0[] = {'u'}; +static const symbol s_1[] = {'U'}; +static const symbol s_2[] = {'i'}; +static const symbol s_3[] = {'I'}; +static const symbol s_4[] = {'i'}; +static const symbol s_5[] = {'u'}; +static const symbol s_6[] = {'a'}; +static const symbol s_7[] = {'e'}; +static const symbol s_8[] = {'i'}; +static const symbol s_9[] = {'a', 'b'}; +static const symbol s_10[] = {'i'}; +static const symbol s_11[] = {'a', 't'}; +static const symbol s_12[] = {'a', 0xC5, 0xA3, 'i'}; +static const symbol s_13[] = {'a', 'b', 'i', 'l'}; +static const symbol s_14[] = {'i', 'b', 'i', 'l'}; +static const symbol s_15[] = {'i', 'v'}; +static const symbol s_16[] = {'i', 'c'}; +static const symbol s_17[] = {'a', 't'}; +static const symbol s_18[] = {'i', 't'}; +static const symbol s_19[] = {0xC5, 0xA3}; +static const symbol s_20[] = {'t'}; +static const symbol s_21[] = {'i', 's', 't'}; +static const symbol s_22[] = {'u'}; + +static int r_prelude(struct SN_env *z) { + while (1) { /* repeat, line 32 */ + int c1 = z->c; + while (1) { /* goto, line 32 */ + int c2 = z->c; + if (in_grouping_U(z, g_v, 97, 259, 0)) + goto lab1; + z->bra = z->c; /* [, line 33 */ + { + int c3 = z->c; /* or, line 33 */ + if (!(eq_s(z, 1, s_0))) + goto lab3; + z->ket = z->c; /* ], line 33 */ + if (in_grouping_U(z, g_v, 97, 259, 0)) + goto lab3; + { + int ret = slice_from_s(z, 1, s_1); /* <-, line 33 */ + if (ret < 0) + return ret; + } + goto lab2; + lab3: + z->c = c3; + if (!(eq_s(z, 1, s_2))) + goto lab1; + z->ket = z->c; /* ], line 34 */ + if (in_grouping_U(z, g_v, 97, 259, 0)) + goto lab1; + { + int ret = slice_from_s(z, 1, s_3); /* <-, line 34 */ + if (ret < 0) + return ret; + } + } + lab2: + z->c = c2; + break; + lab1: + z->c = c2; + { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* goto, line 32 */ + } + } + continue; + lab0: + z->c = c1; + break; + } + return 1; +} + +static int r_mark_regions(struct SN_env *z) { + z->I[0] = z->l; + z->I[1] = z->l; + z->I[2] = z->l; + { + int c1 = z->c; /* do, line 44 */ + { + int c2 = z->c; /* or, line 46 */ + if (in_grouping_U(z, g_v, 97, 259, 0)) + goto lab2; + { + int c3 = z->c; /* or, line 45 */ + if (out_grouping_U(z, g_v, 97, 259, 0)) + goto lab4; + { /* gopast */ /* grouping v, line 45 */ + int ret = out_grouping_U(z, g_v, 97, 259, 1); + if (ret < 0) + goto lab4; + z->c += ret; + } + goto lab3; + lab4: + z->c = c3; + if (in_grouping_U(z, g_v, 97, 259, 0)) + goto lab2; + { /* gopast */ /* non v, line 45 */ + int ret = in_grouping_U(z, g_v, 97, 259, 1); + if (ret < 0) + goto lab2; + z->c += ret; + } + } + lab3: + goto lab1; + lab2: + z->c = c2; + if (out_grouping_U(z, g_v, 97, 259, 0)) + goto lab0; + { + int c4 = z->c; /* or, line 47 */ + if (out_grouping_U(z, g_v, 97, 259, 0)) + goto lab6; + { /* gopast */ /* grouping v, line 47 */ + int ret = out_grouping_U(z, g_v, 97, 259, 1); + if (ret < 0) + goto lab6; + z->c += ret; + } + goto lab5; + lab6: + z->c = c4; + if (in_grouping_U(z, g_v, 97, 259, 0)) + goto lab0; + { 
+ int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* next, line 47 */ + } + } + lab5:; + } + lab1: + z->I[0] = z->c; /* setmark pV, line 48 */ + lab0: + z->c = c1; + } + { + int c5 = z->c; /* do, line 50 */ + { /* gopast */ /* grouping v, line 51 */ + int ret = out_grouping_U(z, g_v, 97, 259, 1); + if (ret < 0) + goto lab7; + z->c += ret; + } + { /* gopast */ /* non v, line 51 */ + int ret = in_grouping_U(z, g_v, 97, 259, 1); + if (ret < 0) + goto lab7; + z->c += ret; + } + z->I[1] = z->c; /* setmark p1, line 51 */ + { /* gopast */ /* grouping v, line 52 */ + int ret = out_grouping_U(z, g_v, 97, 259, 1); + if (ret < 0) + goto lab7; + z->c += ret; + } + { /* gopast */ /* non v, line 52 */ + int ret = in_grouping_U(z, g_v, 97, 259, 1); + if (ret < 0) + goto lab7; + z->c += ret; + } + z->I[2] = z->c; /* setmark p2, line 52 */ + lab7: + z->c = c5; + } + return 1; +} + +static int r_postlude(struct SN_env *z) { + int among_var; + while (1) { /* repeat, line 56 */ + int c1 = z->c; + z->bra = z->c; /* [, line 58 */ + if (z->c >= z->l || (z->p[z->c + 0] != 73 && z->p[z->c + 0] != 85)) + among_var = 3; + else + among_var = find_among(z, a_0, 3); /* substring, line 58 */ + if (!(among_var)) + goto lab0; + z->ket = z->c; /* ], line 58 */ + switch (among_var) { + case 0: + goto lab0; + case 1: { + int ret = slice_from_s(z, 1, s_4); /* <-, line 59 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 1, s_5); /* <-, line 60 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* next, line 61 */ + } break; + } + continue; + lab0: + z->c = c1; + break; + } + return 1; +} + +static int r_RV(struct SN_env *z) { + if (!(z->I[0] <= z->c)) + return 0; + return 1; +} + +static int r_R1(struct SN_env *z) { + if (!(z->I[1] <= z->c)) + return 0; + return 1; +} + +static int r_R2(struct SN_env *z) { + if (!(z->I[2] <= z->c)) + return 0; + return 1; +} + +static int r_step_0(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 73 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((266786 >> (z->p[z->c - 1] & 0x1f)) & 1)) + return 0; + among_var = find_among_b(z, a_1, 16); /* substring, line 73 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 73 */ + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 73 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 75 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 1, s_6); /* <-, line 77 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 1, s_7); /* <-, line 79 */ + if (ret < 0) + return ret; + } break; + case 4: { + int ret = slice_from_s(z, 1, s_8); /* <-, line 81 */ + if (ret < 0) + return ret; + } break; + case 5: { + int m1 = z->l - z->c; + (void)m1; /* not, line 83 */ + if (!(eq_s_b(z, 2, s_9))) + goto lab0; + return 0; + lab0: + z->c = z->l - m1; + } + { + int ret = slice_from_s(z, 1, s_10); /* <-, line 83 */ + if (ret < 0) + return ret; + } + break; + case 6: { + int ret = slice_from_s(z, 2, s_11); /* <-, line 85 */ + if (ret < 0) + return ret; + } break; + case 7: { + int ret = slice_from_s(z, 4, s_12); /* <-, line 87 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_combo_suffix(struct SN_env *z) { + int among_var; + { + int m_test = z->l - z->c; /* test, line 91 */ + z->ket = z->c; 
/* [, line 92 */ + among_var = find_among_b(z, a_2, 46); /* substring, line 92 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 92 */ + { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 92 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_from_s(z, 4, s_13); /* <-, line 101 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 4, s_14); /* <-, line 104 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 2, s_15); /* <-, line 107 */ + if (ret < 0) + return ret; + } break; + case 4: { + int ret = slice_from_s(z, 2, s_16); /* <-, line 113 */ + if (ret < 0) + return ret; + } break; + case 5: { + int ret = slice_from_s(z, 2, s_17); /* <-, line 118 */ + if (ret < 0) + return ret; + } break; + case 6: { + int ret = slice_from_s(z, 2, s_18); /* <-, line 122 */ + if (ret < 0) + return ret; + } break; + } + z->B[0] = 1; /* set standard_suffix_removed, line 125 */ + z->c = z->l - m_test; + } + return 1; +} + +static int r_standard_suffix(struct SN_env *z) { + int among_var; + z->B[0] = 0; /* unset standard_suffix_removed, line 130 */ + while (1) { /* repeat, line 131 */ + int m1 = z->l - z->c; + (void)m1; + { + int ret = r_combo_suffix(z); + if (ret == 0) + goto lab0; /* call combo_suffix, line 131 */ + if (ret < 0) + return ret; + } + continue; + lab0: + z->c = z->l - m1; + break; + } + z->ket = z->c; /* [, line 132 */ + among_var = find_among_b(z, a_3, 62); /* substring, line 132 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 132 */ + { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 132 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 149 */ + if (ret < 0) + return ret; + } break; + case 2: + if (!(eq_s_b(z, 2, s_19))) + return 0; + z->bra = z->c; /* ], line 152 */ + { + int ret = slice_from_s(z, 1, s_20); /* <-, line 152 */ + if (ret < 0) + return ret; + } + break; + case 3: { + int ret = slice_from_s(z, 3, s_21); /* <-, line 156 */ + if (ret < 0) + return ret; + } break; + } + z->B[0] = 1; /* set standard_suffix_removed, line 160 */ + return 1; +} + +static int r_verb_suffix(struct SN_env *z) { + int among_var; + { + int mlimit; /* setlimit, line 164 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 164 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 165 */ + among_var = find_among_b(z, a_4, 94); /* substring, line 165 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 165 */ + switch (among_var) { + case 0: { + z->lb = mlimit; + return 0; + } + case 1: { + int m2 = z->l - z->c; + (void)m2; /* or, line 200 */ + if (out_grouping_b_U(z, g_v, 97, 259, 0)) + goto lab1; + goto lab0; + lab1: + z->c = z->l - m2; + if (!(eq_s_b(z, 1, s_22))) { + z->lb = mlimit; + return 0; + } + } + lab0: { + int ret = slice_del(z); /* delete, line 200 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_del(z); /* delete, line 214 */ + if (ret < 0) + return ret; + } break; + } + z->lb = mlimit; + } + return 1; +} + +static int r_vowel_suffix(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 219 */ + among_var = find_among_b(z, a_5, 5); /* substring, line 219 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 219 */ + { + int ret = r_RV(z); + if (ret == 0) + return 0; /* call 
RV, line 219 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 220 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +extern int romanian_UTF_8_stem(struct SN_env *z) { + { + int c1 = z->c; /* do, line 226 */ + { + int ret = r_prelude(z); + if (ret == 0) + goto lab0; /* call prelude, line 226 */ + if (ret < 0) + return ret; + } + lab0: + z->c = c1; + } + { + int c2 = z->c; /* do, line 227 */ + { + int ret = r_mark_regions(z); + if (ret == 0) + goto lab1; /* call mark_regions, line 227 */ + if (ret < 0) + return ret; + } + lab1: + z->c = c2; + } + z->lb = z->c; + z->c = z->l; /* backwards, line 228 */ + + { + int m3 = z->l - z->c; + (void)m3; /* do, line 229 */ + { + int ret = r_step_0(z); + if (ret == 0) + goto lab2; /* call step_0, line 229 */ + if (ret < 0) + return ret; + } + lab2: + z->c = z->l - m3; + } + { + int m4 = z->l - z->c; + (void)m4; /* do, line 230 */ + { + int ret = r_standard_suffix(z); + if (ret == 0) + goto lab3; /* call standard_suffix, line 230 */ + if (ret < 0) + return ret; + } + lab3: + z->c = z->l - m4; + } + { + int m5 = z->l - z->c; + (void)m5; /* do, line 231 */ + { + int m6 = z->l - z->c; + (void)m6; /* or, line 231 */ + if (!(z->B[0])) + goto lab6; /* Boolean test standard_suffix_removed, line 231 */ + goto lab5; + lab6: + z->c = z->l - m6; + { + int ret = r_verb_suffix(z); + if (ret == 0) + goto lab4; /* call verb_suffix, line 231 */ + if (ret < 0) + return ret; + } + } + lab5: + lab4: + z->c = z->l - m5; + } + { + int m7 = z->l - z->c; + (void)m7; /* do, line 232 */ + { + int ret = r_vowel_suffix(z); + if (ret == 0) + goto lab7; /* call vowel_suffix, line 232 */ + if (ret < 0) + return ret; + } + lab7: + z->c = z->l - m7; + } + z->c = z->lb; + { + int c8 = z->c; /* do, line 234 */ + { + int ret = r_postlude(z); + if (ret == 0) + goto lab8; /* call postlude, line 234 */ + if (ret < 0) + return ret; + } + lab8: + z->c = c8; + } + return 1; +} + +extern struct SN_env *romanian_UTF_8_create_env(void) { return SN_create_env(0, 3, 1); } + +extern void romanian_UTF_8_close_env(struct SN_env *z) { SN_close_env(z, 0); } diff --git a/internal/cpp/stemmer/stem_UTF_8_romanian.h b/internal/cpp/stemmer/stem_UTF_8_romanian.h new file mode 100644 index 00000000000..19260c9707c --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_romanian.h @@ -0,0 +1,17 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *romanian_UTF_8_create_env(void); +extern void romanian_UTF_8_close_env(struct SN_env *z); + +extern int romanian_UTF_8_stem(struct SN_env *z); + +#ifdef __cplusplus +} +#endif diff --git a/internal/cpp/stemmer/stem_UTF_8_russian.cpp b/internal/cpp/stemmer/stem_UTF_8_russian.cpp new file mode 100644 index 00000000000..210d6cbc211 --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_russian.cpp @@ -0,0 +1,774 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#include "header.h" + +#ifdef __cplusplus +extern "C" { +#endif +extern int russian_UTF_8_stem(struct SN_env *z); +#ifdef __cplusplus +} +#endif +static int r_tidy_up(struct SN_env *z); +static int r_derivational(struct SN_env *z); +static int r_noun(struct SN_env *z); +static int r_verb(struct SN_env *z); +static int r_reflexive(struct SN_env *z); +static int r_adjectival(struct SN_env *z); +static int r_adjective(struct SN_env *z); +static int r_perfective_gerund(struct 
SN_env *z); +static int r_R2(struct SN_env *z); +static int r_mark_regions(struct SN_env *z); +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *russian_UTF_8_create_env(void); +extern void russian_UTF_8_close_env(struct SN_env *z); + +#ifdef __cplusplus +} +#endif +static const symbol s_0_0[10] = {0xD0, 0xB2, 0xD1, 0x88, 0xD0, 0xB8, 0xD1, 0x81, 0xD1, 0x8C}; +static const symbol s_0_1[12] = {0xD1, 0x8B, 0xD0, 0xB2, 0xD1, 0x88, 0xD0, 0xB8, 0xD1, 0x81, 0xD1, 0x8C}; +static const symbol s_0_2[12] = {0xD0, 0xB8, 0xD0, 0xB2, 0xD1, 0x88, 0xD0, 0xB8, 0xD1, 0x81, 0xD1, 0x8C}; +static const symbol s_0_3[2] = {0xD0, 0xB2}; +static const symbol s_0_4[4] = {0xD1, 0x8B, 0xD0, 0xB2}; +static const symbol s_0_5[4] = {0xD0, 0xB8, 0xD0, 0xB2}; +static const symbol s_0_6[6] = {0xD0, 0xB2, 0xD1, 0x88, 0xD0, 0xB8}; +static const symbol s_0_7[8] = {0xD1, 0x8B, 0xD0, 0xB2, 0xD1, 0x88, 0xD0, 0xB8}; +static const symbol s_0_8[8] = {0xD0, 0xB8, 0xD0, 0xB2, 0xD1, 0x88, 0xD0, 0xB8}; + +static const struct among a_0[9] = { + /* 0 */ {10, s_0_0, -1, 1, 0}, + /* 1 */ {12, s_0_1, 0, 2, 0}, + /* 2 */ {12, s_0_2, 0, 2, 0}, + /* 3 */ {2, s_0_3, -1, 1, 0}, + /* 4 */ {4, s_0_4, 3, 2, 0}, + /* 5 */ {4, s_0_5, 3, 2, 0}, + /* 6 */ {6, s_0_6, -1, 1, 0}, + /* 7 */ {8, s_0_7, 6, 2, 0}, + /* 8 */ {8, s_0_8, 6, 2, 0}}; + +static const symbol s_1_0[6] = {0xD0, 0xB5, 0xD0, 0xBC, 0xD1, 0x83}; +static const symbol s_1_1[6] = {0xD0, 0xBE, 0xD0, 0xBC, 0xD1, 0x83}; +static const symbol s_1_2[4] = {0xD1, 0x8B, 0xD1, 0x85}; +static const symbol s_1_3[4] = {0xD0, 0xB8, 0xD1, 0x85}; +static const symbol s_1_4[4] = {0xD1, 0x83, 0xD1, 0x8E}; +static const symbol s_1_5[4] = {0xD1, 0x8E, 0xD1, 0x8E}; +static const symbol s_1_6[4] = {0xD0, 0xB5, 0xD1, 0x8E}; +static const symbol s_1_7[4] = {0xD0, 0xBE, 0xD1, 0x8E}; +static const symbol s_1_8[4] = {0xD1, 0x8F, 0xD1, 0x8F}; +static const symbol s_1_9[4] = {0xD0, 0xB0, 0xD1, 0x8F}; +static const symbol s_1_10[4] = {0xD1, 0x8B, 0xD0, 0xB5}; +static const symbol s_1_11[4] = {0xD0, 0xB5, 0xD0, 0xB5}; +static const symbol s_1_12[4] = {0xD0, 0xB8, 0xD0, 0xB5}; +static const symbol s_1_13[4] = {0xD0, 0xBE, 0xD0, 0xB5}; +static const symbol s_1_14[6] = {0xD1, 0x8B, 0xD0, 0xBC, 0xD0, 0xB8}; +static const symbol s_1_15[6] = {0xD0, 0xB8, 0xD0, 0xBC, 0xD0, 0xB8}; +static const symbol s_1_16[4] = {0xD1, 0x8B, 0xD0, 0xB9}; +static const symbol s_1_17[4] = {0xD0, 0xB5, 0xD0, 0xB9}; +static const symbol s_1_18[4] = {0xD0, 0xB8, 0xD0, 0xB9}; +static const symbol s_1_19[4] = {0xD0, 0xBE, 0xD0, 0xB9}; +static const symbol s_1_20[4] = {0xD1, 0x8B, 0xD0, 0xBC}; +static const symbol s_1_21[4] = {0xD0, 0xB5, 0xD0, 0xBC}; +static const symbol s_1_22[4] = {0xD0, 0xB8, 0xD0, 0xBC}; +static const symbol s_1_23[4] = {0xD0, 0xBE, 0xD0, 0xBC}; +static const symbol s_1_24[6] = {0xD0, 0xB5, 0xD0, 0xB3, 0xD0, 0xBE}; +static const symbol s_1_25[6] = {0xD0, 0xBE, 0xD0, 0xB3, 0xD0, 0xBE}; + +static const struct among a_1[26] = { + /* 0 */ {6, s_1_0, -1, 1, 0}, + /* 1 */ {6, s_1_1, -1, 1, 0}, + /* 2 */ {4, s_1_2, -1, 1, 0}, + /* 3 */ {4, s_1_3, -1, 1, 0}, + /* 4 */ {4, s_1_4, -1, 1, 0}, + /* 5 */ {4, s_1_5, -1, 1, 0}, + /* 6 */ {4, s_1_6, -1, 1, 0}, + /* 7 */ {4, s_1_7, -1, 1, 0}, + /* 8 */ {4, s_1_8, -1, 1, 0}, + /* 9 */ {4, s_1_9, -1, 1, 0}, + /* 10 */ {4, s_1_10, -1, 1, 0}, + /* 11 */ {4, s_1_11, -1, 1, 0}, + /* 12 */ {4, s_1_12, -1, 1, 0}, + /* 13 */ {4, s_1_13, -1, 1, 0}, + /* 14 */ {6, s_1_14, -1, 1, 0}, + /* 15 */ {6, s_1_15, -1, 1, 0}, + /* 16 */ {4, s_1_16, -1, 1, 0}, + /* 17 */ {4, s_1_17, -1, 1, 0}, + /* 18 */ 
{4, s_1_18, -1, 1, 0}, + /* 19 */ {4, s_1_19, -1, 1, 0}, + /* 20 */ {4, s_1_20, -1, 1, 0}, + /* 21 */ {4, s_1_21, -1, 1, 0}, + /* 22 */ {4, s_1_22, -1, 1, 0}, + /* 23 */ {4, s_1_23, -1, 1, 0}, + /* 24 */ {6, s_1_24, -1, 1, 0}, + /* 25 */ {6, s_1_25, -1, 1, 0}}; + +static const symbol s_2_0[4] = {0xD0, 0xB2, 0xD1, 0x88}; +static const symbol s_2_1[6] = {0xD1, 0x8B, 0xD0, 0xB2, 0xD1, 0x88}; +static const symbol s_2_2[6] = {0xD0, 0xB8, 0xD0, 0xB2, 0xD1, 0x88}; +static const symbol s_2_3[2] = {0xD1, 0x89}; +static const symbol s_2_4[4] = {0xD1, 0x8E, 0xD1, 0x89}; +static const symbol s_2_5[6] = {0xD1, 0x83, 0xD1, 0x8E, 0xD1, 0x89}; +static const symbol s_2_6[4] = {0xD0, 0xB5, 0xD0, 0xBC}; +static const symbol s_2_7[4] = {0xD0, 0xBD, 0xD0, 0xBD}; + +static const struct among a_2[8] = { + /* 0 */ {4, s_2_0, -1, 1, 0}, + /* 1 */ {6, s_2_1, 0, 2, 0}, + /* 2 */ {6, s_2_2, 0, 2, 0}, + /* 3 */ {2, s_2_3, -1, 1, 0}, + /* 4 */ {4, s_2_4, 3, 1, 0}, + /* 5 */ {6, s_2_5, 4, 2, 0}, + /* 6 */ {4, s_2_6, -1, 1, 0}, + /* 7 */ {4, s_2_7, -1, 1, 0}}; + +static const symbol s_3_0[4] = {0xD1, 0x81, 0xD1, 0x8C}; +static const symbol s_3_1[4] = {0xD1, 0x81, 0xD1, 0x8F}; + +static const struct among a_3[2] = { + /* 0 */ {4, s_3_0, -1, 1, 0}, + /* 1 */ {4, s_3_1, -1, 1, 0}}; + +static const symbol s_4_0[4] = {0xD1, 0x8B, 0xD1, 0x82}; +static const symbol s_4_1[4] = {0xD1, 0x8E, 0xD1, 0x82}; +static const symbol s_4_2[6] = {0xD1, 0x83, 0xD1, 0x8E, 0xD1, 0x82}; +static const symbol s_4_3[4] = {0xD1, 0x8F, 0xD1, 0x82}; +static const symbol s_4_4[4] = {0xD0, 0xB5, 0xD1, 0x82}; +static const symbol s_4_5[6] = {0xD1, 0x83, 0xD0, 0xB5, 0xD1, 0x82}; +static const symbol s_4_6[4] = {0xD0, 0xB8, 0xD1, 0x82}; +static const symbol s_4_7[4] = {0xD0, 0xBD, 0xD1, 0x8B}; +static const symbol s_4_8[6] = {0xD0, 0xB5, 0xD0, 0xBD, 0xD1, 0x8B}; +static const symbol s_4_9[4] = {0xD1, 0x82, 0xD1, 0x8C}; +static const symbol s_4_10[6] = {0xD1, 0x8B, 0xD1, 0x82, 0xD1, 0x8C}; +static const symbol s_4_11[6] = {0xD0, 0xB8, 0xD1, 0x82, 0xD1, 0x8C}; +static const symbol s_4_12[6] = {0xD0, 0xB5, 0xD1, 0x88, 0xD1, 0x8C}; +static const symbol s_4_13[6] = {0xD0, 0xB8, 0xD1, 0x88, 0xD1, 0x8C}; +static const symbol s_4_14[2] = {0xD1, 0x8E}; +static const symbol s_4_15[4] = {0xD1, 0x83, 0xD1, 0x8E}; +static const symbol s_4_16[4] = {0xD0, 0xBB, 0xD0, 0xB0}; +static const symbol s_4_17[6] = {0xD1, 0x8B, 0xD0, 0xBB, 0xD0, 0xB0}; +static const symbol s_4_18[6] = {0xD0, 0xB8, 0xD0, 0xBB, 0xD0, 0xB0}; +static const symbol s_4_19[4] = {0xD0, 0xBD, 0xD0, 0xB0}; +static const symbol s_4_20[6] = {0xD0, 0xB5, 0xD0, 0xBD, 0xD0, 0xB0}; +static const symbol s_4_21[6] = {0xD0, 0xB5, 0xD1, 0x82, 0xD0, 0xB5}; +static const symbol s_4_22[6] = {0xD0, 0xB8, 0xD1, 0x82, 0xD0, 0xB5}; +static const symbol s_4_23[6] = {0xD0, 0xB9, 0xD1, 0x82, 0xD0, 0xB5}; +static const symbol s_4_24[8] = {0xD1, 0x83, 0xD0, 0xB9, 0xD1, 0x82, 0xD0, 0xB5}; +static const symbol s_4_25[8] = {0xD0, 0xB5, 0xD0, 0xB9, 0xD1, 0x82, 0xD0, 0xB5}; +static const symbol s_4_26[4] = {0xD0, 0xBB, 0xD0, 0xB8}; +static const symbol s_4_27[6] = {0xD1, 0x8B, 0xD0, 0xBB, 0xD0, 0xB8}; +static const symbol s_4_28[6] = {0xD0, 0xB8, 0xD0, 0xBB, 0xD0, 0xB8}; +static const symbol s_4_29[2] = {0xD0, 0xB9}; +static const symbol s_4_30[4] = {0xD1, 0x83, 0xD0, 0xB9}; +static const symbol s_4_31[4] = {0xD0, 0xB5, 0xD0, 0xB9}; +static const symbol s_4_32[2] = {0xD0, 0xBB}; +static const symbol s_4_33[4] = {0xD1, 0x8B, 0xD0, 0xBB}; +static const symbol s_4_34[4] = {0xD0, 0xB8, 0xD0, 0xBB}; +static const symbol s_4_35[4] = 
{0xD1, 0x8B, 0xD0, 0xBC}; +static const symbol s_4_36[4] = {0xD0, 0xB5, 0xD0, 0xBC}; +static const symbol s_4_37[4] = {0xD0, 0xB8, 0xD0, 0xBC}; +static const symbol s_4_38[2] = {0xD0, 0xBD}; +static const symbol s_4_39[4] = {0xD0, 0xB5, 0xD0, 0xBD}; +static const symbol s_4_40[4] = {0xD0, 0xBB, 0xD0, 0xBE}; +static const symbol s_4_41[6] = {0xD1, 0x8B, 0xD0, 0xBB, 0xD0, 0xBE}; +static const symbol s_4_42[6] = {0xD0, 0xB8, 0xD0, 0xBB, 0xD0, 0xBE}; +static const symbol s_4_43[4] = {0xD0, 0xBD, 0xD0, 0xBE}; +static const symbol s_4_44[6] = {0xD0, 0xB5, 0xD0, 0xBD, 0xD0, 0xBE}; +static const symbol s_4_45[6] = {0xD0, 0xBD, 0xD0, 0xBD, 0xD0, 0xBE}; + +static const struct among a_4[46] = { + /* 0 */ {4, s_4_0, -1, 2, 0}, + /* 1 */ {4, s_4_1, -1, 1, 0}, + /* 2 */ {6, s_4_2, 1, 2, 0}, + /* 3 */ {4, s_4_3, -1, 2, 0}, + /* 4 */ {4, s_4_4, -1, 1, 0}, + /* 5 */ {6, s_4_5, 4, 2, 0}, + /* 6 */ {4, s_4_6, -1, 2, 0}, + /* 7 */ {4, s_4_7, -1, 1, 0}, + /* 8 */ {6, s_4_8, 7, 2, 0}, + /* 9 */ {4, s_4_9, -1, 1, 0}, + /* 10 */ {6, s_4_10, 9, 2, 0}, + /* 11 */ {6, s_4_11, 9, 2, 0}, + /* 12 */ {6, s_4_12, -1, 1, 0}, + /* 13 */ {6, s_4_13, -1, 2, 0}, + /* 14 */ {2, s_4_14, -1, 2, 0}, + /* 15 */ {4, s_4_15, 14, 2, 0}, + /* 16 */ {4, s_4_16, -1, 1, 0}, + /* 17 */ {6, s_4_17, 16, 2, 0}, + /* 18 */ {6, s_4_18, 16, 2, 0}, + /* 19 */ {4, s_4_19, -1, 1, 0}, + /* 20 */ {6, s_4_20, 19, 2, 0}, + /* 21 */ {6, s_4_21, -1, 1, 0}, + /* 22 */ {6, s_4_22, -1, 2, 0}, + /* 23 */ {6, s_4_23, -1, 1, 0}, + /* 24 */ {8, s_4_24, 23, 2, 0}, + /* 25 */ {8, s_4_25, 23, 2, 0}, + /* 26 */ {4, s_4_26, -1, 1, 0}, + /* 27 */ {6, s_4_27, 26, 2, 0}, + /* 28 */ {6, s_4_28, 26, 2, 0}, + /* 29 */ {2, s_4_29, -1, 1, 0}, + /* 30 */ {4, s_4_30, 29, 2, 0}, + /* 31 */ {4, s_4_31, 29, 2, 0}, + /* 32 */ {2, s_4_32, -1, 1, 0}, + /* 33 */ {4, s_4_33, 32, 2, 0}, + /* 34 */ {4, s_4_34, 32, 2, 0}, + /* 35 */ {4, s_4_35, -1, 2, 0}, + /* 36 */ {4, s_4_36, -1, 1, 0}, + /* 37 */ {4, s_4_37, -1, 2, 0}, + /* 38 */ {2, s_4_38, -1, 1, 0}, + /* 39 */ {4, s_4_39, 38, 2, 0}, + /* 40 */ {4, s_4_40, -1, 1, 0}, + /* 41 */ {6, s_4_41, 40, 2, 0}, + /* 42 */ {6, s_4_42, 40, 2, 0}, + /* 43 */ {4, s_4_43, -1, 1, 0}, + /* 44 */ {6, s_4_44, 43, 2, 0}, + /* 45 */ {6, s_4_45, 43, 1, 0}}; + +static const symbol s_5_0[2] = {0xD1, 0x83}; +static const symbol s_5_1[4] = {0xD1, 0x8F, 0xD1, 0x85}; +static const symbol s_5_2[6] = {0xD0, 0xB8, 0xD1, 0x8F, 0xD1, 0x85}; +static const symbol s_5_3[4] = {0xD0, 0xB0, 0xD1, 0x85}; +static const symbol s_5_4[2] = {0xD1, 0x8B}; +static const symbol s_5_5[2] = {0xD1, 0x8C}; +static const symbol s_5_6[2] = {0xD1, 0x8E}; +static const symbol s_5_7[4] = {0xD1, 0x8C, 0xD1, 0x8E}; +static const symbol s_5_8[4] = {0xD0, 0xB8, 0xD1, 0x8E}; +static const symbol s_5_9[2] = {0xD1, 0x8F}; +static const symbol s_5_10[4] = {0xD1, 0x8C, 0xD1, 0x8F}; +static const symbol s_5_11[4] = {0xD0, 0xB8, 0xD1, 0x8F}; +static const symbol s_5_12[2] = {0xD0, 0xB0}; +static const symbol s_5_13[4] = {0xD0, 0xB5, 0xD0, 0xB2}; +static const symbol s_5_14[4] = {0xD0, 0xBE, 0xD0, 0xB2}; +static const symbol s_5_15[2] = {0xD0, 0xB5}; +static const symbol s_5_16[4] = {0xD1, 0x8C, 0xD0, 0xB5}; +static const symbol s_5_17[4] = {0xD0, 0xB8, 0xD0, 0xB5}; +static const symbol s_5_18[2] = {0xD0, 0xB8}; +static const symbol s_5_19[4] = {0xD0, 0xB5, 0xD0, 0xB8}; +static const symbol s_5_20[4] = {0xD0, 0xB8, 0xD0, 0xB8}; +static const symbol s_5_21[6] = {0xD1, 0x8F, 0xD0, 0xBC, 0xD0, 0xB8}; +static const symbol s_5_22[8] = {0xD0, 0xB8, 0xD1, 0x8F, 0xD0, 0xBC, 0xD0, 0xB8}; +static const symbol 
s_5_23[6] = {0xD0, 0xB0, 0xD0, 0xBC, 0xD0, 0xB8}; +static const symbol s_5_24[2] = {0xD0, 0xB9}; +static const symbol s_5_25[4] = {0xD0, 0xB5, 0xD0, 0xB9}; +static const symbol s_5_26[6] = {0xD0, 0xB8, 0xD0, 0xB5, 0xD0, 0xB9}; +static const symbol s_5_27[4] = {0xD0, 0xB8, 0xD0, 0xB9}; +static const symbol s_5_28[4] = {0xD0, 0xBE, 0xD0, 0xB9}; +static const symbol s_5_29[4] = {0xD1, 0x8F, 0xD0, 0xBC}; +static const symbol s_5_30[6] = {0xD0, 0xB8, 0xD1, 0x8F, 0xD0, 0xBC}; +static const symbol s_5_31[4] = {0xD0, 0xB0, 0xD0, 0xBC}; +static const symbol s_5_32[4] = {0xD0, 0xB5, 0xD0, 0xBC}; +static const symbol s_5_33[6] = {0xD0, 0xB8, 0xD0, 0xB5, 0xD0, 0xBC}; +static const symbol s_5_34[4] = {0xD0, 0xBE, 0xD0, 0xBC}; +static const symbol s_5_35[2] = {0xD0, 0xBE}; + +static const struct among a_5[36] = { + /* 0 */ {2, s_5_0, -1, 1, 0}, + /* 1 */ {4, s_5_1, -1, 1, 0}, + /* 2 */ {6, s_5_2, 1, 1, 0}, + /* 3 */ {4, s_5_3, -1, 1, 0}, + /* 4 */ {2, s_5_4, -1, 1, 0}, + /* 5 */ {2, s_5_5, -1, 1, 0}, + /* 6 */ {2, s_5_6, -1, 1, 0}, + /* 7 */ {4, s_5_7, 6, 1, 0}, + /* 8 */ {4, s_5_8, 6, 1, 0}, + /* 9 */ {2, s_5_9, -1, 1, 0}, + /* 10 */ {4, s_5_10, 9, 1, 0}, + /* 11 */ {4, s_5_11, 9, 1, 0}, + /* 12 */ {2, s_5_12, -1, 1, 0}, + /* 13 */ {4, s_5_13, -1, 1, 0}, + /* 14 */ {4, s_5_14, -1, 1, 0}, + /* 15 */ {2, s_5_15, -1, 1, 0}, + /* 16 */ {4, s_5_16, 15, 1, 0}, + /* 17 */ {4, s_5_17, 15, 1, 0}, + /* 18 */ {2, s_5_18, -1, 1, 0}, + /* 19 */ {4, s_5_19, 18, 1, 0}, + /* 20 */ {4, s_5_20, 18, 1, 0}, + /* 21 */ {6, s_5_21, 18, 1, 0}, + /* 22 */ {8, s_5_22, 21, 1, 0}, + /* 23 */ {6, s_5_23, 18, 1, 0}, + /* 24 */ {2, s_5_24, -1, 1, 0}, + /* 25 */ {4, s_5_25, 24, 1, 0}, + /* 26 */ {6, s_5_26, 25, 1, 0}, + /* 27 */ {4, s_5_27, 24, 1, 0}, + /* 28 */ {4, s_5_28, 24, 1, 0}, + /* 29 */ {4, s_5_29, -1, 1, 0}, + /* 30 */ {6, s_5_30, 29, 1, 0}, + /* 31 */ {4, s_5_31, -1, 1, 0}, + /* 32 */ {4, s_5_32, -1, 1, 0}, + /* 33 */ {6, s_5_33, 32, 1, 0}, + /* 34 */ {4, s_5_34, -1, 1, 0}, + /* 35 */ {2, s_5_35, -1, 1, 0}}; + +static const symbol s_6_0[6] = {0xD0, 0xBE, 0xD1, 0x81, 0xD1, 0x82}; +static const symbol s_6_1[8] = {0xD0, 0xBE, 0xD1, 0x81, 0xD1, 0x82, 0xD1, 0x8C}; + +static const struct among a_6[2] = { + /* 0 */ {6, s_6_0, -1, 1, 0}, + /* 1 */ {8, s_6_1, -1, 1, 0}}; + +static const symbol s_7_0[6] = {0xD0, 0xB5, 0xD0, 0xB9, 0xD1, 0x88}; +static const symbol s_7_1[2] = {0xD1, 0x8C}; +static const symbol s_7_2[8] = {0xD0, 0xB5, 0xD0, 0xB9, 0xD1, 0x88, 0xD0, 0xB5}; +static const symbol s_7_3[2] = {0xD0, 0xBD}; + +static const struct among a_7[4] = { + /* 0 */ {6, s_7_0, -1, 1, 0}, + /* 1 */ {2, s_7_1, -1, 3, 0}, + /* 2 */ {8, s_7_2, -1, 1, 0}, + /* 3 */ {2, s_7_3, -1, 2, 0}}; + +static const unsigned char g_v[] = {33, 65, 8, 232}; + +static const symbol s_0[] = {0xD0, 0xB0}; +static const symbol s_1[] = {0xD1, 0x8F}; +static const symbol s_2[] = {0xD0, 0xB0}; +static const symbol s_3[] = {0xD1, 0x8F}; +static const symbol s_4[] = {0xD0, 0xB0}; +static const symbol s_5[] = {0xD1, 0x8F}; +static const symbol s_6[] = {0xD0, 0xBD}; +static const symbol s_7[] = {0xD0, 0xBD}; +static const symbol s_8[] = {0xD0, 0xBD}; +static const symbol s_9[] = {0xD0, 0xB8}; + +static int r_mark_regions(struct SN_env *z) { + z->I[0] = z->l; + z->I[1] = z->l; + { + int c1 = z->c; /* do, line 61 */ + { /* gopast */ /* grouping v, line 62 */ + int ret = out_grouping_U(z, g_v, 1072, 1103, 1); + if (ret < 0) + goto lab0; + z->c += ret; + } + z->I[0] = z->c; /* setmark pV, line 62 */ + { /* gopast */ /* non v, line 62 */ + int ret = in_grouping_U(z, g_v, 
1072, 1103, 1); + if (ret < 0) + goto lab0; + z->c += ret; + } + { /* gopast */ /* grouping v, line 63 */ + int ret = out_grouping_U(z, g_v, 1072, 1103, 1); + if (ret < 0) + goto lab0; + z->c += ret; + } + { /* gopast */ /* non v, line 63 */ + int ret = in_grouping_U(z, g_v, 1072, 1103, 1); + if (ret < 0) + goto lab0; + z->c += ret; + } + z->I[1] = z->c; /* setmark p2, line 63 */ + lab0: + z->c = c1; + } + return 1; +} + +static int r_R2(struct SN_env *z) { + if (!(z->I[1] <= z->c)) + return 0; + return 1; +} + +static int r_perfective_gerund(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 72 */ + among_var = find_among_b(z, a_0, 9); /* substring, line 72 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 72 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int m1 = z->l - z->c; + (void)m1; /* or, line 76 */ + if (!(eq_s_b(z, 2, s_0))) + goto lab1; + goto lab0; + lab1: + z->c = z->l - m1; + if (!(eq_s_b(z, 2, s_1))) + return 0; + } + lab0: { + int ret = slice_del(z); /* delete, line 76 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_del(z); /* delete, line 83 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_adjective(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 88 */ + among_var = find_among_b(z, a_1, 26); /* substring, line 88 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 88 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 97 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_adjectival(struct SN_env *z) { + int among_var; + { + int ret = r_adjective(z); + if (ret == 0) + return 0; /* call adjective, line 102 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 109 */ + z->ket = z->c; /* [, line 110 */ + among_var = find_among_b(z, a_2, 8); /* substring, line 110 */ + if (!(among_var)) { + z->c = z->l - m_keep; + goto lab0; + } + z->bra = z->c; /* ], line 110 */ + switch (among_var) { + case 0: { + z->c = z->l - m_keep; + goto lab0; + } + case 1: { + int m1 = z->l - z->c; + (void)m1; /* or, line 115 */ + if (!(eq_s_b(z, 2, s_2))) + goto lab2; + goto lab1; + lab2: + z->c = z->l - m1; + if (!(eq_s_b(z, 2, s_3))) { + z->c = z->l - m_keep; + goto lab0; + } + } + lab1: { + int ret = slice_del(z); /* delete, line 115 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_del(z); /* delete, line 122 */ + if (ret < 0) + return ret; + } break; + } + lab0:; + } + return 1; +} + +static int r_reflexive(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 129 */ + if (z->c - 3 <= z->lb || (z->p[z->c - 1] != 140 && z->p[z->c - 1] != 143)) + return 0; + among_var = find_among_b(z, a_3, 2); /* substring, line 129 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 129 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 132 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_verb(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 137 */ + among_var = find_among_b(z, a_4, 46); /* substring, line 137 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 137 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int m1 = z->l - z->c; + (void)m1; /* or, line 143 */ + if (!(eq_s_b(z, 2, s_4))) + goto lab1; + goto lab0; + lab1: + z->c = z->l - m1; + if (!(eq_s_b(z, 2, s_5))) + return 0; + } + lab0: { + int ret 
= slice_del(z); /* delete, line 143 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_del(z); /* delete, line 151 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_noun(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 160 */ + among_var = find_among_b(z, a_5, 36); /* substring, line 160 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 160 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 167 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_derivational(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 176 */ + if (z->c - 5 <= z->lb || (z->p[z->c - 1] != 130 && z->p[z->c - 1] != 140)) + return 0; + among_var = find_among_b(z, a_6, 2); /* substring, line 176 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 176 */ + { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 176 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 179 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_tidy_up(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 184 */ + among_var = find_among_b(z, a_7, 4); /* substring, line 184 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 184 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 188 */ + if (ret < 0) + return ret; + } + z->ket = z->c; /* [, line 189 */ + if (!(eq_s_b(z, 2, s_6))) + return 0; + z->bra = z->c; /* ], line 189 */ + if (!(eq_s_b(z, 2, s_7))) + return 0; + { + int ret = slice_del(z); /* delete, line 189 */ + if (ret < 0) + return ret; + } + break; + case 2: + if (!(eq_s_b(z, 2, s_8))) + return 0; + { + int ret = slice_del(z); /* delete, line 192 */ + if (ret < 0) + return ret; + } + break; + case 3: { + int ret = slice_del(z); /* delete, line 194 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +extern int russian_UTF_8_stem(struct SN_env *z) { + { + int c1 = z->c; /* do, line 201 */ + { + int ret = r_mark_regions(z); + if (ret == 0) + goto lab0; /* call mark_regions, line 201 */ + if (ret < 0) + return ret; + } + lab0: + z->c = c1; + } + z->lb = z->c; + z->c = z->l; /* backwards, line 202 */ + + { + int mlimit; /* setlimit, line 202 */ + int m2 = z->l - z->c; + (void)m2; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 202 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m2; + { + int m3 = z->l - z->c; + (void)m3; /* do, line 203 */ + { + int m4 = z->l - z->c; + (void)m4; /* or, line 204 */ + { + int ret = r_perfective_gerund(z); + if (ret == 0) + goto lab3; /* call perfective_gerund, line 204 */ + if (ret < 0) + return ret; + } + goto lab2; + lab3: + z->c = z->l - m4; + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 205 */ + { + int ret = r_reflexive(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab4; + } /* call reflexive, line 205 */ + if (ret < 0) + return ret; + } + lab4:; + } + { + int m5 = z->l - z->c; + (void)m5; /* or, line 206 */ + { + int ret = r_adjectival(z); + if (ret == 0) + goto lab6; /* call adjectival, line 206 */ + if (ret < 0) + return ret; + } + goto lab5; + lab6: + z->c = z->l - m5; + { + int ret = r_verb(z); + if (ret == 0) + goto lab7; /* call verb, line 206 */ + if (ret < 0) + return ret; + } + goto lab5; + lab7: + z->c = z->l - m5; + { + int ret = r_noun(z); + if (ret == 0) + goto 
lab1; /* call noun, line 206 */ + if (ret < 0) + return ret; + } + } + lab5:; + } + lab2: + lab1: + z->c = z->l - m3; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 209 */ + z->ket = z->c; /* [, line 209 */ + if (!(eq_s_b(z, 2, s_9))) { + z->c = z->l - m_keep; + goto lab8; + } + z->bra = z->c; /* ], line 209 */ + { + int ret = slice_del(z); /* delete, line 209 */ + if (ret < 0) + return ret; + } + lab8:; + } + { + int m6 = z->l - z->c; + (void)m6; /* do, line 212 */ + { + int ret = r_derivational(z); + if (ret == 0) + goto lab9; /* call derivational, line 212 */ + if (ret < 0) + return ret; + } + lab9: + z->c = z->l - m6; + } + { + int m7 = z->l - z->c; + (void)m7; /* do, line 213 */ + { + int ret = r_tidy_up(z); + if (ret == 0) + goto lab10; /* call tidy_up, line 213 */ + if (ret < 0) + return ret; + } + lab10: + z->c = z->l - m7; + } + z->lb = mlimit; + } + z->c = z->lb; + return 1; +} + +extern struct SN_env *russian_UTF_8_create_env(void) { return SN_create_env(0, 2, 0); } + +extern void russian_UTF_8_close_env(struct SN_env *z) { SN_close_env(z, 0); } diff --git a/internal/cpp/stemmer/stem_UTF_8_russian.h b/internal/cpp/stemmer/stem_UTF_8_russian.h new file mode 100644 index 00000000000..5ed058f6360 --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_russian.h @@ -0,0 +1,17 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *russian_UTF_8_create_env(void); +extern void russian_UTF_8_close_env(struct SN_env *z); + +extern int russian_UTF_8_stem(struct SN_env *z); + +#ifdef __cplusplus +} +#endif diff --git a/internal/cpp/stemmer/stem_UTF_8_spanish.cpp b/internal/cpp/stemmer/stem_UTF_8_spanish.cpp new file mode 100644 index 00000000000..1883e2c7c0b --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_spanish.cpp @@ -0,0 +1,1319 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#include "header.h" + +#ifdef __cplusplus +extern "C" { +#endif +extern int spanish_UTF_8_stem(struct SN_env *z); +#ifdef __cplusplus +} +#endif +static int r_residual_suffix(struct SN_env *z); +static int r_verb_suffix(struct SN_env *z); +static int r_y_verb_suffix(struct SN_env *z); +static int r_standard_suffix(struct SN_env *z); +static int r_attached_pronoun(struct SN_env *z); +static int r_R2(struct SN_env *z); +static int r_R1(struct SN_env *z); +static int r_RV(struct SN_env *z); +static int r_mark_regions(struct SN_env *z); +static int r_postlude(struct SN_env *z); +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *spanish_UTF_8_create_env(void); +extern void spanish_UTF_8_close_env(struct SN_env *z); + +#ifdef __cplusplus +} +#endif +static const symbol s_0_1[2] = {0xC3, 0xA1}; +static const symbol s_0_2[2] = {0xC3, 0xA9}; +static const symbol s_0_3[2] = {0xC3, 0xAD}; +static const symbol s_0_4[2] = {0xC3, 0xB3}; +static const symbol s_0_5[2] = {0xC3, 0xBA}; + +static const struct among a_0[6] = { + /* 0 */ {0, 0, -1, 6, 0}, + /* 1 */ {2, s_0_1, 0, 1, 0}, + /* 2 */ {2, s_0_2, 0, 2, 0}, + /* 3 */ {2, s_0_3, 0, 3, 0}, + /* 4 */ {2, s_0_4, 0, 4, 0}, + /* 5 */ {2, s_0_5, 0, 5, 0}}; + +static const symbol s_1_0[2] = {'l', 'a'}; +static const symbol s_1_1[4] = {'s', 'e', 'l', 'a'}; +static const symbol s_1_2[2] = {'l', 'e'}; +static const symbol s_1_3[2] = {'m', 'e'}; +static const symbol s_1_4[2] = {'s', 'e'}; +static const symbol s_1_5[2] = {'l', 'o'}; +static const symbol s_1_6[4] = {'s', 'e', 'l', 'o'}; +static const 
symbol s_1_7[3] = {'l', 'a', 's'}; +static const symbol s_1_8[5] = {'s', 'e', 'l', 'a', 's'}; +static const symbol s_1_9[3] = {'l', 'e', 's'}; +static const symbol s_1_10[3] = {'l', 'o', 's'}; +static const symbol s_1_11[5] = {'s', 'e', 'l', 'o', 's'}; +static const symbol s_1_12[3] = {'n', 'o', 's'}; + +static const struct among a_1[13] = { + /* 0 */ {2, s_1_0, -1, -1, 0}, + /* 1 */ {4, s_1_1, 0, -1, 0}, + /* 2 */ {2, s_1_2, -1, -1, 0}, + /* 3 */ {2, s_1_3, -1, -1, 0}, + /* 4 */ {2, s_1_4, -1, -1, 0}, + /* 5 */ {2, s_1_5, -1, -1, 0}, + /* 6 */ {4, s_1_6, 5, -1, 0}, + /* 7 */ {3, s_1_7, -1, -1, 0}, + /* 8 */ {5, s_1_8, 7, -1, 0}, + /* 9 */ {3, s_1_9, -1, -1, 0}, + /* 10 */ {3, s_1_10, -1, -1, 0}, + /* 11 */ {5, s_1_11, 10, -1, 0}, + /* 12 */ {3, s_1_12, -1, -1, 0}}; + +static const symbol s_2_0[4] = {'a', 'n', 'd', 'o'}; +static const symbol s_2_1[5] = {'i', 'e', 'n', 'd', 'o'}; +static const symbol s_2_2[5] = {'y', 'e', 'n', 'd', 'o'}; +static const symbol s_2_3[5] = {0xC3, 0xA1, 'n', 'd', 'o'}; +static const symbol s_2_4[6] = {'i', 0xC3, 0xA9, 'n', 'd', 'o'}; +static const symbol s_2_5[2] = {'a', 'r'}; +static const symbol s_2_6[2] = {'e', 'r'}; +static const symbol s_2_7[2] = {'i', 'r'}; +static const symbol s_2_8[3] = {0xC3, 0xA1, 'r'}; +static const symbol s_2_9[3] = {0xC3, 0xA9, 'r'}; +static const symbol s_2_10[3] = {0xC3, 0xAD, 'r'}; + +static const struct among a_2[11] = { + /* 0 */ {4, s_2_0, -1, 6, 0}, + /* 1 */ {5, s_2_1, -1, 6, 0}, + /* 2 */ {5, s_2_2, -1, 7, 0}, + /* 3 */ {5, s_2_3, -1, 2, 0}, + /* 4 */ {6, s_2_4, -1, 1, 0}, + /* 5 */ {2, s_2_5, -1, 6, 0}, + /* 6 */ {2, s_2_6, -1, 6, 0}, + /* 7 */ {2, s_2_7, -1, 6, 0}, + /* 8 */ {3, s_2_8, -1, 3, 0}, + /* 9 */ {3, s_2_9, -1, 4, 0}, + /* 10 */ {3, s_2_10, -1, 5, 0}}; + +static const symbol s_3_0[2] = {'i', 'c'}; +static const symbol s_3_1[2] = {'a', 'd'}; +static const symbol s_3_2[2] = {'o', 's'}; +static const symbol s_3_3[2] = {'i', 'v'}; + +static const struct among a_3[4] = { + /* 0 */ {2, s_3_0, -1, -1, 0}, + /* 1 */ {2, s_3_1, -1, -1, 0}, + /* 2 */ {2, s_3_2, -1, -1, 0}, + /* 3 */ {2, s_3_3, -1, 1, 0}}; + +static const symbol s_4_0[4] = {'a', 'b', 'l', 'e'}; +static const symbol s_4_1[4] = {'i', 'b', 'l', 'e'}; +static const symbol s_4_2[4] = {'a', 'n', 't', 'e'}; + +static const struct among a_4[3] = { + /* 0 */ {4, s_4_0, -1, 1, 0}, + /* 1 */ {4, s_4_1, -1, 1, 0}, + /* 2 */ {4, s_4_2, -1, 1, 0}}; + +static const symbol s_5_0[2] = {'i', 'c'}; +static const symbol s_5_1[4] = {'a', 'b', 'i', 'l'}; +static const symbol s_5_2[2] = {'i', 'v'}; + +static const struct among a_5[3] = { + /* 0 */ {2, s_5_0, -1, 1, 0}, + /* 1 */ {4, s_5_1, -1, 1, 0}, + /* 2 */ {2, s_5_2, -1, 1, 0}}; + +static const symbol s_6_0[3] = {'i', 'c', 'a'}; +static const symbol s_6_1[5] = {'a', 'n', 'c', 'i', 'a'}; +static const symbol s_6_2[5] = {'e', 'n', 'c', 'i', 'a'}; +static const symbol s_6_3[5] = {'a', 'd', 'o', 'r', 'a'}; +static const symbol s_6_4[3] = {'o', 's', 'a'}; +static const symbol s_6_5[4] = {'i', 's', 't', 'a'}; +static const symbol s_6_6[3] = {'i', 'v', 'a'}; +static const symbol s_6_7[4] = {'a', 'n', 'z', 'a'}; +static const symbol s_6_8[6] = {'l', 'o', 'g', 0xC3, 0xAD, 'a'}; +static const symbol s_6_9[4] = {'i', 'd', 'a', 'd'}; +static const symbol s_6_10[4] = {'a', 'b', 'l', 'e'}; +static const symbol s_6_11[4] = {'i', 'b', 'l', 'e'}; +static const symbol s_6_12[4] = {'a', 'n', 't', 'e'}; +static const symbol s_6_13[5] = {'m', 'e', 'n', 't', 'e'}; +static const symbol s_6_14[6] = {'a', 'm', 'e', 'n', 't', 'e'}; +static const 
symbol s_6_15[6] = {'a', 'c', 'i', 0xC3, 0xB3, 'n'}; +static const symbol s_6_16[6] = {'u', 'c', 'i', 0xC3, 0xB3, 'n'}; +static const symbol s_6_17[3] = {'i', 'c', 'o'}; +static const symbol s_6_18[4] = {'i', 's', 'm', 'o'}; +static const symbol s_6_19[3] = {'o', 's', 'o'}; +static const symbol s_6_20[7] = {'a', 'm', 'i', 'e', 'n', 't', 'o'}; +static const symbol s_6_21[7] = {'i', 'm', 'i', 'e', 'n', 't', 'o'}; +static const symbol s_6_22[3] = {'i', 'v', 'o'}; +static const symbol s_6_23[4] = {'a', 'd', 'o', 'r'}; +static const symbol s_6_24[4] = {'i', 'c', 'a', 's'}; +static const symbol s_6_25[6] = {'a', 'n', 'c', 'i', 'a', 's'}; +static const symbol s_6_26[6] = {'e', 'n', 'c', 'i', 'a', 's'}; +static const symbol s_6_27[6] = {'a', 'd', 'o', 'r', 'a', 's'}; +static const symbol s_6_28[4] = {'o', 's', 'a', 's'}; +static const symbol s_6_29[5] = {'i', 's', 't', 'a', 's'}; +static const symbol s_6_30[4] = {'i', 'v', 'a', 's'}; +static const symbol s_6_31[5] = {'a', 'n', 'z', 'a', 's'}; +static const symbol s_6_32[7] = {'l', 'o', 'g', 0xC3, 0xAD, 'a', 's'}; +static const symbol s_6_33[6] = {'i', 'd', 'a', 'd', 'e', 's'}; +static const symbol s_6_34[5] = {'a', 'b', 'l', 'e', 's'}; +static const symbol s_6_35[5] = {'i', 'b', 'l', 'e', 's'}; +static const symbol s_6_36[7] = {'a', 'c', 'i', 'o', 'n', 'e', 's'}; +static const symbol s_6_37[7] = {'u', 'c', 'i', 'o', 'n', 'e', 's'}; +static const symbol s_6_38[6] = {'a', 'd', 'o', 'r', 'e', 's'}; +static const symbol s_6_39[5] = {'a', 'n', 't', 'e', 's'}; +static const symbol s_6_40[4] = {'i', 'c', 'o', 's'}; +static const symbol s_6_41[5] = {'i', 's', 'm', 'o', 's'}; +static const symbol s_6_42[4] = {'o', 's', 'o', 's'}; +static const symbol s_6_43[8] = {'a', 'm', 'i', 'e', 'n', 't', 'o', 's'}; +static const symbol s_6_44[8] = {'i', 'm', 'i', 'e', 'n', 't', 'o', 's'}; +static const symbol s_6_45[4] = {'i', 'v', 'o', 's'}; + +static const struct among a_6[46] = { + /* 0 */ {3, s_6_0, -1, 1, 0}, + /* 1 */ {5, s_6_1, -1, 2, 0}, + /* 2 */ {5, s_6_2, -1, 5, 0}, + /* 3 */ {5, s_6_3, -1, 2, 0}, + /* 4 */ {3, s_6_4, -1, 1, 0}, + /* 5 */ {4, s_6_5, -1, 1, 0}, + /* 6 */ {3, s_6_6, -1, 9, 0}, + /* 7 */ {4, s_6_7, -1, 1, 0}, + /* 8 */ {6, s_6_8, -1, 3, 0}, + /* 9 */ {4, s_6_9, -1, 8, 0}, + /* 10 */ {4, s_6_10, -1, 1, 0}, + /* 11 */ {4, s_6_11, -1, 1, 0}, + /* 12 */ {4, s_6_12, -1, 2, 0}, + /* 13 */ {5, s_6_13, -1, 7, 0}, + /* 14 */ {6, s_6_14, 13, 6, 0}, + /* 15 */ {6, s_6_15, -1, 2, 0}, + /* 16 */ {6, s_6_16, -1, 4, 0}, + /* 17 */ {3, s_6_17, -1, 1, 0}, + /* 18 */ {4, s_6_18, -1, 1, 0}, + /* 19 */ {3, s_6_19, -1, 1, 0}, + /* 20 */ {7, s_6_20, -1, 1, 0}, + /* 21 */ {7, s_6_21, -1, 1, 0}, + /* 22 */ {3, s_6_22, -1, 9, 0}, + /* 23 */ {4, s_6_23, -1, 2, 0}, + /* 24 */ {4, s_6_24, -1, 1, 0}, + /* 25 */ {6, s_6_25, -1, 2, 0}, + /* 26 */ {6, s_6_26, -1, 5, 0}, + /* 27 */ {6, s_6_27, -1, 2, 0}, + /* 28 */ {4, s_6_28, -1, 1, 0}, + /* 29 */ {5, s_6_29, -1, 1, 0}, + /* 30 */ {4, s_6_30, -1, 9, 0}, + /* 31 */ {5, s_6_31, -1, 1, 0}, + /* 32 */ {7, s_6_32, -1, 3, 0}, + /* 33 */ {6, s_6_33, -1, 8, 0}, + /* 34 */ {5, s_6_34, -1, 1, 0}, + /* 35 */ {5, s_6_35, -1, 1, 0}, + /* 36 */ {7, s_6_36, -1, 2, 0}, + /* 37 */ {7, s_6_37, -1, 4, 0}, + /* 38 */ {6, s_6_38, -1, 2, 0}, + /* 39 */ {5, s_6_39, -1, 2, 0}, + /* 40 */ {4, s_6_40, -1, 1, 0}, + /* 41 */ {5, s_6_41, -1, 1, 0}, + /* 42 */ {4, s_6_42, -1, 1, 0}, + /* 43 */ {8, s_6_43, -1, 1, 0}, + /* 44 */ {8, s_6_44, -1, 1, 0}, + /* 45 */ {4, s_6_45, -1, 9, 0}}; + +static const symbol s_7_0[2] = {'y', 'a'}; +static const symbol 
s_7_1[2] = {'y', 'e'}; +static const symbol s_7_2[3] = {'y', 'a', 'n'}; +static const symbol s_7_3[3] = {'y', 'e', 'n'}; +static const symbol s_7_4[5] = {'y', 'e', 'r', 'o', 'n'}; +static const symbol s_7_5[5] = {'y', 'e', 'n', 'd', 'o'}; +static const symbol s_7_6[2] = {'y', 'o'}; +static const symbol s_7_7[3] = {'y', 'a', 's'}; +static const symbol s_7_8[3] = {'y', 'e', 's'}; +static const symbol s_7_9[4] = {'y', 'a', 'i', 's'}; +static const symbol s_7_10[5] = {'y', 'a', 'm', 'o', 's'}; +static const symbol s_7_11[3] = {'y', 0xC3, 0xB3}; + +static const struct among a_7[12] = { + /* 0 */ {2, s_7_0, -1, 1, 0}, + /* 1 */ {2, s_7_1, -1, 1, 0}, + /* 2 */ {3, s_7_2, -1, 1, 0}, + /* 3 */ {3, s_7_3, -1, 1, 0}, + /* 4 */ {5, s_7_4, -1, 1, 0}, + /* 5 */ {5, s_7_5, -1, 1, 0}, + /* 6 */ {2, s_7_6, -1, 1, 0}, + /* 7 */ {3, s_7_7, -1, 1, 0}, + /* 8 */ {3, s_7_8, -1, 1, 0}, + /* 9 */ {4, s_7_9, -1, 1, 0}, + /* 10 */ {5, s_7_10, -1, 1, 0}, + /* 11 */ {3, s_7_11, -1, 1, 0}}; + +static const symbol s_8_0[3] = {'a', 'b', 'a'}; +static const symbol s_8_1[3] = {'a', 'd', 'a'}; +static const symbol s_8_2[3] = {'i', 'd', 'a'}; +static const symbol s_8_3[3] = {'a', 'r', 'a'}; +static const symbol s_8_4[4] = {'i', 'e', 'r', 'a'}; +static const symbol s_8_5[3] = {0xC3, 0xAD, 'a'}; +static const symbol s_8_6[5] = {'a', 'r', 0xC3, 0xAD, 'a'}; +static const symbol s_8_7[5] = {'e', 'r', 0xC3, 0xAD, 'a'}; +static const symbol s_8_8[5] = {'i', 'r', 0xC3, 0xAD, 'a'}; +static const symbol s_8_9[2] = {'a', 'd'}; +static const symbol s_8_10[2] = {'e', 'd'}; +static const symbol s_8_11[2] = {'i', 'd'}; +static const symbol s_8_12[3] = {'a', 's', 'e'}; +static const symbol s_8_13[4] = {'i', 'e', 's', 'e'}; +static const symbol s_8_14[4] = {'a', 's', 't', 'e'}; +static const symbol s_8_15[4] = {'i', 's', 't', 'e'}; +static const symbol s_8_16[2] = {'a', 'n'}; +static const symbol s_8_17[4] = {'a', 'b', 'a', 'n'}; +static const symbol s_8_18[4] = {'a', 'r', 'a', 'n'}; +static const symbol s_8_19[5] = {'i', 'e', 'r', 'a', 'n'}; +static const symbol s_8_20[4] = {0xC3, 0xAD, 'a', 'n'}; +static const symbol s_8_21[6] = {'a', 'r', 0xC3, 0xAD, 'a', 'n'}; +static const symbol s_8_22[6] = {'e', 'r', 0xC3, 0xAD, 'a', 'n'}; +static const symbol s_8_23[6] = {'i', 'r', 0xC3, 0xAD, 'a', 'n'}; +static const symbol s_8_24[2] = {'e', 'n'}; +static const symbol s_8_25[4] = {'a', 's', 'e', 'n'}; +static const symbol s_8_26[5] = {'i', 'e', 's', 'e', 'n'}; +static const symbol s_8_27[4] = {'a', 'r', 'o', 'n'}; +static const symbol s_8_28[5] = {'i', 'e', 'r', 'o', 'n'}; +static const symbol s_8_29[5] = {'a', 'r', 0xC3, 0xA1, 'n'}; +static const symbol s_8_30[5] = {'e', 'r', 0xC3, 0xA1, 'n'}; +static const symbol s_8_31[5] = {'i', 'r', 0xC3, 0xA1, 'n'}; +static const symbol s_8_32[3] = {'a', 'd', 'o'}; +static const symbol s_8_33[3] = {'i', 'd', 'o'}; +static const symbol s_8_34[4] = {'a', 'n', 'd', 'o'}; +static const symbol s_8_35[5] = {'i', 'e', 'n', 'd', 'o'}; +static const symbol s_8_36[2] = {'a', 'r'}; +static const symbol s_8_37[2] = {'e', 'r'}; +static const symbol s_8_38[2] = {'i', 'r'}; +static const symbol s_8_39[2] = {'a', 's'}; +static const symbol s_8_40[4] = {'a', 'b', 'a', 's'}; +static const symbol s_8_41[4] = {'a', 'd', 'a', 's'}; +static const symbol s_8_42[4] = {'i', 'd', 'a', 's'}; +static const symbol s_8_43[4] = {'a', 'r', 'a', 's'}; +static const symbol s_8_44[5] = {'i', 'e', 'r', 'a', 's'}; +static const symbol s_8_45[4] = {0xC3, 0xAD, 'a', 's'}; +static const symbol s_8_46[6] = {'a', 'r', 0xC3, 0xAD, 'a', 's'}; +static 
const symbol s_8_47[6] = {'e', 'r', 0xC3, 0xAD, 'a', 's'}; +static const symbol s_8_48[6] = {'i', 'r', 0xC3, 0xAD, 'a', 's'}; +static const symbol s_8_49[2] = {'e', 's'}; +static const symbol s_8_50[4] = {'a', 's', 'e', 's'}; +static const symbol s_8_51[5] = {'i', 'e', 's', 'e', 's'}; +static const symbol s_8_52[5] = {'a', 'b', 'a', 'i', 's'}; +static const symbol s_8_53[5] = {'a', 'r', 'a', 'i', 's'}; +static const symbol s_8_54[6] = {'i', 'e', 'r', 'a', 'i', 's'}; +static const symbol s_8_55[5] = {0xC3, 0xAD, 'a', 'i', 's'}; +static const symbol s_8_56[7] = {'a', 'r', 0xC3, 0xAD, 'a', 'i', 's'}; +static const symbol s_8_57[7] = {'e', 'r', 0xC3, 0xAD, 'a', 'i', 's'}; +static const symbol s_8_58[7] = {'i', 'r', 0xC3, 0xAD, 'a', 'i', 's'}; +static const symbol s_8_59[5] = {'a', 's', 'e', 'i', 's'}; +static const symbol s_8_60[6] = {'i', 'e', 's', 'e', 'i', 's'}; +static const symbol s_8_61[6] = {'a', 's', 't', 'e', 'i', 's'}; +static const symbol s_8_62[6] = {'i', 's', 't', 'e', 'i', 's'}; +static const symbol s_8_63[4] = {0xC3, 0xA1, 'i', 's'}; +static const symbol s_8_64[4] = {0xC3, 0xA9, 'i', 's'}; +static const symbol s_8_65[6] = {'a', 'r', 0xC3, 0xA9, 'i', 's'}; +static const symbol s_8_66[6] = {'e', 'r', 0xC3, 0xA9, 'i', 's'}; +static const symbol s_8_67[6] = {'i', 'r', 0xC3, 0xA9, 'i', 's'}; +static const symbol s_8_68[4] = {'a', 'd', 'o', 's'}; +static const symbol s_8_69[4] = {'i', 'd', 'o', 's'}; +static const symbol s_8_70[4] = {'a', 'm', 'o', 's'}; +static const symbol s_8_71[7] = {0xC3, 0xA1, 'b', 'a', 'm', 'o', 's'}; +static const symbol s_8_72[7] = {0xC3, 0xA1, 'r', 'a', 'm', 'o', 's'}; +static const symbol s_8_73[8] = {'i', 0xC3, 0xA9, 'r', 'a', 'm', 'o', 's'}; +static const symbol s_8_74[6] = {0xC3, 0xAD, 'a', 'm', 'o', 's'}; +static const symbol s_8_75[8] = {'a', 'r', 0xC3, 0xAD, 'a', 'm', 'o', 's'}; +static const symbol s_8_76[8] = {'e', 'r', 0xC3, 0xAD, 'a', 'm', 'o', 's'}; +static const symbol s_8_77[8] = {'i', 'r', 0xC3, 0xAD, 'a', 'm', 'o', 's'}; +static const symbol s_8_78[4] = {'e', 'm', 'o', 's'}; +static const symbol s_8_79[6] = {'a', 'r', 'e', 'm', 'o', 's'}; +static const symbol s_8_80[6] = {'e', 'r', 'e', 'm', 'o', 's'}; +static const symbol s_8_81[6] = {'i', 'r', 'e', 'm', 'o', 's'}; +static const symbol s_8_82[7] = {0xC3, 0xA1, 's', 'e', 'm', 'o', 's'}; +static const symbol s_8_83[8] = {'i', 0xC3, 0xA9, 's', 'e', 'm', 'o', 's'}; +static const symbol s_8_84[4] = {'i', 'm', 'o', 's'}; +static const symbol s_8_85[5] = {'a', 'r', 0xC3, 0xA1, 's'}; +static const symbol s_8_86[5] = {'e', 'r', 0xC3, 0xA1, 's'}; +static const symbol s_8_87[5] = {'i', 'r', 0xC3, 0xA1, 's'}; +static const symbol s_8_88[3] = {0xC3, 0xAD, 's'}; +static const symbol s_8_89[4] = {'a', 'r', 0xC3, 0xA1}; +static const symbol s_8_90[4] = {'e', 'r', 0xC3, 0xA1}; +static const symbol s_8_91[4] = {'i', 'r', 0xC3, 0xA1}; +static const symbol s_8_92[4] = {'a', 'r', 0xC3, 0xA9}; +static const symbol s_8_93[4] = {'e', 'r', 0xC3, 0xA9}; +static const symbol s_8_94[4] = {'i', 'r', 0xC3, 0xA9}; +static const symbol s_8_95[3] = {'i', 0xC3, 0xB3}; + +static const struct among a_8[96] = { + /* 0 */ {3, s_8_0, -1, 2, 0}, + /* 1 */ {3, s_8_1, -1, 2, 0}, + /* 2 */ {3, s_8_2, -1, 2, 0}, + /* 3 */ {3, s_8_3, -1, 2, 0}, + /* 4 */ {4, s_8_4, -1, 2, 0}, + /* 5 */ {3, s_8_5, -1, 2, 0}, + /* 6 */ {5, s_8_6, 5, 2, 0}, + /* 7 */ {5, s_8_7, 5, 2, 0}, + /* 8 */ {5, s_8_8, 5, 2, 0}, + /* 9 */ {2, s_8_9, -1, 2, 0}, + /* 10 */ {2, s_8_10, -1, 2, 0}, + /* 11 */ {2, s_8_11, -1, 2, 0}, + /* 12 */ {3, s_8_12, -1, 2, 0}, + /* 
13 */ {4, s_8_13, -1, 2, 0}, + /* 14 */ {4, s_8_14, -1, 2, 0}, + /* 15 */ {4, s_8_15, -1, 2, 0}, + /* 16 */ {2, s_8_16, -1, 2, 0}, + /* 17 */ {4, s_8_17, 16, 2, 0}, + /* 18 */ {4, s_8_18, 16, 2, 0}, + /* 19 */ {5, s_8_19, 16, 2, 0}, + /* 20 */ {4, s_8_20, 16, 2, 0}, + /* 21 */ {6, s_8_21, 20, 2, 0}, + /* 22 */ {6, s_8_22, 20, 2, 0}, + /* 23 */ {6, s_8_23, 20, 2, 0}, + /* 24 */ {2, s_8_24, -1, 1, 0}, + /* 25 */ {4, s_8_25, 24, 2, 0}, + /* 26 */ {5, s_8_26, 24, 2, 0}, + /* 27 */ {4, s_8_27, -1, 2, 0}, + /* 28 */ {5, s_8_28, -1, 2, 0}, + /* 29 */ {5, s_8_29, -1, 2, 0}, + /* 30 */ {5, s_8_30, -1, 2, 0}, + /* 31 */ {5, s_8_31, -1, 2, 0}, + /* 32 */ {3, s_8_32, -1, 2, 0}, + /* 33 */ {3, s_8_33, -1, 2, 0}, + /* 34 */ {4, s_8_34, -1, 2, 0}, + /* 35 */ {5, s_8_35, -1, 2, 0}, + /* 36 */ {2, s_8_36, -1, 2, 0}, + /* 37 */ {2, s_8_37, -1, 2, 0}, + /* 38 */ {2, s_8_38, -1, 2, 0}, + /* 39 */ {2, s_8_39, -1, 2, 0}, + /* 40 */ {4, s_8_40, 39, 2, 0}, + /* 41 */ {4, s_8_41, 39, 2, 0}, + /* 42 */ {4, s_8_42, 39, 2, 0}, + /* 43 */ {4, s_8_43, 39, 2, 0}, + /* 44 */ {5, s_8_44, 39, 2, 0}, + /* 45 */ {4, s_8_45, 39, 2, 0}, + /* 46 */ {6, s_8_46, 45, 2, 0}, + /* 47 */ {6, s_8_47, 45, 2, 0}, + /* 48 */ {6, s_8_48, 45, 2, 0}, + /* 49 */ {2, s_8_49, -1, 1, 0}, + /* 50 */ {4, s_8_50, 49, 2, 0}, + /* 51 */ {5, s_8_51, 49, 2, 0}, + /* 52 */ {5, s_8_52, -1, 2, 0}, + /* 53 */ {5, s_8_53, -1, 2, 0}, + /* 54 */ {6, s_8_54, -1, 2, 0}, + /* 55 */ {5, s_8_55, -1, 2, 0}, + /* 56 */ {7, s_8_56, 55, 2, 0}, + /* 57 */ {7, s_8_57, 55, 2, 0}, + /* 58 */ {7, s_8_58, 55, 2, 0}, + /* 59 */ {5, s_8_59, -1, 2, 0}, + /* 60 */ {6, s_8_60, -1, 2, 0}, + /* 61 */ {6, s_8_61, -1, 2, 0}, + /* 62 */ {6, s_8_62, -1, 2, 0}, + /* 63 */ {4, s_8_63, -1, 2, 0}, + /* 64 */ {4, s_8_64, -1, 1, 0}, + /* 65 */ {6, s_8_65, 64, 2, 0}, + /* 66 */ {6, s_8_66, 64, 2, 0}, + /* 67 */ {6, s_8_67, 64, 2, 0}, + /* 68 */ {4, s_8_68, -1, 2, 0}, + /* 69 */ {4, s_8_69, -1, 2, 0}, + /* 70 */ {4, s_8_70, -1, 2, 0}, + /* 71 */ {7, s_8_71, 70, 2, 0}, + /* 72 */ {7, s_8_72, 70, 2, 0}, + /* 73 */ {8, s_8_73, 70, 2, 0}, + /* 74 */ {6, s_8_74, 70, 2, 0}, + /* 75 */ {8, s_8_75, 74, 2, 0}, + /* 76 */ {8, s_8_76, 74, 2, 0}, + /* 77 */ {8, s_8_77, 74, 2, 0}, + /* 78 */ {4, s_8_78, -1, 1, 0}, + /* 79 */ {6, s_8_79, 78, 2, 0}, + /* 80 */ {6, s_8_80, 78, 2, 0}, + /* 81 */ {6, s_8_81, 78, 2, 0}, + /* 82 */ {7, s_8_82, 78, 2, 0}, + /* 83 */ {8, s_8_83, 78, 2, 0}, + /* 84 */ {4, s_8_84, -1, 2, 0}, + /* 85 */ {5, s_8_85, -1, 2, 0}, + /* 86 */ {5, s_8_86, -1, 2, 0}, + /* 87 */ {5, s_8_87, -1, 2, 0}, + /* 88 */ {3, s_8_88, -1, 2, 0}, + /* 89 */ {4, s_8_89, -1, 2, 0}, + /* 90 */ {4, s_8_90, -1, 2, 0}, + /* 91 */ {4, s_8_91, -1, 2, 0}, + /* 92 */ {4, s_8_92, -1, 2, 0}, + /* 93 */ {4, s_8_93, -1, 2, 0}, + /* 94 */ {4, s_8_94, -1, 2, 0}, + /* 95 */ {3, s_8_95, -1, 2, 0}}; + +static const symbol s_9_0[1] = {'a'}; +static const symbol s_9_1[1] = {'e'}; +static const symbol s_9_2[1] = {'o'}; +static const symbol s_9_3[2] = {'o', 's'}; +static const symbol s_9_4[2] = {0xC3, 0xA1}; +static const symbol s_9_5[2] = {0xC3, 0xA9}; +static const symbol s_9_6[2] = {0xC3, 0xAD}; +static const symbol s_9_7[2] = {0xC3, 0xB3}; + +static const struct among a_9[8] = { + /* 0 */ {1, s_9_0, -1, 1, 0}, + /* 1 */ {1, s_9_1, -1, 2, 0}, + /* 2 */ {1, s_9_2, -1, 1, 0}, + /* 3 */ {2, s_9_3, -1, 1, 0}, + /* 4 */ {2, s_9_4, -1, 1, 0}, + /* 5 */ {2, s_9_5, -1, 2, 0}, + /* 6 */ {2, s_9_6, -1, 1, 0}, + /* 7 */ {2, s_9_7, -1, 1, 0}}; + +static const unsigned char g_v[] = {17, 65, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 
17, 4, 10}; + +static const symbol s_0[] = {'a'}; +static const symbol s_1[] = {'e'}; +static const symbol s_2[] = {'i'}; +static const symbol s_3[] = {'o'}; +static const symbol s_4[] = {'u'}; +static const symbol s_5[] = {'i', 'e', 'n', 'd', 'o'}; +static const symbol s_6[] = {'a', 'n', 'd', 'o'}; +static const symbol s_7[] = {'a', 'r'}; +static const symbol s_8[] = {'e', 'r'}; +static const symbol s_9[] = {'i', 'r'}; +static const symbol s_10[] = {'u'}; +static const symbol s_11[] = {'i', 'c'}; +static const symbol s_12[] = {'l', 'o', 'g'}; +static const symbol s_13[] = {'u'}; +static const symbol s_14[] = {'e', 'n', 't', 'e'}; +static const symbol s_15[] = {'a', 't'}; +static const symbol s_16[] = {'a', 't'}; +static const symbol s_17[] = {'u'}; +static const symbol s_18[] = {'u'}; +static const symbol s_19[] = {'g'}; +static const symbol s_20[] = {'u'}; +static const symbol s_21[] = {'g'}; + +static int r_mark_regions(struct SN_env *z) { + z->I[0] = z->l; + z->I[1] = z->l; + z->I[2] = z->l; + { + int c1 = z->c; /* do, line 37 */ + { + int c2 = z->c; /* or, line 39 */ + if (in_grouping_U(z, g_v, 97, 252, 0)) + goto lab2; + { + int c3 = z->c; /* or, line 38 */ + if (out_grouping_U(z, g_v, 97, 252, 0)) + goto lab4; + { /* gopast */ /* grouping v, line 38 */ + int ret = out_grouping_U(z, g_v, 97, 252, 1); + if (ret < 0) + goto lab4; + z->c += ret; + } + goto lab3; + lab4: + z->c = c3; + if (in_grouping_U(z, g_v, 97, 252, 0)) + goto lab2; + { /* gopast */ /* non v, line 38 */ + int ret = in_grouping_U(z, g_v, 97, 252, 1); + if (ret < 0) + goto lab2; + z->c += ret; + } + } + lab3: + goto lab1; + lab2: + z->c = c2; + if (out_grouping_U(z, g_v, 97, 252, 0)) + goto lab0; + { + int c4 = z->c; /* or, line 40 */ + if (out_grouping_U(z, g_v, 97, 252, 0)) + goto lab6; + { /* gopast */ /* grouping v, line 40 */ + int ret = out_grouping_U(z, g_v, 97, 252, 1); + if (ret < 0) + goto lab6; + z->c += ret; + } + goto lab5; + lab6: + z->c = c4; + if (in_grouping_U(z, g_v, 97, 252, 0)) + goto lab0; + { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* next, line 40 */ + } + } + lab5:; + } + lab1: + z->I[0] = z->c; /* setmark pV, line 41 */ + lab0: + z->c = c1; + } + { + int c5 = z->c; /* do, line 43 */ + { /* gopast */ /* grouping v, line 44 */ + int ret = out_grouping_U(z, g_v, 97, 252, 1); + if (ret < 0) + goto lab7; + z->c += ret; + } + { /* gopast */ /* non v, line 44 */ + int ret = in_grouping_U(z, g_v, 97, 252, 1); + if (ret < 0) + goto lab7; + z->c += ret; + } + z->I[1] = z->c; /* setmark p1, line 44 */ + { /* gopast */ /* grouping v, line 45 */ + int ret = out_grouping_U(z, g_v, 97, 252, 1); + if (ret < 0) + goto lab7; + z->c += ret; + } + { /* gopast */ /* non v, line 45 */ + int ret = in_grouping_U(z, g_v, 97, 252, 1); + if (ret < 0) + goto lab7; + z->c += ret; + } + z->I[2] = z->c; /* setmark p2, line 45 */ + lab7: + z->c = c5; + } + return 1; +} + +static int r_postlude(struct SN_env *z) { + int among_var; + while (1) { /* repeat, line 49 */ + int c1 = z->c; + z->bra = z->c; /* [, line 50 */ + if (z->c + 1 >= z->l || z->p[z->c + 1] >> 5 != 5 || !((67641858 >> (z->p[z->c + 1] & 0x1f)) & 1)) + among_var = 6; + else + among_var = find_among(z, a_0, 6); /* substring, line 50 */ + if (!(among_var)) + goto lab0; + z->ket = z->c; /* ], line 50 */ + switch (among_var) { + case 0: + goto lab0; + case 1: { + int ret = slice_from_s(z, 1, s_0); /* <-, line 51 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 1, s_1); /* <-, line 52 
*/ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 1, s_2); /* <-, line 53 */ + if (ret < 0) + return ret; + } break; + case 4: { + int ret = slice_from_s(z, 1, s_3); /* <-, line 54 */ + if (ret < 0) + return ret; + } break; + case 5: { + int ret = slice_from_s(z, 1, s_4); /* <-, line 55 */ + if (ret < 0) + return ret; + } break; + case 6: { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab0; + z->c = ret; /* next, line 57 */ + } break; + } + continue; + lab0: + z->c = c1; + break; + } + return 1; +} + +static int r_RV(struct SN_env *z) { + if (!(z->I[0] <= z->c)) + return 0; + return 1; +} + +static int r_R1(struct SN_env *z) { + if (!(z->I[1] <= z->c)) + return 0; + return 1; +} + +static int r_R2(struct SN_env *z) { + if (!(z->I[2] <= z->c)) + return 0; + return 1; +} + +static int r_attached_pronoun(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 68 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((557090 >> (z->p[z->c - 1] & 0x1f)) & 1)) + return 0; + if (!(find_among_b(z, a_1, 13))) + return 0; /* substring, line 68 */ + z->bra = z->c; /* ], line 68 */ + if (z->c - 1 <= z->lb || (z->p[z->c - 1] != 111 && z->p[z->c - 1] != 114)) + return 0; + among_var = find_among_b(z, a_2, 11); /* substring, line 72 */ + if (!(among_var)) + return 0; + { + int ret = r_RV(z); + if (ret == 0) + return 0; /* call RV, line 72 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: + return 0; + case 1: + z->bra = z->c; /* ], line 73 */ + { + int ret = slice_from_s(z, 5, s_5); /* <-, line 73 */ + if (ret < 0) + return ret; + } + break; + case 2: + z->bra = z->c; /* ], line 74 */ + { + int ret = slice_from_s(z, 4, s_6); /* <-, line 74 */ + if (ret < 0) + return ret; + } + break; + case 3: + z->bra = z->c; /* ], line 75 */ + { + int ret = slice_from_s(z, 2, s_7); /* <-, line 75 */ + if (ret < 0) + return ret; + } + break; + case 4: + z->bra = z->c; /* ], line 76 */ + { + int ret = slice_from_s(z, 2, s_8); /* <-, line 76 */ + if (ret < 0) + return ret; + } + break; + case 5: + z->bra = z->c; /* ], line 77 */ + { + int ret = slice_from_s(z, 2, s_9); /* <-, line 77 */ + if (ret < 0) + return ret; + } + break; + case 6: { + int ret = slice_del(z); /* delete, line 81 */ + if (ret < 0) + return ret; + } break; + case 7: + if (!(eq_s_b(z, 1, s_10))) + return 0; + { + int ret = slice_del(z); /* delete, line 82 */ + if (ret < 0) + return ret; + } + break; + } + return 1; +} + +static int r_standard_suffix(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 87 */ + if (z->c - 2 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((835634 >> (z->p[z->c - 1] & 0x1f)) & 1)) + return 0; + among_var = find_among_b(z, a_6, 46); /* substring, line 87 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 87 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 99 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 99 */ + if (ret < 0) + return ret; + } + break; + case 2: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 105 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 105 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 106 */ + z->ket = z->c; /* [, line 106 */ + if (!(eq_s_b(z, 2, s_11))) { + z->c = z->l - m_keep; + goto lab0; + } + z->bra = z->c; /* ], line 106 */ + { + int ret = r_R2(z); + if (ret == 0) { + 
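                        /* the preceding "ic" lies outside R2: undo the optional try and keep the stem as already cut */ +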
z->c = z->l - m_keep; + goto lab0; + } /* call R2, line 106 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 106 */ + if (ret < 0) + return ret; + } + lab0:; + } + break; + case 3: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 111 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 3, s_12); /* <-, line 111 */ + if (ret < 0) + return ret; + } + break; + case 4: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 115 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 1, s_13); /* <-, line 115 */ + if (ret < 0) + return ret; + } + break; + case 5: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 119 */ + if (ret < 0) + return ret; + } + { + int ret = slice_from_s(z, 4, s_14); /* <-, line 119 */ + if (ret < 0) + return ret; + } + break; + case 6: { + int ret = r_R1(z); + if (ret == 0) + return 0; /* call R1, line 123 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 123 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 124 */ + z->ket = z->c; /* [, line 125 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((4718616 >> (z->p[z->c - 1] & 0x1f)) & 1)) { + z->c = z->l - m_keep; + goto lab1; + } + among_var = find_among_b(z, a_3, 4); /* substring, line 125 */ + if (!(among_var)) { + z->c = z->l - m_keep; + goto lab1; + } + z->bra = z->c; /* ], line 125 */ + { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab1; + } /* call R2, line 125 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 125 */ + if (ret < 0) + return ret; + } + switch (among_var) { + case 0: { + z->c = z->l - m_keep; + goto lab1; + } + case 1: + z->ket = z->c; /* [, line 126 */ + if (!(eq_s_b(z, 2, s_15))) { + z->c = z->l - m_keep; + goto lab1; + } + z->bra = z->c; /* ], line 126 */ + { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab1; + } /* call R2, line 126 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 126 */ + if (ret < 0) + return ret; + } + break; + } + lab1:; + } + break; + case 7: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 135 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 135 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 136 */ + z->ket = z->c; /* [, line 137 */ + if (z->c - 3 <= z->lb || z->p[z->c - 1] != 101) { + z->c = z->l - m_keep; + goto lab2; + } + among_var = find_among_b(z, a_4, 3); /* substring, line 137 */ + if (!(among_var)) { + z->c = z->l - m_keep; + goto lab2; + } + z->bra = z->c; /* ], line 137 */ + switch (among_var) { + case 0: { + z->c = z->l - m_keep; + goto lab2; + } + case 1: { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab2; + } /* call R2, line 140 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 140 */ + if (ret < 0) + return ret; + } + break; + } + lab2:; + } + break; + case 8: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 147 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 147 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 148 */ + z->ket = z->c; /* [, line 149 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((4198408 >> (z->p[z->c - 1] & 0x1f)) & 1)) { + z->c = z->l - m_keep; + 
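                            /* the last letter is not 'c', 'l' or 'v', so no a_5 suffix ("ic", "abil", "iv") can match here */ +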
goto lab3; + } + among_var = find_among_b(z, a_5, 3); /* substring, line 149 */ + if (!(among_var)) { + z->c = z->l - m_keep; + goto lab3; + } + z->bra = z->c; /* ], line 149 */ + switch (among_var) { + case 0: { + z->c = z->l - m_keep; + goto lab3; + } + case 1: { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab3; + } /* call R2, line 152 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 152 */ + if (ret < 0) + return ret; + } + break; + } + lab3:; + } + break; + case 9: { + int ret = r_R2(z); + if (ret == 0) + return 0; /* call R2, line 159 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 159 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 160 */ + z->ket = z->c; /* [, line 161 */ + if (!(eq_s_b(z, 2, s_16))) { + z->c = z->l - m_keep; + goto lab4; + } + z->bra = z->c; /* ], line 161 */ + { + int ret = r_R2(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab4; + } /* call R2, line 161 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 161 */ + if (ret < 0) + return ret; + } + lab4:; + } + break; + } + return 1; +} + +static int r_y_verb_suffix(struct SN_env *z) { + int among_var; + { + int mlimit; /* setlimit, line 168 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 168 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 168 */ + among_var = find_among_b(z, a_7, 12); /* substring, line 168 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 168 */ + z->lb = mlimit; + } + switch (among_var) { + case 0: + return 0; + case 1: + if (!(eq_s_b(z, 1, s_17))) + return 0; + { + int ret = slice_del(z); /* delete, line 171 */ + if (ret < 0) + return ret; + } + break; + } + return 1; +} + +static int r_verb_suffix(struct SN_env *z) { + int among_var; + { + int mlimit; /* setlimit, line 176 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 176 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 176 */ + among_var = find_among_b(z, a_8, 96); /* substring, line 176 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 176 */ + z->lb = mlimit; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 179 */ + if (!(eq_s_b(z, 1, s_18))) { + z->c = z->l - m_keep; + goto lab0; + } + { + int m_test = z->l - z->c; /* test, line 179 */ + if (!(eq_s_b(z, 1, s_19))) { + z->c = z->l - m_keep; + goto lab0; + } + z->c = z->l - m_test; + } + lab0:; + } + z->bra = z->c; /* ], line 179 */ + { + int ret = slice_del(z); /* delete, line 179 */ + if (ret < 0) + return ret; + } + break; + case 2: { + int ret = slice_del(z); /* delete, line 200 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_residual_suffix(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 205 */ + among_var = find_among_b(z, a_9, 8); /* substring, line 205 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 205 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = r_RV(z); + if (ret == 0) + return 0; /* call RV, line 208 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 208 */ + if (ret < 0) + return ret; + } + break; + case 2: { + int ret = r_RV(z); + if (ret == 0) + 
return 0; /* call RV, line 210 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 210 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 210 */ + z->ket = z->c; /* [, line 210 */ + if (!(eq_s_b(z, 1, s_20))) { + z->c = z->l - m_keep; + goto lab0; + } + z->bra = z->c; /* ], line 210 */ + { + int m_test = z->l - z->c; /* test, line 210 */ + if (!(eq_s_b(z, 1, s_21))) { + z->c = z->l - m_keep; + goto lab0; + } + z->c = z->l - m_test; + } + { + int ret = r_RV(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab0; + } /* call RV, line 210 */ + if (ret < 0) + return ret; + } + { + int ret = slice_del(z); /* delete, line 210 */ + if (ret < 0) + return ret; + } + lab0:; + } + break; + } + return 1; +} + +extern int spanish_UTF_8_stem(struct SN_env *z) { + { + int c1 = z->c; /* do, line 216 */ + { + int ret = r_mark_regions(z); + if (ret == 0) + goto lab0; /* call mark_regions, line 216 */ + if (ret < 0) + return ret; + } + lab0: + z->c = c1; + } + z->lb = z->c; + z->c = z->l; /* backwards, line 217 */ + + { + int m2 = z->l - z->c; + (void)m2; /* do, line 218 */ + { + int ret = r_attached_pronoun(z); + if (ret == 0) + goto lab1; /* call attached_pronoun, line 218 */ + if (ret < 0) + return ret; + } + lab1: + z->c = z->l - m2; + } + { + int m3 = z->l - z->c; + (void)m3; /* do, line 219 */ + { + int m4 = z->l - z->c; + (void)m4; /* or, line 219 */ + { + int ret = r_standard_suffix(z); + if (ret == 0) + goto lab4; /* call standard_suffix, line 219 */ + if (ret < 0) + return ret; + } + goto lab3; + lab4: + z->c = z->l - m4; + { + int ret = r_y_verb_suffix(z); + if (ret == 0) + goto lab5; /* call y_verb_suffix, line 220 */ + if (ret < 0) + return ret; + } + goto lab3; + lab5: + z->c = z->l - m4; + { + int ret = r_verb_suffix(z); + if (ret == 0) + goto lab2; /* call verb_suffix, line 221 */ + if (ret < 0) + return ret; + } + } + lab3: + lab2: + z->c = z->l - m3; + } + { + int m5 = z->l - z->c; + (void)m5; /* do, line 223 */ + { + int ret = r_residual_suffix(z); + if (ret == 0) + goto lab6; /* call residual_suffix, line 223 */ + if (ret < 0) + return ret; + } + lab6: + z->c = z->l - m5; + } + z->c = z->lb; + { + int c6 = z->c; /* do, line 225 */ + { + int ret = r_postlude(z); + if (ret == 0) + goto lab7; /* call postlude, line 225 */ + if (ret < 0) + return ret; + } + lab7: + z->c = c6; + } + return 1; +} + +extern struct SN_env *spanish_UTF_8_create_env(void) { return SN_create_env(0, 3, 0); } + +extern void spanish_UTF_8_close_env(struct SN_env *z) { SN_close_env(z, 0); } diff --git a/internal/cpp/stemmer/stem_UTF_8_spanish.h b/internal/cpp/stemmer/stem_UTF_8_spanish.h new file mode 100644 index 00000000000..ed8bb3429e6 --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_spanish.h @@ -0,0 +1,17 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *spanish_UTF_8_create_env(void); +extern void spanish_UTF_8_close_env(struct SN_env *z); + +extern int spanish_UTF_8_stem(struct SN_env *z); + +#ifdef __cplusplus +} +#endif diff --git a/internal/cpp/stemmer/stem_UTF_8_swedish.cpp b/internal/cpp/stemmer/stem_UTF_8_swedish.cpp new file mode 100644 index 00000000000..b7acf2e1ab6 --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_swedish.cpp @@ -0,0 +1,371 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#include "header.h" + +#ifdef __cplusplus +extern "C" { +#endif +extern 
int swedish_UTF_8_stem(struct SN_env *z); +#ifdef __cplusplus +} +#endif +static int r_other_suffix(struct SN_env *z); +static int r_consonant_pair(struct SN_env *z); +static int r_main_suffix(struct SN_env *z); +static int r_mark_regions(struct SN_env *z); +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *swedish_UTF_8_create_env(void); +extern void swedish_UTF_8_close_env(struct SN_env *z); + +#ifdef __cplusplus +} +#endif +static const symbol s_0_0[1] = {'a'}; +static const symbol s_0_1[4] = {'a', 'r', 'n', 'a'}; +static const symbol s_0_2[4] = {'e', 'r', 'n', 'a'}; +static const symbol s_0_3[7] = {'h', 'e', 't', 'e', 'r', 'n', 'a'}; +static const symbol s_0_4[4] = {'o', 'r', 'n', 'a'}; +static const symbol s_0_5[2] = {'a', 'd'}; +static const symbol s_0_6[1] = {'e'}; +static const symbol s_0_7[3] = {'a', 'd', 'e'}; +static const symbol s_0_8[4] = {'a', 'n', 'd', 'e'}; +static const symbol s_0_9[4] = {'a', 'r', 'n', 'e'}; +static const symbol s_0_10[3] = {'a', 'r', 'e'}; +static const symbol s_0_11[4] = {'a', 's', 't', 'e'}; +static const symbol s_0_12[2] = {'e', 'n'}; +static const symbol s_0_13[5] = {'a', 'n', 'd', 'e', 'n'}; +static const symbol s_0_14[4] = {'a', 'r', 'e', 'n'}; +static const symbol s_0_15[5] = {'h', 'e', 't', 'e', 'n'}; +static const symbol s_0_16[3] = {'e', 'r', 'n'}; +static const symbol s_0_17[2] = {'a', 'r'}; +static const symbol s_0_18[2] = {'e', 'r'}; +static const symbol s_0_19[5] = {'h', 'e', 't', 'e', 'r'}; +static const symbol s_0_20[2] = {'o', 'r'}; +static const symbol s_0_21[1] = {'s'}; +static const symbol s_0_22[2] = {'a', 's'}; +static const symbol s_0_23[5] = {'a', 'r', 'n', 'a', 's'}; +static const symbol s_0_24[5] = {'e', 'r', 'n', 'a', 's'}; +static const symbol s_0_25[5] = {'o', 'r', 'n', 'a', 's'}; +static const symbol s_0_26[2] = {'e', 's'}; +static const symbol s_0_27[4] = {'a', 'd', 'e', 's'}; +static const symbol s_0_28[5] = {'a', 'n', 'd', 'e', 's'}; +static const symbol s_0_29[3] = {'e', 'n', 's'}; +static const symbol s_0_30[5] = {'a', 'r', 'e', 'n', 's'}; +static const symbol s_0_31[6] = {'h', 'e', 't', 'e', 'n', 's'}; +static const symbol s_0_32[4] = {'e', 'r', 'n', 's'}; +static const symbol s_0_33[2] = {'a', 't'}; +static const symbol s_0_34[5] = {'a', 'n', 'd', 'e', 't'}; +static const symbol s_0_35[3] = {'h', 'e', 't'}; +static const symbol s_0_36[3] = {'a', 's', 't'}; + +static const struct among a_0[37] = { + /* 0 */ {1, s_0_0, -1, 1, 0}, + /* 1 */ {4, s_0_1, 0, 1, 0}, + /* 2 */ {4, s_0_2, 0, 1, 0}, + /* 3 */ {7, s_0_3, 2, 1, 0}, + /* 4 */ {4, s_0_4, 0, 1, 0}, + /* 5 */ {2, s_0_5, -1, 1, 0}, + /* 6 */ {1, s_0_6, -1, 1, 0}, + /* 7 */ {3, s_0_7, 6, 1, 0}, + /* 8 */ {4, s_0_8, 6, 1, 0}, + /* 9 */ {4, s_0_9, 6, 1, 0}, + /* 10 */ {3, s_0_10, 6, 1, 0}, + /* 11 */ {4, s_0_11, 6, 1, 0}, + /* 12 */ {2, s_0_12, -1, 1, 0}, + /* 13 */ {5, s_0_13, 12, 1, 0}, + /* 14 */ {4, s_0_14, 12, 1, 0}, + /* 15 */ {5, s_0_15, 12, 1, 0}, + /* 16 */ {3, s_0_16, -1, 1, 0}, + /* 17 */ {2, s_0_17, -1, 1, 0}, + /* 18 */ {2, s_0_18, -1, 1, 0}, + /* 19 */ {5, s_0_19, 18, 1, 0}, + /* 20 */ {2, s_0_20, -1, 1, 0}, + /* 21 */ {1, s_0_21, -1, 2, 0}, + /* 22 */ {2, s_0_22, 21, 1, 0}, + /* 23 */ {5, s_0_23, 22, 1, 0}, + /* 24 */ {5, s_0_24, 22, 1, 0}, + /* 25 */ {5, s_0_25, 22, 1, 0}, + /* 26 */ {2, s_0_26, 21, 1, 0}, + /* 27 */ {4, s_0_27, 26, 1, 0}, + /* 28 */ {5, s_0_28, 26, 1, 0}, + /* 29 */ {3, s_0_29, 21, 1, 0}, + /* 30 */ {5, s_0_30, 29, 1, 0}, + /* 31 */ {6, s_0_31, 29, 1, 0}, + /* 32 */ {4, s_0_32, 21, 1, 0}, + /* 33 */ {2, s_0_33, -1, 1, 0}, 
+ /* 34 */ {5, s_0_34, -1, 1, 0}, + /* 35 */ {3, s_0_35, -1, 1, 0}, + /* 36 */ {3, s_0_36, -1, 1, 0}}; + +static const symbol s_1_0[2] = {'d', 'd'}; +static const symbol s_1_1[2] = {'g', 'd'}; +static const symbol s_1_2[2] = {'n', 'n'}; +static const symbol s_1_3[2] = {'d', 't'}; +static const symbol s_1_4[2] = {'g', 't'}; +static const symbol s_1_5[2] = {'k', 't'}; +static const symbol s_1_6[2] = {'t', 't'}; + +static const struct among a_1[7] = { + /* 0 */ {2, s_1_0, -1, -1, 0}, + /* 1 */ {2, s_1_1, -1, -1, 0}, + /* 2 */ {2, s_1_2, -1, -1, 0}, + /* 3 */ {2, s_1_3, -1, -1, 0}, + /* 4 */ {2, s_1_4, -1, -1, 0}, + /* 5 */ {2, s_1_5, -1, -1, 0}, + /* 6 */ {2, s_1_6, -1, -1, 0}}; + +static const symbol s_2_0[2] = {'i', 'g'}; +static const symbol s_2_1[3] = {'l', 'i', 'g'}; +static const symbol s_2_2[3] = {'e', 'l', 's'}; +static const symbol s_2_3[5] = {'f', 'u', 'l', 'l', 't'}; +static const symbol s_2_4[5] = {'l', 0xC3, 0xB6, 's', 't'}; + +static const struct among a_2[5] = { + /* 0 */ {2, s_2_0, -1, 1, 0}, + /* 1 */ {3, s_2_1, 0, 1, 0}, + /* 2 */ {3, s_2_2, -1, 1, 0}, + /* 3 */ {5, s_2_3, -1, 3, 0}, + /* 4 */ {5, s_2_4, -1, 2, 0}}; + +static const unsigned char g_v[] = {17, 65, 16, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 0, 32}; + +static const unsigned char g_s_ending[] = {119, 127, 149}; + +static const symbol s_0[] = {'l', 0xC3, 0xB6, 's'}; +static const symbol s_1[] = {'f', 'u', 'l', 'l'}; + +static int r_mark_regions(struct SN_env *z) { + z->I[0] = z->l; + { + int c_test = z->c; /* test, line 29 */ + { + int ret = skip_utf8(z->p, z->c, 0, z->l, +3); + if (ret < 0) + return 0; + z->c = ret; /* hop, line 29 */ + } + z->I[1] = z->c; /* setmark x, line 29 */ + z->c = c_test; + } + if (out_grouping_U(z, g_v, 97, 246, 1) < 0) + return 0; /* goto */ /* grouping v, line 30 */ + { /* gopast */ /* non v, line 30 */ + int ret = in_grouping_U(z, g_v, 97, 246, 1); + if (ret < 0) + return 0; + z->c += ret; + } + z->I[0] = z->c; /* setmark p1, line 30 */ + /* try, line 31 */ + if (!(z->I[0] < z->I[1])) + goto lab0; + z->I[0] = z->I[1]; +lab0: + return 1; +} + +static int r_main_suffix(struct SN_env *z) { + int among_var; + { + int mlimit; /* setlimit, line 37 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 37 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 37 */ + if (z->c <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((1851442 >> (z->p[z->c - 1] & 0x1f)) & 1)) { + z->lb = mlimit; + return 0; + } + among_var = find_among_b(z, a_0, 37); /* substring, line 37 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 37 */ + z->lb = mlimit; + } + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_del(z); /* delete, line 44 */ + if (ret < 0) + return ret; + } break; + case 2: + if (in_grouping_b_U(z, g_s_ending, 98, 121, 0)) + return 0; + { + int ret = slice_del(z); /* delete, line 46 */ + if (ret < 0) + return ret; + } + break; + } + return 1; +} + +static int r_consonant_pair(struct SN_env *z) { + { + int mlimit; /* setlimit, line 50 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 50 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + { + int m2 = z->l - z->c; + (void)m2; /* and, line 52 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((1064976 >> (z->p[z->c - 1] & 0x1f)) & 1)) { + z->lb = mlimit; + return 0; + } + if (!(find_among_b(z, a_1, 7))) { + z->lb = mlimit; + return 0; + } /* 
among, line 51 */ + z->c = z->l - m2; + z->ket = z->c; /* [, line 52 */ + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) { + z->lb = mlimit; + return 0; + } + z->c = ret; /* next, line 52 */ + } + z->bra = z->c; /* ], line 52 */ + { + int ret = slice_del(z); /* delete, line 52 */ + if (ret < 0) + return ret; + } + } + z->lb = mlimit; + } + return 1; +} + +static int r_other_suffix(struct SN_env *z) { + int among_var; + { + int mlimit; /* setlimit, line 55 */ + int m1 = z->l - z->c; + (void)m1; + if (z->c < z->I[0]) + return 0; + z->c = z->I[0]; /* tomark, line 55 */ + mlimit = z->lb; + z->lb = z->c; + z->c = z->l - m1; + z->ket = z->c; /* [, line 56 */ + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((1572992 >> (z->p[z->c - 1] & 0x1f)) & 1)) { + z->lb = mlimit; + return 0; + } + among_var = find_among_b(z, a_2, 5); /* substring, line 56 */ + if (!(among_var)) { + z->lb = mlimit; + return 0; + } + z->bra = z->c; /* ], line 56 */ + switch (among_var) { + case 0: { + z->lb = mlimit; + return 0; + } + case 1: { + int ret = slice_del(z); /* delete, line 57 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 4, s_0); /* <-, line 58 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 4, s_1); /* <-, line 59 */ + if (ret < 0) + return ret; + } break; + } + z->lb = mlimit; + } + return 1; +} + +extern int swedish_UTF_8_stem(struct SN_env *z) { + { + int c1 = z->c; /* do, line 66 */ + { + int ret = r_mark_regions(z); + if (ret == 0) + goto lab0; /* call mark_regions, line 66 */ + if (ret < 0) + return ret; + } + lab0: + z->c = c1; + } + z->lb = z->c; + z->c = z->l; /* backwards, line 67 */ + + { + int m2 = z->l - z->c; + (void)m2; /* do, line 68 */ + { + int ret = r_main_suffix(z); + if (ret == 0) + goto lab1; /* call main_suffix, line 68 */ + if (ret < 0) + return ret; + } + lab1: + z->c = z->l - m2; + } + { + int m3 = z->l - z->c; + (void)m3; /* do, line 69 */ + { + int ret = r_consonant_pair(z); + if (ret == 0) + goto lab2; /* call consonant_pair, line 69 */ + if (ret < 0) + return ret; + } + lab2: + z->c = z->l - m3; + } + { + int m4 = z->l - z->c; + (void)m4; /* do, line 70 */ + { + int ret = r_other_suffix(z); + if (ret == 0) + goto lab3; /* call other_suffix, line 70 */ + if (ret < 0) + return ret; + } + lab3: + z->c = z->l - m4; + } + z->c = z->lb; + return 1; +} + +extern struct SN_env *swedish_UTF_8_create_env(void) { return SN_create_env(0, 2, 0); } + +extern void swedish_UTF_8_close_env(struct SN_env *z) { SN_close_env(z, 0); } diff --git a/internal/cpp/stemmer/stem_UTF_8_swedish.h b/internal/cpp/stemmer/stem_UTF_8_swedish.h new file mode 100644 index 00000000000..9ded1c80c0d --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_swedish.h @@ -0,0 +1,17 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *swedish_UTF_8_create_env(void); +extern void swedish_UTF_8_close_env(struct SN_env *z); + +extern int swedish_UTF_8_stem(struct SN_env *z); + +#ifdef __cplusplus +} +#endif diff --git a/internal/cpp/stemmer/stem_UTF_8_turkish.cpp b/internal/cpp/stemmer/stem_UTF_8_turkish.cpp new file mode 100644 index 00000000000..ab5a933bae7 --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_turkish.cpp @@ -0,0 +1,2978 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#include "header.h" + +#ifdef __cplusplus +extern "C" { +#endif +extern int 
turkish_UTF_8_stem(struct SN_env *z); +#ifdef __cplusplus +} +#endif +static int r_stem_suffix_chain_before_ki(struct SN_env *z); +static int r_stem_noun_suffixes(struct SN_env *z); +static int r_stem_nominal_verb_suffixes(struct SN_env *z); +static int r_postlude(struct SN_env *z); +static int r_post_process_last_consonants(struct SN_env *z); +static int r_more_than_one_syllable_word(struct SN_env *z); +static int r_mark_suffix_with_optional_s_consonant(struct SN_env *z); +static int r_mark_suffix_with_optional_n_consonant(struct SN_env *z); +static int r_mark_suffix_with_optional_U_vowel(struct SN_env *z); +static int r_mark_suffix_with_optional_y_consonant(struct SN_env *z); +static int r_mark_ysA(struct SN_env *z); +static int r_mark_ymUs_(struct SN_env *z); +static int r_mark_yken(struct SN_env *z); +static int r_mark_yDU(struct SN_env *z); +static int r_mark_yUz(struct SN_env *z); +static int r_mark_yUm(struct SN_env *z); +static int r_mark_yU(struct SN_env *z); +static int r_mark_ylA(struct SN_env *z); +static int r_mark_yA(struct SN_env *z); +static int r_mark_possessives(struct SN_env *z); +static int r_mark_sUnUz(struct SN_env *z); +static int r_mark_sUn(struct SN_env *z); +static int r_mark_sU(struct SN_env *z); +static int r_mark_nUz(struct SN_env *z); +static int r_mark_nUn(struct SN_env *z); +static int r_mark_nU(struct SN_env *z); +static int r_mark_ndAn(struct SN_env *z); +static int r_mark_ndA(struct SN_env *z); +static int r_mark_ncA(struct SN_env *z); +static int r_mark_nA(struct SN_env *z); +static int r_mark_lArI(struct SN_env *z); +static int r_mark_lAr(struct SN_env *z); +static int r_mark_ki(struct SN_env *z); +static int r_mark_DUr(struct SN_env *z); +static int r_mark_DAn(struct SN_env *z); +static int r_mark_DA(struct SN_env *z); +static int r_mark_cAsInA(struct SN_env *z); +static int r_is_reserved_word(struct SN_env *z); +static int r_check_vowel_harmony(struct SN_env *z); +static int r_append_U_to_stems_ending_with_d_or_g(struct SN_env *z); +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *turkish_UTF_8_create_env(void); +extern void turkish_UTF_8_close_env(struct SN_env *z); + +#ifdef __cplusplus +} +#endif +static const symbol s_0_0[1] = {'m'}; +static const symbol s_0_1[1] = {'n'}; +static const symbol s_0_2[3] = {'m', 'i', 'z'}; +static const symbol s_0_3[3] = {'n', 'i', 'z'}; +static const symbol s_0_4[3] = {'m', 'u', 'z'}; +static const symbol s_0_5[3] = {'n', 'u', 'z'}; +static const symbol s_0_6[4] = {'m', 0xC4, 0xB1, 'z'}; +static const symbol s_0_7[4] = {'n', 0xC4, 0xB1, 'z'}; +static const symbol s_0_8[4] = {'m', 0xC3, 0xBC, 'z'}; +static const symbol s_0_9[4] = {'n', 0xC3, 0xBC, 'z'}; + +static const struct among a_0[10] = { + /* 0 */ {1, s_0_0, -1, -1, 0}, + /* 1 */ {1, s_0_1, -1, -1, 0}, + /* 2 */ {3, s_0_2, -1, -1, 0}, + /* 3 */ {3, s_0_3, -1, -1, 0}, + /* 4 */ {3, s_0_4, -1, -1, 0}, + /* 5 */ {3, s_0_5, -1, -1, 0}, + /* 6 */ {4, s_0_6, -1, -1, 0}, + /* 7 */ {4, s_0_7, -1, -1, 0}, + /* 8 */ {4, s_0_8, -1, -1, 0}, + /* 9 */ {4, s_0_9, -1, -1, 0}}; + +static const symbol s_1_0[4] = {'l', 'e', 'r', 'i'}; +static const symbol s_1_1[5] = {'l', 'a', 'r', 0xC4, 0xB1}; + +static const struct among a_1[2] = { + /* 0 */ {4, s_1_0, -1, -1, 0}, + /* 1 */ {5, s_1_1, -1, -1, 0}}; + +static const symbol s_2_0[2] = {'n', 'i'}; +static const symbol s_2_1[2] = {'n', 'u'}; +static const symbol s_2_2[3] = {'n', 0xC4, 0xB1}; +static const symbol s_2_3[3] = {'n', 0xC3, 0xBC}; + +static const struct among a_2[4] = { + /* 0 */ {2, s_2_0, -1, -1, 0}, + 
/* 1 */ {2, s_2_1, -1, -1, 0}, + /* 2 */ {3, s_2_2, -1, -1, 0}, + /* 3 */ {3, s_2_3, -1, -1, 0}}; + +static const symbol s_3_0[2] = {'i', 'n'}; +static const symbol s_3_1[2] = {'u', 'n'}; +static const symbol s_3_2[3] = {0xC4, 0xB1, 'n'}; +static const symbol s_3_3[3] = {0xC3, 0xBC, 'n'}; + +static const struct among a_3[4] = { + /* 0 */ {2, s_3_0, -1, -1, 0}, + /* 1 */ {2, s_3_1, -1, -1, 0}, + /* 2 */ {3, s_3_2, -1, -1, 0}, + /* 3 */ {3, s_3_3, -1, -1, 0}}; + +static const symbol s_4_0[1] = {'a'}; +static const symbol s_4_1[1] = {'e'}; + +static const struct among a_4[2] = { + /* 0 */ {1, s_4_0, -1, -1, 0}, + /* 1 */ {1, s_4_1, -1, -1, 0}}; + +static const symbol s_5_0[2] = {'n', 'a'}; +static const symbol s_5_1[2] = {'n', 'e'}; + +static const struct among a_5[2] = { + /* 0 */ {2, s_5_0, -1, -1, 0}, + /* 1 */ {2, s_5_1, -1, -1, 0}}; + +static const symbol s_6_0[2] = {'d', 'a'}; +static const symbol s_6_1[2] = {'t', 'a'}; +static const symbol s_6_2[2] = {'d', 'e'}; +static const symbol s_6_3[2] = {'t', 'e'}; + +static const struct among a_6[4] = { + /* 0 */ {2, s_6_0, -1, -1, 0}, + /* 1 */ {2, s_6_1, -1, -1, 0}, + /* 2 */ {2, s_6_2, -1, -1, 0}, + /* 3 */ {2, s_6_3, -1, -1, 0}}; + +static const symbol s_7_0[3] = {'n', 'd', 'a'}; +static const symbol s_7_1[3] = {'n', 'd', 'e'}; + +static const struct among a_7[2] = { + /* 0 */ {3, s_7_0, -1, -1, 0}, + /* 1 */ {3, s_7_1, -1, -1, 0}}; + +static const symbol s_8_0[3] = {'d', 'a', 'n'}; +static const symbol s_8_1[3] = {'t', 'a', 'n'}; +static const symbol s_8_2[3] = {'d', 'e', 'n'}; +static const symbol s_8_3[3] = {'t', 'e', 'n'}; + +static const struct among a_8[4] = { + /* 0 */ {3, s_8_0, -1, -1, 0}, + /* 1 */ {3, s_8_1, -1, -1, 0}, + /* 2 */ {3, s_8_2, -1, -1, 0}, + /* 3 */ {3, s_8_3, -1, -1, 0}}; + +static const symbol s_9_0[4] = {'n', 'd', 'a', 'n'}; +static const symbol s_9_1[4] = {'n', 'd', 'e', 'n'}; + +static const struct among a_9[2] = { + /* 0 */ {4, s_9_0, -1, -1, 0}, + /* 1 */ {4, s_9_1, -1, -1, 0}}; + +static const symbol s_10_0[2] = {'l', 'a'}; +static const symbol s_10_1[2] = {'l', 'e'}; + +static const struct among a_10[2] = { + /* 0 */ {2, s_10_0, -1, -1, 0}, + /* 1 */ {2, s_10_1, -1, -1, 0}}; + +static const symbol s_11_0[2] = {'c', 'a'}; +static const symbol s_11_1[2] = {'c', 'e'}; + +static const struct among a_11[2] = { + /* 0 */ {2, s_11_0, -1, -1, 0}, + /* 1 */ {2, s_11_1, -1, -1, 0}}; + +static const symbol s_12_0[2] = {'i', 'm'}; +static const symbol s_12_1[2] = {'u', 'm'}; +static const symbol s_12_2[3] = {0xC4, 0xB1, 'm'}; +static const symbol s_12_3[3] = {0xC3, 0xBC, 'm'}; + +static const struct among a_12[4] = { + /* 0 */ {2, s_12_0, -1, -1, 0}, + /* 1 */ {2, s_12_1, -1, -1, 0}, + /* 2 */ {3, s_12_2, -1, -1, 0}, + /* 3 */ {3, s_12_3, -1, -1, 0}}; + +static const symbol s_13_0[3] = {'s', 'i', 'n'}; +static const symbol s_13_1[3] = {'s', 'u', 'n'}; +static const symbol s_13_2[4] = {'s', 0xC4, 0xB1, 'n'}; +static const symbol s_13_3[4] = {'s', 0xC3, 0xBC, 'n'}; + +static const struct among a_13[4] = { + /* 0 */ {3, s_13_0, -1, -1, 0}, + /* 1 */ {3, s_13_1, -1, -1, 0}, + /* 2 */ {4, s_13_2, -1, -1, 0}, + /* 3 */ {4, s_13_3, -1, -1, 0}}; + +static const symbol s_14_0[2] = {'i', 'z'}; +static const symbol s_14_1[2] = {'u', 'z'}; +static const symbol s_14_2[3] = {0xC4, 0xB1, 'z'}; +static const symbol s_14_3[3] = {0xC3, 0xBC, 'z'}; + +static const struct among a_14[4] = { + /* 0 */ {2, s_14_0, -1, -1, 0}, + /* 1 */ {2, s_14_1, -1, -1, 0}, + /* 2 */ {3, s_14_2, -1, -1, 0}, + /* 3 */ {3, s_14_3, -1, -1, 0}}; + +static const 
symbol s_15_0[5] = {'s', 'i', 'n', 'i', 'z'}; +static const symbol s_15_1[5] = {'s', 'u', 'n', 'u', 'z'}; +static const symbol s_15_2[7] = {'s', 0xC4, 0xB1, 'n', 0xC4, 0xB1, 'z'}; +static const symbol s_15_3[7] = {'s', 0xC3, 0xBC, 'n', 0xC3, 0xBC, 'z'}; + +static const struct among a_15[4] = { + /* 0 */ {5, s_15_0, -1, -1, 0}, + /* 1 */ {5, s_15_1, -1, -1, 0}, + /* 2 */ {7, s_15_2, -1, -1, 0}, + /* 3 */ {7, s_15_3, -1, -1, 0}}; + +static const symbol s_16_0[3] = {'l', 'a', 'r'}; +static const symbol s_16_1[3] = {'l', 'e', 'r'}; + +static const struct among a_16[2] = { + /* 0 */ {3, s_16_0, -1, -1, 0}, + /* 1 */ {3, s_16_1, -1, -1, 0}}; + +static const symbol s_17_0[3] = {'n', 'i', 'z'}; +static const symbol s_17_1[3] = {'n', 'u', 'z'}; +static const symbol s_17_2[4] = {'n', 0xC4, 0xB1, 'z'}; +static const symbol s_17_3[4] = {'n', 0xC3, 0xBC, 'z'}; + +static const struct among a_17[4] = { + /* 0 */ {3, s_17_0, -1, -1, 0}, + /* 1 */ {3, s_17_1, -1, -1, 0}, + /* 2 */ {4, s_17_2, -1, -1, 0}, + /* 3 */ {4, s_17_3, -1, -1, 0}}; + +static const symbol s_18_0[3] = {'d', 'i', 'r'}; +static const symbol s_18_1[3] = {'t', 'i', 'r'}; +static const symbol s_18_2[3] = {'d', 'u', 'r'}; +static const symbol s_18_3[3] = {'t', 'u', 'r'}; +static const symbol s_18_4[4] = {'d', 0xC4, 0xB1, 'r'}; +static const symbol s_18_5[4] = {'t', 0xC4, 0xB1, 'r'}; +static const symbol s_18_6[4] = {'d', 0xC3, 0xBC, 'r'}; +static const symbol s_18_7[4] = {'t', 0xC3, 0xBC, 'r'}; + +static const struct among a_18[8] = { + /* 0 */ {3, s_18_0, -1, -1, 0}, + /* 1 */ {3, s_18_1, -1, -1, 0}, + /* 2 */ {3, s_18_2, -1, -1, 0}, + /* 3 */ {3, s_18_3, -1, -1, 0}, + /* 4 */ {4, s_18_4, -1, -1, 0}, + /* 5 */ {4, s_18_5, -1, -1, 0}, + /* 6 */ {4, s_18_6, -1, -1, 0}, + /* 7 */ {4, s_18_7, -1, -1, 0}}; + +static const symbol s_19_0[7] = {'c', 'a', 's', 0xC4, 0xB1, 'n', 'a'}; +static const symbol s_19_1[6] = {'c', 'e', 's', 'i', 'n', 'e'}; + +static const struct among a_19[2] = { + /* 0 */ {7, s_19_0, -1, -1, 0}, + /* 1 */ {6, s_19_1, -1, -1, 0}}; + +static const symbol s_20_0[2] = {'d', 'i'}; +static const symbol s_20_1[2] = {'t', 'i'}; +static const symbol s_20_2[3] = {'d', 'i', 'k'}; +static const symbol s_20_3[3] = {'t', 'i', 'k'}; +static const symbol s_20_4[3] = {'d', 'u', 'k'}; +static const symbol s_20_5[3] = {'t', 'u', 'k'}; +static const symbol s_20_6[4] = {'d', 0xC4, 0xB1, 'k'}; +static const symbol s_20_7[4] = {'t', 0xC4, 0xB1, 'k'}; +static const symbol s_20_8[4] = {'d', 0xC3, 0xBC, 'k'}; +static const symbol s_20_9[4] = {'t', 0xC3, 0xBC, 'k'}; +static const symbol s_20_10[3] = {'d', 'i', 'm'}; +static const symbol s_20_11[3] = {'t', 'i', 'm'}; +static const symbol s_20_12[3] = {'d', 'u', 'm'}; +static const symbol s_20_13[3] = {'t', 'u', 'm'}; +static const symbol s_20_14[4] = {'d', 0xC4, 0xB1, 'm'}; +static const symbol s_20_15[4] = {'t', 0xC4, 0xB1, 'm'}; +static const symbol s_20_16[4] = {'d', 0xC3, 0xBC, 'm'}; +static const symbol s_20_17[4] = {'t', 0xC3, 0xBC, 'm'}; +static const symbol s_20_18[3] = {'d', 'i', 'n'}; +static const symbol s_20_19[3] = {'t', 'i', 'n'}; +static const symbol s_20_20[3] = {'d', 'u', 'n'}; +static const symbol s_20_21[3] = {'t', 'u', 'n'}; +static const symbol s_20_22[4] = {'d', 0xC4, 0xB1, 'n'}; +static const symbol s_20_23[4] = {'t', 0xC4, 0xB1, 'n'}; +static const symbol s_20_24[4] = {'d', 0xC3, 0xBC, 'n'}; +static const symbol s_20_25[4] = {'t', 0xC3, 0xBC, 'n'}; +static const symbol s_20_26[2] = {'d', 'u'}; +static const symbol s_20_27[2] = {'t', 'u'}; +static const symbol s_20_28[3] = {'d', 
0xC4, 0xB1}; +static const symbol s_20_29[3] = {'t', 0xC4, 0xB1}; +static const symbol s_20_30[3] = {'d', 0xC3, 0xBC}; +static const symbol s_20_31[3] = {'t', 0xC3, 0xBC}; + +static const struct among a_20[32] = { + /* 0 */ {2, s_20_0, -1, -1, 0}, + /* 1 */ {2, s_20_1, -1, -1, 0}, + /* 2 */ {3, s_20_2, -1, -1, 0}, + /* 3 */ {3, s_20_3, -1, -1, 0}, + /* 4 */ {3, s_20_4, -1, -1, 0}, + /* 5 */ {3, s_20_5, -1, -1, 0}, + /* 6 */ {4, s_20_6, -1, -1, 0}, + /* 7 */ {4, s_20_7, -1, -1, 0}, + /* 8 */ {4, s_20_8, -1, -1, 0}, + /* 9 */ {4, s_20_9, -1, -1, 0}, + /* 10 */ {3, s_20_10, -1, -1, 0}, + /* 11 */ {3, s_20_11, -1, -1, 0}, + /* 12 */ {3, s_20_12, -1, -1, 0}, + /* 13 */ {3, s_20_13, -1, -1, 0}, + /* 14 */ {4, s_20_14, -1, -1, 0}, + /* 15 */ {4, s_20_15, -1, -1, 0}, + /* 16 */ {4, s_20_16, -1, -1, 0}, + /* 17 */ {4, s_20_17, -1, -1, 0}, + /* 18 */ {3, s_20_18, -1, -1, 0}, + /* 19 */ {3, s_20_19, -1, -1, 0}, + /* 20 */ {3, s_20_20, -1, -1, 0}, + /* 21 */ {3, s_20_21, -1, -1, 0}, + /* 22 */ {4, s_20_22, -1, -1, 0}, + /* 23 */ {4, s_20_23, -1, -1, 0}, + /* 24 */ {4, s_20_24, -1, -1, 0}, + /* 25 */ {4, s_20_25, -1, -1, 0}, + /* 26 */ {2, s_20_26, -1, -1, 0}, + /* 27 */ {2, s_20_27, -1, -1, 0}, + /* 28 */ {3, s_20_28, -1, -1, 0}, + /* 29 */ {3, s_20_29, -1, -1, 0}, + /* 30 */ {3, s_20_30, -1, -1, 0}, + /* 31 */ {3, s_20_31, -1, -1, 0}}; + +static const symbol s_21_0[2] = {'s', 'a'}; +static const symbol s_21_1[2] = {'s', 'e'}; +static const symbol s_21_2[3] = {'s', 'a', 'k'}; +static const symbol s_21_3[3] = {'s', 'e', 'k'}; +static const symbol s_21_4[3] = {'s', 'a', 'm'}; +static const symbol s_21_5[3] = {'s', 'e', 'm'}; +static const symbol s_21_6[3] = {'s', 'a', 'n'}; +static const symbol s_21_7[3] = {'s', 'e', 'n'}; + +static const struct among a_21[8] = { + /* 0 */ {2, s_21_0, -1, -1, 0}, + /* 1 */ {2, s_21_1, -1, -1, 0}, + /* 2 */ {3, s_21_2, -1, -1, 0}, + /* 3 */ {3, s_21_3, -1, -1, 0}, + /* 4 */ {3, s_21_4, -1, -1, 0}, + /* 5 */ {3, s_21_5, -1, -1, 0}, + /* 6 */ {3, s_21_6, -1, -1, 0}, + /* 7 */ {3, s_21_7, -1, -1, 0}}; + +static const symbol s_22_0[4] = {'m', 'i', 0xC5, 0x9F}; +static const symbol s_22_1[4] = {'m', 'u', 0xC5, 0x9F}; +static const symbol s_22_2[5] = {'m', 0xC4, 0xB1, 0xC5, 0x9F}; +static const symbol s_22_3[5] = {'m', 0xC3, 0xBC, 0xC5, 0x9F}; + +static const struct among a_22[4] = { + /* 0 */ {4, s_22_0, -1, -1, 0}, + /* 1 */ {4, s_22_1, -1, -1, 0}, + /* 2 */ {5, s_22_2, -1, -1, 0}, + /* 3 */ {5, s_22_3, -1, -1, 0}}; + +static const symbol s_23_0[1] = {'b'}; +static const symbol s_23_1[1] = {'c'}; +static const symbol s_23_2[1] = {'d'}; +static const symbol s_23_3[2] = {0xC4, 0x9F}; + +static const struct among a_23[4] = { + /* 0 */ {1, s_23_0, -1, 1, 0}, + /* 1 */ {1, s_23_1, -1, 2, 0}, + /* 2 */ {1, s_23_2, -1, 3, 0}, + /* 3 */ {2, s_23_3, -1, 4, 0}}; + +static const unsigned char g_vowel[] = {17, 65, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 8, 0, 0, 0, 0, 0, 0, 1}; + +static const unsigned char g_U[] = {1, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 1}; + +static const unsigned char g_vowel1[] = {1, 64, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}; + +static const unsigned char g_vowel2[] = {17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 130}; + +static const unsigned char g_vowel3[] = {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}; + +static const unsigned char g_vowel4[] = {17}; + +static const unsigned char g_vowel5[] = {65}; + +static const unsigned char 
g_vowel6[] = {65}; + +static const symbol s_0[] = {'a'}; +static const symbol s_1[] = {'e'}; +static const symbol s_2[] = {0xC4, 0xB1}; +static const symbol s_3[] = {'i'}; +static const symbol s_4[] = {'o'}; +static const symbol s_5[] = {0xC3, 0xB6}; +static const symbol s_6[] = {'u'}; +static const symbol s_7[] = {0xC3, 0xBC}; +static const symbol s_8[] = {'n'}; +static const symbol s_9[] = {'n'}; +static const symbol s_10[] = {'s'}; +static const symbol s_11[] = {'s'}; +static const symbol s_12[] = {'y'}; +static const symbol s_13[] = {'y'}; +static const symbol s_14[] = {'k', 'i'}; +static const symbol s_15[] = {'k', 'e', 'n'}; +static const symbol s_16[] = {'p'}; +static const symbol s_17[] = {0xC3, 0xA7}; +static const symbol s_18[] = {'t'}; +static const symbol s_19[] = {'k'}; +static const symbol s_20[] = {'d'}; +static const symbol s_21[] = {'g'}; +static const symbol s_22[] = {'a'}; +static const symbol s_23[] = {0xC4, 0xB1}; +static const symbol s_24[] = {0xC4, 0xB1}; +static const symbol s_25[] = {'e'}; +static const symbol s_26[] = {'i'}; +static const symbol s_27[] = {'i'}; +static const symbol s_28[] = {'o'}; +static const symbol s_29[] = {'u'}; +static const symbol s_30[] = {'u'}; +static const symbol s_31[] = {0xC3, 0xB6}; +static const symbol s_32[] = {0xC3, 0xBC}; +static const symbol s_33[] = {0xC3, 0xBC}; +static const symbol s_34[] = {'a', 'd'}; +static const symbol s_35[] = {'s', 'o', 'y', 'a', 'd'}; + +static int r_check_vowel_harmony(struct SN_env *z) { + { + int m_test = z->l - z->c; /* test, line 112 */ + if (out_grouping_b_U(z, g_vowel, 97, 305, 1) < 0) + return 0; /* goto */ /* grouping vowel, line 114 */ + { + int m1 = z->l - z->c; + (void)m1; /* or, line 116 */ + if (!(eq_s_b(z, 1, s_0))) + goto lab1; + if (out_grouping_b_U(z, g_vowel1, 97, 305, 1) < 0) + goto lab1; /* goto */ /* grouping vowel1, line 116 */ + goto lab0; + lab1: + z->c = z->l - m1; + if (!(eq_s_b(z, 1, s_1))) + goto lab2; + if (out_grouping_b_U(z, g_vowel2, 101, 252, 1) < 0) + goto lab2; /* goto */ /* grouping vowel2, line 117 */ + goto lab0; + lab2: + z->c = z->l - m1; + if (!(eq_s_b(z, 2, s_2))) + goto lab3; + if (out_grouping_b_U(z, g_vowel3, 97, 305, 1) < 0) + goto lab3; /* goto */ /* grouping vowel3, line 118 */ + goto lab0; + lab3: + z->c = z->l - m1; + if (!(eq_s_b(z, 1, s_3))) + goto lab4; + if (out_grouping_b_U(z, g_vowel4, 101, 105, 1) < 0) + goto lab4; /* goto */ /* grouping vowel4, line 119 */ + goto lab0; + lab4: + z->c = z->l - m1; + if (!(eq_s_b(z, 1, s_4))) + goto lab5; + if (out_grouping_b_U(z, g_vowel5, 111, 117, 1) < 0) + goto lab5; /* goto */ /* grouping vowel5, line 120 */ + goto lab0; + lab5: + z->c = z->l - m1; + if (!(eq_s_b(z, 2, s_5))) + goto lab6; + if (out_grouping_b_U(z, g_vowel6, 246, 252, 1) < 0) + goto lab6; /* goto */ /* grouping vowel6, line 121 */ + goto lab0; + lab6: + z->c = z->l - m1; + if (!(eq_s_b(z, 1, s_6))) + goto lab7; + if (out_grouping_b_U(z, g_vowel5, 111, 117, 1) < 0) + goto lab7; /* goto */ /* grouping vowel5, line 122 */ + goto lab0; + lab7: + z->c = z->l - m1; + if (!(eq_s_b(z, 2, s_7))) + return 0; + if (out_grouping_b_U(z, g_vowel6, 246, 252, 1) < 0) + return 0; /* goto */ /* grouping vowel6, line 123 */ + } + lab0: + z->c = z->l - m_test; + } + return 1; +} + +static int r_mark_suffix_with_optional_n_consonant(struct SN_env *z) { + { + int m1 = z->l - z->c; + (void)m1; /* or, line 134 */ + { + int m_test = z->l - z->c; /* test, line 133 */ + if (!(eq_s_b(z, 1, s_8))) + goto lab1; + z->c = z->l - m_test; + } + { + int ret = skip_utf8(z->p, 
z->c, z->lb, 0, -1); + if (ret < 0) + goto lab1; + z->c = ret; /* next, line 133 */ + } + { + int m_test = z->l - z->c; /* test, line 133 */ + if (in_grouping_b_U(z, g_vowel, 97, 305, 0)) + goto lab1; + z->c = z->l - m_test; + } + goto lab0; + lab1: + z->c = z->l - m1; + { + int m2 = z->l - z->c; + (void)m2; /* not, line 135 */ + { + int m_test = z->l - z->c; /* test, line 135 */ + if (!(eq_s_b(z, 1, s_9))) + goto lab2; + z->c = z->l - m_test; + } + return 0; + lab2: + z->c = z->l - m2; + } + { + int m_test = z->l - z->c; /* test, line 135 */ + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) + return 0; + z->c = ret; /* next, line 135 */ + } + { + int m_test = z->l - z->c; /* test, line 135 */ + if (in_grouping_b_U(z, g_vowel, 97, 305, 0)) + return 0; + z->c = z->l - m_test; + } + z->c = z->l - m_test; + } + } +lab0: + return 1; +} + +static int r_mark_suffix_with_optional_s_consonant(struct SN_env *z) { + { + int m1 = z->l - z->c; + (void)m1; /* or, line 145 */ + { + int m_test = z->l - z->c; /* test, line 144 */ + if (!(eq_s_b(z, 1, s_10))) + goto lab1; + z->c = z->l - m_test; + } + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) + goto lab1; + z->c = ret; /* next, line 144 */ + } + { + int m_test = z->l - z->c; /* test, line 144 */ + if (in_grouping_b_U(z, g_vowel, 97, 305, 0)) + goto lab1; + z->c = z->l - m_test; + } + goto lab0; + lab1: + z->c = z->l - m1; + { + int m2 = z->l - z->c; + (void)m2; /* not, line 146 */ + { + int m_test = z->l - z->c; /* test, line 146 */ + if (!(eq_s_b(z, 1, s_11))) + goto lab2; + z->c = z->l - m_test; + } + return 0; + lab2: + z->c = z->l - m2; + } + { + int m_test = z->l - z->c; /* test, line 146 */ + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) + return 0; + z->c = ret; /* next, line 146 */ + } + { + int m_test = z->l - z->c; /* test, line 146 */ + if (in_grouping_b_U(z, g_vowel, 97, 305, 0)) + return 0; + z->c = z->l - m_test; + } + z->c = z->l - m_test; + } + } +lab0: + return 1; +} + +static int r_mark_suffix_with_optional_y_consonant(struct SN_env *z) { + { + int m1 = z->l - z->c; + (void)m1; /* or, line 155 */ + { + int m_test = z->l - z->c; /* test, line 154 */ + if (!(eq_s_b(z, 1, s_12))) + goto lab1; + z->c = z->l - m_test; + } + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) + goto lab1; + z->c = ret; /* next, line 154 */ + } + { + int m_test = z->l - z->c; /* test, line 154 */ + if (in_grouping_b_U(z, g_vowel, 97, 305, 0)) + goto lab1; + z->c = z->l - m_test; + } + goto lab0; + lab1: + z->c = z->l - m1; + { + int m2 = z->l - z->c; + (void)m2; /* not, line 156 */ + { + int m_test = z->l - z->c; /* test, line 156 */ + if (!(eq_s_b(z, 1, s_13))) + goto lab2; + z->c = z->l - m_test; + } + return 0; + lab2: + z->c = z->l - m2; + } + { + int m_test = z->l - z->c; /* test, line 156 */ + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) + return 0; + z->c = ret; /* next, line 156 */ + } + { + int m_test = z->l - z->c; /* test, line 156 */ + if (in_grouping_b_U(z, g_vowel, 97, 305, 0)) + return 0; + z->c = z->l - m_test; + } + z->c = z->l - m_test; + } + } +lab0: + return 1; +} + +static int r_mark_suffix_with_optional_U_vowel(struct SN_env *z) { + { + int m1 = z->l - z->c; + (void)m1; /* or, line 161 */ + { + int m_test = z->l - z->c; /* test, line 160 */ + if (in_grouping_b_U(z, g_U, 105, 305, 0)) + goto lab1; + z->c = z->l - m_test; + } + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) + goto lab1; + z->c = ret; /* next, line 160 */ + } + { 
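+                /* the U vowel only counts as a suffix when it directly follows a non-vowel */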
+ int m_test = z->l - z->c; /* test, line 160 */ + if (out_grouping_b_U(z, g_vowel, 97, 305, 0)) + goto lab1; + z->c = z->l - m_test; + } + goto lab0; + lab1: + z->c = z->l - m1; + { + int m2 = z->l - z->c; + (void)m2; /* not, line 162 */ + { + int m_test = z->l - z->c; /* test, line 162 */ + if (in_grouping_b_U(z, g_U, 105, 305, 0)) + goto lab2; + z->c = z->l - m_test; + } + return 0; + lab2: + z->c = z->l - m2; + } + { + int m_test = z->l - z->c; /* test, line 162 */ + { + int ret = skip_utf8(z->p, z->c, z->lb, 0, -1); + if (ret < 0) + return 0; + z->c = ret; /* next, line 162 */ + } + { + int m_test = z->l - z->c; /* test, line 162 */ + if (out_grouping_b_U(z, g_vowel, 97, 305, 0)) + return 0; + z->c = z->l - m_test; + } + z->c = z->l - m_test; + } + } +lab0: + return 1; +} + +static int r_mark_possessives(struct SN_env *z) { + if (z->c <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((67133440 >> (z->p[z->c - 1] & 0x1f)) & 1)) + return 0; + if (!(find_among_b(z, a_0, 10))) + return 0; /* among, line 167 */ + { + int ret = r_mark_suffix_with_optional_U_vowel(z); + if (ret == 0) + return 0; /* call mark_suffix_with_optional_U_vowel, line 169 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_mark_sU(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 173 */ + if (ret < 0) + return ret; + } + if (in_grouping_b_U(z, g_U, 105, 305, 0)) + return 0; + { + int ret = r_mark_suffix_with_optional_s_consonant(z); + if (ret == 0) + return 0; /* call mark_suffix_with_optional_s_consonant, line 175 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_mark_lArI(struct SN_env *z) { + if (z->c - 3 <= z->lb || (z->p[z->c - 1] != 105 && z->p[z->c - 1] != 177)) + return 0; + if (!(find_among_b(z, a_1, 2))) + return 0; /* among, line 179 */ + return 1; +} + +static int r_mark_yU(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 183 */ + if (ret < 0) + return ret; + } + if (in_grouping_b_U(z, g_U, 105, 305, 0)) + return 0; + { + int ret = r_mark_suffix_with_optional_y_consonant(z); + if (ret == 0) + return 0; /* call mark_suffix_with_optional_y_consonant, line 185 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_mark_nU(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 189 */ + if (ret < 0) + return ret; + } + if (!(find_among_b(z, a_2, 4))) + return 0; /* among, line 190 */ + return 1; +} + +static int r_mark_nUn(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 194 */ + if (ret < 0) + return ret; + } + if (z->c - 1 <= z->lb || z->p[z->c - 1] != 110) + return 0; + if (!(find_among_b(z, a_3, 4))) + return 0; /* among, line 195 */ + { + int ret = r_mark_suffix_with_optional_n_consonant(z); + if (ret == 0) + return 0; /* call mark_suffix_with_optional_n_consonant, line 196 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_mark_yA(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 200 */ + if (ret < 0) + return ret; + } + if (z->c <= z->lb || (z->p[z->c - 1] != 97 && z->p[z->c - 1] != 101)) + return 0; + if (!(find_among_b(z, a_4, 2))) + return 0; /* among, line 201 */ + { + int ret = r_mark_suffix_with_optional_y_consonant(z); + if (ret == 0) + return 0; /* call 
mark_suffix_with_optional_y_consonant, line 202 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_mark_nA(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 206 */ + if (ret < 0) + return ret; + } + if (z->c - 1 <= z->lb || (z->p[z->c - 1] != 97 && z->p[z->c - 1] != 101)) + return 0; + if (!(find_among_b(z, a_5, 2))) + return 0; /* among, line 207 */ + return 1; +} + +static int r_mark_DA(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 211 */ + if (ret < 0) + return ret; + } + if (z->c - 1 <= z->lb || (z->p[z->c - 1] != 97 && z->p[z->c - 1] != 101)) + return 0; + if (!(find_among_b(z, a_6, 4))) + return 0; /* among, line 212 */ + return 1; +} + +static int r_mark_ndA(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 216 */ + if (ret < 0) + return ret; + } + if (z->c - 2 <= z->lb || (z->p[z->c - 1] != 97 && z->p[z->c - 1] != 101)) + return 0; + if (!(find_among_b(z, a_7, 2))) + return 0; /* among, line 217 */ + return 1; +} + +static int r_mark_DAn(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 221 */ + if (ret < 0) + return ret; + } + if (z->c - 2 <= z->lb || z->p[z->c - 1] != 110) + return 0; + if (!(find_among_b(z, a_8, 4))) + return 0; /* among, line 222 */ + return 1; +} + +static int r_mark_ndAn(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 226 */ + if (ret < 0) + return ret; + } + if (z->c - 3 <= z->lb || z->p[z->c - 1] != 110) + return 0; + if (!(find_among_b(z, a_9, 2))) + return 0; /* among, line 227 */ + return 1; +} + +static int r_mark_ylA(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 231 */ + if (ret < 0) + return ret; + } + if (z->c - 1 <= z->lb || (z->p[z->c - 1] != 97 && z->p[z->c - 1] != 101)) + return 0; + if (!(find_among_b(z, a_10, 2))) + return 0; /* among, line 232 */ + { + int ret = r_mark_suffix_with_optional_y_consonant(z); + if (ret == 0) + return 0; /* call mark_suffix_with_optional_y_consonant, line 233 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_mark_ki(struct SN_env *z) { + if (!(eq_s_b(z, 2, s_14))) + return 0; + return 1; +} + +static int r_mark_ncA(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 241 */ + if (ret < 0) + return ret; + } + if (z->c - 1 <= z->lb || (z->p[z->c - 1] != 97 && z->p[z->c - 1] != 101)) + return 0; + if (!(find_among_b(z, a_11, 2))) + return 0; /* among, line 242 */ + { + int ret = r_mark_suffix_with_optional_n_consonant(z); + if (ret == 0) + return 0; /* call mark_suffix_with_optional_n_consonant, line 243 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_mark_yUm(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 247 */ + if (ret < 0) + return ret; + } + if (z->c - 1 <= z->lb || z->p[z->c - 1] != 109) + return 0; + if (!(find_among_b(z, a_12, 4))) + return 0; /* among, line 248 */ + { + int ret = r_mark_suffix_with_optional_y_consonant(z); + if (ret == 0) + return 0; /* call mark_suffix_with_optional_y_consonant, line 249 */ + if (ret < 0) + return ret; + } + return 1; +} + +static 
int r_mark_sUn(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 253 */ + if (ret < 0) + return ret; + } + if (z->c - 2 <= z->lb || z->p[z->c - 1] != 110) + return 0; + if (!(find_among_b(z, a_13, 4))) + return 0; /* among, line 254 */ + return 1; +} + +static int r_mark_yUz(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 258 */ + if (ret < 0) + return ret; + } + if (z->c - 1 <= z->lb || z->p[z->c - 1] != 122) + return 0; + if (!(find_among_b(z, a_14, 4))) + return 0; /* among, line 259 */ + { + int ret = r_mark_suffix_with_optional_y_consonant(z); + if (ret == 0) + return 0; /* call mark_suffix_with_optional_y_consonant, line 260 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_mark_sUnUz(struct SN_env *z) { + if (z->c - 4 <= z->lb || z->p[z->c - 1] != 122) + return 0; + if (!(find_among_b(z, a_15, 4))) + return 0; /* among, line 264 */ + return 1; +} + +static int r_mark_lAr(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 268 */ + if (ret < 0) + return ret; + } + if (z->c - 2 <= z->lb || z->p[z->c - 1] != 114) + return 0; + if (!(find_among_b(z, a_16, 2))) + return 0; /* among, line 269 */ + return 1; +} + +static int r_mark_nUz(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 273 */ + if (ret < 0) + return ret; + } + if (z->c - 2 <= z->lb || z->p[z->c - 1] != 122) + return 0; + if (!(find_among_b(z, a_17, 4))) + return 0; /* among, line 274 */ + return 1; +} + +static int r_mark_DUr(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 278 */ + if (ret < 0) + return ret; + } + if (z->c - 2 <= z->lb || z->p[z->c - 1] != 114) + return 0; + if (!(find_among_b(z, a_18, 8))) + return 0; /* among, line 279 */ + return 1; +} + +static int r_mark_cAsInA(struct SN_env *z) { + if (z->c - 5 <= z->lb || (z->p[z->c - 1] != 97 && z->p[z->c - 1] != 101)) + return 0; + if (!(find_among_b(z, a_19, 2))) + return 0; /* among, line 283 */ + return 1; +} + +static int r_mark_yDU(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 287 */ + if (ret < 0) + return ret; + } + if (!(find_among_b(z, a_20, 32))) + return 0; /* among, line 288 */ + { + int ret = r_mark_suffix_with_optional_y_consonant(z); + if (ret == 0) + return 0; /* call mark_suffix_with_optional_y_consonant, line 292 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_mark_ysA(struct SN_env *z) { + if (z->c - 1 <= z->lb || z->p[z->c - 1] >> 5 != 3 || !((26658 >> (z->p[z->c - 1] & 0x1f)) & 1)) + return 0; + if (!(find_among_b(z, a_21, 8))) + return 0; /* among, line 297 */ + { + int ret = r_mark_suffix_with_optional_y_consonant(z); + if (ret == 0) + return 0; /* call mark_suffix_with_optional_y_consonant, line 298 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_mark_ymUs_(struct SN_env *z) { + { + int ret = r_check_vowel_harmony(z); + if (ret == 0) + return 0; /* call check_vowel_harmony, line 302 */ + if (ret < 0) + return ret; + } + if (z->c - 3 <= z->lb || z->p[z->c - 1] != 159) + return 0; + if (!(find_among_b(z, a_22, 4))) + return 0; /* among, line 303 */ + { + int ret = r_mark_suffix_with_optional_y_consonant(z); + if (ret == 0) + return 0; /* call 
mark_suffix_with_optional_y_consonant, line 304 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_mark_yken(struct SN_env *z) { + if (!(eq_s_b(z, 3, s_15))) + return 0; + { + int ret = r_mark_suffix_with_optional_y_consonant(z); + if (ret == 0) + return 0; /* call mark_suffix_with_optional_y_consonant, line 308 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_stem_nominal_verb_suffixes(struct SN_env *z) { + z->ket = z->c; /* [, line 312 */ + z->B[0] = 1; /* set continue_stemming_noun_suffixes, line 313 */ + { + int m1 = z->l - z->c; + (void)m1; /* or, line 315 */ + { + int m2 = z->l - z->c; + (void)m2; /* or, line 314 */ + { + int ret = r_mark_ymUs_(z); + if (ret == 0) + goto lab3; /* call mark_ymUs_, line 314 */ + if (ret < 0) + return ret; + } + goto lab2; + lab3: + z->c = z->l - m2; + { + int ret = r_mark_yDU(z); + if (ret == 0) + goto lab4; /* call mark_yDU, line 314 */ + if (ret < 0) + return ret; + } + goto lab2; + lab4: + z->c = z->l - m2; + { + int ret = r_mark_ysA(z); + if (ret == 0) + goto lab5; /* call mark_ysA, line 314 */ + if (ret < 0) + return ret; + } + goto lab2; + lab5: + z->c = z->l - m2; + { + int ret = r_mark_yken(z); + if (ret == 0) + goto lab1; /* call mark_yken, line 314 */ + if (ret < 0) + return ret; + } + } + lab2: + goto lab0; + lab1: + z->c = z->l - m1; + { + int ret = r_mark_cAsInA(z); + if (ret == 0) + goto lab6; /* call mark_cAsInA, line 316 */ + if (ret < 0) + return ret; + } + { + int m3 = z->l - z->c; + (void)m3; /* or, line 316 */ + { + int ret = r_mark_sUnUz(z); + if (ret == 0) + goto lab8; /* call mark_sUnUz, line 316 */ + if (ret < 0) + return ret; + } + goto lab7; + lab8: + z->c = z->l - m3; + { + int ret = r_mark_lAr(z); + if (ret == 0) + goto lab9; /* call mark_lAr, line 316 */ + if (ret < 0) + return ret; + } + goto lab7; + lab9: + z->c = z->l - m3; + { + int ret = r_mark_yUm(z); + if (ret == 0) + goto lab10; /* call mark_yUm, line 316 */ + if (ret < 0) + return ret; + } + goto lab7; + lab10: + z->c = z->l - m3; + { + int ret = r_mark_sUn(z); + if (ret == 0) + goto lab11; /* call mark_sUn, line 316 */ + if (ret < 0) + return ret; + } + goto lab7; + lab11: + z->c = z->l - m3; + { + int ret = r_mark_yUz(z); + if (ret == 0) + goto lab12; /* call mark_yUz, line 316 */ + if (ret < 0) + return ret; + } + goto lab7; + lab12: + z->c = z->l - m3; + } + lab7: { + int ret = r_mark_ymUs_(z); + if (ret == 0) + goto lab6; /* call mark_ymUs_, line 316 */ + if (ret < 0) + return ret; + } + goto lab0; + lab6: + z->c = z->l - m1; + { + int ret = r_mark_lAr(z); + if (ret == 0) + goto lab13; /* call mark_lAr, line 319 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 319 */ + { + int ret = slice_del(z); /* delete, line 319 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 319 */ + z->ket = z->c; /* [, line 319 */ + { + int m4 = z->l - z->c; + (void)m4; /* or, line 319 */ + { + int ret = r_mark_DUr(z); + if (ret == 0) + goto lab16; /* call mark_DUr, line 319 */ + if (ret < 0) + return ret; + } + goto lab15; + lab16: + z->c = z->l - m4; + { + int ret = r_mark_yDU(z); + if (ret == 0) + goto lab17; /* call mark_yDU, line 319 */ + if (ret < 0) + return ret; + } + goto lab15; + lab17: + z->c = z->l - m4; + { + int ret = r_mark_ysA(z); + if (ret == 0) + goto lab18; /* call mark_ysA, line 319 */ + if (ret < 0) + return ret; + } + goto lab15; + lab18: + z->c = z->l - m4; + { + int ret = r_mark_ymUs_(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab14; + } /* call 
mark_ymUs_, line 319 */ + if (ret < 0) + return ret; + } + } + lab15: + lab14:; + } + z->B[0] = 0; /* unset continue_stemming_noun_suffixes, line 320 */ + goto lab0; + lab13: + z->c = z->l - m1; + { + int ret = r_mark_nUz(z); + if (ret == 0) + goto lab19; /* call mark_nUz, line 323 */ + if (ret < 0) + return ret; + } + { + int m5 = z->l - z->c; + (void)m5; /* or, line 323 */ + { + int ret = r_mark_yDU(z); + if (ret == 0) + goto lab21; /* call mark_yDU, line 323 */ + if (ret < 0) + return ret; + } + goto lab20; + lab21: + z->c = z->l - m5; + { + int ret = r_mark_ysA(z); + if (ret == 0) + goto lab19; /* call mark_ysA, line 323 */ + if (ret < 0) + return ret; + } + } + lab20: + goto lab0; + lab19: + z->c = z->l - m1; + { + int m6 = z->l - z->c; + (void)m6; /* or, line 325 */ + { + int ret = r_mark_sUnUz(z); + if (ret == 0) + goto lab24; /* call mark_sUnUz, line 325 */ + if (ret < 0) + return ret; + } + goto lab23; + lab24: + z->c = z->l - m6; + { + int ret = r_mark_yUz(z); + if (ret == 0) + goto lab25; /* call mark_yUz, line 325 */ + if (ret < 0) + return ret; + } + goto lab23; + lab25: + z->c = z->l - m6; + { + int ret = r_mark_sUn(z); + if (ret == 0) + goto lab26; /* call mark_sUn, line 325 */ + if (ret < 0) + return ret; + } + goto lab23; + lab26: + z->c = z->l - m6; + { + int ret = r_mark_yUm(z); + if (ret == 0) + goto lab22; /* call mark_yUm, line 325 */ + if (ret < 0) + return ret; + } + } + lab23: + z->bra = z->c; /* ], line 325 */ + { + int ret = slice_del(z); /* delete, line 325 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 325 */ + z->ket = z->c; /* [, line 325 */ + { + int ret = r_mark_ymUs_(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab27; + } /* call mark_ymUs_, line 325 */ + if (ret < 0) + return ret; + } + lab27:; + } + goto lab0; + lab22: + z->c = z->l - m1; + { + int ret = r_mark_DUr(z); + if (ret == 0) + return 0; /* call mark_DUr, line 327 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 327 */ + { + int ret = slice_del(z); /* delete, line 327 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 327 */ + z->ket = z->c; /* [, line 327 */ + { + int m7 = z->l - z->c; + (void)m7; /* or, line 327 */ + { + int ret = r_mark_sUnUz(z); + if (ret == 0) + goto lab30; /* call mark_sUnUz, line 327 */ + if (ret < 0) + return ret; + } + goto lab29; + lab30: + z->c = z->l - m7; + { + int ret = r_mark_lAr(z); + if (ret == 0) + goto lab31; /* call mark_lAr, line 327 */ + if (ret < 0) + return ret; + } + goto lab29; + lab31: + z->c = z->l - m7; + { + int ret = r_mark_yUm(z); + if (ret == 0) + goto lab32; /* call mark_yUm, line 327 */ + if (ret < 0) + return ret; + } + goto lab29; + lab32: + z->c = z->l - m7; + { + int ret = r_mark_sUn(z); + if (ret == 0) + goto lab33; /* call mark_sUn, line 327 */ + if (ret < 0) + return ret; + } + goto lab29; + lab33: + z->c = z->l - m7; + { + int ret = r_mark_yUz(z); + if (ret == 0) + goto lab34; /* call mark_yUz, line 327 */ + if (ret < 0) + return ret; + } + goto lab29; + lab34: + z->c = z->l - m7; + } + lab29: { + int ret = r_mark_ymUs_(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab28; + } /* call mark_ymUs_, line 327 */ + if (ret < 0) + return ret; + } + lab28:; + } + } +lab0: + z->bra = z->c; /* ], line 328 */ + { + int ret = slice_del(z); /* delete, line 328 */ + if (ret < 0) + return ret; + } + return 1; +} + +static int r_stem_suffix_chain_before_ki(struct SN_env *z) { + z->ket = z->c; /* [, line 333 */ + { + int 
ret = r_mark_ki(z); + if (ret == 0) + return 0; /* call mark_ki, line 334 */ + if (ret < 0) + return ret; + } + { + int m1 = z->l - z->c; + (void)m1; /* or, line 342 */ + { + int ret = r_mark_DA(z); + if (ret == 0) + goto lab1; /* call mark_DA, line 336 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 336 */ + { + int ret = slice_del(z); /* delete, line 336 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 336 */ + z->ket = z->c; /* [, line 336 */ + { + int m2 = z->l - z->c; + (void)m2; /* or, line 338 */ + { + int ret = r_mark_lAr(z); + if (ret == 0) + goto lab4; /* call mark_lAr, line 337 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 337 */ + { + int ret = slice_del(z); /* delete, line 337 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 337 */ + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab5; + } /* call stem_suffix_chain_before_ki, line 337 */ + if (ret < 0) + return ret; + } + lab5:; + } + goto lab3; + lab4: + z->c = z->l - m2; + { + int ret = r_mark_possessives(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab2; + } /* call mark_possessives, line 339 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 339 */ + { + int ret = slice_del(z); /* delete, line 339 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 339 */ + z->ket = z->c; /* [, line 339 */ + { + int ret = r_mark_lAr(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab6; + } /* call mark_lAr, line 339 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 339 */ + { + int ret = slice_del(z); /* delete, line 339 */ + if (ret < 0) + return ret; + } + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab6; + } /* call stem_suffix_chain_before_ki, line 339 */ + if (ret < 0) + return ret; + } + lab6:; + } + } + lab3: + lab2:; + } + goto lab0; + lab1: + z->c = z->l - m1; + { + int ret = r_mark_nUn(z); + if (ret == 0) + goto lab7; /* call mark_nUn, line 343 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 343 */ + { + int ret = slice_del(z); /* delete, line 343 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 343 */ + z->ket = z->c; /* [, line 343 */ + { + int m3 = z->l - z->c; + (void)m3; /* or, line 345 */ + { + int ret = r_mark_lArI(z); + if (ret == 0) + goto lab10; /* call mark_lArI, line 344 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 344 */ + { + int ret = slice_del(z); /* delete, line 344 */ + if (ret < 0) + return ret; + } + goto lab9; + lab10: + z->c = z->l - m3; + z->ket = z->c; /* [, line 346 */ + { + int m4 = z->l - z->c; + (void)m4; /* or, line 346 */ + { + int ret = r_mark_possessives(z); + if (ret == 0) + goto lab13; /* call mark_possessives, line 346 */ + if (ret < 0) + return ret; + } + goto lab12; + lab13: + z->c = z->l - m4; + { + int ret = r_mark_sU(z); + if (ret == 0) + goto lab11; /* call mark_sU, line 346 */ + if (ret < 0) + return ret; + } + } + lab12: + z->bra = z->c; /* ], line 346 */ + { + int ret = slice_del(z); /* delete, line 346 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 346 */ + z->ket = z->c; /* [, line 346 */ + { + int ret = r_mark_lAr(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab14; + } /* call mark_lAr, line 346 */ + if (ret < 0) 
+ return ret; + } + z->bra = z->c; /* ], line 346 */ + { + int ret = slice_del(z); /* delete, line 346 */ + if (ret < 0) + return ret; + } + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab14; + } /* call stem_suffix_chain_before_ki, line 346 */ + if (ret < 0) + return ret; + } + lab14:; + } + goto lab9; + lab11: + z->c = z->l - m3; + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab8; + } /* call stem_suffix_chain_before_ki, line 348 */ + if (ret < 0) + return ret; + } + } + lab9: + lab8:; + } + goto lab0; + lab7: + z->c = z->l - m1; + { + int ret = r_mark_ndA(z); + if (ret == 0) + return 0; /* call mark_ndA, line 351 */ + if (ret < 0) + return ret; + } + { + int m5 = z->l - z->c; + (void)m5; /* or, line 353 */ + { + int ret = r_mark_lArI(z); + if (ret == 0) + goto lab16; /* call mark_lArI, line 352 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 352 */ + { + int ret = slice_del(z); /* delete, line 352 */ + if (ret < 0) + return ret; + } + goto lab15; + lab16: + z->c = z->l - m5; + { + int ret = r_mark_sU(z); + if (ret == 0) + goto lab17; /* call mark_sU, line 354 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 354 */ + { + int ret = slice_del(z); /* delete, line 354 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 354 */ + z->ket = z->c; /* [, line 354 */ + { + int ret = r_mark_lAr(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab18; + } /* call mark_lAr, line 354 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 354 */ + { + int ret = slice_del(z); /* delete, line 354 */ + if (ret < 0) + return ret; + } + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab18; + } /* call stem_suffix_chain_before_ki, line 354 */ + if (ret < 0) + return ret; + } + lab18:; + } + goto lab15; + lab17: + z->c = z->l - m5; + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) + return 0; /* call stem_suffix_chain_before_ki, line 356 */ + if (ret < 0) + return ret; + } + } + lab15:; + } +lab0: + return 1; +} + +static int r_stem_noun_suffixes(struct SN_env *z) { + { + int m1 = z->l - z->c; + (void)m1; /* or, line 363 */ + z->ket = z->c; /* [, line 362 */ + { + int ret = r_mark_lAr(z); + if (ret == 0) + goto lab1; /* call mark_lAr, line 362 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 362 */ + { + int ret = slice_del(z); /* delete, line 362 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 362 */ + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab2; + } /* call stem_suffix_chain_before_ki, line 362 */ + if (ret < 0) + return ret; + } + lab2:; + } + goto lab0; + lab1: + z->c = z->l - m1; + z->ket = z->c; /* [, line 364 */ + { + int ret = r_mark_ncA(z); + if (ret == 0) + goto lab3; /* call mark_ncA, line 364 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 364 */ + { + int ret = slice_del(z); /* delete, line 364 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 365 */ + { + int m2 = z->l - z->c; + (void)m2; /* or, line 367 */ + z->ket = z->c; /* [, line 366 */ + { + int ret = r_mark_lArI(z); + if (ret == 0) + goto lab6; /* call mark_lArI, line 366 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 366 */ + { + int ret = slice_del(z); /* delete, line 366 
*/ + if (ret < 0) + return ret; + } + goto lab5; + lab6: + z->c = z->l - m2; + z->ket = z->c; /* [, line 368 */ + { + int m3 = z->l - z->c; + (void)m3; /* or, line 368 */ + { + int ret = r_mark_possessives(z); + if (ret == 0) + goto lab9; /* call mark_possessives, line 368 */ + if (ret < 0) + return ret; + } + goto lab8; + lab9: + z->c = z->l - m3; + { + int ret = r_mark_sU(z); + if (ret == 0) + goto lab7; /* call mark_sU, line 368 */ + if (ret < 0) + return ret; + } + } + lab8: + z->bra = z->c; /* ], line 368 */ + { + int ret = slice_del(z); /* delete, line 368 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 368 */ + z->ket = z->c; /* [, line 368 */ + { + int ret = r_mark_lAr(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab10; + } /* call mark_lAr, line 368 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 368 */ + { + int ret = slice_del(z); /* delete, line 368 */ + if (ret < 0) + return ret; + } + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab10; + } /* call stem_suffix_chain_before_ki, line 368 */ + if (ret < 0) + return ret; + } + lab10:; + } + goto lab5; + lab7: + z->c = z->l - m2; + z->ket = z->c; /* [, line 370 */ + { + int ret = r_mark_lAr(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab4; + } /* call mark_lAr, line 370 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 370 */ + { + int ret = slice_del(z); /* delete, line 370 */ + if (ret < 0) + return ret; + } + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab4; + } /* call stem_suffix_chain_before_ki, line 370 */ + if (ret < 0) + return ret; + } + } + lab5: + lab4:; + } + goto lab0; + lab3: + z->c = z->l - m1; + z->ket = z->c; /* [, line 374 */ + { + int m4 = z->l - z->c; + (void)m4; /* or, line 374 */ + { + int ret = r_mark_ndA(z); + if (ret == 0) + goto lab13; /* call mark_ndA, line 374 */ + if (ret < 0) + return ret; + } + goto lab12; + lab13: + z->c = z->l - m4; + { + int ret = r_mark_nA(z); + if (ret == 0) + goto lab11; /* call mark_nA, line 374 */ + if (ret < 0) + return ret; + } + } + lab12: { + int m5 = z->l - z->c; + (void)m5; /* or, line 377 */ + { + int ret = r_mark_lArI(z); + if (ret == 0) + goto lab15; /* call mark_lArI, line 376 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 376 */ + { + int ret = slice_del(z); /* delete, line 376 */ + if (ret < 0) + return ret; + } + goto lab14; + lab15: + z->c = z->l - m5; + { + int ret = r_mark_sU(z); + if (ret == 0) + goto lab16; /* call mark_sU, line 378 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 378 */ + { + int ret = slice_del(z); /* delete, line 378 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 378 */ + z->ket = z->c; /* [, line 378 */ + { + int ret = r_mark_lAr(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab17; + } /* call mark_lAr, line 378 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 378 */ + { + int ret = slice_del(z); /* delete, line 378 */ + if (ret < 0) + return ret; + } + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab17; + } /* call stem_suffix_chain_before_ki, line 378 */ + if (ret < 0) + return ret; + } + lab17:; + } + goto lab14; + lab16: + z->c = z->l - m5; + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) + goto lab11; /* call stem_suffix_chain_before_ki, line 380 */ + if 
(ret < 0) + return ret; + } + } + lab14: + goto lab0; + lab11: + z->c = z->l - m1; + z->ket = z->c; /* [, line 384 */ + { + int m6 = z->l - z->c; + (void)m6; /* or, line 384 */ + { + int ret = r_mark_ndAn(z); + if (ret == 0) + goto lab20; /* call mark_ndAn, line 384 */ + if (ret < 0) + return ret; + } + goto lab19; + lab20: + z->c = z->l - m6; + { + int ret = r_mark_nU(z); + if (ret == 0) + goto lab18; /* call mark_nU, line 384 */ + if (ret < 0) + return ret; + } + } + lab19: { + int m7 = z->l - z->c; + (void)m7; /* or, line 384 */ + { + int ret = r_mark_sU(z); + if (ret == 0) + goto lab22; /* call mark_sU, line 384 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 384 */ + { + int ret = slice_del(z); /* delete, line 384 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 384 */ + z->ket = z->c; /* [, line 384 */ + { + int ret = r_mark_lAr(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab23; + } /* call mark_lAr, line 384 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 384 */ + { + int ret = slice_del(z); /* delete, line 384 */ + if (ret < 0) + return ret; + } + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab23; + } /* call stem_suffix_chain_before_ki, line 384 */ + if (ret < 0) + return ret; + } + lab23:; + } + goto lab21; + lab22: + z->c = z->l - m7; + { + int ret = r_mark_lArI(z); + if (ret == 0) + goto lab18; /* call mark_lArI, line 384 */ + if (ret < 0) + return ret; + } + } + lab21: + goto lab0; + lab18: + z->c = z->l - m1; + z->ket = z->c; /* [, line 386 */ + { + int ret = r_mark_DAn(z); + if (ret == 0) + goto lab24; /* call mark_DAn, line 386 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 386 */ + { + int ret = slice_del(z); /* delete, line 386 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 386 */ + z->ket = z->c; /* [, line 386 */ + { + int m8 = z->l - z->c; + (void)m8; /* or, line 389 */ + { + int ret = r_mark_possessives(z); + if (ret == 0) + goto lab27; /* call mark_possessives, line 388 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 388 */ + { + int ret = slice_del(z); /* delete, line 388 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 388 */ + z->ket = z->c; /* [, line 388 */ + { + int ret = r_mark_lAr(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab28; + } /* call mark_lAr, line 388 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 388 */ + { + int ret = slice_del(z); /* delete, line 388 */ + if (ret < 0) + return ret; + } + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab28; + } /* call stem_suffix_chain_before_ki, line 388 */ + if (ret < 0) + return ret; + } + lab28:; + } + goto lab26; + lab27: + z->c = z->l - m8; + { + int ret = r_mark_lAr(z); + if (ret == 0) + goto lab29; /* call mark_lAr, line 390 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 390 */ + { + int ret = slice_del(z); /* delete, line 390 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 390 */ + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab30; + } /* call stem_suffix_chain_before_ki, line 390 */ + if (ret < 0) + return ret; + } + lab30:; + } + goto lab26; + lab29: + z->c = z->l - m8; + { + int ret = r_stem_suffix_chain_before_ki(z); + if 
(ret == 0) { + z->c = z->l - m_keep; + goto lab25; + } /* call stem_suffix_chain_before_ki, line 392 */ + if (ret < 0) + return ret; + } + } + lab26: + lab25:; + } + goto lab0; + lab24: + z->c = z->l - m1; + z->ket = z->c; /* [, line 396 */ + { + int m9 = z->l - z->c; + (void)m9; /* or, line 396 */ + { + int ret = r_mark_nUn(z); + if (ret == 0) + goto lab33; /* call mark_nUn, line 396 */ + if (ret < 0) + return ret; + } + goto lab32; + lab33: + z->c = z->l - m9; + { + int ret = r_mark_ylA(z); + if (ret == 0) + goto lab31; /* call mark_ylA, line 396 */ + if (ret < 0) + return ret; + } + } + lab32: + z->bra = z->c; /* ], line 396 */ + { + int ret = slice_del(z); /* delete, line 396 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 397 */ + { + int m10 = z->l - z->c; + (void)m10; /* or, line 399 */ + z->ket = z->c; /* [, line 398 */ + { + int ret = r_mark_lAr(z); + if (ret == 0) + goto lab36; /* call mark_lAr, line 398 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 398 */ + { + int ret = slice_del(z); /* delete, line 398 */ + if (ret < 0) + return ret; + } + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) + goto lab36; /* call stem_suffix_chain_before_ki, line 398 */ + if (ret < 0) + return ret; + } + goto lab35; + lab36: + z->c = z->l - m10; + z->ket = z->c; /* [, line 400 */ + { + int m11 = z->l - z->c; + (void)m11; /* or, line 400 */ + { + int ret = r_mark_possessives(z); + if (ret == 0) + goto lab39; /* call mark_possessives, line 400 */ + if (ret < 0) + return ret; + } + goto lab38; + lab39: + z->c = z->l - m11; + { + int ret = r_mark_sU(z); + if (ret == 0) + goto lab37; /* call mark_sU, line 400 */ + if (ret < 0) + return ret; + } + } + lab38: + z->bra = z->c; /* ], line 400 */ + { + int ret = slice_del(z); /* delete, line 400 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 400 */ + z->ket = z->c; /* [, line 400 */ + { + int ret = r_mark_lAr(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab40; + } /* call mark_lAr, line 400 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 400 */ + { + int ret = slice_del(z); /* delete, line 400 */ + if (ret < 0) + return ret; + } + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab40; + } /* call stem_suffix_chain_before_ki, line 400 */ + if (ret < 0) + return ret; + } + lab40:; + } + goto lab35; + lab37: + z->c = z->l - m10; + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab34; + } /* call stem_suffix_chain_before_ki, line 402 */ + if (ret < 0) + return ret; + } + } + lab35: + lab34:; + } + goto lab0; + lab31: + z->c = z->l - m1; + z->ket = z->c; /* [, line 406 */ + { + int ret = r_mark_lArI(z); + if (ret == 0) + goto lab41; /* call mark_lArI, line 406 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 406 */ + { + int ret = slice_del(z); /* delete, line 406 */ + if (ret < 0) + return ret; + } + goto lab0; + lab41: + z->c = z->l - m1; + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) + goto lab42; /* call stem_suffix_chain_before_ki, line 408 */ + if (ret < 0) + return ret; + } + goto lab0; + lab42: + z->c = z->l - m1; + z->ket = z->c; /* [, line 410 */ + { + int m12 = z->l - z->c; + (void)m12; /* or, line 410 */ + { + int ret = r_mark_DA(z); + if (ret == 0) + goto lab45; /* call mark_DA, line 410 */ + if (ret < 0) + return ret; + } + goto lab44; + lab45: + z->c = 
z->l - m12; + { + int ret = r_mark_yU(z); + if (ret == 0) + goto lab46; /* call mark_yU, line 410 */ + if (ret < 0) + return ret; + } + goto lab44; + lab46: + z->c = z->l - m12; + { + int ret = r_mark_yA(z); + if (ret == 0) + goto lab43; /* call mark_yA, line 410 */ + if (ret < 0) + return ret; + } + } + lab44: + z->bra = z->c; /* ], line 410 */ + { + int ret = slice_del(z); /* delete, line 410 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 410 */ + z->ket = z->c; /* [, line 410 */ + { + int m13 = z->l - z->c; + (void)m13; /* or, line 410 */ + { + int ret = r_mark_possessives(z); + if (ret == 0) + goto lab49; /* call mark_possessives, line 410 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 410 */ + { + int ret = slice_del(z); /* delete, line 410 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 410 */ + z->ket = z->c; /* [, line 410 */ + { + int ret = r_mark_lAr(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab50; + } /* call mark_lAr, line 410 */ + if (ret < 0) + return ret; + } + lab50:; + } + goto lab48; + lab49: + z->c = z->l - m13; + { + int ret = r_mark_lAr(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab47; + } /* call mark_lAr, line 410 */ + if (ret < 0) + return ret; + } + } + lab48: + z->bra = z->c; /* ], line 410 */ + { + int ret = slice_del(z); /* delete, line 410 */ + if (ret < 0) + return ret; + } + z->ket = z->c; /* [, line 410 */ + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab47; + } /* call stem_suffix_chain_before_ki, line 410 */ + if (ret < 0) + return ret; + } + lab47:; + } + goto lab0; + lab43: + z->c = z->l - m1; + z->ket = z->c; /* [, line 412 */ + { + int m14 = z->l - z->c; + (void)m14; /* or, line 412 */ + { + int ret = r_mark_possessives(z); + if (ret == 0) + goto lab52; /* call mark_possessives, line 412 */ + if (ret < 0) + return ret; + } + goto lab51; + lab52: + z->c = z->l - m14; + { + int ret = r_mark_sU(z); + if (ret == 0) + return 0; /* call mark_sU, line 412 */ + if (ret < 0) + return ret; + } + } + lab51: + z->bra = z->c; /* ], line 412 */ + { + int ret = slice_del(z); /* delete, line 412 */ + if (ret < 0) + return ret; + } + { + int m_keep = z->l - z->c; /* (void) m_keep;*/ /* try, line 412 */ + z->ket = z->c; /* [, line 412 */ + { + int ret = r_mark_lAr(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab53; + } /* call mark_lAr, line 412 */ + if (ret < 0) + return ret; + } + z->bra = z->c; /* ], line 412 */ + { + int ret = slice_del(z); /* delete, line 412 */ + if (ret < 0) + return ret; + } + { + int ret = r_stem_suffix_chain_before_ki(z); + if (ret == 0) { + z->c = z->l - m_keep; + goto lab53; + } /* call stem_suffix_chain_before_ki, line 412 */ + if (ret < 0) + return ret; + } + lab53:; + } + } +lab0: + return 1; +} + +static int r_post_process_last_consonants(struct SN_env *z) { + int among_var; + z->ket = z->c; /* [, line 416 */ + among_var = find_among_b(z, a_23, 4); /* substring, line 416 */ + if (!(among_var)) + return 0; + z->bra = z->c; /* ], line 416 */ + switch (among_var) { + case 0: + return 0; + case 1: { + int ret = slice_from_s(z, 1, s_16); /* <-, line 417 */ + if (ret < 0) + return ret; + } break; + case 2: { + int ret = slice_from_s(z, 2, s_17); /* <-, line 418 */ + if (ret < 0) + return ret; + } break; + case 3: { + int ret = slice_from_s(z, 1, s_18); /* <-, line 419 */ + if (ret < 0) + return ret; + } break; + case 4: { + int ret = 
slice_from_s(z, 1, s_19); /* <-, line 420 */ + if (ret < 0) + return ret; + } break; + } + return 1; +} + +static int r_append_U_to_stems_ending_with_d_or_g(struct SN_env *z) { + { + int m_test = z->l - z->c; /* test, line 431 */ + { + int m1 = z->l - z->c; + (void)m1; /* or, line 431 */ + if (!(eq_s_b(z, 1, s_20))) + goto lab1; + goto lab0; + lab1: + z->c = z->l - m1; + if (!(eq_s_b(z, 1, s_21))) + return 0; + } + lab0: + z->c = z->l - m_test; + } + { + int m2 = z->l - z->c; + (void)m2; /* or, line 433 */ + { + int m_test = z->l - z->c; /* test, line 432 */ + if (out_grouping_b_U(z, g_vowel, 97, 305, 1) < 0) + goto lab3; /* goto */ /* grouping vowel, line 432 */ + { + int m3 = z->l - z->c; + (void)m3; /* or, line 432 */ + if (!(eq_s_b(z, 1, s_22))) + goto lab5; + goto lab4; + lab5: + z->c = z->l - m3; + if (!(eq_s_b(z, 2, s_23))) + goto lab3; + } + lab4: + z->c = z->l - m_test; + } + { + int c_keep = z->c; + int ret = insert_s(z, z->c, z->c, 2, s_24); /* <+, line 432 */ + z->c = c_keep; + if (ret < 0) + return ret; + } + goto lab2; + lab3: + z->c = z->l - m2; + { + int m_test = z->l - z->c; /* test, line 434 */ + if (out_grouping_b_U(z, g_vowel, 97, 305, 1) < 0) + goto lab6; /* goto */ /* grouping vowel, line 434 */ + { + int m4 = z->l - z->c; + (void)m4; /* or, line 434 */ + if (!(eq_s_b(z, 1, s_25))) + goto lab8; + goto lab7; + lab8: + z->c = z->l - m4; + if (!(eq_s_b(z, 1, s_26))) + goto lab6; + } + lab7: + z->c = z->l - m_test; + } + { + int c_keep = z->c; + int ret = insert_s(z, z->c, z->c, 1, s_27); /* <+, line 434 */ + z->c = c_keep; + if (ret < 0) + return ret; + } + goto lab2; + lab6: + z->c = z->l - m2; + { + int m_test = z->l - z->c; /* test, line 436 */ + if (out_grouping_b_U(z, g_vowel, 97, 305, 1) < 0) + goto lab9; /* goto */ /* grouping vowel, line 436 */ + { + int m5 = z->l - z->c; + (void)m5; /* or, line 436 */ + if (!(eq_s_b(z, 1, s_28))) + goto lab11; + goto lab10; + lab11: + z->c = z->l - m5; + if (!(eq_s_b(z, 1, s_29))) + goto lab9; + } + lab10: + z->c = z->l - m_test; + } + { + int c_keep = z->c; + int ret = insert_s(z, z->c, z->c, 1, s_30); /* <+, line 436 */ + z->c = c_keep; + if (ret < 0) + return ret; + } + goto lab2; + lab9: + z->c = z->l - m2; + { + int m_test = z->l - z->c; /* test, line 438 */ + if (out_grouping_b_U(z, g_vowel, 97, 305, 1) < 0) + return 0; /* goto */ /* grouping vowel, line 438 */ + { + int m6 = z->l - z->c; + (void)m6; /* or, line 438 */ + if (!(eq_s_b(z, 2, s_31))) + goto lab13; + goto lab12; + lab13: + z->c = z->l - m6; + if (!(eq_s_b(z, 2, s_32))) + return 0; + } + lab12: + z->c = z->l - m_test; + } + { + int c_keep = z->c; + int ret = insert_s(z, z->c, z->c, 2, s_33); /* <+, line 438 */ + z->c = c_keep; + if (ret < 0) + return ret; + } + } +lab2: + return 1; +} + +static int r_more_than_one_syllable_word(struct SN_env *z) { + { + int c_test = z->c; /* test, line 446 */ + { + int i = 2; + while (1) { /* atleast, line 446 */ + int c1 = z->c; + { /* gopast */ /* grouping vowel, line 446 */ + int ret = out_grouping_U(z, g_vowel, 97, 305, 1); + if (ret < 0) + goto lab0; + z->c += ret; + } + i--; + continue; + lab0: + z->c = c1; + break; + } + if (i > 0) + return 0; + } + z->c = c_test; + } + return 1; +} + +static int r_is_reserved_word(struct SN_env *z) { + { + int c1 = z->c; /* or, line 451 */ + { + int c_test = z->c; /* test, line 450 */ + while (1) { /* gopast, line 450 */ + if (!(eq_s(z, 2, s_34))) + goto lab2; + break; + lab2: { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + goto lab1; + z->c = ret; /* gopast, line 450 
*/ + } + } + z->I[0] = 2; + if (!(z->I[0] == z->l)) + goto lab1; + z->c = c_test; + } + goto lab0; + lab1: + z->c = c1; + { + int c_test = z->c; /* test, line 452 */ + while (1) { /* gopast, line 452 */ + if (!(eq_s(z, 5, s_35))) + goto lab3; + break; + lab3: { + int ret = skip_utf8(z->p, z->c, 0, z->l, 1); + if (ret < 0) + return 0; + z->c = ret; /* gopast, line 452 */ + } + } + z->I[0] = 5; + if (!(z->I[0] == z->l)) + return 0; + z->c = c_test; + } + } +lab0: + return 1; +} + +static int r_postlude(struct SN_env *z) { + { + int c1 = z->c; /* not, line 456 */ + { + int ret = r_is_reserved_word(z); + if (ret == 0) + goto lab0; /* call is_reserved_word, line 456 */ + if (ret < 0) + return ret; + } + return 0; + lab0: + z->c = c1; + } + z->lb = z->c; + z->c = z->l; /* backwards, line 457 */ + + { + int m2 = z->l - z->c; + (void)m2; /* do, line 458 */ + { + int ret = r_append_U_to_stems_ending_with_d_or_g(z); + if (ret == 0) + goto lab1; /* call append_U_to_stems_ending_with_d_or_g, line 458 */ + if (ret < 0) + return ret; + } + lab1: + z->c = z->l - m2; + } + { + int m3 = z->l - z->c; + (void)m3; /* do, line 459 */ + { + int ret = r_post_process_last_consonants(z); + if (ret == 0) + goto lab2; /* call post_process_last_consonants, line 459 */ + if (ret < 0) + return ret; + } + lab2: + z->c = z->l - m3; + } + z->c = z->lb; + return 1; +} + +extern int turkish_UTF_8_stem(struct SN_env *z) { + { + int ret = r_more_than_one_syllable_word(z); + if (ret == 0) + return 0; /* call more_than_one_syllable_word, line 465 */ + if (ret < 0) + return ret; + } + z->lb = z->c; + z->c = z->l; /* backwards, line 467 */ + + { + int m1 = z->l - z->c; + (void)m1; /* do, line 468 */ + { + int ret = r_stem_nominal_verb_suffixes(z); + if (ret == 0) + goto lab0; /* call stem_nominal_verb_suffixes, line 468 */ + if (ret < 0) + return ret; + } + lab0: + z->c = z->l - m1; + } + if (!(z->B[0])) + return 0; /* Boolean test continue_stemming_noun_suffixes, line 469 */ + { + int m2 = z->l - z->c; + (void)m2; /* do, line 470 */ + { + int ret = r_stem_noun_suffixes(z); + if (ret == 0) + goto lab1; /* call stem_noun_suffixes, line 470 */ + if (ret < 0) + return ret; + } + lab1: + z->c = z->l - m2; + } + z->c = z->lb; + { + int ret = r_postlude(z); + if (ret == 0) + return 0; /* call postlude, line 473 */ + if (ret < 0) + return ret; + } + return 1; +} + +extern struct SN_env *turkish_UTF_8_create_env(void) { return SN_create_env(0, 1, 1); } + +extern void turkish_UTF_8_close_env(struct SN_env *z) { SN_close_env(z, 0); } diff --git a/internal/cpp/stemmer/stem_UTF_8_turkish.h b/internal/cpp/stemmer/stem_UTF_8_turkish.h new file mode 100644 index 00000000000..6873d5c0f4e --- /dev/null +++ b/internal/cpp/stemmer/stem_UTF_8_turkish.h @@ -0,0 +1,17 @@ + +/* This file was generated automatically by the Snowball to ANSI C compiler */ + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +extern struct SN_env *turkish_UTF_8_create_env(void); +extern void turkish_UTF_8_close_env(struct SN_env *z); + +extern int turkish_UTF_8_stem(struct SN_env *z); + +#ifdef __cplusplus +} +#endif diff --git a/internal/cpp/stemmer/stemmer.cpp b/internal/cpp/stemmer/stemmer.cpp new file mode 100644 index 00000000000..cc6bb7daff6 --- /dev/null +++ b/internal/cpp/stemmer/stemmer.cpp @@ -0,0 +1,149 @@ +// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "api.h"
+#include "stem_UTF_8_danish.h"
+#include "stem_UTF_8_dutch.h"
+#include "stem_UTF_8_english.h"
+#include "stem_UTF_8_finnish.h"
+#include "stem_UTF_8_french.h"
+#include "stem_UTF_8_german.h"
+#include "stem_UTF_8_hungarian.h"
+#include "stem_UTF_8_italian.h"
+#include "stem_UTF_8_norwegian.h"
+#include "stem_UTF_8_porter.h"
+#include "stem_UTF_8_portuguese.h"
+#include "stem_UTF_8_romanian.h"
+#include "stem_UTF_8_russian.h"
+#include "stem_UTF_8_spanish.h"
+#include "stem_UTF_8_swedish.h"
+#include "stem_UTF_8_turkish.h"
+#include "stemmer.h"
+
+#ifdef __cplusplus
+
+extern "C" {
+#endif
+struct StemFunc {
+
+    struct SN_env *(*create)(void);
+    void (*close)(struct SN_env *);
+    int (*stem)(struct SN_env *);
+
+    struct SN_env *env;
+};
+
+#ifdef __cplusplus
+}
+#endif
+
+StemFunc STEM_FUNCTION[STEM_LANG_EOS] = {
+    {0, 0, 0, 0},
+    {danish_UTF_8_create_env, danish_UTF_8_close_env, danish_UTF_8_stem, 0},
+    {dutch_UTF_8_create_env, dutch_UTF_8_close_env, dutch_UTF_8_stem, 0},
+    {english_UTF_8_create_env, english_UTF_8_close_env, english_UTF_8_stem, 0},
+    {finnish_UTF_8_create_env, finnish_UTF_8_close_env, finnish_UTF_8_stem, 0},
+    {french_UTF_8_create_env, french_UTF_8_close_env, french_UTF_8_stem, 0},
+    {german_UTF_8_create_env, german_UTF_8_close_env, german_UTF_8_stem, 0},
+    {hungarian_UTF_8_create_env, hungarian_UTF_8_close_env, hungarian_UTF_8_stem, 0},
+    {italian_UTF_8_create_env, italian_UTF_8_close_env, italian_UTF_8_stem, 0},
+    {norwegian_UTF_8_create_env, norwegian_UTF_8_close_env, norwegian_UTF_8_stem, 0},
+    {porter_UTF_8_create_env, porter_UTF_8_close_env, porter_UTF_8_stem, 0},
+    {portuguese_UTF_8_create_env, portuguese_UTF_8_close_env, portuguese_UTF_8_stem, 0},
+    {romanian_UTF_8_create_env, romanian_UTF_8_close_env, romanian_UTF_8_stem, 0},
+    {russian_UTF_8_create_env, russian_UTF_8_close_env, russian_UTF_8_stem, 0},
+    {spanish_UTF_8_create_env, spanish_UTF_8_close_env, spanish_UTF_8_stem, 0},
+    {swedish_UTF_8_create_env, swedish_UTF_8_close_env, swedish_UTF_8_stem, 0},
+    {turkish_UTF_8_create_env, turkish_UTF_8_close_env, turkish_UTF_8_stem, 0},
+};
+
+Stemmer::Stemmer() {
+    // stemLang_ = STEM_LANG_UNKNOWN;
+    stem_function_ = 0;
+}
+
+Stemmer::~Stemmer() { DeInit(); }
+
+bool Stemmer::Init(Language language) {
+    // create stemming function structure
+    stem_function_ = static_cast<void *>(new StemFunc);
+    if (stem_function_ == 0) {
+        return false;
+    }
+
+    // set stemming functions
+    if (language > 0 && language < STEM_LANG_EOS) {
+        static_cast<StemFunc *>(stem_function_)->create = STEM_FUNCTION[language].create;
+        static_cast<StemFunc *>(stem_function_)->close = STEM_FUNCTION[language].close;
+        static_cast<StemFunc *>(stem_function_)->stem = STEM_FUNCTION[language].stem;
+        static_cast<StemFunc *>(stem_function_)->env = STEM_FUNCTION[language].env;
+    } else {
+        delete static_cast<StemFunc *>(stem_function_);
+        stem_function_ = 0;
+        return false;
+    }
+
+    // create env
+    static_cast<StemFunc *>(stem_function_)->env = static_cast<StemFunc *>(stem_function_)->create();
+    if (static_cast<StemFunc *>(stem_function_)->env == 0) {
+        DeInit();
+        return false;
+    }
+
+    return true;
+}
+////////////
+// struct SN_env {
+//     symbol *p;
+//     int c;
+//     int l;
+//     int lb;
+//     int bra;
+//     int ket;
+//     symbol **S;
+//     int *I;
+//     unsigned char *B;
+// };
+////////////
+
+void Stemmer::DeInit(void) {
+    if (stem_function_) {
+        static_cast<StemFunc *>(stem_function_)->close(((StemFunc *)stem_function_)->env);
+        delete static_cast<StemFunc *>(stem_function_);
+        stem_function_ = 0;
+    }
+}
+
+bool Stemmer::Stem(const std::string &term, std::string &resultWord) {
+    if (!stem_function_) {
+        return false;
+    }
+
+    // set environment
+    if (SN_set_current(static_cast<StemFunc *>(stem_function_)->env, term.length(), (const symbol *)term.c_str())) {
+        static_cast<StemFunc *>(stem_function_)->env->l = 0;
+        return false;
+    }
+
+    // stemming
+    if (((StemFunc *)stem_function_)->stem(((StemFunc *)stem_function_)->env) < 0) {
+        return false;
+    }
+
+    ((StemFunc *)stem_function_)->env->p[((StemFunc *)stem_function_)->env->l] = 0;
+
+    resultWord = (char *)((StemFunc *)stem_function_)->env->p;
+
+    return true;
+}
diff --git a/internal/cpp/stemmer/stemmer.h b/internal/cpp/stemmer/stemmer.h
new file mode 100644
index 00000000000..ba84a05f2a2
--- /dev/null
+++ b/internal/cpp/stemmer/stemmer.h
@@ -0,0 +1,58 @@
+// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+
+enum Language
+{
+    STEM_LANG_UNKNOWN = 0,
+    STEM_LANG_DANISH = 1,
+    STEM_LANG_DUTCH = 2,
+    STEM_LANG_ENGLISH,
+    STEM_LANG_FINNISH,
+    STEM_LANG_FRENCH,
+    STEM_LANG_GERMAN,
+    STEM_LANG_HUNGARIAN,
+    STEM_LANG_ITALIAN,
+    STEM_LANG_NORWEGIAN,
+    STEM_LANG_PORT,
+    STEM_LANG_PORTUGUESE,
+    STEM_LANG_ROMANIAN,
+    STEM_LANG_RUSSIAN,
+    STEM_LANG_SPANISH,
+    STEM_LANG_SWEDISH,
+    STEM_LANG_TURKISH,
+    STEM_LANG_EOS,
+};
+
+class Stemmer
+{
+public:
+    Stemmer();
+
+    virtual ~Stemmer();
+
+    bool Init(Language language);
+
+    void DeInit();
+
+    bool Stem(const std::string& term, std::string& resultWord);
+
+private:
+    // int stemLang_; ///< language for stemming
+
+    void* stem_function_; ///< stemming function
+};
diff --git a/internal/cpp/stemmer/utilities.cpp b/internal/cpp/stemmer/utilities.cpp
new file mode 100644
index 00000000000..79092e60c43
--- /dev/null
+++ b/internal/cpp/stemmer/utilities.cpp
@@ -0,0 +1,509 @@
+// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "header.h"
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#define unless(C) if (!(C))
+
+#define CREATE_SIZE 1
+
+extern symbol *create_s(void) {
+    symbol *p;
+    void *mem = malloc(HEAD + (CREATE_SIZE + 1) * sizeof(symbol));
+    if (mem == NULL)
+        return NULL;
+    p = (symbol *)(HEAD + (char *)mem);
+    CAPACITY(p) = CREATE_SIZE;
+    SET_SIZE(p, CREATE_SIZE);
+    return p;
+}
+
+extern void lose_s(symbol *p) {
+    if (p == NULL)
+        return;
+    free((char *)p - HEAD);
+}
+
+/*
+   new_p = skip_utf8(p, c, lb, l, n); skips n characters forwards from p + c
+   if n +ve, or n characters backwards from p + c - 1 if n -ve. new_p is the
+   new position, or -1 on failure.
+
+   -- used to implement hop and next in the utf8 case.
+*/
+
+extern int skip_utf8(const symbol *p, int c, int lb, int l, int n) {
+    int b;
+    if (n >= 0) {
+        for (; n > 0; n--) {
+            if (c >= l)
+                return -1;
+            b = p[c++];
+            if (b >= 0xC0) { /* 1100 0000 */
+                while (c < l) {
+                    b = p[c];
+                    if (b >= 0xC0 || b < 0x80)
+                        break;
+                    /* break unless b is 10------ */
+                    c++;
+                }
+            }
+        }
+    } else {
+        for (; n < 0; n++) {
+            if (c <= lb)
+                return -1;
+            b = p[--c];
+            if (b >= 0x80) { /* 1000 0000 */
+                while (c > lb) {
+                    b = p[c];
+                    if (b >= 0xC0)
+                        break; /* 1100 0000 */
+                    c--;
+                }
+            }
+        }
+    }
+    return c;
+}
+
+/* Code for character groupings: utf8 cases */
+
+static int get_utf8(const symbol *p, int c, int l, int *slot) {
+    int b0, b1;
+    if (c >= l)
+        return 0;
+    b0 = p[c++];
+    if (b0 < 0xC0 || c == l) { /* 1100 0000 */
+        *slot = b0;
+        return 1;
+    }
+    b1 = p[c++];
+    if (b0 < 0xE0 || c == l) { /* 1110 0000 */
+        *slot = (b0 & 0x1F) << 6 | (b1 & 0x3F);
+        return 2;
+    }
+    *slot = (b0 & 0xF) << 12 | (b1 & 0x3F) << 6 | (p[c] & 0x3F);
+    return 3;
+}
+
+static int get_b_utf8(const symbol *p, int c, int lb, int *slot) {
+    int b0, b1;
+    if (c <= lb)
+        return 0;
+    b0 = p[--c];
+    if (b0 < 0x80 || c == lb) { /* 1000 0000 */
+        *slot = b0;
+        return 1;
+    }
+    b1 = p[--c];
+    if (b1 >= 0xC0 || c == lb) { /* 1100 0000 */
+        *slot = (b1 & 0x1F) << 6 | (b0 & 0x3F);
+        return 2;
+    }
+    *slot = (p[c] & 0xF) << 12 | (b1 & 0x3F) << 6 | (b0 & 0x3F);
+    return 3;
+}
+
+extern int in_grouping_U(struct SN_env *z, const unsigned char *s, int min, int max, int repeat) {
+    do {
+        int ch;
+        int w = get_utf8(z->p, z->c, z->l, &ch);
+        unless(w) return -1;
+        if (ch > max || (ch -= min) < 0 || (s[ch >> 3] & (0X1 << (ch & 0X7))) == 0)
+            return w;
+        z->c += w;
+    } while (repeat);
+    return 0;
+}
+
+extern int in_grouping_b_U(struct SN_env *z, const unsigned char *s, int min, int max, int repeat) {
+    do {
+        int ch;
+        int w = get_b_utf8(z->p, z->c, z->lb, &ch);
+        unless(w) return -1;
+        if (ch > max || (ch -= min) < 0 || (s[ch >> 3] & (0X1 << (ch & 0X7))) == 0)
+            return w;
+        z->c -= w;
+    } while (repeat);
+    return 0;
+}
+
+extern int out_grouping_U(struct SN_env *z, const unsigned char *s, int min, int max, int repeat) {
+    do {
+        int ch;
+        int w = get_utf8(z->p, z->c, z->l, &ch);
+        unless(w) return -1;
+        unless(ch > max || (ch -= min) < 0 || (s[ch >> 3] & (0X1 << (ch & 0X7))) == 0) return w;
+        z->c += w;
+    } while (repeat);
+    return 0;
+}
+
+extern int out_grouping_b_U(struct SN_env *z, const unsigned char *s, int min, int max, int repeat) {
+    do {
+        int ch;
+        int w = get_b_utf8(z->p, z->c, z->lb, &ch);
+        unless(w) return -1;
+        unless(ch > max || (ch -= min) < 0 || (s[ch >> 3] & (0X1 << (ch & 0X7))) == 0) return w;
+        z->c -= w;
+    } while (repeat);
+    return 0;
+}
+
+/* Code for character groupings: non-utf8 cases */
+
+extern int in_grouping(struct SN_env *z, const unsigned char *s, int min, int max, int repeat) {
do { + int ch; + if (z->c >= z->l) + return -1; + ch = z->p[z->c]; + if (ch > max || (ch -= min) < 0 || (s[ch >> 3] & (0X1 << (ch & 0X7))) == 0) + return 1; + z->c++; + } while (repeat); + return 0; +} + +extern int in_grouping_b(struct SN_env *z, const unsigned char *s, int min, int max, int repeat) { + do { + int ch; + if (z->c <= z->lb) + return -1; + ch = z->p[z->c - 1]; + if (ch > max || (ch -= min) < 0 || (s[ch >> 3] & (0X1 << (ch & 0X7))) == 0) + return 1; + z->c--; + } while (repeat); + return 0; +} + +extern int out_grouping(struct SN_env *z, const unsigned char *s, int min, int max, int repeat) { + do { + int ch; + if (z->c >= z->l) + return -1; + ch = z->p[z->c]; + unless(ch > max || (ch -= min) < 0 || (s[ch >> 3] & (0X1 << (ch & 0X7))) == 0) return 1; + z->c++; + } while (repeat); + return 0; +} + +extern int out_grouping_b(struct SN_env *z, const unsigned char *s, int min, int max, int repeat) { + do { + int ch; + if (z->c <= z->lb) + return -1; + ch = z->p[z->c - 1]; + unless(ch > max || (ch -= min) < 0 || (s[ch >> 3] & (0X1 << (ch & 0X7))) == 0) return 1; + z->c--; + } while (repeat); + return 0; +} + +extern int eq_s(struct SN_env *z, int s_size, const symbol *s) { + if (z->l - z->c < s_size || memcmp(z->p + z->c, s, s_size * sizeof(symbol)) != 0) + return 0; + z->c += s_size; + return 1; +} + +extern int eq_s_b(struct SN_env *z, int s_size, const symbol *s) { + if (z->c - z->lb < s_size || memcmp(z->p + z->c - s_size, s, s_size * sizeof(symbol)) != 0) + return 0; + z->c -= s_size; + return 1; +} + +extern int eq_v(struct SN_env *z, const symbol *p) { return eq_s(z, SIZE(p), p); } + +extern int eq_v_b(struct SN_env *z, const symbol *p) { return eq_s_b(z, SIZE(p), p); } + +extern int find_among(struct SN_env *z, const struct among *v, int v_size) { + + int i = 0; + int j = v_size; + + int c = z->c; + int l = z->l; + symbol *q = z->p + c; + + const struct among *w; + + int common_i = 0; + int common_j = 0; + + int first_key_inspected = 0; + + while (1) { + int k = i + ((j - i) >> 1); + int diff = 0; + int common = common_i < common_j ? common_i : common_j; /* smaller */ + w = v + k; + { + int i2; + for (i2 = common; i2 < w->s_size; i2++) { + if (c + common == l) { + diff = -1; + break; + } + diff = q[common] - w->s[i2]; + if (diff != 0) + break; + common++; + } + } + if (diff < 0) { + j = k; + common_j = common; + } else { + i = k; + common_i = common; + } + if (j - i <= 1) { + if (i > 0) + break; /* v->s has been inspected */ + if (j == i) + break; /* only one item in v */ + + /* - but now we need to go round once more to get + v->s inspected. This looks messy, but is actually + the optimal approach. */ + + if (first_key_inspected) + break; + first_key_inspected = 1; + } + } + while (1) { + w = v + i; + if (common_i >= w->s_size) { + z->c = c + w->s_size; + if (w->function == 0) + return w->result; + { + int res = w->function(z); + z->c = c + w->s_size; + if (res) + return w->result; + } + } + i = w->substring_i; + if (i < 0) + return 0; + } +} + +/* find_among_b is for backwards processing. Same comments apply */ + +extern int find_among_b(struct SN_env *z, const struct among *v, int v_size) { + + int i = 0; + int j = v_size; + + int c = z->c; + int lb = z->lb; + symbol *q = z->p + c - 1; + + const struct among *w; + + int common_i = 0; + int common_j = 0; + + int first_key_inspected = 0; + + while (1) { + int k = i + ((j - i) >> 1); + int diff = 0; + int common = common_i < common_j ? 
common_i : common_j; + w = v + k; + { + int i2; + for (i2 = w->s_size - 1 - common; i2 >= 0; i2--) { + if (c - common == lb) { + diff = -1; + break; + } + diff = q[-common] - w->s[i2]; + if (diff != 0) + break; + common++; + } + } + if (diff < 0) { + j = k; + common_j = common; + } else { + i = k; + common_i = common; + } + if (j - i <= 1) { + if (i > 0) + break; + if (j == i) + break; + if (first_key_inspected) + break; + first_key_inspected = 1; + } + } + while (1) { + w = v + i; + if (common_i >= w->s_size) { + z->c = c - w->s_size; + if (w->function == 0) + return w->result; + { + int res = w->function(z); + z->c = c - w->s_size; + if (res) + return w->result; + } + } + i = w->substring_i; + if (i < 0) + return 0; + } +} + +/* Increase the size of the buffer pointed to by p to at least n symbols. + * If insufficient memory, returns NULL and frees the old buffer. + */ +static symbol *increase_size(symbol *p, int n) { + symbol *q; + int new_size = n + 20; + void *mem = realloc((char *)p - HEAD, HEAD + (new_size + 1) * sizeof(symbol)); + if (mem == NULL) { + lose_s(p); + return NULL; + } + q = (symbol *)(HEAD + (char *)mem); + CAPACITY(q) = new_size; + return q; +} + +/* to replace symbols between c_bra and c_ket in z->p by the + s_size symbols at s. + Returns 0 on success, -1 on error. + Also, frees z->p (and sets it to NULL) on error. +*/ +extern int replace_s(struct SN_env *z, int c_bra, int c_ket, int s_size, const symbol *s, int *adjptr) { + int adjustment; + int len; + if (z->p == NULL) { + z->p = create_s(); + if (z->p == NULL) + return -1; + } + adjustment = s_size - (c_ket - c_bra); + len = SIZE(z->p); + if (adjustment != 0) { + if (adjustment + len > CAPACITY(z->p)) { + z->p = increase_size(z->p, adjustment + len); + if (z->p == NULL) + return -1; + } + memmove(z->p + c_ket + adjustment, z->p + c_ket, (len - c_ket) * sizeof(symbol)); + SET_SIZE(z->p, adjustment + len); + z->l += adjustment; + if (z->c >= c_ket) + z->c += adjustment; + else if (z->c > c_bra) + z->c = c_bra; + } + unless(s_size == 0) memmove(z->p + c_bra, s, s_size * sizeof(symbol)); + if (adjptr != NULL) + *adjptr = adjustment; + return 0; +} + +static int slice_check(struct SN_env *z) { + + if (z->bra < 0 || z->bra > z->ket || z->ket > z->l || z->p == NULL || z->l > SIZE(z->p)) /* this line could be removed */ + { +#if 0 + fprintf(stderr, "faulty slice operation:\n"); + debug(z, -1, 0); +#endif + return -1; + } + return 0; +} + +extern int slice_from_s(struct SN_env *z, int s_size, const symbol *s) { + if (slice_check(z)) + return -1; + return replace_s(z, z->bra, z->ket, s_size, s, NULL); +} + +extern int slice_from_v(struct SN_env *z, const symbol *p) { return slice_from_s(z, SIZE(p), p); } + +extern int slice_del(struct SN_env *z) { return slice_from_s(z, 0, 0); } + +extern int insert_s(struct SN_env *z, int bra, int ket, int s_size, const symbol *s) { + int adjustment; + if (replace_s(z, bra, ket, s_size, s, &adjustment)) + return -1; + if (bra <= z->bra) + z->bra += adjustment; + if (bra <= z->ket) + z->ket += adjustment; + return 0; +} + +extern int insert_v(struct SN_env *z, int bra, int ket, const symbol *p) { + int adjustment; + if (replace_s(z, bra, ket, SIZE(p), p, &adjustment)) + return -1; + if (bra <= z->bra) + z->bra += adjustment; + if (bra <= z->ket) + z->ket += adjustment; + return 0; +} + +extern symbol *slice_to(struct SN_env *z, symbol *p) { + if (slice_check(z)) { + lose_s(p); + return NULL; + } + { + int len = z->ket - z->bra; + if (CAPACITY(p) < len) { + p = increase_size(p, len); + if (p 
== NULL)
+                return NULL;
+        }
+        memmove(p, z->p + z->bra, len * sizeof(symbol));
+        SET_SIZE(p, len);
+    }
+    return p;
+}
+
+extern symbol *assign_to(struct SN_env *z, symbol *p) {
+    int len = z->l;
+    if (CAPACITY(p) < len) {
+        p = increase_size(p, len);
+        if (p == NULL)
+            return NULL;
+    }
+    memmove(p, z->p, len * sizeof(symbol));
+    SET_SIZE(p, len);
+    return p;
+}
diff --git a/internal/cpp/string_utils.h b/internal/cpp/string_utils.h
new file mode 100644
index 00000000000..05ef0281370
--- /dev/null
+++ b/internal/cpp/string_utils.h
@@ -0,0 +1,476 @@
+// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
+#include <emmintrin.h>
+#elif defined(__GNUC__) && defined(__aarch64__)
+#include "sse2neon.h" // assumed SSE2-on-NEON shim; the original include target was lost
+#endif
+
+#include <algorithm>
+#include <cctype>
+#include <cstdint>
+#include <cstring>
+#include <string>
+#include <string_view>
+
+[[nodiscard]] constexpr uint8_t ToUpper(uint8_t ch) noexcept { return ch >= 'a' && ch <= 'z' ? ch - 32 : ch; }
+
+[[nodiscard]] constexpr uint8_t ToLower(uint8_t ch) noexcept { return ch >= 'A' && ch <= 'Z' ? ch + 32 : ch; }
+
+inline void ToLower(char* data, size_t len)
+{
+#ifdef __SSE2__
+    while (len >= 16)
+    {
+        /* By Peter Cordes */
+        __m128i input = _mm_loadu_si128((__m128i*)data);
+        __m128i rangeshift = _mm_sub_epi8(input, _mm_set1_epi8('A' - 128));
+        __m128i nomodify = _mm_cmpgt_epi8(rangeshift, _mm_set1_epi8(25 - 128));
+        __m128i flip = _mm_andnot_si128(nomodify, _mm_set1_epi8(0x20));
+        _mm_storeu_si128((__m128i*)data, _mm_xor_si128(input, flip));
+        len -= 16;
+        data += 16;
+    }
+#endif
+    while (len-- > 0)
+    {
+        *data += ((unsigned char)(*data - 'A') < 26) << 5;
+        ++data;
+    }
+}
+
+inline void ToLower(const char* data, size_t len, char* out, size_t out_limit)
+{
+    // NOTE: the caller must guarantee out_limit > len; the terminating '\0'
+    // is written at out[len].
+    (void)out_limit;
+    memcpy(out, data, len);
+    char* begin = out;
+    char* end = out + len;
+    char* p = begin;
+#if defined(__SSE2__)
+    static constexpr int SSE2_BYTES = sizeof(__m128i);
+    const char* sse2_end = begin + (len & ~(SSE2_BYTES - 1));
+    const auto a_minus1 = _mm_set1_epi8('A' - 1);
+    const auto z_plus1 = _mm_set1_epi8('Z' + 1);
+    const auto delta = _mm_set1_epi8('a' - 'A');
+    for (; p < sse2_end; p += SSE2_BYTES) // was `p > sse2_end`, which never looped
+    {
+        auto bytes = _mm_loadu_si128((const __m128i*)p);
+        _mm_maskmoveu_si128(_mm_xor_si128(bytes, delta),
+                            _mm_and_si128(_mm_cmpgt_epi8(bytes, a_minus1), _mm_cmpgt_epi8(z_plus1, bytes)), p);
+    }
+#endif
+    for (; p < end; p += 1)
+    {
+        if ('A' <= (*p) && (*p) <= 'Z')
+            (*p) += 32;
+    }
+    (*end) = '\0';
+}
+
+inline std::string ToLowerString(std::string_view s)
+{
+    std::string result{s.data(), s.size()};
+    char* begin = result.data();
+    char* end = result.data() + s.size();
+
+    char* p = begin;
+#if defined(__SSE2__)
+    const size_t size = result.size();
+    static constexpr int SSE2_BYTES = sizeof(__m128i);
+    const char* sse2_end = begin + (size & ~(SSE2_BYTES - 1));
+
+    const auto a_minus1 = _mm_set1_epi8('A' - 1);
+    const auto z_plus1 = _mm_set1_epi8('Z' + 1);
+    const auto delta = _mm_set1_epi8('a' - 'A');
+    for (; p < sse2_end; p += SSE2_BYTES) // was `p > sse2_end`, which never looped
+    {
+        auto bytes = _mm_loadu_si128((const __m128i*)p);
+        _mm_maskmoveu_si128(_mm_xor_si128(bytes, delta),
+                            _mm_and_si128(_mm_cmpgt_epi8(bytes, a_minus1), _mm_cmpgt_epi8(z_plus1, bytes)), p);
+    }
+#endif
+    for (; p < end; p += 1)
+    {
+        if ('A' <= (*p) && (*p) <= 'Z')
+            (*p) += 32;
+    }
+    return result;
+}
+
+inline bool IsUTF8Sep(const uint8_t c) { return c < 128 && !std::isalnum(c); }
+
+template <typename T>
+inline uint32_t GetLeadingZeroBits(T x)
+{
+    if constexpr (sizeof(T) <= sizeof(unsigned int))
+    {
+        return __builtin_clz(x);
+    }
+    else if constexpr (sizeof(T) <= sizeof(unsigned long int))
+    {
+        return __builtin_clzl(x);
+    }
+    else
+    {
+        return __builtin_clzll(x);
+    }
+}
+
+template <typename T>
+inline uint32_t BitScanReverse(T x)
+{
+    return (std::max(sizeof(T), sizeof(unsigned int))) * 8 - 1 - GetLeadingZeroBits(x);
+}
+
+/// return UTF-8 code point sequence length
+inline uint32_t UTF8SeqLength(const uint8_t first_octet)
+{
+    if (first_octet < 0x80 || first_octet >= 0xF8)
+        return 1;
+
+    const uint32_t bits = 8;
+    const auto first_zero = BitScanReverse(static_cast<uint8_t>(~first_octet));
+
+    return bits - 1 - first_zero;
+}
+
+// Per-byte UTF-8 sequence lengths, indexed by the first byte of a sequence.
+static const uint8_t UTF8_BYTE_LENGTH_TABLE[256] = {
+    // start byte of 1-byte utf8 char: 0b0000'0000 ~ 0b0111'1111
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+    // continuation byte: 0b1000'0000 ~ 0b1011'1111 (also mapped to length 1)
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+    // start byte of 2-byte utf8 char: 0b1100'0000 ~ 0b1101'1111
+    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+    // start byte of 3-byte utf8 char: 0b1110'0000 ~ 0b1110'1111
+    3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
+    // start byte of 4-byte utf8 char: 0b1111'0000 ~ 0b1111'0111
+    // invalid utf8 byte: 0b1111'1000 ~ 0b1111'1111 (mapped to length 1)
+    4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1};
+
+inline uint32_t UTF8Length(const std::string_view str)
+{
+    uint32_t len = 0;
+    for (uint32_t i = 0, char_size = 0; i < str.size(); i += char_size)
+    {
+        char_size = UTF8_BYTE_LENGTH_TABLE[static_cast<uint8_t>(str[i])];
+        ++len;
+    }
+    return len;
+}
+
+static inline std::string UTF8Substr(const std::string& str, std::size_t start, std::size_t len)
+{
+    std::size_t str_len = str.length();
+    std::size_t i = 0;
+    std::size_t byte_index = 0;
+    std::size_t start_byte = 0;
+    std::size_t end_byte = 0;
+
+    while (byte_index < str_len && i < (start + len))
+    {
+        std::size_t char_len = UTF8_BYTE_LENGTH_TABLE[static_cast<uint8_t>(str[byte_index])];
+        if (i >= start)
+        {
+            if (i == start)
+            {
+                start_byte = byte_index;
+            }
+            end_byte = byte_index + char_len;
+        }
+
+        byte_index += char_len;
+        i += 1;
+    }
+
+    return str.substr(start_byte, end_byte - start_byte);
+}
+
+static inline std::string_view UTF8Substrview(const std::string_view str, const std::size_t start,
+                                              const std::size_t len)
+{
+    const std::size_t str_len = str.length();
+    std::size_t i = 0;
+    std::size_t byte_index = 0;
+    std::size_t start_byte = 0;
+    std::size_t end_byte = 0;
+
+    while (byte_index < str_len && i < (start + len))
+    {
+        const std::size_t char_len = UTF8_BYTE_LENGTH_TABLE[static_cast<uint8_t>(str[byte_index])];
+        if (i >= start)
+        {
+            if (i == start)
+            {
+                start_byte = byte_index;
+            }
+            end_byte = byte_index + char_len;
+        }
+
+        byte_index += char_len;
+        i += 1;
+    }
+
+    return str.substr(start_byte, end_byte - start_byte);
+}
diff --git a/internal/cpp/term.cpp b/internal/cpp/term.cpp
new file mode 100644
index 00000000000..8ac9e16d21c
--- /dev/null
+++ b/internal/cpp/term.cpp
@@ -0,0 +1,24 @@
+// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "term.h"
+
+std::string PLACE_HOLDER("");
+
+void Term::Reset() {
+    text_.clear();
+    word_offset_ = 0;
+}
+
+Term TermList::global_temporary_;
\ No newline at end of file
diff --git a/internal/cpp/term.h b/internal/cpp/term.h
new file mode 100644
index 00000000000..663c39da74b
--- /dev/null
+++ b/internal/cpp/term.h
@@ -0,0 +1,72 @@
+//
+// Created by infiniflow on 1/31/26.
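+// Term is one token (its text plus word/byte offsets and a payload tag);
+// TermList is a deque of Terms whose Add() overloads copy or swap the
+// incoming text into a reused scratch element to avoid extra allocations.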
+//
+
+#pragma once
+
+#include <cstdint>
+#include <deque>
+#include <string>
+
+class Term {
+public:
+    Term() : word_offset_(0), end_offset_(0), payload_(0) {
+    }
+
+    Term(const std::string &str) : text_(str), word_offset_(0), end_offset_(0), payload_(0) {
+    }
+
+    ~Term() {
+    }
+
+    void Reset();
+
+    uint32_t Length() { return static_cast<uint32_t>(text_.length()); }
+
+    std::string Text() const { return text_; }
+
+public:
+    std::string text_;
+    uint32_t word_offset_;
+    uint32_t end_offset_;
+    uint16_t payload_;
+};
+
+class TermList : public std::deque<Term> {
+public:
+    void Add(const char *text, const uint32_t len, const uint32_t offset, const uint32_t end_offset,
+             const uint16_t payload = 0) {
+        push_back(global_temporary_);
+        back().text_.assign(text, len);
+        back().word_offset_ = offset;
+        back().end_offset_ = end_offset;
+        back().payload_ = payload;
+    }
+
+    // void Add(cppjieba::Word &cut_word) {
+    //     push_back(global_temporary_);
+    //     std::swap(back().text_, cut_word.word);
+    //     back().word_offset_ = cut_word.offset;
+    // }
+
+    void Add(const std::string &token, const uint32_t offset, const uint32_t end_offset, const uint16_t payload = 0) {
+        push_back(global_temporary_);
+        back().text_ = token;
+        back().word_offset_ = offset;
+        back().end_offset_ = end_offset;
+        back().payload_ = payload;
+    }
+
+    void Add(std::string &token, const uint32_t offset, const uint32_t end_offset, const uint16_t payload = 0) {
+        push_back(global_temporary_);
+        std::swap(back().text_, token);
+        back().word_offset_ = offset;
+        back().end_offset_ = end_offset;
+        back().payload_ = payload;
+    }
+
+private:
+    static Term global_temporary_;
+};
+
+extern std::string PLACE_HOLDER;
diff --git a/internal/cpp/tokenizer.cpp b/internal/cpp/tokenizer.cpp
new file mode 100644
index 00000000000..edc61491734
--- /dev/null
+++ b/internal/cpp/tokenizer.cpp
@@ -0,0 +1,315 @@
+// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tokenizer.h"
+#include <cctype>
+#include <cstring>
+
+const CharType ALLOW_CHR = 0;     /// < regular term
+const CharType DELIMITER_CHR = 1; /// < delimiter
+const CharType SPACE_CHR = 2;     /// < space term
+const CharType UNITE_CHR = 3;     /// < united term
+
+CharTypeTable::CharTypeTable(bool use_def_delim) {
+    memset(char_type_table_, 0, sizeof(char_type_table_));
+    // if use_def_delim is not set, every character stays ALLOW_CHR
+    if (!use_def_delim)
+        return;
+    // set the lower 4 bit to record default char type
+    for (unsigned i = 0; i < BYTE_MAX; i++) {
+        if (std::isalnum(i) || i > 127)
+            continue;
+        else if (std::isspace(i))
+            char_type_table_[i] = SPACE_CHR;
+        else
+            char_type_table_[i] = DELIMITER_CHR;
+    }
+}
+
+void CharTypeTable::SetConfig(const TokenizeConfig &conf) {
+    // set the higher 4 bit to record user defined option type
+    std::string str; // local copy of each option string (copying is not strictly needed)
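+    // Order matters below: divides_, then unites_, then allows_ are applied,
+    // so a byte named by more than one set keeps the classification of the
+    // last set that mentions it (allows_ wins over unites_ over divides_).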
+
+    str = conf.divides_;
+    if (!str.empty()) {
+        for (unsigned int j = 0; j < str.length(); j++) {
+            char_type_table_[(uint8_t)str[j]] = DELIMITER_CHR;
+        }
+    }
+
+    str = conf.unites_;
+    if (!str.empty()) {
+        for (unsigned int j = 0; j < str.length(); j++) {
+            char_type_table_[(uint8_t)str[j]] = UNITE_CHR;
+        }
+    }
+
+    str = conf.allows_;
+    if (!str.empty()) {
+        for (unsigned int j = 0; j < str.length(); j++) {
+            char_type_table_[(uint8_t)str[j]] = ALLOW_CHR;
+        }
+    }
+}
+
+void Tokenizer::SetConfig(const TokenizeConfig &conf) { table_.SetConfig(conf); }
+
+void Tokenizer::Tokenize(const std::string &input) {
+    // Non-owning view of the caller's string; the caller must keep `input`
+    // alive for as long as NextToken() is being called.
+    input_ = (std::string *)&input;
+    input_cursor_ = 0;
+}
+
+bool Tokenizer::NextToken() {
+    while (input_cursor_ < input_->length() && table_.GetType(input_->at(input_cursor_)) == SPACE_CHR) {
+        input_cursor_++;
+    }
+    if (input_cursor_ == input_->length())
+        return false;
+
+    output_buffer_cursor_ = 0;
+
+    if (output_buffer_cursor_ >= output_buffer_size_) {
+        GrowOutputBuffer();
+    }
+    token_start_cursor_ = input_cursor_;
+    output_buffer_[output_buffer_cursor_++] = input_->at(input_cursor_);
+    if (table_.GetType(input_->at(input_cursor_)) == DELIMITER_CHR) {
+        ++input_cursor_;
+        is_delimiter_ = true;
+        return true;
+    } else {
+        ++input_cursor_;
+        is_delimiter_ = false;
+
+        while (input_cursor_ < input_->length()) {
+            CharType cur_type = table_.GetType(input_->at(input_cursor_));
+            if (cur_type == SPACE_CHR || cur_type == DELIMITER_CHR) {
+                return true;
+            } else if (cur_type == ALLOW_CHR) {
+                if (output_buffer_cursor_ >= output_buffer_size_) {
+                    GrowOutputBuffer();
+                }
+                output_buffer_[output_buffer_cursor_++] = input_->at(input_cursor_++);
+            } else {
+                ++input_cursor_;
+            }
+        }
+        return true;
+    }
+}
+
+bool Tokenizer::GrowOutputBuffer() {
+    // Double the buffer and keep the bytes already collected for the current
+    // token; the previous code allocated a fresh buffer and discarded them.
+    size_t new_size = output_buffer_size_ * 2;
+    auto new_buffer = std::make_unique<char[]>(new_size);
+    memcpy(new_buffer.get(), output_buffer_.get(), output_buffer_cursor_);
+    output_buffer_ = std::move(new_buffer);
+    output_buffer_size_ = new_size;
+    return true;
+}
+
+bool Tokenizer::Tokenize(const std::string &input_string, TermList &special_terms, TermList &prim_terms) {
+    special_terms.clear();
+    prim_terms.clear();
+
+    size_t len = input_string.length();
+    if (len == 0)
+        return false;
+
+    Term t;
+    TermList::iterator it;
+
+    unsigned int word_off = 0, char_off = 0;
+
+    char cur_char;
+    CharType cur_type;
+
+    for (char_off = 0; char_off < len;) // char_off is always advanced inside the loop
+    {
+        cur_type = table_.GetType(input_string.at(char_off));
+
+        if (cur_type == ALLOW_CHR || cur_type == UNITE_CHR) {
+            it = prim_terms.insert(prim_terms.end(), t);
+
+            do {
+                cur_char = input_string.at(char_off);
+                cur_type = table_.GetType(cur_char);
+
+                if (cur_type == ALLOW_CHR) {
+                    it->text_ += cur_char;
+                } else if (cur_type == SPACE_CHR || cur_type == DELIMITER_CHR) {
+                    break;
+                }
+
+                char_off++;
+            } while (char_off < len);
+
+            if (it->text_.length() == 0) {
+                prim_terms.erase(it);
+                continue;
+            }
+
+            it->word_offset_ = word_off++;
+
+        } else if (cur_type == DELIMITER_CHR) {
+
+            it = special_terms.insert(special_terms.end(), t);
+
+            do {
+                cur_char = input_string.at(char_off);
+                cur_type = table_.GetType(cur_char);
+
+                if (cur_type == DELIMITER_CHR)
+                    it->text_ += cur_char;
+                else
+                    break;
+                char_off++;
+            } while (char_off < len);
+
+            it->word_offset_ = word_off++;
+
+        } else
+            char_off++;
+    }
+
+    return true;
+}
+
+bool Tokenizer::Tokenize(const std::string &input_string, TermList &prim_terms) {
+    prim_terms.clear();
+    size_t len = input_string.length();
+    if (len == 0)
+        return false;
+
+    Term t;
+    TermList::iterator it;
+
+    unsigned int word_off = 0,
char_off = 0; + + char cur_char; + CharType cur_type; + + for (char_off = 0; char_off < len;) // char_off++ ) + { + cur_type = table_.GetType(input_string.at(char_off)); + + if (cur_type == ALLOW_CHR || cur_type == UNITE_CHR) { + + it = prim_terms.insert(prim_terms.end(), t); + // it->begin_ = char_off; + + do { + cur_char = input_string.at(char_off); + cur_type = table_.GetType(cur_char); + + if (cur_type == ALLOW_CHR) { + it->text_ += cur_char; + } else if (cur_type == SPACE_CHR || cur_type == DELIMITER_CHR) { + break; + } + + char_off++; + } while (char_off < len); + + if (it->text_.length() == 0) { + prim_terms.erase(it); + continue; + // char_off--; + } + + it->word_offset_ = word_off++; + + // char_off--; + } else if (cur_type == DELIMITER_CHR) { + if (((char_off + 1) < len) && table_.GetType(input_string.at(char_off + 1)) != DELIMITER_CHR) { + word_off++; + } + char_off++; + } else + char_off++; + } + + return true; +} + +bool Tokenizer::TokenizeWhite(const std::string &input_string, TermList &raw_terms) { + raw_terms.clear(); + + size_t len = input_string.length(); + if (len == 0) + return false; + + Term t; + TermList::iterator it; + + unsigned int word_off = 0, char_off = 0; + + char cur_char; + CharType cur_type; + // CharType cur_type, preType; + + for (char_off = 0; char_off < len;) // char_off++ ) + { + cur_type = table_.GetType(input_string.at(char_off)); + + if (cur_type == ALLOW_CHR || cur_type == UNITE_CHR) { + it = raw_terms.insert(raw_terms.end(), t); + // it->begin_ = char_off; + + do { + cur_char = input_string.at(char_off); + cur_type = table_.GetType(cur_char); + + if (cur_type == ALLOW_CHR) { + it->text_ += cur_char; + } else if (cur_type == SPACE_CHR || cur_type == DELIMITER_CHR) { + break; + } + + char_off++; + } while (char_off < len); + + if (it->text_.length() == 0) { + raw_terms.erase(it); + continue; + // char_off--; + } + + it->word_offset_ = word_off++; + + // char_off--; + } else if (cur_type == DELIMITER_CHR) { + + it = raw_terms.insert(raw_terms.end(), t); + + do { + cur_char = input_string.at(char_off); + cur_type = table_.GetType(cur_char); + if (cur_type == DELIMITER_CHR) + it->text_ += cur_char; + else + break; + char_off++; + } while (char_off < len); + + it->word_offset_ = word_off++; + + // char_off--; + } else { + // SPACE_CHR nothing to do + char_off++; + } + } + + return true; +} \ No newline at end of file diff --git a/internal/cpp/tokenizer.h b/internal/cpp/tokenizer.h new file mode 100644 index 00000000000..a3dd7492b57 --- /dev/null +++ b/internal/cpp/tokenizer.h @@ -0,0 +1,113 @@ +// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
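+
+// Usage sketch (illustrative only, not part of the original header): the
+// streaming API is SetConfig -> Tokenize -> repeated NextToken() calls.
+//
+//   Tokenizer tok;
+//   tok.Tokenize("hello-world");
+//   while (tok.NextToken()) {
+//       std::string term(tok.GetToken(), tok.GetLength());
+//       // yields "hello", then "-" (IsDelimiter() == true), then "world"
+//   }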
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include "term.h"
+
+// One slot per possible byte value (0x00-0xFF); was 255, which left byte
+// 0xFF outside the table bounds.
+constexpr unsigned BYTE_MAX = 256;
+
+class TokenizeConfig {
+public:
+    void AddAllows(std::string astr) { allows_ += astr; }
+    void AddDivides(std::string dstr) { divides_ += dstr; }
+    void AddUnites(std::string ustr) { unites_ += ustr; }
+    std::string allows_;
+    std::string divides_;
+    std::string unites_;
+};
+
+typedef unsigned char CharType;
+
+extern const CharType ALLOW_CHR;     /// < regular term
+extern const CharType DELIMITER_CHR; /// < delimiter
+extern const CharType SPACE_CHR;     /// < space term
+extern const CharType UNITE_CHR;     /// < united term
+
+class CharTypeTable {
+    CharType char_type_table_[BYTE_MAX];
+
+public:
+    CharTypeTable(bool use_def_delim = true);
+
+    void SetConfig(const TokenizeConfig &conf);
+
+    CharType GetType(uint8_t c) { return char_type_table_[c]; }
+
+    bool IsAllow(uint8_t c) { return char_type_table_[c] == ALLOW_CHR; }
+
+    bool IsDivide(uint8_t c) { return char_type_table_[c] == DELIMITER_CHR; }
+
+    bool IsUnite(uint8_t c) { return char_type_table_[c] == UNITE_CHR; }
+
+    bool IsEqualType(uint8_t c1, uint8_t c2) { return char_type_table_[c1] == char_type_table_[c2]; }
+};
+
+class Tokenizer {
+public:
+    Tokenizer(bool use_def_delim = true) : table_(use_def_delim) { output_buffer_ = std::make_unique<char[]>(output_buffer_size_); }
+
+    ~Tokenizer() {}
+
+    /// \brief set the user defined char types
+    /// \param conf char type option list
+    void SetConfig(const TokenizeConfig &conf);
+
+    /// \brief tokenize the input text; call NextToken(), GetToken(), GetLength() to get the result.
+    /// \param input input text string
+    void Tokenize(const std::string &input);
+
+    bool NextToken();
+
+    inline const char *GetToken() { return output_buffer_.get(); }
+
+    inline size_t GetLength() { return output_buffer_cursor_; }
+
+    inline bool IsDelimiter() { return is_delimiter_; }
+
+    inline size_t GetTokenStartCursor() const { return token_start_cursor_; }
+
+    inline size_t GetInputCursor() const { return input_cursor_; }
+
+    /// \brief tokenize the input text, output delimiter terms and primary terms separately
+    bool Tokenize(const std::string &input_string, TermList &special_terms, TermList &prim_terms);
+
+    /// \brief tokenize the input text, remove the space chars, output raw term list
+    bool TokenizeWhite(const std::string &input_string, TermList &raw_terms);
+
+    /// \brief tokenize the input text, output the primary term list only
+    bool Tokenize(const std::string &input_string, TermList &prim_terms);
+
+private:
+    bool GrowOutputBuffer();
+
+private:
+    CharTypeTable table_;
+
+    std::string *input_{nullptr};
+
+    size_t token_start_cursor_{0};
+
+    size_t input_cursor_{0};
+
+    size_t output_buffer_size_{4096};
+
+    std::unique_ptr<char[]> output_buffer_;
+
+    size_t output_buffer_cursor_{0};
+
+    bool is_delimiter_{false};
+};
diff --git a/internal/cpp/util/logging.h b/internal/cpp/util/logging.h
new file mode 100644
index 00000000000..787d68a956b
--- /dev/null
+++ b/internal/cpp/util/logging.h
@@ -0,0 +1,111 @@
+// Copyright 2009 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef UTIL_LOGGING_H_
+#define UTIL_LOGGING_H_
+
+// Simplified version of Google's logging.
+
+#include <assert.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <ostream>
+#include <sstream>
+#include <stdexcept>
+
+#include "util/util.h"
+
+// Debug-only checking.
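+// (The DCHECK* macros expand to plain assert(), so they disappear when
+// NDEBUG is defined; the CHECK* macros further below always evaluate.)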
+#define DCHECK(condition) assert(condition) +#define DCHECK_EQ(val1, val2) assert((val1) == (val2)) +#define DCHECK_NE(val1, val2) assert((val1) != (val2)) +#define DCHECK_LE(val1, val2) assert((val1) <= (val2)) +#define DCHECK_LT(val1, val2) assert((val1) < (val2)) +#define DCHECK_GE(val1, val2) assert((val1) >= (val2)) +#define DCHECK_GT(val1, val2) assert((val1) > (val2)) + +// Always-on checking +#define CHECK(x) if(x){}else LogMessageFatal(__FILE__, __LINE__).stream() << "Check failed: " #x +#define CHECK_LT(x, y) CHECK((x) < (y)) +#define CHECK_GT(x, y) CHECK((x) > (y)) +#define CHECK_LE(x, y) CHECK((x) <= (y)) +#define CHECK_GE(x, y) CHECK((x) >= (y)) +#define CHECK_EQ(x, y) CHECK((x) == (y)) +#define CHECK_NE(x, y) CHECK((x) != (y)) + +#define LOG_INFO LogMessage(__FILE__, __LINE__) +#define LOG_WARNING LogMessage(__FILE__, __LINE__) +#define LOG_ERROR LogMessage(__FILE__, __LINE__) +#define LOG_FATAL LogMessageFatal(__FILE__, __LINE__) +#define LOG_QFATAL LOG_FATAL + +// It seems that one of the Windows header files defines ERROR as 0. +#ifdef _WIN32 +#define LOG_0 LOG_INFO +#endif + +#ifdef NDEBUG +#define LOG_DFATAL LOG_ERROR +#else +#define LOG_DFATAL LOG_FATAL +#endif + +#define LOG(severity) LOG_ ## severity.stream() + +#define VLOG(x) if((x)>0){}else LOG_INFO.stream() + +class LogMessage { + public: + LogMessage(const char* file, int line) + : flushed_(false) { +// stream() << file << ":" << line << ": "; + } + void Flush() { +// stream() << "\n"; +// std::string s = str_.str(); +// size_t n = s.size(); +// if (fwrite(s.data(), 1, n, stderr) < n) {} // shut up gcc +// flushed_ = true; + } + ~LogMessage() { + if (!flushed_) { + Flush(); + } + } + std::ostream& stream() { return str_; } + + private: + bool flushed_; + std::ostringstream str_; + + LogMessage(const LogMessage&) = delete; + LogMessage& operator=(const LogMessage&) = delete; +}; + +// Silence "destructor never returns" warning for ~LogMessageFatal(). +// Since this is a header file, push and then pop to limit the scope. +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable: 4722) +#endif + +class LogMessageFatal : public LogMessage { + public: + LogMessageFatal(const char* file, int line) + : LogMessage(file, line) { + throw std::runtime_error("RE2 Fatal Error"); + } + ~LogMessageFatal() { + Flush(); + } + private: + LogMessageFatal(const LogMessageFatal&) = delete; + LogMessageFatal& operator=(const LogMessageFatal&) = delete; +}; + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +#endif // UTIL_LOGGING_H_ diff --git a/internal/cpp/util/mix.h b/internal/cpp/util/mix.h new file mode 100644 index 00000000000..39539b4d75c --- /dev/null +++ b/internal/cpp/util/mix.h @@ -0,0 +1,41 @@ +// Copyright 2016 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef UTIL_MIX_H_ +#define UTIL_MIX_H_ + +#include +#include + +namespace re2 { + +// Silence "truncation of constant value" warning for kMul in 32-bit mode. +// Since this is a header file, push and then pop to limit the scope. 
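+// Usage sketch (illustrative, not part of the original header): combine
+// hash values as
+//   HashMix mix(h1); mix.Mix(h2); size_t combined = mix.get();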
+#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4309) +#endif + +class HashMix { +public: + HashMix() : hash_(1) {} + explicit HashMix(size_t val) : hash_(val + 83) {} + void Mix(size_t val) { + static const size_t kMul = static_cast(0xdc3eb94af8ab4c93ULL); + hash_ *= kMul; + hash_ = ((hash_ << 19) | (hash_ >> (std::numeric_limits::digits - 19))) + val; + } + size_t get() const { return hash_; } + +private: + size_t hash_; +}; + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +} // namespace re2 + +#endif // UTIL_MIX_H_ diff --git a/internal/cpp/util/mutex.h b/internal/cpp/util/mutex.h new file mode 100644 index 00000000000..de71839bf20 --- /dev/null +++ b/internal/cpp/util/mutex.h @@ -0,0 +1,169 @@ +// Copyright 2007 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef UTIL_MUTEX_H_ +#define UTIL_MUTEX_H_ + +/* + * A simple mutex wrapper, supporting locks and read-write locks. + * You should assume the locks are *not* re-entrant. + */ + +#ifdef RE2_NO_THREADS +#include +#define MUTEX_IS_LOCK_COUNTER +#else +#ifdef _WIN32 +// Requires Windows Vista or Windows Server 2008 at minimum. +#include +#if defined(WINVER) && WINVER >= 0x0600 +#define MUTEX_IS_WIN32_SRWLOCK +#endif +#else +#ifndef _POSIX_C_SOURCE +#define _POSIX_C_SOURCE 200809L +#endif +#include +#if defined(_POSIX_READER_WRITER_LOCKS) && _POSIX_READER_WRITER_LOCKS > 0 +#define MUTEX_IS_PTHREAD_RWLOCK +#endif +#endif +#endif + +#if defined(MUTEX_IS_LOCK_COUNTER) +typedef int MutexType; +#elif defined(MUTEX_IS_WIN32_SRWLOCK) +typedef SRWLOCK MutexType; +#elif defined(MUTEX_IS_PTHREAD_RWLOCK) +#include +#include +#include +typedef pthread_rwlock_t MutexType; +#else +#include +typedef std::shared_mutex MutexType; +#endif + +namespace re2 { + +class Mutex { +public: + inline Mutex(); + inline ~Mutex(); + inline void Lock(); // Block if needed until free then acquire exclusively + inline void Unlock(); // Release a lock acquired via Lock() + // Note that on systems that don't support read-write locks, these may + // be implemented as synonyms to Lock() and Unlock(). So you can use + // these for efficiency, but don't use them anyplace where being able + // to do shared reads is necessary to avoid deadlock. + inline void ReaderLock(); // Block until free or shared then acquire a share + inline void ReaderUnlock(); // Release a read share of this Mutex + inline void WriterLock() { Lock(); } // Acquire an exclusive lock + inline void WriterUnlock() { Unlock(); } // Release a lock from WriterLock() + +private: + MutexType mutex_; + + // Catch the error of writing Mutex when intending MutexLock. 
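+    // (Writing `Mutex lock(&mu);` by mistake would otherwise pick this
+    // pointer overload and compile while locking nothing.)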
+ Mutex(Mutex *ignored); + + Mutex(const Mutex &) = delete; + Mutex &operator=(const Mutex &) = delete; +}; + +#if defined(MUTEX_IS_LOCK_COUNTER) + +Mutex::Mutex() : mutex_(0) {} +Mutex::~Mutex() { assert(mutex_ == 0); } +void Mutex::Lock() { assert(--mutex_ == -1); } +void Mutex::Unlock() { assert(mutex_++ == -1); } +void Mutex::ReaderLock() { assert(++mutex_ > 0); } +void Mutex::ReaderUnlock() { assert(mutex_-- > 0); } + +#elif defined(MUTEX_IS_WIN32_SRWLOCK) + +Mutex::Mutex() : mutex_(SRWLOCK_INIT) {} +Mutex::~Mutex() {} +void Mutex::Lock() { AcquireSRWLockExclusive(&mutex_); } +void Mutex::Unlock() { ReleaseSRWLockExclusive(&mutex_); } +void Mutex::ReaderLock() { AcquireSRWLockShared(&mutex_); } +void Mutex::ReaderUnlock() { ReleaseSRWLockShared(&mutex_); } + +#elif defined(MUTEX_IS_PTHREAD_RWLOCK) + +#define SAFE_PTHREAD(fncall) \ + do { \ + if ((fncall) != 0) \ + throw std::runtime_error("RE2 pthread failure"); \ + } while (0); + +Mutex::Mutex() { SAFE_PTHREAD(pthread_rwlock_init(&mutex_, NULL)); } +Mutex::~Mutex() { pthread_rwlock_destroy(&mutex_); } +void Mutex::Lock() { SAFE_PTHREAD(pthread_rwlock_wrlock(&mutex_)); } +void Mutex::Unlock() { SAFE_PTHREAD(pthread_rwlock_unlock(&mutex_)); } +void Mutex::ReaderLock() { SAFE_PTHREAD(pthread_rwlock_rdlock(&mutex_)); } +void Mutex::ReaderUnlock() { SAFE_PTHREAD(pthread_rwlock_unlock(&mutex_)); } + +#undef SAFE_PTHREAD + +#else + +Mutex::Mutex() {} +Mutex::~Mutex() {} +void Mutex::Lock() { mutex_.lock(); } +void Mutex::Unlock() { mutex_.unlock(); } +void Mutex::ReaderLock() { mutex_.lock_shared(); } +void Mutex::ReaderUnlock() { mutex_.unlock_shared(); } + +#endif + +// -------------------------------------------------------------------------- +// Some helper classes + +// MutexLock(mu) acquires mu when constructed and releases it when destroyed. +class MutexLock { +public: + explicit MutexLock(Mutex *mu) : mu_(mu) { mu_->Lock(); } + ~MutexLock() { mu_->Unlock(); } + +private: + Mutex *const mu_; + + MutexLock(const MutexLock &) = delete; + MutexLock &operator=(const MutexLock &) = delete; +}; + +// ReaderMutexLock and WriterMutexLock do the same, for rwlocks +class ReaderMutexLock { +public: + explicit ReaderMutexLock(Mutex *mu) : mu_(mu) { mu_->ReaderLock(); } + ~ReaderMutexLock() { mu_->ReaderUnlock(); } + +private: + Mutex *const mu_; + + ReaderMutexLock(const ReaderMutexLock &) = delete; + ReaderMutexLock &operator=(const ReaderMutexLock &) = delete; +}; + +class WriterMutexLock { +public: + explicit WriterMutexLock(Mutex *mu) : mu_(mu) { mu_->WriterLock(); } + ~WriterMutexLock() { mu_->WriterUnlock(); } + +private: + Mutex *const mu_; + + WriterMutexLock(const WriterMutexLock &) = delete; + WriterMutexLock &operator=(const WriterMutexLock &) = delete; +}; + +// Catch bug where variable name is omitted, e.g. MutexLock (&mu); +#define MutexLock(x) static_assert(false, "MutexLock declaration missing variable name") +#define ReaderMutexLock(x) static_assert(false, "ReaderMutexLock declaration missing variable name") +#define WriterMutexLock(x) static_assert(false, "WriterMutexLock declaration missing variable name") + +} // namespace re2 + +#endif // UTIL_MUTEX_H_ diff --git a/internal/cpp/util/rune.cc b/internal/cpp/util/rune.cc new file mode 100644 index 00000000000..fa71d483ef2 --- /dev/null +++ b/internal/cpp/util/rune.cc @@ -0,0 +1,246 @@ +/* + * The authors of this software are Rob Pike and Ken Thompson. + * Copyright (c) 2002 by Lucent Technologies. 
+ * Permission to use, copy, modify, and distribute this software for any + * purpose without fee is hereby granted, provided that this entire notice + * is included in all copies of any software which is or includes a copy + * or modification of this software and in all copies of the supporting + * documentation for such software. + * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED + * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY + * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY + * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. + */ + +#include +#include + +#include "util/utf.h" + +namespace re2 { + +enum { + Bit1 = 7, + Bitx = 6, + Bit2 = 5, + Bit3 = 4, + Bit4 = 3, + Bit5 = 2, + + T1 = ((1 << (Bit1 + 1)) - 1) ^ 0xFF, /* 0000 0000 */ + Tx = ((1 << (Bitx + 1)) - 1) ^ 0xFF, /* 1000 0000 */ + T2 = ((1 << (Bit2 + 1)) - 1) ^ 0xFF, /* 1100 0000 */ + T3 = ((1 << (Bit3 + 1)) - 1) ^ 0xFF, /* 1110 0000 */ + T4 = ((1 << (Bit4 + 1)) - 1) ^ 0xFF, /* 1111 0000 */ + T5 = ((1 << (Bit5 + 1)) - 1) ^ 0xFF, /* 1111 1000 */ + + Rune1 = (1 << (Bit1 + 0 * Bitx)) - 1, /* 0000 0000 0111 1111 */ + Rune2 = (1 << (Bit2 + 1 * Bitx)) - 1, /* 0000 0111 1111 1111 */ + Rune3 = (1 << (Bit3 + 2 * Bitx)) - 1, /* 1111 1111 1111 1111 */ + Rune4 = (1 << (Bit4 + 3 * Bitx)) - 1, + /* 0001 1111 1111 1111 1111 1111 */ + + Maskx = (1 << Bitx) - 1, /* 0011 1111 */ + Testx = Maskx ^ 0xFF, /* 1100 0000 */ + + Bad = Runeerror, +}; + +int chartorune(Rune *rune, const char *str) { + int c, c1, c2, c3; + Rune l; + + /* + * one character sequence + * 00000-0007F => T1 + */ + c = *(unsigned char *)str; + if (c < Tx) { + *rune = c; + return 1; + } + + /* + * two character sequence + * 0080-07FF => T2 Tx + */ + c1 = *(unsigned char *)(str + 1) ^ Tx; + if (c1 & Testx) + goto bad; + if (c < T3) { + if (c < T2) + goto bad; + l = ((c << Bitx) | c1) & Rune2; + if (l <= Rune1) + goto bad; + *rune = l; + return 2; + } + + /* + * three character sequence + * 0800-FFFF => T3 Tx Tx + */ + c2 = *(unsigned char *)(str + 2) ^ Tx; + if (c2 & Testx) + goto bad; + if (c < T4) { + l = ((((c << Bitx) | c1) << Bitx) | c2) & Rune3; + if (l <= Rune2) + goto bad; + *rune = l; + return 3; + } + + /* + * four character sequence (21-bit value) + * 10000-1FFFFF => T4 Tx Tx Tx + */ + c3 = *(unsigned char *)(str + 3) ^ Tx; + if (c3 & Testx) + goto bad; + if (c < T5) { + l = ((((((c << Bitx) | c1) << Bitx) | c2) << Bitx) | c3) & Rune4; + if (l <= Rune3) + goto bad; + *rune = l; + return 4; + } + + /* + * Support for 5-byte or longer UTF-8 would go here, but + * since we don't have that, we'll just fall through to bad. + */ + + /* + * bad decoding + */ +bad: + *rune = Bad; + return 1; +} + +int runetochar(char *str, const Rune *rune) { + /* Runes are signed, so convert to unsigned for range check. */ + unsigned int c; + + /* + * one character sequence + * 00000-0007F => 00-7F + */ + c = *rune; + if (c <= Rune1) { + str[0] = static_cast(c); + return 1; + } + + /* + * two character sequence + * 0080-07FF => T2 Tx + */ + if (c <= Rune2) { + str[0] = T2 | static_cast(c >> 1 * Bitx); + str[1] = Tx | (c & Maskx); + return 2; + } + + /* + * If the Rune is out of range, convert it to the error rune. + * Do this test here because the error rune encodes to three bytes. + * Doing it earlier would duplicate work, since an out of range + * Rune wouldn't have fit in one or two bytes. 
+ */ + if (c > Runemax) + c = Runeerror; + + /* + * three character sequence + * 0800-FFFF => T3 Tx Tx + */ + if (c <= Rune3) { + str[0] = T3 | static_cast(c >> 2 * Bitx); + str[1] = Tx | ((c >> 1 * Bitx) & Maskx); + str[2] = Tx | (c & Maskx); + return 3; + } + + /* + * four character sequence (21-bit value) + * 10000-1FFFFF => T4 Tx Tx Tx + */ + str[0] = T4 | static_cast(c >> 3 * Bitx); + str[1] = Tx | ((c >> 2 * Bitx) & Maskx); + str[2] = Tx | ((c >> 1 * Bitx) & Maskx); + str[3] = Tx | (c & Maskx); + return 4; +} + +int runelen(Rune rune) { + char str[10]; + + return runetochar(str, &rune); +} + +int fullrune(const char *str, int n) { + if (n > 0) { + int c = *(unsigned char *)str; + if (c < Tx) + return 1; + if (n > 1) { + if (c < T3) + return 1; + if (n > 2) { + if (c < T4 || n > 3) + return 1; + } + } + } + return 0; +} + +int utflen(const char *s) { + int c; + int n; + Rune rune; + + n = 0; + for (;;) { + c = *(unsigned char *)s; + if (c < Runeself) { + if (c == 0) + return n; + s++; + } else + s += chartorune(&rune, s); + n++; + } + return 0; +} + +char *utfrune(const char *s, Rune c) { + int c1; + Rune r; + int n; + + if (c < Runesync) /* not part of utf sequence */ + return strchr((char *)s, c); + + for (;;) { + c1 = *(unsigned char *)s; + if (c1 < Runeself) { /* one byte rune */ + if (c1 == 0) + return 0; + if (c1 == c) + return (char *)s; + s++; + continue; + } + n = chartorune(&r, s); + if (r == c) + return (char *)s; + s += n; + } + return 0; +} + +} // namespace re2 diff --git a/internal/cpp/util/strutil.cc b/internal/cpp/util/strutil.cc new file mode 100644 index 00000000000..db11d3e7ce0 --- /dev/null +++ b/internal/cpp/util/strutil.cc @@ -0,0 +1,166 @@ +// Copyright 1999-2005 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include +#include + +#include "util/strutil.h" + +#ifdef _WIN32 +#define snprintf _snprintf +#define vsnprintf _vsnprintf +#endif + +namespace re2 { + +// ---------------------------------------------------------------------- +// CEscapeString() +// Copies 'src' to 'dest', escaping dangerous characters using +// C-style escape sequences. 'src' and 'dest' should not overlap. +// Returns the number of bytes written to 'dest' (not including the \0) +// or (size_t)-1 if there was insufficient space. +// ---------------------------------------------------------------------- +static size_t CEscapeString(const char *src, size_t src_len, char *dest, size_t dest_len) { + const char *src_end = src + src_len; + size_t used = 0; + + for (; src < src_end; src++) { + if (dest_len - used < 2) // space for two-character escape + return (size_t)-1; + + unsigned char c = *src; + switch (c) { + case '\n': + dest[used++] = '\\'; + dest[used++] = 'n'; + break; + case '\r': + dest[used++] = '\\'; + dest[used++] = 'r'; + break; + case '\t': + dest[used++] = '\\'; + dest[used++] = 't'; + break; + case '\"': + dest[used++] = '\\'; + dest[used++] = '\"'; + break; + case '\'': + dest[used++] = '\\'; + dest[used++] = '\''; + break; + case '\\': + dest[used++] = '\\'; + dest[used++] = '\\'; + break; + default: + // Note that if we emit \xNN and the src character after that is a hex + // digit then that digit must be escaped too to prevent it being + // interpreted as part of the character code by C. 
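+            // Emitting fixed-width octal ("\%03o", always exactly three
+            // digits) sidesteps that problem, which is why hex is not used.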
+ if (c < ' ' || c > '~') { + if (dest_len - used < 5) // space for four-character escape + \0 + return (size_t)-1; + snprintf(dest + used, 5, "\\%03o", c); + used += 4; + } else { + dest[used++] = c; + break; + } + } + } + + if (dest_len - used < 1) // make sure that there is room for \0 + return (size_t)-1; + + dest[used] = '\0'; // doesn't count towards return value though + return used; +} + +// ---------------------------------------------------------------------- +// CEscape() +// Copies 'src' to result, escaping dangerous characters using +// C-style escape sequences. 'src' and 'dest' should not overlap. +// ---------------------------------------------------------------------- +std::string CEscape(const StringPiece &src) { + const size_t dest_len = src.size() * 4 + 1; // Maximum possible expansion + char *dest = new char[dest_len]; + const size_t used = CEscapeString(src.data(), src.size(), dest, dest_len); + std::string s = std::string(dest, used); + delete[] dest; + return s; +} + +void PrefixSuccessor(std::string *prefix) { + // We can increment the last character in the string and be done + // unless that character is 255, in which case we have to erase the + // last character and increment the previous character, unless that + // is 255, etc. If the string is empty or consists entirely of + // 255's, we just return the empty string. + while (!prefix->empty()) { + char &c = prefix->back(); + if (c == '\xff') { // char literal avoids signed/unsigned. + prefix->pop_back(); + } else { + ++c; + break; + } + } +} + +static void StringAppendV(std::string *dst, const char *format, va_list ap) { + // First try with a small fixed size buffer + char space[1024]; + + // It's possible for methods that use a va_list to invalidate + // the data in it upon use. The fix is to make a copy + // of the structure before using it and use that copy instead. + va_list backup_ap; + va_copy(backup_ap, ap); + int result = vsnprintf(space, sizeof(space), format, backup_ap); + va_end(backup_ap); + + if ((result >= 0) && (static_cast(result) < sizeof(space))) { + // It fit + dst->append(space, result); + return; + } + + // Repeatedly increase buffer size until it fits + int length = sizeof(space); + while (true) { + if (result < 0) { + // Older behavior: just try doubling the buffer size + length *= 2; + } else { + // We need exactly "result+1" characters + length = result + 1; + } + char *buf = new char[length]; + + // Restore the va_list before we use it again + va_copy(backup_ap, ap); + result = vsnprintf(buf, length, format, backup_ap); + va_end(backup_ap); + + if ((result >= 0) && (result < length)) { + // It fit + dst->append(buf, result); + delete[] buf; + return; + } + delete[] buf; + } +} + +std::string StringPrintf(const char *format, ...) { + va_list ap; + va_start(ap, format); + std::string result; + StringAppendV(&result, format, ap); + va_end(ap); + return result; +} + +} // namespace re2 diff --git a/internal/cpp/util/strutil.h b/internal/cpp/util/strutil.h new file mode 100644 index 00000000000..6f44cf04a1c --- /dev/null +++ b/internal/cpp/util/strutil.h @@ -0,0 +1,21 @@ +// Copyright 2016 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
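+//
+// Behavior sketch (from the implementations in strutil.cc): CEscape("a\tb")
+// yields the printable string "a\\tb", and PrefixSuccessor("ab\xff") drops
+// the trailing 0xFF byte and increments the previous one, giving "ac".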
+ +#ifndef UTIL_STRUTIL_H_ +#define UTIL_STRUTIL_H_ + +#include + +#include "re2/stringpiece.h" +#include "util/util.h" + +namespace re2 { + +std::string CEscape(const StringPiece &src); +void PrefixSuccessor(std::string *prefix); +std::string StringPrintf(const char *format, ...); + +} // namespace re2 + +#endif // UTIL_STRUTIL_H_ diff --git a/internal/cpp/util/utf.h b/internal/cpp/util/utf.h new file mode 100644 index 00000000000..6c865a45e4f --- /dev/null +++ b/internal/cpp/util/utf.h @@ -0,0 +1,43 @@ +/* + * The authors of this software are Rob Pike and Ken Thompson. + * Copyright (c) 2002 by Lucent Technologies. + * Permission to use, copy, modify, and distribute this software for any + * purpose without fee is hereby granted, provided that this entire notice + * is included in all copies of any software which is or includes a copy + * or modification of this software and in all copies of the supporting + * documentation for such software. + * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED + * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY + * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY + * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. + * + * This file and rune.cc have been converted to compile as C++ code + * in name space re2. + */ + +#ifndef UTIL_UTF_H_ +#define UTIL_UTF_H_ + +#include + +namespace re2 { + +typedef signed int Rune; /* Code-point values in Unicode 4.0 are 21 bits wide.*/ + +enum { + UTFmax = 4, /* maximum bytes per rune */ + Runesync = 0x80, /* cannot represent part of a UTF sequence (<) */ + Runeself = 0x80, /* rune and UTF sequences are the same (<) */ + Runeerror = 0xFFFD, /* decoding error in UTF */ + Runemax = 0x10FFFF, /* maximum rune value */ +}; + +int runetochar(char *s, const Rune *r); +int chartorune(Rune *r, const char *s); +int fullrune(const char *s, int n); +int utflen(const char *s); +char *utfrune(const char *, Rune); + +} // namespace re2 + +#endif // UTIL_UTF_H_ diff --git a/internal/cpp/util/util.h b/internal/cpp/util/util.h new file mode 100644 index 00000000000..d978414a719 --- /dev/null +++ b/internal/cpp/util/util.h @@ -0,0 +1,44 @@ +// Copyright 2009 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef UTIL_UTIL_H_ +#define UTIL_UTIL_H_ + +#define arraysize(array) (sizeof(array) / sizeof((array)[0])) + +#ifndef ATTRIBUTE_NORETURN +#if defined(__GNUC__) +#define ATTRIBUTE_NORETURN __attribute__((noreturn)) +#elif defined(_MSC_VER) +#define ATTRIBUTE_NORETURN __declspec(noreturn) +#else +#define ATTRIBUTE_NORETURN +#endif +#endif + +#ifndef ATTRIBUTE_UNUSED +#if defined(__GNUC__) +#define ATTRIBUTE_UNUSED __attribute__((unused)) +#else +#define ATTRIBUTE_UNUSED +#endif +#endif + +#ifndef FALLTHROUGH_INTENDED +#if defined(__clang__) +#define FALLTHROUGH_INTENDED [[clang::fallthrough]] +#elif defined(__GNUC__) && __GNUC__ >= 7 +#define FALLTHROUGH_INTENDED [[gnu::fallthrough]] +#else +#define FALLTHROUGH_INTENDED \ + do { \ + } while (0) +#endif +#endif + +#ifndef NO_THREAD_SAFETY_ANALYSIS +#define NO_THREAD_SAFETY_ANALYSIS +#endif + +#endif // UTIL_UTIL_H_ diff --git a/internal/cpp/wordnet_lemmatizer.cpp b/internal/cpp/wordnet_lemmatizer.cpp new file mode 100644 index 00000000000..673a008a015 --- /dev/null +++ b/internal/cpp/wordnet_lemmatizer.cpp @@ -0,0 +1,225 @@ +// Copyright(C) 2024 InfiniFlow, Inc. All rights reserved. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "wordnet_lemmatizer.h"
+#include <filesystem>
+#include <fstream>
+#include <sstream>
+
+namespace fs = std::filesystem;
+
+static const std::string ADJ = "a";
+static const std::string ADJ_SAT = "s";
+static const std::string ADV = "r";
+static const std::string NOUN = "n";
+static const std::string VERB = "v";
+
+WordNetLemmatizer::WordNetLemmatizer(const std::string &wordnet_path) : wordnet_path_(wordnet_path) { Load(); }
+
+WordNetLemmatizer::~WordNetLemmatizer() = default;
+
+int32_t WordNetLemmatizer::Load() {
+    file_map_ = {{ADJ, "adj"}, {ADV, "adv"}, {NOUN, "noun"}, {VERB, "verb"}};
+
+    MORPHOLOGICAL_SUBSTITUTIONS = {
+        {NOUN, {{"s", ""}, {"ses", "s"}, {"ves", "f"}, {"xes", "x"}, {"zes", "z"}, {"ches", "ch"}, {"shes", "sh"}, {"men", "man"}, {"ies", "y"}}},
+        {VERB, {{"s", ""}, {"ies", "y"}, {"es", "e"}, {"es", ""}, {"ed", "e"}, {"ed", ""}, {"ing", "e"}, {"ing", ""}}},
+        {ADJ, {{"er", ""}, {"est", ""}, {"er", "e"}, {"est", "e"}}},
+        {ADV, {}},
+        {ADJ_SAT, {{"er", ""}, {"est", ""}, {"er", "e"}, {"est", "e"}}}};
+
+    POS_LIST = {NOUN, VERB, ADJ, ADV};
+
+    auto ret = LoadLemmas();
+    if (ret != 0) {
+        return ret;
+    }
+
+    LoadExceptions();
+    // return Status::OK();
+    return 0;
+}
+
+int32_t WordNetLemmatizer::LoadLemmas() {
+    fs::path root(wordnet_path_);
+    for (const auto &pair : file_map_) {
+        const std::string &pos_abbrev = pair.first;
+        const std::string &pos_name = pair.second;
+        fs::path index_path(root / ("index." + pos_name));
+
+        std::ifstream file(index_path.string());
+        if (!file.is_open()) {
+            return -1;
+            // return Status::InvalidAnalyzerFile(fmt::format("Failed to load WordNet lemmatizer, index.{}", pos_name));
+        }
+
+        std::string line;
+
+        while (std::getline(file, line)) {
+            if (line.empty() || line[0] == ' ') {
+                continue;
+            }
+
+            std::istringstream stream(line);
+            try {
+                std::string lemma;
+                stream >> lemma;
+
+                if (lemmas_.find(lemma) == lemmas_.end()) {
+                    lemmas_[lemma] = std::unordered_set<std::string>();
+                }
+                lemmas_[lemma].insert(pos_abbrev);
+
+                if (pos_abbrev == ADJ) {
+                    // adjective entries double as "adjective satellite" entries
+                    lemmas_[lemma].insert(ADJ_SAT);
+                }
+
+            } catch (const std::exception &) {
+                return -1;
+                // return Status::InvalidAnalyzerFile("Failed to load WordNet lemmatizer lemmas");
+            }
+        }
+    }
+    // return Status::OK();
+    return 0;
+}
+
+void WordNetLemmatizer::LoadExceptions() {
+    fs::path root(wordnet_path_);
+    for (const auto &pair : file_map_) {
+        const std::string &pos_abbrev = pair.first;
+        const std::string &pos_name = pair.second;
+        fs::path exc_path(root / (pos_name + ".exc"));
+
+        std::ifstream file(exc_path.string());
+        if (!file.is_open()) {
+            continue;
+        }
+
+        exceptions_[pos_abbrev] = {};
+
+        std::string line;
+        while (std::getline(file, line)) {
+            std::istringstream stream(line);
+            std::string inflected_form;
+            stream >> inflected_form;
+
+            std::vector<std::string> base_forms;
+            std::string base_form;
+            while (stream >> base_form) {
+                base_forms.push_back(base_form);
+            }
+
+            exceptions_[pos_abbrev][inflected_form] = base_forms;
+        }
+    }
+    exceptions_[ADJ_SAT] = exceptions_[ADJ];
+}
+
+std::vector<std::string> WordNetLemmatizer::CollectSubstitutions(const std::vector<std::string> &forms, const std::string &pos) {
+    const auto &substitutions = MORPHOLOGICAL_SUBSTITUTIONS.at(pos);
+    std::vector<std::string> results;
+
+    for (const auto &form : forms) {
+        for (const auto &[old_suffix, new_suffix] : substitutions) {
+            if (form.size() >= old_suffix.size() && form.compare(form.size() - old_suffix.size(), old_suffix.size(), old_suffix) == 0) {
+                results.push_back(form.substr(0, form.size() - old_suffix.size()) + new_suffix);
+            }
+        }
+    }
+    return results;
+}
+
+std::vector<std::string> WordNetLemmatizer::CollectSubstitutions(const std::string &form, const std::string &pos) {
+    const auto &substitutions = MORPHOLOGICAL_SUBSTITUTIONS.at(pos);
+    std::vector<std::string> results;
+
+    for (const auto &[old_suffix, new_suffix] : substitutions) {
+        if (form.size() >= old_suffix.size() && form.compare(form.size() - old_suffix.size(), old_suffix.size(), old_suffix) == 0) {
+            results.push_back(form.substr(0, form.size() - old_suffix.size()) + new_suffix);
+        }
+    }
+    return results;
+}
+
+std::vector<std::string> WordNetLemmatizer::FilterForms(const std::vector<std::string> &forms, const std::string &pos) {
+    std::vector<std::string> result;
+    std::unordered_set<std::string> seen;
+
+    for (const auto &form : forms) {
+        if (lemmas_.find(form) != lemmas_.end()) {
+            if (lemmas_[form].find(pos) != lemmas_[form].end()) {
+                if (seen.find(form) == seen.end()) {
+                    result.push_back(form);
+                    seen.insert(form);
+                }
+            }
+        }
+    }
+    return result;
+}
+
+std::vector<std::string> WordNetLemmatizer::Morphy(const std::string &form, const std::string &pos, bool check_exceptions) {
+    const auto &pos_exceptions = exceptions_.at(pos);
+
+    // Check exceptions first
+    if (check_exceptions && pos_exceptions.find(form) != pos_exceptions.end()) {
+        std::vector<std::string> forms = pos_exceptions.at(form);
+        forms.push_back(form);
+        return FilterForms(forms, pos);
+    }
+
+    // Apply morphological rules with recursion (like Java version)
+    std::vector<std::string> forms = CollectSubstitutions(form, pos);
+    std::vector<std::string> combined_forms = forms;
+    combined_forms.push_back(form);
+
+    // First attempt with original form and first-level substitutions
+    auto results = FilterForms(combined_forms, pos);
+    if (!results.empty()) {
+        return results;
+    }
+
+    // Recursively apply rules (Java version's while loop)
+    while (!forms.empty()) {
+        forms = CollectSubstitutions(forms, pos);
+        results = FilterForms(forms, pos);
+        if (!results.empty()) {
+            return results;
+        }
+    }
+
+    // Return empty result if no valid lemma found
+    return {};
+}
+
+std::string WordNetLemmatizer::Lemmatize(const std::string &form, const std::string &pos) {
+    std::vector<std::string> parts_of_speech;
+    if (!pos.empty()) {
+        parts_of_speech.push_back(pos);
+    } else {
+        parts_of_speech = POS_LIST;
+    }
+
+    for (const auto &part : parts_of_speech) {
+        auto analyses = Morphy(form, part);
+        if (!analyses.empty()) {
+            return analyses[0];
+        }
+    }
+
+    return form;
+}
\ No newline at end of file
diff --git a/internal/cpp/wordnet_lemmatizer.h b/internal/cpp/wordnet_lemmatizer.h
new file mode 100644
index 00000000000..d4e9c49b182
--- /dev/null
+++ b/internal/cpp/wordnet_lemmatizer.h
@@ -0,0 +1,52 @@
+// Copyright(C) 2024 InfiniFlow, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+class WordNetLemmatizer {
+public:
+    explicit WordNetLemmatizer(const std::string &wordnet_path);
+
+    ~WordNetLemmatizer();
+
+    int32_t Load();
+
+    std::string Lemmatize(const std::string &form, const std::string &pos = "");
+
+private:
+    int32_t LoadLemmas();
+
+    void LoadExceptions();
+
+    std::vector<std::string> Morphy(const std::string &form, const std::string &pos, bool check_exceptions = true);
+
+    std::vector<std::string> CollectSubstitutions(const std::vector<std::string> &forms, const std::string &pos);
+    std::vector<std::string> CollectSubstitutions(const std::string &form, const std::string &pos);
+
+    std::vector<std::string> FilterForms(const std::vector<std::string> &forms, const std::string &pos);
+
+    std::string wordnet_path_;
+
+    std::unordered_map<std::string, std::unordered_set<std::string>> lemmas_;
+    std::unordered_map<std::string, std::unordered_map<std::string, std::vector<std::string>>> exceptions_;
+    std::unordered_map<std::string, std::vector<std::pair<std::string, std::string>>> MORPHOLOGICAL_SUBSTITUTIONS;
+    std::vector<std::string> POS_LIST;
+    std::unordered_map<std::string, std::string> file_map_;
+};
diff --git a/internal/dao/chat.go b/internal/dao/chat.go
new file mode 100644
index 00000000000..1500ea540a4
--- /dev/null
+++ b/internal/dao/chat.go
@@ -0,0 +1,212 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +package dao + +import ( + "fmt" + "strings" + + "ragflow/internal/model" +) + +// ChatDAO chat data access object +type ChatDAO struct{} + +// NewChatDAO create chat DAO +func NewChatDAO() *ChatDAO { + return &ChatDAO{} +} + +// ListByTenantID list chats by tenant ID +func (dao *ChatDAO) ListByTenantID(tenantID string, status string) ([]*model.Chat, error) { + var chats []*model.Chat + + query := DB.Model(&model.Chat{}). + Where("tenant_id = ?", tenantID) + + if status != "" { + query = query.Where("status = ?", status) + } + + // Order by create_time desc + if err := query.Order("create_time DESC").Find(&chats).Error; err != nil { + return nil, err + } + + return chats, nil +} + +// ListByTenantIDs list chats by tenant IDs with pagination and filtering +func (dao *ChatDAO) ListByTenantIDs(tenantIDs []string, userID string, page, pageSize int, orderby string, desc bool, keywords string) ([]*model.Chat, int64, error) { + var chats []*model.Chat + var total int64 + + // Build query with join to user table for nickname and avatar + query := DB.Model(&model.Chat{}). + Select(` + dialog.*, + user.nickname, + user.avatar as tenant_avatar + `). + Joins("LEFT JOIN user ON dialog.tenant_id = user.id"). + Where("(dialog.tenant_id IN ? OR dialog.tenant_id = ?) AND dialog.status = ?", tenantIDs, userID, "1") + + // Apply keyword filter + if keywords != "" { + query = query.Where("LOWER(dialog.name) LIKE ?", "%"+strings.ToLower(keywords)+"%") + } + + // Apply ordering + orderDirection := "ASC" + if desc { + orderDirection = "DESC" + } + query = query.Order(orderby + " " + orderDirection) + + // Count total + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + // Apply pagination + if page > 0 && pageSize > 0 { + offset := (page - 1) * pageSize + if err := query.Offset(offset).Limit(pageSize).Find(&chats).Error; err != nil { + return nil, 0, err + } + } else { + if err := query.Find(&chats).Error; err != nil { + return nil, 0, err + } + } + + return chats, total, nil +} + +// ListByOwnerIDs list chats by owner IDs with filtering (manual pagination) +func (dao *ChatDAO) ListByOwnerIDs(ownerIDs []string, userID string, orderby string, desc bool, keywords string) ([]*model.Chat, int64, error) { + var chats []*model.Chat + + // Build query with join to user table + query := DB.Model(&model.Chat{}). + Select(` + dialog.*, + user.nickname, + user.avatar as tenant_avatar + `). + Joins("LEFT JOIN user ON dialog.tenant_id = user.id"). + Where("(dialog.tenant_id IN ? OR dialog.tenant_id = ?) 
AND dialog.status = ?", ownerIDs, userID, "1") + + // Apply keyword filter + if keywords != "" { + query = query.Where("LOWER(dialog.name) LIKE ?", "%"+strings.ToLower(keywords)+"%") + } + + // Filter by owner IDs (additional filter to ensure tenant_id is in ownerIDs) + query = query.Where("dialog.tenant_id IN ?", ownerIDs) + + // Apply ordering + orderDirection := "ASC" + if desc { + orderDirection = "DESC" + } + query = query.Order(orderby + " " + orderDirection) + + // Get all matching records + if err := query.Find(&chats).Error; err != nil { + return nil, 0, err + } + + total := int64(len(chats)) + + return chats, total, nil +} + +// GetByID gets chat by ID +func (dao *ChatDAO) GetByID(id string) (*model.Chat, error) { + var chat model.Chat + err := DB.Where("id = ?", id).First(&chat).Error + if err != nil { + return nil, err + } + return &chat, nil +} + +// GetByIDAndStatus gets chat by ID and status +func (dao *ChatDAO) GetByIDAndStatus(id string, status string) (*model.Chat, error) { + var chat model.Chat + err := DB.Where("id = ? AND status = ?", id, status).First(&chat).Error + if err != nil { + return nil, err + } + return &chat, nil +} + +// GetExistingNames gets existing dialog names for a tenant +func (dao *ChatDAO) GetExistingNames(tenantID string, status string) ([]string, error) { + var names []string + err := DB.Model(&model.Chat{}). + Where("tenant_id = ? AND status = ?", tenantID, status). + Pluck("name", &names).Error + return names, err +} + +// Create creates a new chat/dialog +func (dao *ChatDAO) Create(chat *model.Chat) error { + return DB.Create(chat).Error +} + +// UpdateByID updates a chat by ID +func (dao *ChatDAO) UpdateByID(id string, updates map[string]interface{}) error { + return DB.Model(&model.Chat{}).Where("id = ?", id).Updates(updates).Error +} + +// UpdateManyByID updates multiple chats by ID (batch update) +func (dao *ChatDAO) UpdateManyByID(updates []map[string]interface{}) error { + if len(updates) == 0 { + return nil + } + + // Use transaction for batch update + tx := DB.Begin() + if tx.Error != nil { + return tx.Error + } + + for _, update := range updates { + id, ok := update["id"].(string) + if !ok { + tx.Rollback() + return fmt.Errorf("invalid id in update") + } + + // Remove id from updates map + updatesWithoutID := make(map[string]interface{}) + for k, v := range update { + if k != "id" { + updatesWithoutID[k] = v + } + } + + if err := tx.Model(&model.Chat{}).Where("id = ?", id).Updates(updatesWithoutID).Error; err != nil { + tx.Rollback() + return err + } + } + + return tx.Commit().Error +} diff --git a/internal/dao/chat_session.go b/internal/dao/chat_session.go new file mode 100644 index 00000000000..f728b7ca8db --- /dev/null +++ b/internal/dao/chat_session.go @@ -0,0 +1,85 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package dao + +import ( + "ragflow/internal/model" +) + +// ChatSessionDAO chat session data access object +type ChatSessionDAO struct{} + +// NewChatSessionDAO create chat session DAO +func NewChatSessionDAO() *ChatSessionDAO { + return &ChatSessionDAO{} +} + +// GetByID gets chat session by ID +func (dao *ChatSessionDAO) GetByID(id string) (*model.ChatSession, error) { + var conv model.ChatSession + err := DB.Where("id = ?", id).First(&conv).Error + if err != nil { + return nil, err + } + return &conv, nil +} + +// Create creates a new chat session +func (dao *ChatSessionDAO) Create(conv *model.ChatSession) error { + return DB.Create(conv).Error +} + +// UpdateByID updates a chat session by ID +func (dao *ChatSessionDAO) UpdateByID(id string, updates map[string]interface{}) error { + return DB.Model(&model.ChatSession{}).Where("id = ?", id).Updates(updates).Error +} + +// DeleteByID deletes a chat session by ID (hard delete) +func (dao *ChatSessionDAO) DeleteByID(id string) error { + return DB.Where("id = ?", id).Delete(&model.ChatSession{}).Error +} + +// ListByDialogID lists chat sessions by dialog ID +func (dao *ChatSessionDAO) ListByDialogID(dialogID string) ([]*model.ChatSession, error) { + var convs []*model.ChatSession + err := DB.Where("dialog_id = ?", dialogID). + Order("create_time DESC"). + Find(&convs).Error + return convs, err +} + +// CheckDialogExists checks if a dialog exists with given tenant_id and dialog_id +func (dao *ChatSessionDAO) CheckDialogExists(tenantID, dialogID string) (bool, error) { + var count int64 + err := DB.Model(&model.Chat{}). + Where("tenant_id = ? AND id = ? AND status = ?", tenantID, dialogID, "1"). + Count(&count).Error + if err != nil { + return false, err + } + return count > 0, nil +} + +// GetDialogByID gets dialog by ID +func (dao *ChatSessionDAO) GetDialogByID(dialogID string) (*model.Chat, error) { + var dialog model.Chat + err := DB.Where("id = ? AND status = ?", dialogID, "1").First(&dialog).Error + if err != nil { + return nil, err + } + return &dialog, nil +} diff --git a/internal/dao/connector.go b/internal/dao/connector.go new file mode 100644 index 00000000000..f8d0c9555ad --- /dev/null +++ b/internal/dao/connector.go @@ -0,0 +1,79 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package dao + +import ( + "ragflow/internal/model" +) + +// ConnectorDAO connector data access object +type ConnectorDAO struct{} + +// NewConnectorDAO create connector DAO +func NewConnectorDAO() *ConnectorDAO { + return &ConnectorDAO{} +} + +// ConnectorListItem connector list item (subset of fields) +type ConnectorListItem struct { + ID string `json:"id"` + Name string `json:"name"` + Source string `json:"source"` + Status string `json:"status"` +} + +// ListByTenantID list connectors by tenant ID +// Only selects id, name, source, status fields (matching Python implementation) +func (dao *ConnectorDAO) ListByTenantID(tenantID string) ([]*ConnectorListItem, error) { + var connectors []*ConnectorListItem + + err := DB.Model(&model.Connector{}). + Select("id", "name", "source", "status"). + Where("tenant_id = ?", tenantID). + Find(&connectors).Error + + if err != nil { + return nil, err + } + + return connectors, nil +} + +// GetByID get connector by ID +func (dao *ConnectorDAO) GetByID(id string) (*model.Connector, error) { + var connector model.Connector + err := DB.Where("id = ?", id).First(&connector).Error + if err != nil { + return nil, err + } + return &connector, nil +} + +// Create create a new connector +func (dao *ConnectorDAO) Create(connector *model.Connector) error { + return DB.Create(connector).Error +} + +// UpdateByID update connector by ID +func (dao *ConnectorDAO) UpdateByID(id string, updates map[string]interface{}) error { + return DB.Model(&model.Connector{}).Where("id = ?", id).Updates(updates).Error +} + +// DeleteByID delete connector by ID +func (dao *ConnectorDAO) DeleteByID(id string) error { + return DB.Where("id = ?", id).Delete(&model.Connector{}).Error +} diff --git a/internal/dao/database.go b/internal/dao/database.go new file mode 100644 index 00000000000..b0d0e1ee5ed --- /dev/null +++ b/internal/dao/database.go @@ -0,0 +1,91 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package dao + +import ( + "fmt" + "ragflow/internal/server" + "time" + + gormLogger "gorm.io/gorm/logger" + + "gorm.io/driver/mysql" + "gorm.io/gorm" + + "ragflow/internal/logger" +) + +var DB *gorm.DB + +// InitDB initialize database connection +func InitDB() error { + cfg := server.GetConfig() + dbCfg := cfg.Database + + dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local", + dbCfg.Username, + dbCfg.Password, + dbCfg.Host, + dbCfg.Port, + dbCfg.Database, + dbCfg.Charset, + ) + + // Set log level + var gormLogLevel gormLogger.LogLevel + if cfg.Server.Mode == "debug" { + gormLogLevel = gormLogger.Info + } else { + gormLogLevel = gormLogger.Silent + } + + // Connect to database + var err error + DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ + Logger: gormLogger.Default.LogMode(gormLogLevel), + NowFunc: func() time.Time { + return time.Now().Local() + }, + }) + if err != nil { + return fmt.Errorf("failed to connect database: %w", err) + } + + // Get general database object sql.DB + sqlDB, err := DB.DB() + if err != nil { + return fmt.Errorf("failed to get database instance: %w", err) + } + + // Set connection pool + sqlDB.SetMaxIdleConns(10) + sqlDB.SetMaxOpenConns(100) + sqlDB.SetConnMaxLifetime(time.Hour) + + // Auto migrate + //if err := DB.AutoMigrate(&model.User{}, &model.Document{}); err != nil { + // return fmt.Errorf("failed to migrate database: %w", err) + //} + + logger.Info("Database connected and migrated successfully") + return nil +} + +// GetDB get database instance +func GetDB() *gorm.DB { + return DB +} diff --git a/internal/dao/document.go b/internal/dao/document.go new file mode 100644 index 00000000000..49bdb51edfc --- /dev/null +++ b/internal/dao/document.go @@ -0,0 +1,81 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package dao + +import ( + "ragflow/internal/model" +) + +// DocumentDAO document data access object +type DocumentDAO struct{} + +// NewDocumentDAO create document DAO +func NewDocumentDAO() *DocumentDAO { + return &DocumentDAO{} +} + +// Create create document +func (dao *DocumentDAO) Create(document *model.Document) error { + return DB.Create(document).Error +} + +// GetByID get document by ID +func (dao *DocumentDAO) GetByID(id string) (*model.Document, error) { + var document model.Document + err := DB.Preload("Author").First(&document, "id = ?", id).Error + if err != nil { + return nil, err + } + return &document, nil +} + +// GetByAuthorID get documents by author ID +func (dao *DocumentDAO) GetByAuthorID(authorID string, offset, limit int) ([]*model.Document, int64, error) { + var documents []*model.Document + var total int64 + + query := DB.Model(&model.Document{}).Where("created_by = ?", authorID) + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + err := query.Preload("Author").Offset(offset).Limit(limit).Find(&documents).Error + return documents, total, err +} + +// Update update document +func (dao *DocumentDAO) Update(document *model.Document) error { + return DB.Save(document).Error +} + +// Delete delete document +func (dao *DocumentDAO) Delete(id string) error { + return DB.Delete(&model.Document{}, "id = ?", id).Error +} + +// List list documents +func (dao *DocumentDAO) List(offset, limit int) ([]*model.Document, int64, error) { + var documents []*model.Document + var total int64 + + if err := DB.Model(&model.Document{}).Count(&total).Error; err != nil { + return nil, 0, err + } + + err := DB.Preload("Author").Offset(offset).Limit(limit).Find(&documents).Error + return documents, total, err +} diff --git a/internal/dao/file.go b/internal/dao/file.go new file mode 100644 index 00000000000..bbf9a660989 --- /dev/null +++ b/internal/dao/file.go @@ -0,0 +1,202 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package dao + +import ( + "strings" + + "github.com/google/uuid" + + "ragflow/internal/model" +) + +// FileDAO file data access object +type FileDAO struct{} + +// NewFileDAO create file DAO +func NewFileDAO() *FileDAO { + return &FileDAO{} +} + +// GetByID gets file by ID +func (dao *FileDAO) GetByID(id string) (*model.File, error) { + var file model.File + err := DB.Where("id = ?", id).First(&file).Error + if err != nil { + return nil, err + } + return &file, nil +} + +// GetByPfID gets files by parent folder ID with pagination and filtering +func (dao *FileDAO) GetByPfID(tenantID, pfID string, page, pageSize int, orderby string, desc bool, keywords string) ([]*model.File, int64, error) { + var files []*model.File + var total int64 + + query := DB.Model(&model.File{}). + Where("tenant_id = ? AND parent_id = ? 
AND id != ?", tenantID, pfID, pfID) + + // Apply keyword filter + if keywords != "" { + query = query.Where("LOWER(name) LIKE ?", "%"+strings.ToLower(keywords)+"%") + } + + // Count total + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + // Apply ordering + orderDirection := "ASC" + if desc { + orderDirection = "DESC" + } + query = query.Order(orderby + " " + orderDirection) + + // Apply pagination + if page > 0 && pageSize > 0 { + offset := (page - 1) * pageSize + if err := query.Offset(offset).Limit(pageSize).Find(&files).Error; err != nil { + return nil, 0, err + } + } else { + if err := query.Find(&files).Error; err != nil { + return nil, 0, err + } + } + + return files, total, nil +} + +// GetRootFolder gets or creates root folder for tenant +func (dao *FileDAO) GetRootFolder(tenantID string) (*model.File, error) { + var file model.File + err := DB.Where("tenant_id = ? AND parent_id = id", tenantID).First(&file).Error + if err == nil { + return &file, nil + } + + // Create root folder if not exists + fileID := generateUUID() + file = model.File{ + ID: fileID, + ParentID: fileID, + TenantID: tenantID, + CreatedBy: tenantID, + Name: "/", + Type: "folder", + Size: 0, + } + file.SourceType = "" + + if err := DB.Create(&file).Error; err != nil { + return nil, err + } + return &file, nil +} + +// GetParentFolder gets parent folder of a file +func (dao *FileDAO) GetParentFolder(fileID string) (*model.File, error) { + var file model.File + err := DB.Where("id = ?", fileID).First(&file).Error + if err != nil { + return nil, err + } + + var parentFile model.File + err = DB.Where("id = ?", file.ParentID).First(&parentFile).Error + if err != nil { + return nil, err + } + return &parentFile, nil +} + +// ListByParentID lists all files by parent ID (including subfolders) +func (dao *FileDAO) ListByParentID(parentID string) ([]*model.File, error) { + var files []*model.File + err := DB.Where("parent_id = ? AND id != ?", parentID, parentID).Find(&files).Error + return files, err +} + +// GetFolderSize calculates folder size recursively +func (dao *FileDAO) GetFolderSize(folderID string) (int64, error) { + var size int64 + + var dfs func(parentID string) error + dfs = func(parentID string) error { + var files []*model.File + if err := DB.Select("id", "size", "type"). + Where("parent_id = ? AND id != ?", parentID, parentID). + Find(&files).Error; err != nil { + return err + } + + for _, f := range files { + size += f.Size + if f.Type == "folder" { + if err := dfs(f.ID); err != nil { + return err + } + } + } + return nil + } + + if err := dfs(folderID); err != nil { + return 0, err + } + return size, nil +} + +// HasChildFolder checks if folder has child folders +func (dao *FileDAO) HasChildFolder(folderID string) (bool, error) { + var count int64 + err := DB.Model(&model.File{}). + Where("parent_id = ? AND id != ? AND type = ?", folderID, folderID, "folder"). 
+		Count(&count).Error
+	return count > 0, err
+}
+
+// GetAllParentFolders gets all parent folders in path (from current to root)
+func (dao *FileDAO) GetAllParentFolders(startID string) ([]*model.File, error) {
+	var parentFolders []*model.File
+	currentID := startID
+
+	for currentID != "" {
+		var file model.File
+		err := DB.Where("id = ?", currentID).First(&file).Error
+		if err != nil {
+			return nil, err
+		}
+
+		parentFolders = append(parentFolders, &file)
+
+		// Stop if we've reached the root folder (parent_id == id)
+		if file.ParentID == file.ID {
+			break
+		}
+		currentID = file.ParentID
+	}
+
+	return parentFolders, nil
+}
+
+// generateUUID generates a UUID
+func generateUUID() string {
+	id := uuid.New().String()
+	return strings.ReplaceAll(id, "-", "")
+}
diff --git a/internal/dao/file2document.go b/internal/dao/file2document.go
new file mode 100644
index 00000000000..81ce813e320
--- /dev/null
+++ b/internal/dao/file2document.go
@@ -0,0 +1,60 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+package dao
+
+import (
+	"ragflow/internal/model"
+)
+
+// File2DocumentDAO file to document mapping data access object
+type File2DocumentDAO struct{}
+
+// NewFile2DocumentDAO create file2document DAO
+func NewFile2DocumentDAO() *File2DocumentDAO {
+	return &File2DocumentDAO{}
+}
+
+// GetKBInfoByFileID gets knowledge base info by file ID
+func (dao *File2DocumentDAO) GetKBInfoByFileID(fileID string) ([]map[string]interface{}, error) {
+	var results []map[string]interface{}
+
+	// Join on column equality rather than binding the file ID into the ON
+	// clause; the WHERE condition below restricts the result to this file.
+	rows, err := DB.Model(&model.File{}).
+		Select("knowledgebase.id, knowledgebase.name, file2document.document_id").
+		Joins("JOIN file2document ON file2document.file_id = file.id").
+		Joins("JOIN document ON document.id = file2document.document_id").
+		Joins("JOIN knowledgebase ON knowledgebase.id = document.kb_id").
+		Where("file.id = ?", fileID).
+		Rows()
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	for rows.Next() {
+		var kbID, kbName, docID string
+		if err := rows.Scan(&kbID, &kbName, &docID); err != nil {
+			continue
+		}
+		results = append(results, map[string]interface{}{
+			"kb_id":       kbID,
+			"kb_name":     kbName,
+			"document_id": docID,
+		})
+	}
+
+	return results, nil
+}
diff --git a/internal/dao/kb.go b/internal/dao/kb.go
new file mode 100644
index 00000000000..cf36e1a7e61
--- /dev/null
+++ b/internal/dao/kb.go
@@ -0,0 +1,149 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +package dao + +import ( + "ragflow/internal/model" + "strings" +) + +// KnowledgebaseDAO knowledge base data access object +type KnowledgebaseDAO struct{} + +// NewKnowledgebaseDAO create knowledge base DAO +func NewKnowledgebaseDAO() *KnowledgebaseDAO { + return &KnowledgebaseDAO{} +} + +// ListByTenantIDs list knowledge bases by tenant IDs +func (dao *KnowledgebaseDAO) ListByTenantIDs(tenantIDs []string, userID string, page, pageSize int, orderby string, desc bool, keywords, parserID string) ([]*model.Knowledgebase, int64, error) { + var kbs []*model.Knowledgebase + var total int64 + + query := DB.Model(&model.Knowledgebase{}). + Joins("LEFT JOIN user ON knowledgebase.tenant_id = user.id"). + Where("(knowledgebase.tenant_id IN ? AND knowledgebase.permission = ?) OR knowledgebase.tenant_id = ?", tenantIDs, "team", userID). + Where("knowledgebase.status = ?", "1") + + if keywords != "" { + query = query.Where("LOWER(knowledgebase.name) LIKE ?", "%"+strings.ToLower(keywords)+"%") + } + + if parserID != "" { + query = query.Where("knowledgebase.parser_id = ?", parserID) + } + + // Order + if desc { + query = query.Order(orderby + " DESC") + } else { + query = query.Order(orderby + " ASC") + } + + // Count + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + // Pagination + if page > 0 && pageSize > 0 { + offset := (page - 1) * pageSize + if err := query.Offset(offset).Limit(pageSize).Find(&kbs).Error; err != nil { + return nil, 0, err + } + } else { + if err := query.Find(&kbs).Error; err != nil { + return nil, 0, err + } + } + + return kbs, total, nil +} + +// ListByOwnerIDs list knowledge bases by owner IDs +func (dao *KnowledgebaseDAO) ListByOwnerIDs(ownerIDs []string, page, pageSize int, orderby string, desc bool, keywords, parserID string) ([]*model.Knowledgebase, int64, error) { + var kbs []*model.Knowledgebase + + query := DB.Model(&model.Knowledgebase{}). + Joins("LEFT JOIN user ON knowledgebase.tenant_id = user.id"). + Where("knowledgebase.tenant_id IN ?", ownerIDs). + Where("knowledgebase.status = ?", "1") + + if keywords != "" { + query = query.Where("LOWER(knowledgebase.name) LIKE ?", "%"+strings.ToLower(keywords)+"%") + } + + if parserID != "" { + query = query.Where("knowledgebase.parser_id = ?", parserID) + } + + // Order + if desc { + query = query.Order(orderby + " DESC") + } else { + query = query.Order(orderby + " ASC") + } + + if err := query.Find(&kbs).Error; err != nil { + return nil, 0, err + } + + total := int64(len(kbs)) + + // Manual pagination + if page > 0 && pageSize > 0 { + start := (page - 1) * pageSize + end := start + pageSize + if end > int(total) { + end = int(total) + } + if start < end { + kbs = kbs[start:end] + } else { + kbs = []*model.Knowledgebase{} + } + } + + return kbs, total, nil +} + +// GetByID gets knowledge base by ID +func (dao *KnowledgebaseDAO) GetByID(id string) (*model.Knowledgebase, error) { + var kb model.Knowledgebase + err := DB.Where("id = ? AND status = ?", id, "1").First(&kb).Error + if err != nil { + return nil, err + } + return &kb, nil +} + +// GetByIDAndTenantID gets knowledge base by ID and tenant ID +func (dao *KnowledgebaseDAO) GetByIDAndTenantID(id, tenantID string) (*model.Knowledgebase, error) { + var kb model.Knowledgebase + err := DB.Where("id = ? AND tenant_id = ? 
AND status = ?", id, tenantID, "1").First(&kb).Error + if err != nil { + return nil, err + } + return &kb, nil +} + +// GetByIDs gets knowledge bases by IDs +func (dao *KnowledgebaseDAO) GetByIDs(ids []string) ([]*model.Knowledgebase, error) { + var kbs []*model.Knowledgebase + err := DB.Where("id IN ? AND status = ?", ids, "1").Find(&kbs).Error + return kbs, err +} diff --git a/internal/dao/llm.go b/internal/dao/llm.go new file mode 100644 index 00000000000..44590ca9dcf --- /dev/null +++ b/internal/dao/llm.go @@ -0,0 +1,69 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package dao + +import ( + "ragflow/internal/model" +) + +// LLMDAO LLM data access object +type LLMDAO struct{} + +// NewLLMDAO create LLM DAO +func NewLLMDAO() *LLMDAO { + return &LLMDAO{} +} + +// GetAll gets all LLMs +func (dao *LLMDAO) GetAll() ([]*model.LLM, error) { + var llms []*model.LLM + err := DB.Find(&llms).Error + if err != nil { + return nil, err + } + return llms, nil +} + +// GetAllValid gets all valid LLMs +func (dao *LLMDAO) GetAllValid() ([]*model.LLM, error) { + var llms []*model.LLM + err := DB.Where("status = ?", "1").Find(&llms).Error + if err != nil { + return nil, err + } + return llms, nil +} + +// GetByFactory gets LLMs by factory +func (dao *LLMDAO) GetByFactory(factory string) ([]*model.LLM, error) { + var llms []*model.LLM + err := DB.Where("fid = ?", factory).Find(&llms).Error + if err != nil { + return nil, err + } + return llms, nil +} + +// GetByFactoryAndName gets LLM by factory and name +func (dao *LLMDAO) GetByFactoryAndName(factory, name string) (*model.LLM, error) { + var llm model.LLM + err := DB.Where("fid = ? AND llm_name = ?", factory, name).First(&llm).Error + if err != nil { + return nil, err + } + return &llm, nil +} diff --git a/internal/dao/model_provider.go b/internal/dao/model_provider.go new file mode 100644 index 00000000000..83e8bc80cd2 --- /dev/null +++ b/internal/dao/model_provider.go @@ -0,0 +1,123 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//
+
+package dao
+
+import (
+	"ragflow/internal/server"
+	"sync"
+)
+
+// ModelProviderDAO provides access to model provider configuration data
+type ModelProviderDAO struct{}
+
+var (
+	modelProviderDAOInstance *ModelProviderDAO
+	modelProviderDAOOnce     sync.Once
+)
+
+// NewModelProviderDAO creates a new ModelProviderDAO instance (singleton)
+func NewModelProviderDAO() *ModelProviderDAO {
+	modelProviderDAOOnce.Do(func() {
+		modelProviderDAOInstance = &ModelProviderDAO{}
+	})
+	return modelProviderDAOInstance
+}
+
+// GetAllProviders returns all model providers
+func (dao *ModelProviderDAO) GetAllProviders() []server.ModelProvider {
+	return server.GetModelProviders()
+}
+
+// GetProviderByName returns the model provider with the given name
+func (dao *ModelProviderDAO) GetProviderByName(name string) *server.ModelProvider {
+	return server.GetModelProviderByName(name)
+}
+
+// GetLLMByProviderAndName returns the LLM with the given provider name and model name
+func (dao *ModelProviderDAO) GetLLMByProviderAndName(providerName, modelName string) *server.LLM {
+	return server.GetLLMByProviderAndName(providerName, modelName)
+}
+
+// GetLLMsByType returns all LLMs across all providers that match the given model type
+func (dao *ModelProviderDAO) GetLLMsByType(modelType string) []server.LLM {
+	var result []server.LLM
+	for _, provider := range server.GetModelProviders() {
+		for _, llm := range provider.LLMs {
+			if llm.ModelType == modelType {
+				result = append(result, llm)
+			}
+		}
+	}
+	return result
+}
+
+// GetProvidersByTag returns providers that have the given tag in their tags string
+func (dao *ModelProviderDAO) GetProvidersByTag(tag string) []server.ModelProvider {
+	var result []server.ModelProvider
+	for _, provider := range server.GetModelProviders() {
+		if containsTag(provider.Tags, tag) {
+			result = append(result, provider)
+		}
+	}
+	return result
+}
+
+// GetLLMsByProviderAndType returns LLMs for a specific provider that match the given model type
+func (dao *ModelProviderDAO) GetLLMsByProviderAndType(providerName, modelType string) []server.LLM {
+	provider := server.GetModelProviderByName(providerName)
+	if provider == nil {
+		return nil
+	}
+	var result []server.LLM
+	for _, llm := range provider.LLMs {
+		if llm.ModelType == modelType {
+			result = append(result, llm)
+		}
+	}
+	return result
+}
+
+// containsTag checks if a comma-separated tag string contains a specific tag.
+// Tags are compared by exact token match; they are assumed to be uppercase and
+// comma-separated without spaces. This may need refinement based on the actual
+// tag format.
+func containsTag(tags, tag string) bool {
+	for _, t := range splitTags(tags) {
+		if t == tag {
+			return true
+		}
+	}
+	return false
+}
+
+// splitTags splits a comma-separated tag string into tokens.
+// Note: no space trimming is performed; tags are assumed to contain none.
+func splitTags(tags string) []string {
+	var result []string
+	start := 0
+	for i, ch := range tags {
+		if ch == ',' {
+			if start < i {
+				result = append(result, tags[start:i])
+			}
+			start = i + 1
+		}
+	}
+	if start < len(tags) {
+		result = append(result, tags[start:])
+	}
+	return result
+}
diff --git a/internal/dao/search.go b/internal/dao/search.go
new file mode 100644
index 00000000000..5cdcd44225d
--- /dev/null
+++ b/internal/dao/search.go
@@ -0,0 +1,127 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package dao + +import ( + "strings" + + "ragflow/internal/model" +) + +// SearchDAO search data access object +type SearchDAO struct{} + +// NewSearchDAO create search DAO +func NewSearchDAO() *SearchDAO { + return &SearchDAO{} +} + +// ListByTenantIDs list searches by tenant IDs with pagination and filtering +func (dao *SearchDAO) ListByTenantIDs(tenantIDs []string, userID string, page, pageSize int, orderby string, desc bool, keywords string) ([]*model.Search, int64, error) { + var searches []*model.Search + var total int64 + + // Build query with join to user table for nickname and avatar + query := DB.Model(&model.Search{}). + Select(` + search.*, + user.nickname, + user.avatar as tenant_avatar + `). + Joins("LEFT JOIN user ON search.tenant_id = user.id"). + Where("(search.tenant_id IN ? OR search.tenant_id = ?) AND search.status = ?", tenantIDs, userID, "1") + + // Apply keyword filter + if keywords != "" { + query = query.Where("LOWER(search.name) LIKE ?", "%"+strings.ToLower(keywords)+"%") + } + + // Apply ordering + orderDirection := "ASC" + if desc { + orderDirection = "DESC" + } + query = query.Order(orderby + " " + orderDirection) + + // Count total + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + // Apply pagination + if page > 0 && pageSize > 0 { + offset := (page - 1) * pageSize + if err := query.Offset(offset).Limit(pageSize).Find(&searches).Error; err != nil { + return nil, 0, err + } + } else { + if err := query.Find(&searches).Error; err != nil { + return nil, 0, err + } + } + + return searches, total, nil +} + +// ListByOwnerIDs list searches by owner IDs with filtering (manual pagination) +func (dao *SearchDAO) ListByOwnerIDs(ownerIDs []string, userID string, orderby string, desc bool, keywords string) ([]*model.Search, int64, error) { + var searches []*model.Search + + // Build query with join to user table + query := DB.Model(&model.Search{}). + Select(` + search.*, + user.nickname, + user.avatar as tenant_avatar + `). + Joins("LEFT JOIN user ON search.tenant_id = user.id"). + Where("(search.tenant_id IN ? OR search.tenant_id = ?) 
AND search.status = ?", ownerIDs, userID, "1") + + // Apply keyword filter + if keywords != "" { + query = query.Where("LOWER(search.name) LIKE ?", "%"+strings.ToLower(keywords)+"%") + } + + // Filter by owner IDs (additional filter to ensure tenant_id is in ownerIDs) + query = query.Where("search.tenant_id IN ?", ownerIDs) + + // Apply ordering + orderDirection := "ASC" + if desc { + orderDirection = "DESC" + } + query = query.Order(orderby + " " + orderDirection) + + // Get all matching records + if err := query.Find(&searches).Error; err != nil { + return nil, 0, err + } + + total := int64(len(searches)) + + return searches, total, nil +} + +// GetByID gets search by ID +func (dao *SearchDAO) GetByID(id string) (*model.Search, error) { + var search model.Search + err := DB.Where("id = ?", id).First(&search).Error + if err != nil { + return nil, err + } + return &search, nil +} diff --git a/internal/dao/tenant.go b/internal/dao/tenant.go new file mode 100644 index 00000000000..c992b1a7429 --- /dev/null +++ b/internal/dao/tenant.go @@ -0,0 +1,90 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package dao + +import ( + "ragflow/internal/model" +) + +// TenantDAO tenant data access object +type TenantDAO struct{} + +// NewTenantDAO create tenant DAO +func NewTenantDAO() *TenantDAO { + return &TenantDAO{} +} + +// GetJoinedTenantsByUserID get joined tenants by user ID +func (dao *TenantDAO) GetJoinedTenantsByUserID(userID string) ([]*TenantWithRole, error) { + var results []*TenantWithRole + + err := DB.Model(&model.Tenant{}). + Select("tenant.id as tenant_id, tenant.name, tenant.llm_id, tenant.embd_id, tenant.asr_id, tenant.img2txt_id, user_tenant.role"). + Joins("INNER JOIN user_tenant ON user_tenant.tenant_id = tenant.id"). + Where("user_tenant.user_id = ? AND user_tenant.status = ? AND user_tenant.role = ? AND tenant.status = ?", userID, "1", "normal", "1"). 
+ Scan(&results).Error + + return results, err +} + +// TenantWithRole tenant with role information +type TenantWithRole struct { + TenantID string `gorm:"column:tenant_id" json:"tenant_id"` + Name string `gorm:"column:name" json:"name"` + LLMID string `gorm:"column:llm_id" json:"llm_id"` + EmbDID string `gorm:"column:embd_id" json:"embd_id"` + ASRID string `gorm:"column:asr_id" json:"asr_id"` + Img2TxtID string `gorm:"column:img2txt_id" json:"img2txt_id"` + Role string `gorm:"column:role" json:"role"` +} + +// TenantInfo tenant information with role (for owner tenant) +type TenantInfo struct { + TenantID string `gorm:"column:tenant_id" json:"tenant_id"` + Name *string `gorm:"column:name" json:"name,omitempty"` + LLMID string `gorm:"column:llm_id" json:"llm_id"` + EmbDID string `gorm:"column:embd_id" json:"embd_id"` + RerankID string `gorm:"column:rerank_id" json:"rerank_id"` + ASRID string `gorm:"column:asr_id" json:"asr_id"` + Img2TxtID string `gorm:"column:img2txt_id" json:"img2txt_id"` + TTSID *string `gorm:"column:tts_id" json:"tts_id,omitempty"` + ParserIDs string `gorm:"column:parser_ids" json:"parser_ids"` + Role string `gorm:"column:role" json:"role"` +} + +// GetInfoByUserID get tenant information for the owner tenant of a user +func (dao *TenantDAO) GetInfoByUserID(userID string) ([]*TenantInfo, error) { + var results []*TenantInfo + + err := DB.Model(&model.Tenant{}). + Select("tenant.id as tenant_id, tenant.name, tenant.llm_id, tenant.embd_id, tenant.rerank_id, tenant.asr_id, tenant.img2txt_id, tenant.tts_id, tenant.parser_ids, user_tenant.role"). + Joins("INNER JOIN user_tenant ON user_tenant.tenant_id = tenant.id"). + Where("user_tenant.user_id = ? AND user_tenant.status = ? AND user_tenant.role = ? AND tenant.status = ?", userID, "1", "owner", "1"). + Scan(&results).Error + + return results, err +} + +// GetByID gets tenant by ID +func (dao *TenantDAO) GetByID(id string) (*model.Tenant, error) { + var tenant model.Tenant + err := DB.Where("id = ? AND status = ?", id, "1").First(&tenant).Error + if err != nil { + return nil, err + } + return &tenant, nil +} diff --git a/internal/dao/tenant_llm.go b/internal/dao/tenant_llm.go new file mode 100644 index 00000000000..8752e041fa9 --- /dev/null +++ b/internal/dao/tenant_llm.go @@ -0,0 +1,136 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package dao + +import ( + "ragflow/internal/model" +) + +// TenantLLMDAO tenant LLM data access object +type TenantLLMDAO struct{} + +// NewTenantLLMDAO create tenant LLM DAO +func NewTenantLLMDAO() *TenantLLMDAO { + return &TenantLLMDAO{} +} + +// GetByTenantAndModelName get tenant LLM by tenant ID and model name +func (dao *TenantLLMDAO) GetByTenantAndModelName(tenantID, providerName string, modelName string) (*model.TenantLLM, error) { + var tenantLLM model.TenantLLM + err := DB.Where("tenant_id = ? AND llm_factory = ? 
AND llm_name = ?", tenantID, providerName, modelName).First(&tenantLLM).Error + if err != nil { + return nil, err + } + return &tenantLLM, nil +} + +// GetByTenantAndType get tenant LLM by tenant ID and model type +func (dao *TenantLLMDAO) GetByTenantAndType(tenantID string, modelType model.ModelType) (*model.TenantLLM, error) { + var tenantLLM model.TenantLLM + err := DB.Where("tenant_id = ? AND model_type = ?", tenantID, modelType).First(&tenantLLM).Error + if err != nil { + return nil, err + } + return &tenantLLM, nil +} + +// GetByTenantAndFactory get tenant LLM by tenant ID, model type and factory +func (dao *TenantLLMDAO) GetByTenantAndFactory(tenantID string, modelType model.ModelType, factory string) (*model.TenantLLM, error) { + var tenantLLM model.TenantLLM + err := DB.Where("tenant_id = ? AND model_type = ? AND llm_factory = ?", tenantID, modelType, factory).First(&tenantLLM).Error + if err != nil { + return nil, err + } + return &tenantLLM, nil +} + +// ListByTenant list all tenant LLMs for a tenant +func (dao *TenantLLMDAO) ListByTenant(tenantID string) ([]model.TenantLLM, error) { + var tenantLLMs []model.TenantLLM + err := DB.Where("tenant_id = ?", tenantID).Find(&tenantLLMs).Error + if err != nil { + return nil, err + } + return tenantLLMs, nil +} + +// GetByTenantFactoryAndModelName get tenant LLM by tenant ID, factory and model name +func (dao *TenantLLMDAO) GetByTenantFactoryAndModelName(tenantID, factory, modelName string) (*model.TenantLLM, error) { + var tenantLLM model.TenantLLM + err := DB.Where("tenant_id = ? AND llm_factory = ? AND llm_name = ?", tenantID, factory, modelName).First(&tenantLLM).Error + if err != nil { + return nil, err + } + return &tenantLLM, nil +} + +// Create create a new tenant LLM record +func (dao *TenantLLMDAO) Create(tenantLLM *model.TenantLLM) error { + return DB.Create(tenantLLM).Error +} + +// Update update an existing tenant LLM record +func (dao *TenantLLMDAO) Update(tenantLLM *model.TenantLLM) error { + return DB.Save(tenantLLM).Error +} + +// Delete delete a tenant LLM record by tenant ID, factory and model name +func (dao *TenantLLMDAO) Delete(tenantID, factory, modelName string) error { + return DB.Where("tenant_id = ? AND llm_factory = ? AND llm_name = ?", tenantID, factory, modelName).Delete(&model.TenantLLM{}).Error +} + +// GetMyLLMs get tenant LLMs with factory details +func (dao *TenantLLMDAO) GetMyLLMs(tenantID string, includeDetails bool) ([]model.MyLLM, error) { + var myLLMs []model.MyLLM + + // Base query + query := DB.Table("tenant_llm tl"). + Select("tl.llm_factory, lf.logo, lf.tags, tl.model_type, tl.llm_name, tl.used_tokens, tl.status"). + Joins("JOIN llm_factories lf ON tl.llm_factory = lf.name"). + Where("tl.tenant_id = ? AND tl.api_key IS NOT NULL", tenantID) + + // Add detailed fields if requested + if includeDetails { + query = query.Select("tl.llm_factory, lf.logo, lf.tags, tl.model_type, tl.llm_name, tl.used_tokens, tl.status, tl.api_base, tl.max_tokens") + } + + err := query.Find(&myLLMs).Error + if err != nil { + return nil, err + } + return myLLMs, nil +} + +// ListValidByTenant lists valid tenant LLMs for a tenant +func (dao *TenantLLMDAO) ListValidByTenant(tenantID string) ([]*model.TenantLLM, error) { + var tenantLLMs []*model.TenantLLM + err := DB.Where("tenant_id = ? AND api_key IS NOT NULL AND api_key != ? 
AND status = ?", tenantID, "", "1").Find(&tenantLLMs).Error + if err != nil { + return nil, err + } + return tenantLLMs, nil +} + +// ListAllByTenant lists all tenant LLMs for a tenant +func (dao *TenantLLMDAO) ListAllByTenant(tenantID string) ([]*model.TenantLLM, error) { + var tenantLLMs []*model.TenantLLM + err := DB.Where("tenant_id = ?", tenantID).Find(&tenantLLMs).Error + if err != nil { + return nil, err + } + return tenantLLMs, nil +} diff --git a/internal/dao/user.go b/internal/dao/user.go new file mode 100644 index 00000000000..014be061979 --- /dev/null +++ b/internal/dao/user.go @@ -0,0 +1,103 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package dao + +import ( + "ragflow/internal/model" +) + +// UserDAO user data access object +type UserDAO struct{} + +// NewUserDAO create user DAO +func NewUserDAO() *UserDAO { + return &UserDAO{} +} + +// Create create user +func (dao *UserDAO) Create(user *model.User) error { + return DB.Create(user).Error +} + +// GetByID get user by ID +func (dao *UserDAO) GetByID(id uint) (*model.User, error) { + var user model.User + err := DB.First(&user, id).Error + if err != nil { + return nil, err + } + return &user, nil +} + +// GetByUsername get user by username +func (dao *UserDAO) GetByUsername(username string) (*model.User, error) { + var user model.User + err := DB.Where("username = ?", username).First(&user).Error + if err != nil { + return nil, err + } + return &user, nil +} + +// GetByEmail get user by email +func (dao *UserDAO) GetByEmail(email string) (*model.User, error) { + var user model.User + query := DB.Where("email = ?", email) + err := query.First(&user).Error + if err != nil { + return nil, err + } + return &user, nil +} + +// GetByAccessToken get user by access token +func (dao *UserDAO) GetByAccessToken(token string) (*model.User, error) { + var user model.User + err := DB.Where("access_token = ?", token).First(&user).Error + if err != nil { + return nil, err + } + return &user, nil +} + +// Update update user +func (dao *UserDAO) Update(user *model.User) error { + return DB.Save(user).Error +} + +// UpdateAccessToken update user's access token +func (dao *UserDAO) UpdateAccessToken(user *model.User, token string) error { + return DB.Model(user).Update("access_token", token).Error +} + +// List list users +func (dao *UserDAO) List(offset, limit int) ([]*model.User, int64, error) { + var users []*model.User + var total int64 + + if err := DB.Model(&model.User{}).Count(&total).Error; err != nil { + return nil, 0, err + } + + err := DB.Offset(offset).Limit(limit).Find(&users).Error + return users, total, err +} + +// Delete delete user +func (dao *UserDAO) Delete(id uint) error { + return DB.Delete(&model.User{}, id).Error +} diff --git a/internal/dao/user_canvas.go b/internal/dao/user_canvas.go new file mode 100644 index 00000000000..5d819cdcb27 --- /dev/null +++ b/internal/dao/user_canvas.go @@ -0,0 +1,129 @@ +// +// Copyright 2026 The 
InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package dao + +import ( + "ragflow/internal/model" +) + +// UserCanvasDAO user canvas data access object +type UserCanvasDAO struct{} + +// NewUserCanvasDAO create user canvas DAO +func NewUserCanvasDAO() *UserCanvasDAO { + return &UserCanvasDAO{} +} + +// Create user canvas +func (dao *UserCanvasDAO) Create(userCanvas *model.UserCanvas) error { + return DB.Create(userCanvas).Error +} + +// GetByID get user canvas by ID +func (dao *UserCanvasDAO) GetByID(id string) (*model.UserCanvas, error) { + var canvas model.UserCanvas + err := DB.Where("id = ?", id).First(&canvas).Error + if err != nil { + return nil, err + } + return &canvas, nil +} + +// Update update user canvas +func (dao *UserCanvasDAO) Update(userCanvas *model.UserCanvas) error { + return DB.Save(userCanvas).Error +} + +// Delete delete user canvas +func (dao *UserCanvasDAO) Delete(id string) error { + return DB.Delete(&model.UserCanvas{}, id).Error +} + +// GetList get canvases list with pagination and filtering +// Similar to Python UserCanvasService.get_list +func (dao *UserCanvasDAO) GetList( + tenantID string, + pageNumber, itemsPerPage int, + orderby string, + desc bool, + id, title string, + canvasCategory string, +) ([]*model.UserCanvas, error) { + + query := DB.Model(&model.UserCanvas{}). + Where("user_id = ?", tenantID) + + if id != "" { + query = query.Where("id = ?", id) + } + if title != "" { + query = query.Where("title = ?", title) + } + if canvasCategory != "" { + query = query.Where("canvas_category = ?", canvasCategory) + } else { + // Default to agent category + query = query.Where("canvas_category = ?", "agent_canvas") + } + + // Order by + if desc { + query = query.Order(orderby + " DESC") + } else { + query = query.Order(orderby + " ASC") + } + + // Pagination + if pageNumber > 0 && itemsPerPage > 0 { + offset := (pageNumber - 1) * itemsPerPage + query = query.Offset(offset).Limit(itemsPerPage) + } + + var canvases []*model.UserCanvas + err := query.Find(&canvases).Error + return canvases, err +} + +// GetAllCanvasesByTenantIDs get all permitted canvases by tenant IDs +// Similar to Python UserCanvasService.get_all_agents_by_tenant_ids +func (dao *UserCanvasDAO) GetAllCanvasesByTenantIDs(tenantIDs []string, userID string) ([]*CanvasBasicInfo, error) { + + query := DB.Model(&model.UserCanvas{}). + Select("id, avatar, title, permission, canvas_type, canvas_category"). + Where("user_id IN (?) AND permission = ?", tenantIDs, "team"). + Or("user_id = ?", userID). 
+ Order("create_time ASC") + + var results []*CanvasBasicInfo + err := query.Scan(&results).Error + return results, err +} + +// GetByCanvasID get user canvas by canvas ID (alias for GetByID) +func (dao *UserCanvasDAO) GetByCanvasID(canvasID string) (*model.UserCanvas, error) { + return dao.GetByID(canvasID) +} + +// CanvasBasicInfo basic canvas information for list responses +type CanvasBasicInfo struct { + ID string `gorm:"column:id" json:"id"` + Avatar *string `gorm:"column:avatar" json:"avatar,omitempty"` + Title *string `gorm:"column:title" json:"title,omitempty"` + Permission string `gorm:"column:permission" json:"permission"` + CanvasType *string `gorm:"column:canvas_type" json:"canvas_type,omitempty"` + CanvasCategory string `gorm:"column:canvas_category" json:"canvas_category"` +} diff --git a/internal/dao/user_tenant.go b/internal/dao/user_tenant.go new file mode 100644 index 00000000000..f6eb2e13bb6 --- /dev/null +++ b/internal/dao/user_tenant.go @@ -0,0 +1,126 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package dao + +import ( + "ragflow/internal/model" +) + +// UserTenantDAO user tenant data access object +type UserTenantDAO struct{} + +// NewUserTenantDAO create user tenant DAO +func NewUserTenantDAO() *UserTenantDAO { + return &UserTenantDAO{} +} + +// Create create user tenant relationship +func (dao *UserTenantDAO) Create(userTenant *model.UserTenant) error { + return DB.Create(userTenant).Error +} + +// GetByID get user tenant relationship by ID +func (dao *UserTenantDAO) GetByID(id string) (*model.UserTenant, error) { + var userTenant model.UserTenant + err := DB.Where("id = ? AND status = ?", id, "1").First(&userTenant).Error + if err != nil { + return nil, err + } + return &userTenant, nil +} + +// Update update user tenant relationship +func (dao *UserTenantDAO) Update(userTenant *model.UserTenant) error { + return DB.Save(userTenant).Error +} + +// Delete delete user tenant relationship (soft delete by setting status to "0") +func (dao *UserTenantDAO) Delete(id string) error { + return DB.Model(&model.UserTenant{}).Where("id = ?", id).Update("status", "0").Error +} + +// GetByUserID get user tenant relationships by user ID +func (dao *UserTenantDAO) GetByUserID(userID string) ([]*model.UserTenant, error) { + var relations []*model.UserTenant + err := DB.Where("user_id = ? AND status = ?", userID, "1").Find(&relations).Error + return relations, err +} + +// GetByTenantID get user tenant relationships by tenant ID +func (dao *UserTenantDAO) GetByTenantID(tenantID string) ([]*model.UserTenant, error) { + var relations []*model.UserTenant + err := DB.Where("tenant_id = ? AND status = ?", tenantID, "1").Find(&relations).Error + return relations, err +} + +// GetTenantIDsByUserID get tenant ID list by user ID +func (dao *UserTenantDAO) GetTenantIDsByUserID(userID string) ([]string, error) { + var tenantIDs []string + err := DB.Model(&model.UserTenant{}). + Select("tenant_id"). 
+ Where("user_id = ? AND status = ?", userID, "1"). + Pluck("tenant_id", &tenantIDs).Error + return tenantIDs, err +} + +// FilterByUserIDAndTenantID filter user tenant relationship by user ID and tenant ID +func (dao *UserTenantDAO) FilterByUserIDAndTenantID(userID, tenantID string) (*model.UserTenant, error) { + var userTenant model.UserTenant + err := DB.Where("user_id = ? AND tenant_id = ? AND status = ?", userID, tenantID, "1"). + First(&userTenant).Error + if err != nil { + return nil, err + } + return &userTenant, nil +} + +// GetByUserIDAndRole get user tenant relationships by user ID and role +func (dao *UserTenantDAO) GetByUserIDAndRole(userID, role string) ([]*model.UserTenant, error) { + var relations []*model.UserTenant + err := DB.Where("user_id = ? AND role = ? AND status = ?", userID, role, "1").Find(&relations).Error + return relations, err +} + +// GetNumMembers get number of members in a tenant (excluding owner) +func (dao *UserTenantDAO) GetNumMembers(tenantID string) (int64, error) { + var count int64 + err := DB.Model(&model.UserTenant{}). + Where("tenant_id = ? AND status = ? AND role != ?", tenantID, "1", "owner"). + Count(&count).Error + return count, err +} + +// TenantInfoByUserID tenant info with user details +type TenantInfoByUserID struct { + TenantID string `json:"tenant_id"` + Role string `json:"role"` + Nickname string `json:"nickname"` + Email string `json:"email"` + Avatar string `json:"avatar"` + UpdateDate string `json:"update_date"` +} + +// GetTenantsByUserID get tenants by user ID with user details +func (dao *UserTenantDAO) GetTenantsByUserID(userID string) ([]*TenantInfoByUserID, error) { + var results []*TenantInfoByUserID + err := DB.Table("user_tenant"). + Select("user_tenant.tenant_id, user_tenant.role, user.nickname, user.email, user.avatar, user.update_date"). + Joins("JOIN user ON user_tenant.tenant_id = user.id AND user_tenant.user_id = ? AND user_tenant.status = ?", userID, "1"). + Where("user_tenant.status = ?", "1"). + Scan(&results).Error + return results, err +} diff --git a/internal/engine/README.md b/internal/engine/README.md new file mode 100644 index 00000000000..b2226119cfd --- /dev/null +++ b/internal/engine/README.md @@ -0,0 +1,200 @@ +# Doc Engine Implementation + +RAGFlow Go document engine implementation, supporting Elasticsearch and Infinity storage engines. + +## Directory Structure + +``` +internal/engine/ +├── engine.go # DocEngine interface definition +├── engine_factory.go # Factory function +├── global.go # Global engine instance management +├── elasticsearch/ # Elasticsearch implementation +│ ├── client.go # ES client initialization +│ ├── search.go # Search implementation +│ ├── index.go # Index operations +│ └── document.go # Document operations +└── infinity/ # Infinity implementation + ├── client.go # Infinity client initialization (placeholder) + ├── search.go # Search implementation (placeholder) + ├── index.go # Table operations (placeholder) + └── document.go # Document operations (placeholder) +``` + +## Configuration + +### Using Elasticsearch + +Add to `conf/service_conf.yaml`: + +```yaml +doc_engine: + type: elasticsearch + es: + hosts: "http://localhost:9200" + username: "elastic" + password: "infini_rag_flow" +``` + +### Using Infinity + +```yaml +doc_engine: + type: infinity + infinity: + uri: "localhost:23817" + postgres_port: 5432 + db_name: "default_db" +``` + +**Note**: Infinity implementation is a placeholder waiting for the official Infinity Go SDK. 
Only Elasticsearch is fully functional at this time.
+
+## Usage
+
+### 1. Initialize Engine
+
+The engine is automatically initialized on service startup (see `cmd/server_main.go`):
+
+```go
+// Initialize doc engine
+if err := engine.Init(&cfg.DocEngine); err != nil {
+    log.Fatalf("Failed to initialize doc engine: %v", err)
+}
+defer engine.Close()
+```
+
+### 2. Use in Service
+
+In `ChunkService`:
+
+```go
+type ChunkService struct {
+    docEngine  engine.DocEngine
+    engineType config.EngineType
+}
+
+func NewChunkService() *ChunkService {
+    cfg := config.Get()
+    return &ChunkService{
+        docEngine:  engine.Get(),
+        engineType: cfg.DocEngine.Type,
+    }
+}
+
+// Search
+func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest) (*RetrievalTestResponse, error) {
+    ctx := context.Background()
+
+    switch s.engineType {
+    case config.EngineElasticsearch:
+        // Use Elasticsearch retrieval
+        searchReq := &elasticsearch.SearchRequest{
+            IndexNames: []string{"chunks"},
+            Query:      elasticsearch.BuildMatchTextQuery([]string{"content"}, req.Question, "AUTO"),
+            Size:       10,
+        }
+        result, err := s.docEngine.Search(ctx, searchReq)
+        if err != nil {
+            return nil, err
+        }
+        esResp := result.(*elasticsearch.SearchResponse)
+        // Process esResp into the response...
+        _ = esResp
+        return &RetrievalTestResponse{}, nil
+
+    case config.EngineInfinity:
+        // Infinity not implemented yet
+        return nil, fmt.Errorf("infinity not yet implemented")
+
+    default:
+        return nil, fmt.Errorf("unsupported engine type: %s", s.engineType)
+    }
+}
+```
+
+### 3. Direct Use of Global Engine
+
+```go
+import "ragflow/internal/engine"
+
+// Get engine instance
+docEngine := engine.Get()
+ctx := context.Background()
+
+// Search
+searchReq := &elasticsearch.SearchRequest{
+    IndexNames: []string{"my_index"},
+    Query:      elasticsearch.BuildTermQuery("status", "active"),
+}
+result, err := docEngine.Search(ctx, searchReq)
+
+// Index operations
+err = docEngine.CreateIndex(ctx, "my_index", mapping)
+err = docEngine.DeleteIndex(ctx, "my_index")
+exists, _ := docEngine.IndexExists(ctx, "my_index")
+
+// Document operations
+err = docEngine.IndexDocument(ctx, "my_index", "doc_id", docData)
+bulkResp, _ := docEngine.BulkIndex(ctx, "my_index", docs)
+doc, _ := docEngine.GetDocument(ctx, "my_index", "doc_id")
+err = docEngine.DeleteDocument(ctx, "my_index", "doc_id")
+```
+
+## API Documentation
+
+### DocEngine Interface
+
+```go
+type DocEngine interface {
+    // Search
+    Search(ctx context.Context, req interface{}) (interface{}, error)
+
+    // Index operations
+    CreateIndex(ctx context.Context, indexName string, mapping interface{}) error
+    DeleteIndex(ctx context.Context, indexName string) error
+    IndexExists(ctx context.Context, indexName string) (bool, error)
+
+    // Document operations
+    IndexDocument(ctx context.Context, indexName, docID string, doc interface{}) error
+    BulkIndex(ctx context.Context, indexName string, docs []interface{}) (interface{}, error)
+    GetDocument(ctx context.Context, indexName, docID string) (interface{}, error)
+    DeleteDocument(ctx context.Context, indexName, docID string) error
+
+    // Health check
+    Ping(ctx context.Context) error
+    Close() error
+}
+```
+
+## Dependencies
+
+### Elasticsearch
+- `github.com/elastic/go-elasticsearch/v8`
+
+### Infinity
+- **Not available yet** - Waiting for official Infinity Go SDK
+
+## Notes
+
+1. **Type Conversion**: The `Search` method returns `interface{}`, requiring a type assertion based on the engine type
+2. **Model Definitions**: Each engine has its own request/response models defined in their respective packages
+3. **Error Handling**: It's recommended to handle errors uniformly in the service layer and return user-friendly error messages
+4. 
**Performance Optimization**: For large volumes of documents, prefer using `BulkIndex` for batch operations
+5. **Connection Management**: The engine is closed automatically when the program exits; no manual management is needed
+6. **Infinity Status**: Infinity implementation is currently a placeholder. Only Elasticsearch is fully functional.
+
+## Extending with New Engines
+
+To add a new document engine (e.g., Milvus, Qdrant):
+
+1. Create a new directory under `internal/engine/`, e.g., `milvus/`
+2. Implement four files: `client.go`, `search.go`, `index.go`, `document.go`
+3. Add corresponding creation logic in `engine_factory.go`
+4. Add configuration structure in `config.go`
+5. Update service layer code to support the new engine
+
+## Correspondence with Python Project
+
+| Python Module | Go Module |
+|--------------|-----------|
+| `common/doc_store/doc_store_base.py` | `internal/engine/engine.go` |
+| `rag/utils/es_conn.py` | `internal/engine/elasticsearch/` |
+| `rag/utils/infinity_conn.py` | `internal/engine/infinity/` (placeholder) |
+| `common/settings.py` | `internal/config/config.go` |
+
+## Current Status
+
+- ✅ Elasticsearch: Fully implemented and functional
+- ⏳ Infinity: Placeholder implementation, waiting for official Go SDK
+- 📋 OceanBase: Not implemented (removed from requirements)
diff --git a/internal/engine/elasticsearch/client.go b/internal/engine/elasticsearch/client.go
new file mode 100644
index 00000000000..bfd10d056b6
--- /dev/null
+++ b/internal/engine/elasticsearch/client.go
@@ -0,0 +1,103 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// + +package elasticsearch + +import ( + "context" + "fmt" + "net/http" + "ragflow/internal/server" + "time" + + "github.com/elastic/go-elasticsearch/v8" + "github.com/elastic/go-elasticsearch/v8/esapi" +) + +// Engine Elasticsearch engine implementation +type elasticsearchEngine struct { + client *elasticsearch.Client + config *server.ElasticsearchConfig +} + +// NewEngine creates an Elasticsearch engine +func NewEngine(cfg interface{}) (*elasticsearchEngine, error) { + esConfig, ok := cfg.(*server.ElasticsearchConfig) + if !ok { + return nil, fmt.Errorf("invalid Elasticsearch config type, expected *config.ElasticsearchConfig") + } + + // Create ES client + client, err := elasticsearch.NewClient(elasticsearch.Config{ + Addresses: []string{esConfig.Hosts}, + Username: esConfig.Username, + Password: esConfig.Password, + Transport: &http.Transport{ + MaxIdleConnsPerHost: 10, + ResponseHeaderTimeout: 30 * time.Second, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to create Elasticsearch client: %w", err) + } + + // Check connection + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + req := esapi.InfoRequest{} + res, err := req.Do(ctx, client) + if err != nil { + return nil, fmt.Errorf("failed to ping Elasticsearch: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + return nil, fmt.Errorf("Elasticsearch returned error: %s", res.Status()) + } + + engine := &elasticsearchEngine{ + client: client, + config: esConfig, + } + + return engine, nil +} + +// Type returns the engine type +func (e *elasticsearchEngine) Type() string { + return "elasticsearch" +} + +// Ping health check +func (e *elasticsearchEngine) Ping(ctx context.Context) error { + req := esapi.InfoRequest{} + res, err := req.Do(ctx, e.client) + if err != nil { + return err + } + defer res.Body.Close() + if res.IsError() { + return fmt.Errorf("elasticsearch ping failed: %s", res.Status()) + } + return nil +} + +// Close closes the connection +func (e *elasticsearchEngine) Close() error { + // Go-elasticsearch client doesn't have a Close method, connection is managed by the transport + return nil +} diff --git a/internal/engine/elasticsearch/document.go b/internal/engine/elasticsearch/document.go new file mode 100644 index 00000000000..393a81d3992 --- /dev/null +++ b/internal/engine/elasticsearch/document.go @@ -0,0 +1,238 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package elasticsearch + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + + "github.com/elastic/go-elasticsearch/v8/esapi" +) + +// IndexDocument indexes a single document +func (e *elasticsearchEngine) IndexDocument(ctx context.Context, indexName, docID string, doc interface{}) error { + if indexName == "" { + return fmt.Errorf("index name cannot be empty") + } + if docID == "" { + return fmt.Errorf("document id cannot be empty") + } + if doc == nil { + return fmt.Errorf("document cannot be nil") + } + + // Serialize document + data, err := json.Marshal(doc) + if err != nil { + return fmt.Errorf("failed to marshal document: %w", err) + } + + // Index document + req := esapi.IndexRequest{ + Index: indexName, + DocumentID: docID, + Body: bytes.NewReader(data), + Refresh: "true", + } + + res, err := req.Do(ctx, e.client) + if err != nil { + return fmt.Errorf("failed to index document: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + return fmt.Errorf("elasticsearch returned error: %s", res.Status()) + } + + return nil +} + +// BulkIndex indexes documents in bulk +func (e *elasticsearchEngine) BulkIndex(ctx context.Context, indexName string, docs []interface{}) (interface{}, error) { + if indexName == "" { + return nil, fmt.Errorf("index name cannot be empty") + } + if len(docs) == 0 { + return nil, fmt.Errorf("documents cannot be empty") + } + + // Build bulk request + var buf bytes.Buffer + for _, doc := range docs { + docMap, ok := doc.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("document must be map[string]interface{}") + } + + docID, hasID := docMap["_id"] + if !hasID { + return nil, fmt.Errorf("document missing _id field") + } + + // Delete _id field to avoid duplication + delete(docMap, "_id") + + // Add index operation + meta := map[string]interface{}{ + "_index": indexName, + "_id": docID, + } + metaData, _ := json.Marshal(meta) + docData, _ := json.Marshal(docMap) + + buf.Write(metaData) + buf.WriteByte('\n') + buf.Write(docData) + buf.WriteByte('\n') + } + + // Execute bulk request + req := esapi.BulkRequest{ + Body: &buf, + Refresh: "true", + } + + res, err := req.Do(ctx, e.client) + if err != nil { + return nil, fmt.Errorf("bulk index failed: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + return nil, fmt.Errorf("elasticsearch returned error: %s", res.Status()) + } + + // Parse response + var result map[string]interface{} + if err := json.NewDecoder(res.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // Check for errors + if errors, ok := result["errors"].(bool); ok && errors { + // Get error details + if items, ok := result["items"].([]interface{}); ok && len(items) > 0 { + for _, item := range items { + if itemMap, ok := item.(map[string]interface{}); ok { + for _, op := range itemMap { + if opMap, ok := op.(map[string]interface{}); ok { + if errInfo, ok := opMap["error"].(map[string]interface{}); ok { + if reason, ok := errInfo["reason"].(string); ok { + return nil, fmt.Errorf("bulk index error: %s", reason) + } + } + } + } + } + } + } + return nil, fmt.Errorf("bulk index has errors") + } + + response := &BulkResponse{ + Took: int64(result["took"].(float64)), + Errors: result["errors"].(bool), + Indexed: len(docs), + } + + return response, nil +} + +// BulkResponse bulk operation response +type BulkResponse struct { + Took int64 + Errors bool + Indexed int +} + +// GetDocument gets a document +func (e *elasticsearchEngine) GetDocument(ctx context.Context, 
indexName, docID string) (interface{}, error) { + if indexName == "" { + return nil, fmt.Errorf("index name cannot be empty") + } + if docID == "" { + return nil, fmt.Errorf("document id cannot be empty") + } + + // Get document + req := esapi.GetRequest{ + Index: indexName, + DocumentID: docID, + } + + res, err := req.Do(ctx, e.client) + if err != nil { + return nil, fmt.Errorf("failed to get document: %w", err) + } + defer res.Body.Close() + + if res.StatusCode == 404 { + return nil, fmt.Errorf("document not found") + } + + if res.IsError() { + return nil, fmt.Errorf("elasticsearch returned error: %s", res.Status()) + } + + // Parse response + var result map[string]interface{} + if err := json.NewDecoder(res.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + if found, ok := result["found"].(bool); !ok || !found { + return nil, fmt.Errorf("document not found") + } + + return result["_source"], nil +} + +// DeleteDocument deletes a document +func (e *elasticsearchEngine) DeleteDocument(ctx context.Context, indexName, docID string) error { + if indexName == "" { + return fmt.Errorf("index name cannot be empty") + } + if docID == "" { + return fmt.Errorf("document id cannot be empty") + } + + // Delete document + req := esapi.DeleteRequest{ + Index: indexName, + DocumentID: docID, + Refresh: "true", + } + + res, err := req.Do(ctx, e.client) + if err != nil { + return fmt.Errorf("failed to delete document: %w", err) + } + defer res.Body.Close() + + if res.StatusCode == 404 { + return fmt.Errorf("document not found") + } + + if res.IsError() { + return fmt.Errorf("elasticsearch returned error: %s", res.Status()) + } + + return nil +} diff --git a/internal/engine/elasticsearch/index.go b/internal/engine/elasticsearch/index.go new file mode 100644 index 00000000000..795c41bf04e --- /dev/null +++ b/internal/engine/elasticsearch/index.go @@ -0,0 +1,144 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package elasticsearch + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + + "github.com/elastic/go-elasticsearch/v8/esapi" +) + +// CreateIndex creates an index +func (e *elasticsearchEngine) CreateIndex(ctx context.Context, indexName string, mapping interface{}) error { + if indexName == "" { + return fmt.Errorf("index name cannot be empty") + } + + // Check if index already exists + exists, err := e.IndexExists(ctx, indexName) + if err != nil { + return fmt.Errorf("failed to check index existence: %w", err) + } + if exists { + return fmt.Errorf("index '%s' already exists", indexName) + } + + // Prepare request body + var body io.Reader + if mapping != nil { + if str, ok := mapping.(string); ok { + body = bytes.NewBufferString(str) + } else { + data, err := json.Marshal(mapping) + if err != nil { + return fmt.Errorf("failed to marshal mapping: %w", err) + } + body = bytes.NewReader(data) + } + } + + // Create index + req := esapi.IndicesCreateRequest{ + Index: indexName, + Body: body, + } + + res, err := req.Do(ctx, e.client) + if err != nil { + return fmt.Errorf("failed to create index: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + return fmt.Errorf("elasticsearch returned error: %s", res.Status()) + } + + // Parse response + var result map[string]interface{} + if err := json.NewDecoder(res.Body).Decode(&result); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + acknowledged, ok := result["acknowledged"].(bool) + if !ok || !acknowledged { + return fmt.Errorf("index creation not acknowledged") + } + + return nil +} + +// DeleteIndex deletes an index +func (e *elasticsearchEngine) DeleteIndex(ctx context.Context, indexName string) error { + if indexName == "" { + return fmt.Errorf("index name cannot be empty") + } + + // Check if index exists + exists, err := e.IndexExists(ctx, indexName) + if err != nil { + return fmt.Errorf("failed to check index existence: %w", err) + } + if !exists { + return fmt.Errorf("index '%s' does not exist", indexName) + } + + // Delete index + req := esapi.IndicesDeleteRequest{ + Index: []string{indexName}, + } + + res, err := req.Do(ctx, e.client) + if err != nil { + return fmt.Errorf("failed to delete index: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + return fmt.Errorf("elasticsearch returned error: %s", res.Status()) + } + + return nil +} + +// IndexExists checks if index exists +func (e *elasticsearchEngine) IndexExists(ctx context.Context, indexName string) (bool, error) { + if indexName == "" { + return false, fmt.Errorf("index name cannot be empty") + } + + req := esapi.IndicesExistsRequest{ + Index: []string{indexName}, + } + + res, err := req.Do(ctx, e.client) + if err != nil { + return false, fmt.Errorf("failed to check index existence: %w", err) + } + defer res.Body.Close() + + if res.StatusCode == 200 { + return true, nil + } else if res.StatusCode == 404 { + return false, nil + } + + return false, fmt.Errorf("elasticsearch returned error: %s", res.Status()) +} diff --git a/internal/engine/elasticsearch/search.go b/internal/engine/elasticsearch/search.go new file mode 100644 index 00000000000..c4338295200 --- /dev/null +++ b/internal/engine/elasticsearch/search.go @@ -0,0 +1,528 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package elasticsearch + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "strconv" + "strings" + + "github.com/elastic/go-elasticsearch/v8/esapi" + "go.uber.org/zap" + + "ragflow/internal/engine/types" + "ragflow/internal/logger" +) + +// SearchRequest Elasticsearch search request (legacy, kept for backward compatibility) +type SearchRequest struct { + IndexNames []string + Query map[string]interface{} + Filters map[string]interface{} // Filter conditions (e.g., kb_id, doc_id, available_int) + Size int + From int + Highlight map[string]interface{} + Source []string + Sort []interface{} +} + +// SearchResponse Elasticsearch search response +type SearchResponse struct { + Hits struct { + Total struct { + Value int64 `json:"value"` + } `json:"total"` + Hits []struct { + ID string `json:"_id"` + Score float64 `json:"_score"` + Source map[string]interface{} `json:"_source"` + } `json:"hits"` + } `json:"hits"` + Aggregations map[string]interface{} `json:"aggregations"` +} + +// Search executes search (supports both unified engine.SearchRequest and legacy SearchRequest) +func (e *elasticsearchEngine) Search(ctx context.Context, req interface{}) (interface{}, error) { + + switch searchReq := req.(type) { + case *types.SearchRequest: + return e.searchUnified(ctx, searchReq) + case *SearchRequest: + return e.searchLegacy(ctx, searchReq) + default: + return nil, fmt.Errorf("invalid search request type: %T", req) + } +} + +// searchUnified handles the unified engine.SearchRequest +func (e *elasticsearchEngine) searchUnified(ctx context.Context, req *types.SearchRequest) (*types.SearchResponse, error) { + if len(req.IndexNames) == 0 { + return nil, fmt.Errorf("index names cannot be empty") + } + + // Build pagination parameters + offset, limit := calculatePagination(req.Page, req.Size, req.TopK) + + // Build filter clauses (default: available=1, meaning available_int >= 1) + // Reference: rag/utils/es_conn.py L60-L78 + filterClauses := buildFilterClauses(req.KbIDs, req.DocIDs, 1) + + // Build search query body + queryBody := make(map[string]interface{}) + + // Use MatchText if available (from QueryBuilder), otherwise use original Question + matchText := req.MatchText + if matchText == "" { + matchText = req.Question + } + + var vectorFieldName string + if req.KeywordOnly || len(req.Vector) == 0 { + // Keyword-only search + queryBody["query"] = buildESKeywordQuery(matchText, filterClauses, 1.0) + } else { + // Hybrid search: keyword + vector + // Calculate text weight + textWeight := 1.0 - req.VectorSimilarityWeight + // Build boolean query for text match and filters + boolQuery := buildESKeywordQuery(matchText, filterClauses, 1.0) + // Add boost to the bool query (as in Python code) + if boolMap, ok := boolQuery["bool"].(map[string]interface{}); ok { + boolMap["boost"] = textWeight + } + // Build kNN query + dimension := len(req.Vector) + var fieldBuilder strings.Builder + fieldBuilder.WriteString("q_") + fieldBuilder.WriteString(strconv.Itoa(dimension)) + fieldBuilder.WriteString("_vec") + vectorFieldName = fieldBuilder.String() + + k := req.TopK + if k <= 0 { 
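+		// No caller-supplied TopK: fall back to a wide default candidate pool.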
+ k = 1024 + } + numCandidates := k * 2 + + knnQuery := map[string]interface{}{ + "field": vectorFieldName, + "query_vector": req.Vector, + "k": k, + "num_candidates": numCandidates, + "filter": boolQuery, + "similarity": req.SimilarityThreshold, + } + + queryBody["knn"] = knnQuery + queryBody["query"] = boolQuery + } + + queryBody["size"] = limit + queryBody["from"] = offset + + // Serialize query + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(queryBody); err != nil { + return nil, fmt.Errorf("error encoding query: %w", err) + } + + // Log search details + logger.Debug("Elasticsearch searching indices", zap.Strings("indices", req.IndexNames)) + logger.Debug("Elasticsearch DSL", zap.Any("dsl", queryBody)) + + // Build search request + reqES := esapi.SearchRequest{ + Index: req.IndexNames, + Body: &buf, + } + + // Execute search + res, err := reqES.Do(ctx, e.client) + if err != nil { + return nil, fmt.Errorf("search failed: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + logger.Error("Elasticsearch failed to read error response body", err) + } else { + logger.Warn("Elasticsearch error response", zap.String("body", string(bodyBytes))) + } + return nil, fmt.Errorf("Elasticsearch returned error: %s", res.Status()) + } + + // Parse response + var esResp SearchResponse + if err := json.NewDecoder(res.Body).Decode(&esResp); err != nil { + return nil, fmt.Errorf("error parsing response: %w", err) + } + + // Convert to unified response + chunks := convertESResponse(&esResp, vectorFieldName) + return &types.SearchResponse{ + Chunks: chunks, + Total: esResp.Hits.Total.Value, + }, nil +} + +// searchLegacy handles the legacy elasticsearch.SearchRequest (backward compatibility) +func (e *elasticsearchEngine) searchLegacy(ctx context.Context, searchReq *SearchRequest) (*SearchResponse, error) { + if len(searchReq.IndexNames) == 0 { + return nil, fmt.Errorf("index names cannot be empty") + } + + // Build search query + queryBody := make(map[string]interface{}) + + // Process Filters first - convert to Elasticsearch filter clauses + var filterClauses []map[string]interface{} + if searchReq.Filters != nil && len(searchReq.Filters) > 0 { + for field, value := range searchReq.Filters { + switch v := value.(type) { + case map[string]interface{}: + filterClauses = append(filterClauses, map[string]interface{}{ + field: v, + }) + default: + filterClauses = append(filterClauses, map[string]interface{}{ + "term": map[string]interface{}{ + field: v, + }, + }) + } + } + } + + if searchReq.Query != nil { + queryCopy := make(map[string]interface{}) + for k, v := range searchReq.Query { + queryCopy[k] = v + } + + if knnValue, ok := queryCopy["knn"]; ok { + queryBody["knn"] = knnValue + delete(queryCopy, "knn") + } + + if len(queryCopy) > 0 { + if len(filterClauses) > 0 { + queryBody["query"] = map[string]interface{}{ + "bool": map[string]interface{}{ + "must": queryCopy, + "filter": filterClauses, + }, + } + } else { + queryBody["query"] = queryCopy + } + } else if len(filterClauses) > 0 { + queryBody["query"] = map[string]interface{}{ + "bool": map[string]interface{}{ + "filter": filterClauses, + }, + } + } + } else if len(filterClauses) > 0 { + queryBody["query"] = map[string]interface{}{ + "bool": map[string]interface{}{ + "filter": filterClauses, + }, + } + } + if searchReq.Size > 0 { + queryBody["size"] = searchReq.Size + } + if searchReq.From > 0 { + queryBody["from"] = searchReq.From + } + if searchReq.Highlight != nil { 
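+		// Forward the caller-supplied highlight configuration to ES unchanged.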
+ queryBody["highlight"] = searchReq.Highlight + } + if len(searchReq.Source) > 0 { + queryBody["_source"] = searchReq.Source + } + if len(searchReq.Sort) > 0 { + queryBody["sort"] = searchReq.Sort + } + + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(queryBody); err != nil { + return nil, fmt.Errorf("error encoding query: %w", err) + } + + logger.Debug("Elasticsearch searching indices", zap.Strings("indices", searchReq.IndexNames)) + logger.Debug("Elasticsearch DSL", zap.Any("dsl", queryBody)) + + reqES := esapi.SearchRequest{ + Index: searchReq.IndexNames, + Body: &buf, + } + + res, err := reqES.Do(ctx, e.client) + if err != nil { + return nil, fmt.Errorf("search failed: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + logger.Error("Elasticsearch failed to read error response body", err) + } else { + logger.Warn("Elasticsearch error response", zap.String("body", string(bodyBytes))) + } + return nil, fmt.Errorf("Elasticsearch returned error: %s", res.Status()) + } + + var response SearchResponse + if err := json.NewDecoder(res.Body).Decode(&response); err != nil { + return nil, fmt.Errorf("error parsing response: %w", err) + } + + return &response, nil +} + +// calculatePagination calculates offset and limit based on page, size and topK +func calculatePagination(page, size, topK int) (int, int) { + if page < 1 { + page = 1 + } + if size <= 0 { + size = 30 + } + if topK <= 0 { + topK = 1024 + } + + RERANK_LIMIT := max(30, (64/size)*size) + if RERANK_LIMIT < size { + RERANK_LIMIT = size + } + if RERANK_LIMIT > topK { + RERANK_LIMIT = topK + } + + offset := (page - 1) * RERANK_LIMIT + if offset < 0 { + offset = 0 + } + + return offset, RERANK_LIMIT +} + +// buildFilterClauses builds ES filter clauses from kb_ids, doc_ids and available_int +// Reference: rag/utils/es_conn.py L60-L78 +// When available=0: available_int < 1 +// When available!=0: NOT (available_int < 1) +func buildFilterClauses(kbIDs, docIDs []string, available int) []map[string]interface{} { + var filters []map[string]interface{} + + if len(kbIDs) > 0 { + filters = append(filters, map[string]interface{}{ + "terms": map[string]interface{}{"kb_id": kbIDs}, + }) + } + + if len(docIDs) > 0 { + filters = append(filters, map[string]interface{}{ + "terms": map[string]interface{}{"doc_id": docIDs}, + }) + } + + // Add available_int filter + // Reference: rag/utils/es_conn.py L63-L68 + if available == 0 { + // available_int < 1 + filters = append(filters, map[string]interface{}{ + "range": map[string]interface{}{ + "available_int": map[string]interface{}{ + "lt": 1, + }, + }, + }) + } else { + // must_not: available_int < 1 (i.e., available_int >= 1) + filters = append(filters, map[string]interface{}{ + "bool": map[string]interface{}{ + "must_not": []map[string]interface{}{ + { + "range": map[string]interface{}{ + "available_int": map[string]interface{}{ + "lt": 1, + }, + }, + }, + }, + }, + }) + } + + return filters +} + +// buildESKeywordQuery builds keyword-only search query for ES +// Uses query_string if matchText is in query_string format, otherwise uses multi_match +// boost is applied to the text match clause (query_string or multi_match) +func buildESKeywordQuery(matchText string, filterClauses []map[string]interface{}, boost float64) map[string]interface{} { + var mustClause map[string]interface{} + + // Use query_string for complex queries + queryString := map[string]interface{}{ + "query": matchText, + "fields": 
[]string{"title_tks^10", "title_sm_tks^5", "important_kwd^30", "important_tks^20", "question_tks^20", "content_ltks^2", "content_sm_ltks"},
+		"type":                 "best_fields",
+		"minimum_should_match": "30%",
+		"boost":                boost,
+	}
+	mustClause = map[string]interface{}{
+		"query_string": queryString,
+	}
+
+	return map[string]interface{}{
+		"bool": map[string]interface{}{
+			"must":   mustClause,
+			"filter": filterClauses,
+		},
+	}
+}
+
+// convertESResponse converts ES SearchResponse to unified chunks format
+func convertESResponse(esResp *SearchResponse, vectorFieldName string) []map[string]interface{} {
+	if esResp == nil || esResp.Hits.Hits == nil {
+		return []map[string]interface{}{}
+	}
+
+	chunks := make([]map[string]interface{}, len(esResp.Hits.Hits))
+	for i, hit := range esResp.Hits.Hits {
+		// The vector field (vectorFieldName) arrives as []float64; converting
+		// it to float32 is left to the caller when needed, e.g.:
+		//   vectorField := hit.Source[vectorFieldName]
+		//   chunks[i][vectorFieldName] = utility.Float64ToFloat32(vectorField)
+		chunks[i] = hit.Source
+		chunks[i]["_score"] = hit.Score
+		chunks[i]["_id"] = hit.ID
+	}
+	return chunks
+}
+
+// Helper query builder functions (legacy)
+
+// BuildMatchTextQuery builds a text match query
+func BuildMatchTextQuery(fields []string, text string, fuzziness string) map[string]interface{} {
+	query := map[string]interface{}{
+		"multi_match": map[string]interface{}{
+			"query":  text,
+			"fields": fields,
+		},
+	}
+
+	if fuzziness != "" {
+		if multiMatch, ok := query["multi_match"].(map[string]interface{}); ok {
+			multiMatch["fuzziness"] = fuzziness
+		}
+	}
+
+	return query
+}
+
+// BuildTermQuery builds a term query
+func BuildTermQuery(field string, value interface{}) map[string]interface{} {
+	return map[string]interface{}{
+		"term": map[string]interface{}{
+			field: value,
+		},
+	}
+}
+
+// BuildRangeQuery builds a range query
+func BuildRangeQuery(field string, from, to interface{}) map[string]interface{} {
+	rangeQuery := make(map[string]interface{})
+	if from != nil {
+		rangeQuery["gte"] = from
+	}
+	if to != nil {
+		rangeQuery["lte"] = to
+	}
+
+	return map[string]interface{}{
+		"range": map[string]interface{}{
+			field: rangeQuery,
+		},
+	}
+}
+
+// BuildBoolQuery builds a bool query
+func BuildBoolQuery() map[string]interface{} {
+	return map[string]interface{}{
+		"bool": make(map[string]interface{}),
+	}
+}
+
+// AddMust adds must clause to bool query
+func AddMust(query map[string]interface{}, clauses ...map[string]interface{}) {
+	if boolQuery, ok := query["bool"].(map[string]interface{}); ok {
+		if _, exists := boolQuery["must"]; !exists {
+			boolQuery["must"] = []map[string]interface{}{}
+		}
+		if must, ok := boolQuery["must"].([]map[string]interface{}); ok {
+			boolQuery["must"] = append(must, clauses...)
+		}
+	}
+}
+
+// AddShould adds should clause to bool query
+func AddShould(query map[string]interface{}, clauses ...map[string]interface{}) {
+	if boolQuery, ok := query["bool"].(map[string]interface{}); ok {
+		if _, exists := boolQuery["should"]; !exists {
+			boolQuery["should"] = []map[string]interface{}{}
+		}
+		if should, ok := boolQuery["should"].([]map[string]interface{}); ok {
+			boolQuery["should"] = append(should, clauses...)
+ } + } +} + +// AddFilter adds filter clause to bool query +func AddFilter(query map[string]interface{}, clauses ...map[string]interface{}) { + if boolQuery, ok := query["bool"].(map[string]interface{}); ok { + if _, exists := boolQuery["filter"]; !exists { + boolQuery["filter"] = []map[string]interface{}{} + } + if filter, ok := boolQuery["filter"].([]map[string]interface{}); ok { + boolQuery["filter"] = append(filter, clauses...) + } + } +} + +// AddMustNot adds must_not clause to bool query +func AddMustNot(query map[string]interface{}, clauses ...map[string]interface{}) { + if boolQuery, ok := query["bool"].(map[string]interface{}); ok { + if _, exists := boolQuery["must_not"]; !exists { + boolQuery["must_not"] = []map[string]interface{}{} + } + if mustNot, ok := boolQuery["must_not"].([]map[string]interface{}); ok { + boolQuery["must_not"] = append(mustNot, clauses...) + } + } +} diff --git a/internal/engine/engine.go b/internal/engine/engine.go new file mode 100644 index 00000000000..c8e91654263 --- /dev/null +++ b/internal/engine/engine.go @@ -0,0 +1,67 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package engine + +import ( + "context" + + "ragflow/internal/engine/types" +) + +// EngineType document engine type +type EngineType string + +const ( + EngineElasticsearch EngineType = "elasticsearch" + EngineInfinity EngineType = "infinity" +) + +// SearchRequest is an alias for types.SearchRequest +type SearchRequest = types.SearchRequest + +// SearchResponse is an alias for types.SearchResponse +type SearchResponse = types.SearchResponse + +// DocEngine document storage engine interface +type DocEngine interface { + // Search + Search(ctx context.Context, req interface{}) (interface{}, error) + + // Index operations + CreateIndex(ctx context.Context, indexName string, mapping interface{}) error + DeleteIndex(ctx context.Context, indexName string) error + IndexExists(ctx context.Context, indexName string) (bool, error) + + // Document operations + IndexDocument(ctx context.Context, indexName, docID string, doc interface{}) error + BulkIndex(ctx context.Context, indexName string, docs []interface{}) (interface{}, error) + GetDocument(ctx context.Context, indexName, docID string) (interface{}, error) + DeleteDocument(ctx context.Context, indexName, docID string) error + + // Health check + Ping(ctx context.Context) error + Close() error +} + +// Type returns the engine type (helper method for runtime type checking) +// This is a workaround since we can't import elasticsearch or infinity packages directly +func Type(docEngine DocEngine) EngineType { + // Type checking through interface methods is not straightforward + // This is a placeholder that should be implemented differently + // or rely on configuration to know the type + return EngineType("unknown") +} diff --git a/internal/engine/global.go b/internal/engine/global.go new file mode 100644 index 00000000000..315dfb4baae --- /dev/null +++ 
b/internal/engine/global.go @@ -0,0 +1,70 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package engine + +import ( + "fmt" + "ragflow/internal/server" + "sync" + + "go.uber.org/zap" + + "ragflow/internal/engine/elasticsearch" + "ragflow/internal/engine/infinity" + "ragflow/internal/logger" +) + +var ( + globalEngine DocEngine + once sync.Once +) + +// Init initializes document engine +func Init(cfg *server.DocEngineConfig) error { + var initErr error + once.Do(func() { + var err error + switch EngineType(cfg.Type) { + case EngineElasticsearch: + globalEngine, err = elasticsearch.NewEngine(cfg.ES) + case EngineInfinity: + globalEngine, err = infinity.NewEngine(cfg.Infinity) + default: + err = fmt.Errorf("unsupported doc engine type: %s", cfg.Type) + } + + if err != nil { + initErr = fmt.Errorf("failed to create doc engine: %w", err) + return + } + logger.Info("Doc engine initialized", zap.String("type", string(cfg.Type))) + }) + return initErr +} + +// Get gets global document engine instance +func Get() DocEngine { + return globalEngine +} + +// Close closes document engine +func Close() error { + if globalEngine != nil { + return globalEngine.Close() + } + return nil +} diff --git a/internal/engine/infinity/client.go b/internal/engine/infinity/client.go new file mode 100644 index 00000000000..7c1dbcaacc6 --- /dev/null +++ b/internal/engine/infinity/client.go @@ -0,0 +1,59 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package infinity + +import ( + "context" + "fmt" + "ragflow/internal/server" +) + +// Engine Infinity engine implementation +// Note: Infinity Go SDK is not yet available. This is a placeholder implementation. 
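+// Every operation currently returns a "not implemented" error, so callers
+// should select the engine through configuration (see engine.Init in
+// internal/engine/global.go) rather than constructing this type directly.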
+type infinityEngine struct { + config *server.InfinityConfig +} + +// NewEngine creates an Infinity engine +// Note: This is a placeholder implementation waiting for official Infinity Go SDK +func NewEngine(cfg interface{}) (*infinityEngine, error) { + infConfig, ok := cfg.(*server.InfinityConfig) + if !ok { + return nil, fmt.Errorf("invalid infinity config type, expected *config.InfinityConfig") + } + + engine := &infinityEngine{ + config: infConfig, + } + + return engine, nil +} + +// Type returns the engine type +func (e *infinityEngine) Type() string { + return "infinity" +} + +// Ping health check +func (e *infinityEngine) Ping(ctx context.Context) error { + return fmt.Errorf("infinity engine not implemented: waiting for official Go SDK") +} + +// Close closes the connection +func (e *infinityEngine) Close() error { + return nil +} diff --git a/internal/engine/infinity/document.go b/internal/engine/infinity/document.go new file mode 100644 index 00000000000..f56545e83eb --- /dev/null +++ b/internal/engine/infinity/document.go @@ -0,0 +1,47 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package infinity + +import ( + "context" + "fmt" +) + +// IndexDocument indexes a single document +func (e *infinityEngine) IndexDocument(ctx context.Context, tableName, docID string, doc interface{}) error { + return fmt.Errorf("infinity insert not implemented: waiting for official Go SDK") +} + +// BulkIndex indexes documents in bulk +func (e *infinityEngine) BulkIndex(ctx context.Context, tableName string, docs []interface{}) (interface{}, error) { + return nil, fmt.Errorf("infinity bulk insert not implemented: waiting for official Go SDK") +} + +// BulkResponse bulk operation response +type BulkResponse struct { + Inserted int +} + +// GetDocument gets a document +func (e *infinityEngine) GetDocument(ctx context.Context, tableName, docID string) (interface{}, error) { + return nil, fmt.Errorf("infinity get document not implemented: waiting for official Go SDK") +} + +// DeleteDocument deletes a document +func (e *infinityEngine) DeleteDocument(ctx context.Context, tableName, docID string) error { + return fmt.Errorf("infinity delete not implemented: waiting for official Go SDK") +} diff --git a/internal/engine/infinity/index.go b/internal/engine/infinity/index.go new file mode 100644 index 00000000000..f4bab3dfb4e --- /dev/null +++ b/internal/engine/infinity/index.go @@ -0,0 +1,37 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +package infinity + +import ( + "context" + "fmt" +) + +// CreateIndex creates a table/index +func (e *infinityEngine) CreateIndex(ctx context.Context, indexName string, mapping interface{}) error { + return fmt.Errorf("infinity create table not implemented: waiting for official Go SDK") +} + +// DeleteIndex deletes a table/index +func (e *infinityEngine) DeleteIndex(ctx context.Context, indexName string) error { + return fmt.Errorf("infinity drop table not implemented: waiting for official Go SDK") +} + +// IndexExists checks if table/index exists +func (e *infinityEngine) IndexExists(ctx context.Context, indexName string) (bool, error) { + return false, fmt.Errorf("infinity check table existence not implemented: waiting for official Go SDK") +} diff --git a/internal/engine/infinity/search.go b/internal/engine/infinity/search.go new file mode 100644 index 00000000000..e1aa033c04e --- /dev/null +++ b/internal/engine/infinity/search.go @@ -0,0 +1,205 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package infinity + +import ( + "context" + "fmt" + "strconv" + "strings" + + "ragflow/internal/engine/types" +) + +// SearchRequest Infinity search request (legacy, kept for backward compatibility) +type SearchRequest struct { + TableName string + ColumnNames []string + MatchText *MatchTextExpr + MatchDense *MatchDenseExpr + Fusion *FusionExpr + Offset int + Limit int + Filter map[string]interface{} +} + +// SearchResponse Infinity search response +type SearchResponse struct { + Rows []map[string]interface{} + Total int64 +} + +// MatchTextExpr text match expression +type MatchTextExpr struct { + Fields []string + MatchingText string + TopN int + ExtraOptions map[string]interface{} +} + +// MatchDenseExpr vector match expression +type MatchDenseExpr struct { + VectorColumnName string + EmbeddingData []float64 + EmbeddingDataType string + DistanceType string + TopN int + ExtraOptions map[string]interface{} +} + +// FusionExpr fusion expression +type FusionExpr struct { + Method string + TopN int + Weights []float64 + FusionParams map[string]interface{} +} + +// Search executes search (supports both unified engine.SearchRequest and legacy SearchRequest) +func (e *infinityEngine) Search(ctx context.Context, req interface{}) (interface{}, error) { + switch searchReq := req.(type) { + case *types.SearchRequest: + return e.searchUnified(ctx, searchReq) + case *SearchRequest: + return e.searchLegacy(ctx, searchReq) + default: + return nil, fmt.Errorf("invalid search request type: %T", req) + } +} + +// searchUnified handles the unified engine.SearchRequest +func (e *infinityEngine) searchUnified(ctx context.Context, req *types.SearchRequest) (*types.SearchResponse, error) { + if len(req.IndexNames) == 0 { + return nil, fmt.Errorf("index names cannot be empty") + } + + // For Infinity, we use the first index 
name as table name + tableName := req.IndexNames[0] + + // Get retrieval parameters with defaults + similarityThreshold := req.SimilarityThreshold + if similarityThreshold <= 0 { + similarityThreshold = 0.1 + } + + topK := req.TopK + if topK <= 0 { + topK = 1024 + } + + vectorSimilarityWeight := req.VectorSimilarityWeight + if vectorSimilarityWeight < 0 || vectorSimilarityWeight > 1 { + vectorSimilarityWeight = 0.3 + } + + pageSize := req.Size + if pageSize <= 0 { + pageSize = 30 + } + + offset := (req.Page - 1) * pageSize + if offset < 0 { + offset = 0 + } + + // Build search request + searchReq := &SearchRequest{ + TableName: tableName, + Limit: pageSize, + Offset: offset, + Filter: buildInfinityFilters(req.KbIDs, req.DocIDs), + } + + // Add text match (question is always required) + searchReq.MatchText = &MatchTextExpr{ + Fields: []string{"title_tks", "content_ltks"}, + MatchingText: req.Question, + TopN: topK, + } + + // Add vector match if vector is provided and not keyword-only mode + if !req.KeywordOnly && len(req.Vector) > 0 { + fieldName := buildInfinityVectorFieldName(req.Vector) + searchReq.MatchDense = &MatchDenseExpr{ + VectorColumnName: fieldName, + EmbeddingData: req.Vector, + EmbeddingDataType: "float", + DistanceType: "cosine", + TopN: topK, + ExtraOptions: map[string]interface{}{ + "similarity": similarityThreshold, + }, + } + // Infinity uses weighted_sum fusion with weights + searchReq.Fusion = &FusionExpr{ + Method: "weighted_sum", + TopN: topK, + Weights: []float64{ + 1.0 - vectorSimilarityWeight, // text weight + vectorSimilarityWeight, // vector weight + }, + } + } + + // Execute the actual search (would call Infinity SDK here) + // For now, return not implemented + return nil, fmt.Errorf("infinity search unified not implemented: waiting for official Go SDK") +} + +// searchLegacy handles the legacy infinity.SearchRequest (backward compatibility) +func (e *infinityEngine) searchLegacy(ctx context.Context, req *SearchRequest) (*SearchResponse, error) { + // This would contain the actual Infinity search implementation + return nil, fmt.Errorf("infinity search legacy not implemented: waiting for official Go SDK") +} + +// buildInfinityFilters builds filter conditions for Infinity +func buildInfinityFilters(kbIDs []string, docIDs []string) map[string]interface{} { + filters := make(map[string]interface{}) + + // kb_id filter + if len(kbIDs) > 0 { + if len(kbIDs) == 1 { + filters["kb_id"] = kbIDs[0] + } else { + filters["kb_id"] = kbIDs + } + } + + // doc_id filter + if len(docIDs) > 0 { + if len(docIDs) == 1 { + filters["doc_id"] = docIDs[0] + } else { + filters["doc_id"] = docIDs + } + } + + // available_int filter (default to 1 for available chunks) + filters["available_int"] = 1 + + return filters +} + +// buildInfinityVectorFieldName builds vector field name based on dimension +func buildInfinityVectorFieldName(vector []float64) string { + dimension := len(vector) + var fieldBuilder strings.Builder + fieldBuilder.WriteString("q_") + fieldBuilder.WriteString(strconv.Itoa(dimension)) + fieldBuilder.WriteString("_vec") + return fieldBuilder.String() +} diff --git a/internal/engine/types/types.go b/internal/engine/types/types.go new file mode 100644 index 00000000000..e1ebfc4abf8 --- /dev/null +++ b/internal/engine/types/types.go @@ -0,0 +1,54 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package types + +// SearchRequest unified search request for all engines +type SearchRequest struct { + // Common fields + IndexNames []string // For ES: index names; For Infinity: treated as table names + Question string // Search query text + Vector []float64 // Embedding vector (optional, for hybrid search) + + // Query analysis results (from QueryBuilder.Question) + MatchText string // Processed match text for ES query_string + Keywords []string // Extracted keywords from question + + // Filters + KbIDs []string // Knowledge base IDs filter + DocIDs []string // Document IDs filter + + // Pagination + Page int // Page number (1-based) + Size int // Page size + TopK int // Number of candidates for retrieval + + // Search mode + KeywordOnly bool // If true, only do keyword search (no vector search) + + // Scoring parameters + SimilarityThreshold float64 // Minimum similarity score (default: 0.1) + VectorSimilarityWeight float64 // Weight for vector vs keyword (default: 0.3) + + // Engine-specific options (optional, for advanced use) + Options map[string]interface{} +} + +// SearchResponse unified search response for all engines +type SearchResponse struct { + Chunks []map[string]interface{} // Search results + Total int64 // Total number of matches +} diff --git a/internal/go_binding/rag_analyzer.go b/internal/go_binding/rag_analyzer.go new file mode 100644 index 00000000000..f1386f51a85 --- /dev/null +++ b/internal/go_binding/rag_analyzer.go @@ -0,0 +1,265 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package rag_analyzer + +/* +#cgo CXXFLAGS: -std=c++20 -I${SRCDIR}/.. 
+#cgo linux LDFLAGS: ${SRCDIR}/../cpp/cmake-build-release/librag_tokenizer_c_api.a -lstdc++ -lm -lpthread /usr/lib/x86_64-linux-gnu/libpcre2-8.a
+#cgo darwin LDFLAGS: ${SRCDIR}/../cpp/cmake-build-release/librag_tokenizer_c_api.a -lstdc++ -lm -lpthread /usr/local/lib/libpcre2-8.a
+
+#include <stdlib.h>
+#include "../cpp/rag_analyzer_c_api.h"
+*/
+import "C"
+import (
+	"fmt"
+	"unsafe"
+)
+
+// Token represents a single token from the analyzer
+type Token struct {
+	Text      string
+	Offset    uint32
+	EndOffset uint32
+}
+
+// TokenWithPosition represents a token with position information
+type TokenWithPosition struct {
+	Text      string
+	Offset    uint32
+	EndOffset uint32
+}
+
+// Analyzer wraps the C RAGAnalyzer
+type Analyzer struct {
+	handle C.RAGAnalyzerHandle
+}
+
+// NewAnalyzer creates a new RAGAnalyzer instance
+// path: path to dictionary files (containing rag/, wordnet/, opencc/ directories)
+func NewAnalyzer(path string) (*Analyzer, error) {
+	cPath := C.CString(path)
+	defer C.free(unsafe.Pointer(cPath))
+
+	handle := C.RAGAnalyzer_Create(cPath)
+	if handle == nil {
+		return nil, fmt.Errorf("failed to create RAGAnalyzer")
+	}
+
+	return &Analyzer{handle: handle}, nil
+}
+
+// Load loads the analyzer dictionaries
+func (a *Analyzer) Load() error {
+	if a.handle == nil {
+		return fmt.Errorf("analyzer is not initialized")
+	}
+
+	ret := C.RAGAnalyzer_Load(a.handle)
+	if ret != 0 {
+		return fmt.Errorf("failed to load analyzer, error code: %d", ret)
+	}
+	return nil
+}
+
+// SetFineGrained sets whether to use fine-grained tokenization
+func (a *Analyzer) SetFineGrained(fineGrained bool) {
+	if a.handle == nil {
+		return
+	}
+	C.RAGAnalyzer_SetFineGrained(a.handle, C.bool(fineGrained))
+}
+
+// SetEnablePosition sets whether to enable position tracking
+func (a *Analyzer) SetEnablePosition(enablePosition bool) {
+	if a.handle == nil {
+		return
+	}
+	C.RAGAnalyzer_SetEnablePosition(a.handle, C.bool(enablePosition))
+}
+
+// Analyze analyzes the input text and returns all tokens
+func (a *Analyzer) Analyze(text string) ([]Token, error) {
+	if a.handle == nil {
+		return nil, fmt.Errorf("analyzer is not initialized")
+	}
+
+	// Since the C API now uses TermList instead of callback,
+	// we need a different approach. Let's use Tokenize for now
+	// and return the tokens parsed from the space-separated string.
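+	// Offsets recovered this way index into the space-joined token string,
+	// not the original input; use TokenizeWithPosition for exact positions.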
+ result, err := a.Tokenize(text) + if err != nil { + return nil, err + } + + // Parse the space-separated result into tokens + // This is a simplified version - for full position support, + // we would need to modify the C API to return structured data + tokens := parseTokens(result) + return tokens, nil +} + +// parseTokens splits a space-separated string into tokens +func parseTokens(result string) []Token { + var tokens []Token + start := 0 + for i := 0; i <= len(result); i++ { + if i == len(result) || result[i] == ' ' { + if start < i { + tokens = append(tokens, Token{ + Text: result[start:i], + Offset: uint32(start), + // EndOffset will be approximate without position tracking + EndOffset: uint32(i), + }) + } + start = i + 1 + } + } + return tokens +} + +// Tokenize analyzes text and returns a space-separated string of tokens +func (a *Analyzer) Tokenize(text string) (string, error) { + if a.handle == nil { + return "", fmt.Errorf("analyzer is not initialized") + } + + cText := C.CString(text) + defer C.free(unsafe.Pointer(cText)) + + cResult := C.RAGAnalyzer_Tokenize(a.handle, cText) + if cResult == nil { + return "", fmt.Errorf("tokenize failed") + } + defer C.free(unsafe.Pointer(cResult)) + + return C.GoString(cResult), nil +} + +// TokenizeWithPosition analyzes text and returns tokens with position information +func (a *Analyzer) TokenizeWithPosition(text string) ([]TokenWithPosition, error) { + if a.handle == nil { + return nil, fmt.Errorf("analyzer is not initialized") + } + + cText := C.CString(text) + defer C.free(unsafe.Pointer(cText)) + + cTokenList := C.RAGAnalyzer_TokenizeWithPosition(a.handle, cText) + if cTokenList == nil { + return nil, fmt.Errorf("tokenize with position failed") + } + defer C.RAGAnalyzer_FreeTokenList(cTokenList) + + // Convert C token list to Go slice + tokens := make([]TokenWithPosition, cTokenList.count) + + // Iterate through tokens using helper functions + for i := 0; i < int(cTokenList.count); i++ { + // Calculate pointer to the i-th token + cToken := unsafe.Pointer( + uintptr(unsafe.Pointer(cTokenList.tokens)) + + uintptr(i)*unsafe.Sizeof(C.struct_RAGTokenWithPosition{}), + ) + + // Use C helper functions to access fields (pass as void*) + tokens[i] = TokenWithPosition{ + Text: C.GoString(C.RAGToken_GetText(cToken)), + Offset: uint32(C.RAGToken_GetOffset(cToken)), + EndOffset: uint32(C.RAGToken_GetEndOffset(cToken)), + } + } + + return tokens, nil +} + +// Close destroys the analyzer and releases resources +func (a *Analyzer) Close() { + if a.handle != nil { + C.RAGAnalyzer_Destroy(a.handle) + a.handle = nil + } +} + +// FineGrainedTokenize performs fine-grained tokenization on space-separated tokens +// Input: space-separated tokens (e.g., "hello world 测试") +// Output: space-separated fine-grained tokens (e.g., "hello world 测 试") +func (a *Analyzer) FineGrainedTokenize(tokens string) (string, error) { + if a.handle == nil { + return "", fmt.Errorf("analyzer is not initialized") + } + + cTokens := C.CString(tokens) + defer C.free(unsafe.Pointer(cTokens)) + + cResult := C.RAGAnalyzer_FineGrainedTokenize(a.handle, cTokens) + if cResult == nil { + return "", fmt.Errorf("fine-grained tokenize failed") + } + defer C.free(unsafe.Pointer(cResult)) + + return C.GoString(cResult), nil +} + +// GetTermFreq returns the frequency of a term (matching Python rag_tokenizer.freq) +// Returns: frequency value, or 0 if term not found +func (a *Analyzer) GetTermFreq(term string) int32 { + if a.handle == nil { + return 0 + } + + cTerm := C.CString(term) + defer 
C.free(unsafe.Pointer(cTerm)) + + return int32(C.RAGAnalyzer_GetTermFreq(a.handle, cTerm)) +} + +// GetTermTag returns the POS tag of a term (matching Python rag_tokenizer.tag) +// Returns: POS tag string (e.g., "n", "v", "ns"), or empty string if term not found or no tag +func (a *Analyzer) GetTermTag(term string) string { + if a.handle == nil { + return "" + } + + cTerm := C.CString(term) + defer C.free(unsafe.Pointer(cTerm)) + + cResult := C.RAGAnalyzer_GetTermTag(a.handle, cTerm) + if cResult == nil { + return "" + } + defer C.free(unsafe.Pointer(cResult)) + + return C.GoString(cResult) +} + +// Copy creates a new independent analyzer instance from the current one +// The new instance shares the loaded dictionaries but has independent internal state +// This is useful for creating per-request analyzer instances in concurrent environments +func (a *Analyzer) Copy() *Analyzer { + if a.handle == nil { + return nil + } + + handle := C.RAGAnalyzer_Copy(a.handle) + if handle == nil { + return nil + } + + return &Analyzer{handle: handle} +} diff --git a/internal/handler/chat.go b/internal/handler/chat.go new file mode 100644 index 00000000000..aa09c3353b9 --- /dev/null +++ b/internal/handler/chat.go @@ -0,0 +1,314 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package handler + +import ( + "net/http" + "strconv" + + "github.com/gin-gonic/gin" + + "ragflow/internal/service" +) + +// ChatHandler chat handler +type ChatHandler struct { + chatService *service.ChatService + userService *service.UserService +} + +// NewChatHandler create chat handler +func NewChatHandler(chatService *service.ChatService, userService *service.UserService) *ChatHandler { + return &ChatHandler{ + chatService: chatService, + userService: userService, + } +} + +// ListChats list chats +// @Summary List Chats +// @Description Get list of chats (dialogs) for the current user +// @Tags chat +// @Accept json +// @Produce json +// @Success 200 {object} service.ListChatsResponse +// @Router /v1/dialog/list [get] +func (h *ChatHandler) ListChats(c *gin.Context) { + // Get access token from Authorization header + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + return + } + + // Get user by access token + user, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Invalid access token", + }) + return + } + userID := user.ID + + // List chats - default to valid status "1" (same as Python StatusEnum.VALID.value) + result, err := h.chatService.ListChats(userID, "1") + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": result, + "message": "success", + }) +} + +// ListChatsNext list chats with advanced filtering and pagination +// @Summary List Chats Next +// @Description Get list of chats with filtering, pagination and sorting (equivalent to list_dialogs_next) +// @Tags chat +// @Accept json +// @Produce json +// @Param keywords query string false "search keywords" +// @Param page query int false "page number" +// @Param page_size query int false "items per page" +// @Param orderby query string false "order by field (default: create_time)" +// @Param desc query bool false "descending order (default: true)" +// @Param request body service.ListChatsNextRequest true "filter options including owner_ids" +// @Success 200 {object} service.ListChatsNextResponse +// @Router /v1/dialog/next [post] +func (h *ChatHandler) ListChatsNext(c *gin.Context) { + // Get access token from Authorization header + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + return + } + + // Get user by access token + user, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Invalid access token", + }) + return + } + userID := user.ID + + // Parse query parameters + keywords := c.Query("keywords") + + page := 0 + if pageStr := c.Query("page"); pageStr != "" { + if p, err := strconv.Atoi(pageStr); err == nil && p > 0 { + page = p + } + } + + pageSize := 0 + if pageSizeStr := c.Query("page_size"); pageSizeStr != "" { + if ps, err := strconv.Atoi(pageSizeStr); err == nil && ps > 0 { + pageSize = ps + } + } + + orderby := c.DefaultQuery("orderby", "create_time") + + desc := true + if descStr := c.Query("desc"); descStr != "" { + desc = descStr != "false" + } + + // Parse request body for owner_ids + var req service.ListChatsNextRequest + if c.Request.ContentLength > 0 { + if err := c.ShouldBindJSON(&req); err 
!= nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": err.Error(), + }) + return + } + } + + // List chats with advanced filtering + result, err := h.chatService.ListChatsNext(userID, keywords, page, pageSize, orderby, desc, req.OwnerIDs) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": result, + "message": "success", + }) +} + +// SetDialog create or update a dialog +// @Summary Set Dialog +// @Description Create or update a dialog (chat). If dialog_id is provided, updates existing dialog; otherwise creates new one. +// @Tags chat +// @Accept json +// @Produce json +// @Param request body service.SetDialogRequest true "dialog configuration" +// @Success 200 {object} service.SetDialogResponse +// @Router /v1/dialog/set [post] +func (h *ChatHandler) SetDialog(c *gin.Context) { + // Get access token from Authorization header + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + return + } + + // Get user by access token + user, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Invalid access token", + }) + return + } + userID := user.ID + + // Parse request body + var req service.SetDialogRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": err.Error(), + }) + return + } + + // Validate required field: prompt_config + if req.PromptConfig == nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "prompt_config is required", + }) + return + } + + // Call service to set dialog + result, err := h.chatService.SetDialog(userID, &req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": result, + "message": "success", + }) +} + +// RemoveDialogsRequest remove dialogs request +type RemoveDialogsRequest struct { + DialogIDs []string `json:"dialog_ids" binding:"required"` +} + +// RemoveChats remove/delete dialogs (soft delete by setting status to invalid) +// @Summary Remove Dialogs +// @Description Remove dialogs by setting their status to invalid. Only the owner of the dialog can perform this operation. 
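+// @Description Example request body (illustrative ID value): {"dialog_ids": ["<dialog-id>"]}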
+// @Tags chat +// @Accept json +// @Produce json +// @Param request body RemoveDialogsRequest true "dialog IDs to remove" +// @Success 200 {object} map[string]interface{} +// @Router /v1/dialog/rm [post] +func (h *ChatHandler) RemoveChats(c *gin.Context) { + // Get access token from Authorization header + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + return + } + + // Get user by access token + user, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Invalid access token", + }) + return + } + userID := user.ID + + // Parse request body + var req RemoveDialogsRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": err.Error(), + }) + return + } + + // Call service to remove dialogs + if err := h.chatService.RemoveChats(userID, req.DialogIDs); err != nil { + // Check if it's an authorization error + if err.Error() == "only owner of chat authorized for this operation" { + c.JSON(http.StatusForbidden, gin.H{ + "code": 403, + "data": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": true, + "message": "success", + }) +} diff --git a/internal/handler/chat_session.go b/internal/handler/chat_session.go new file mode 100644 index 00000000000..fd5d4492310 --- /dev/null +++ b/internal/handler/chat_session.go @@ -0,0 +1,377 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "fmt" + "io" + "net/http" + + "github.com/gin-gonic/gin" + + "ragflow/internal/service" +) + +// ChatSessionHandler chat session (conversation) handler +type ChatSessionHandler struct { + chatSessionService *service.ChatSessionService + userService *service.UserService +} + +// NewChatSessionHandler create chat session handler +func NewChatSessionHandler(chatSessionService *service.ChatSessionService, userService *service.UserService) *ChatSessionHandler { + return &ChatSessionHandler{ + chatSessionService: chatSessionService, + userService: userService, + } +} + +// SetChatSession create or update a chat session +// @Summary Set chat session +// @Description Create or update a chat session. If is_new is true, creates new chat session; otherwise updates existing one. 
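+// @Description Example body (illustrative; remaining fields are defined by service.SetChatSessionRequest): {"is_new": true}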
+// @Tags chat_session +// @Accept json +// @Produce json +// @Param request body service.SetChatSessionRequest true "chat session configuration" +// @Success 200 {object} service.SetChatSessionResponse +// @Router /v1/conversation/set [post] +func (h *ChatSessionHandler) SetChatSession(c *gin.Context) { + // Get access token from Authorization header + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + return + } + + // Get user by access token + user, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Invalid access token", + }) + return + } + userID := user.ID + + // Parse request body + var req service.SetChatSessionRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": err.Error(), + }) + return + } + + // Call service to set chat session + result, err := h.chatSessionService.SetChatSession(userID, &req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": result, + "message": "success", + }) +} + +// RemoveChatSessionsRequest remove chat sessions request +type RemoveChatSessionsRequest struct { + ConversationIDs []string `json:"conversation_ids" binding:"required"` +} + +// RemoveChatSessions remove/delete chat sessions +// @Summary Remove Chat Sessions +// @Description Remove chat sessions by their IDs. Only the owner of the chat session can perform this operation. +// @Tags chat_session +// @Accept json +// @Produce json +// @Param request body RemoveChatSessionsRequest true "chat session IDs to remove" +// @Success 200 {object} map[string]interface{} +// @Router /v1/conversation/rm [post] +func (h *ChatSessionHandler) RemoveChatSessions(c *gin.Context) { + // Get access token from Authorization header + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + return + } + + // Get user by access token + user, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Invalid access token", + }) + return + } + userID := user.ID + + // Parse request body + var req RemoveChatSessionsRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": err.Error(), + }) + return + } + + // Call service to remove chat sessions + if err := h.chatSessionService.RemoveChatSessions(userID, req.ConversationIDs); err != nil { + // Check if it's an authorization error + if err.Error() == "Only owner of chat session authorized for this operation" { + c.JSON(http.StatusForbidden, gin.H{ + "code": 403, + "data": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": true, + "message": "success", + }) +} + +// ListChatSessions list chat sessions for a dialog +// @Summary List Chat Sessions +// @Description Get list of chat sessions for a specific dialog +// @Tags chat_session +// @Accept json +// @Produce json +// @Param dialog_id query string true "dialog ID" +// @Success 200 {object} service.ListChatSessionsResponse +// 
@Router /v1/conversation/list [get] +func (h *ChatSessionHandler) ListChatSessions(c *gin.Context) { + // Get access token from Authorization header + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + return + } + + // Get user by access token + user, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Invalid access token", + }) + return + } + userID := user.ID + + // Get dialog_id from query parameter + dialogID := c.Query("dialog_id") + if dialogID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "dialog_id is required", + }) + return + } + + // Call service to list chat sessions + result, err := h.chatSessionService.ListChatSessions(userID, dialogID) + if err != nil { + // Check if it's an authorization error + if err.Error() == "Only owner of dialog authorized for this operation" { + c.JSON(http.StatusForbidden, gin.H{ + "code": 403, + "data": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": result.Sessions, + "message": "success", + }) +} + +// CompletionRequest completion request +type CompletionRequest struct { + ConversationID string `json:"conversation_id" binding:"required"` + Messages []map[string]interface{} `json:"messages" binding:"required"` + LLMID string `json:"llm_id,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` +} + +// Completion chat completion +// @Summary Chat Completion +// @Description Send messages to the chat model and get a response. Supports streaming and non-streaming modes. 
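+// @Description Example body (illustrative values): {"conversation_id": "<conversation-id>", "messages": [{"role": "user", "content": "Hello"}], "stream": true}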
+// @Tags chat_session
+// @Accept json
+// @Produce json
+// @Param request body CompletionRequest true "completion request"
+// @Success 200 {object} map[string]interface{}
+// @Router /v1/conversation/completion [post]
func (h *ChatSessionHandler) Completion(c *gin.Context) {
+	// Get access token from Authorization header
+	token := c.GetHeader("Authorization")
+	if token == "" {
+		c.JSON(http.StatusUnauthorized, gin.H{
+			"code":    401,
+			"message": "Missing Authorization header",
+		})
+		return
+	}
+
+	// Get user by access token
+	user, err := h.userService.GetUserByToken(token)
+	if err != nil {
+		c.JSON(http.StatusUnauthorized, gin.H{
+			"code":    401,
+			"message": "Invalid access token",
+		})
+		return
+	}
+	userID := user.ID
+
+	// Parse request body
+	var req CompletionRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"code":    400,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	// Build chat model config
+	chatModelConfig := make(map[string]interface{})
+	if req.Temperature != 0 {
+		chatModelConfig["temperature"] = req.Temperature
+	}
+	if req.TopP != 0 {
+		chatModelConfig["top_p"] = req.TopP
+	}
+	if req.FrequencyPenalty != 0 {
+		chatModelConfig["frequency_penalty"] = req.FrequencyPenalty
+	}
+	if req.PresencePenalty != 0 {
+		chatModelConfig["presence_penalty"] = req.PresencePenalty
+	}
+	if req.MaxTokens != 0 {
+		chatModelConfig["max_tokens"] = req.MaxTokens
+	}
+
+	// Process messages - filter out system messages and initial assistant messages
+	var processedMessages []map[string]interface{}
+	for _, m := range req.Messages {
+		role, _ := m["role"].(string)
+		if role == "system" {
+			continue
+		}
+		if role == "assistant" && len(processedMessages) == 0 {
+			continue
+		}
+		processedMessages = append(processedMessages, m)
+	}
+
+	// Get last message ID if present
+	var messageID string
+	if len(processedMessages) > 0 {
+		if id, ok := processedMessages[len(processedMessages)-1]["id"].(string); ok {
+			messageID = id
+		}
+	}
+
+	// Call service
+	if req.Stream {
+		// Streaming response
+		c.Header("Content-Type", "text/event-stream")
+		c.Header("Cache-Control", "no-cache")
+		c.Header("Connection", "keep-alive")
+		c.Header("X-Accel-Buffering", "no")
+
+		// Create a channel for streaming data
+		streamChan := make(chan string)
+		go func() {
+			defer close(streamChan)
+			err := h.chatSessionService.CompletionStream(userID, req.ConversationID, processedMessages, req.LLMID, chatModelConfig, messageID, streamChan)
+			if err != nil {
+				streamChan <- fmt.Sprintf("data: %s\n\n", err.Error())
+			}
+		}()
+
+		// Stream data to client
+		c.Stream(func(w io.Writer) bool {
+			data, ok := <-streamChan
+			if !ok {
+				return false
+			}
+			w.Write([]byte(data))
+			return true
+		})
+	} else {
+		// Non-streaming response
+		result, err := h.chatSessionService.Completion(userID, req.ConversationID, processedMessages, req.LLMID, chatModelConfig, messageID)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{
+				"code":    500,
+				"message": err.Error(),
+			})
+			return
+		}
+
+		c.JSON(http.StatusOK, gin.H{
+			"code":    0,
+			"data":    result,
+			"message": "",
+		})
+	}
+}
diff --git a/internal/handler/chunk.go b/internal/handler/chunk.go
new file mode 100644
index 00000000000..10b19830da3
--- /dev/null
+++ b/internal/handler/chunk.go
@@ -0,0 +1,180 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + "ragflow/internal/service" +) + +// ChunkHandler chunk handler +type ChunkHandler struct { + chunkService *service.ChunkService + userService *service.UserService +} + +// NewChunkHandler create chunk handler +func NewChunkHandler(chunkService *service.ChunkService, userService *service.UserService) *ChunkHandler { + return &ChunkHandler{ + chunkService: chunkService, + userService: userService, + } +} + +// RetrievalTest performs retrieval test for chunks +// @Summary Retrieval Test +// @Description Test retrieval of chunks based on question and knowledge base +// @Tags chunks +// @Accept json +// @Produce json +// @Param request body service.RetrievalTestRequest true "retrieval test parameters" +// @Success 200 {object} map[string]interface{} +// @Router /v1/chunk/retrieval_test [post] +func (h *ChunkHandler) RetrievalTest(c *gin.Context) { + // Extract access token from Authorization header + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + return + } + + // Get user by access token + user, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Invalid access token", + }) + return + } + + // Bind JSON request + var req service.RetrievalTestRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": err.Error(), + }) + return + } + + // Set default values for optional parameters + if req.Page == nil { + defaultPage := 1 + req.Page = &defaultPage + } + if req.Size == nil { + defaultSize := 30 + req.Size = &defaultSize + } + if req.TopK == nil { + defaultTopK := 1024 + req.TopK = &defaultTopK + } + if req.UseKG == nil { + defaultUseKG := false + req.UseKG = &defaultUseKG + } + + // Validate required fields + if req.Question == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "question is required", + }) + return + } + if req.KbID == nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "kb_id is required", + }) + return + } + + // Validate kb_id type: string or []string + switch v := req.KbID.(type) { + case string: + if v == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "kb_id cannot be empty string", + }) + return + } + case []interface{}: + // Convert to []string + var kbIDs []string + for _, item := range v { + if str, ok := item.(string); ok && str != "" { + kbIDs = append(kbIDs, str) + } else { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "kb_id array must contain non-empty strings", + }) + return + } + } + if len(kbIDs) == 0 { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "kb_id array cannot be empty", + }) + return + } + // Convert back to interface{} for service + req.KbID = kbIDs + case []string: + // Already correct type + if len(v) == 0 { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "kb_id array 
cannot be empty", + }) + return + } + default: + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "kb_id must be string or array of strings", + }) + return + } + + // Call service with user ID for permission checks + resp, err := h.chunkService.RetrievalTest(&req, user.ID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": resp, + "message": "success", + }) +} diff --git a/internal/handler/connector.go b/internal/handler/connector.go new file mode 100644 index 00000000000..9f54b804198 --- /dev/null +++ b/internal/handler/connector.go @@ -0,0 +1,86 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + "ragflow/internal/service" +) + +// ConnectorHandler connector handler +type ConnectorHandler struct { + connectorService *service.ConnectorService + userService *service.UserService +} + +// NewConnectorHandler create connector handler +func NewConnectorHandler(connectorService *service.ConnectorService, userService *service.UserService) *ConnectorHandler { + return &ConnectorHandler{ + connectorService: connectorService, + userService: userService, + } +} + +// ListConnectors list connectors +// @Summary List Connectors +// @Description Get list of connectors for the current user (equivalent to Python's list_connector) +// @Tags connector +// @Accept json +// @Produce json +// @Success 200 {object} service.ListConnectorsResponse +// @Router /connector/list [get] +func (h *ConnectorHandler) ListConnectors(c *gin.Context) { + // Get access token from Authorization header + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + return + } + + // Get user by access token + user, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Invalid access token", + }) + return + } + userID := user.ID + + // List connectors + result, err := h.connectorService.ListConnectors(userID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": result.Connectors, + "message": "success", + }) +} diff --git a/internal/handler/document.go b/internal/handler/document.go new file mode 100644 index 00000000000..10f08b6baf8 --- /dev/null +++ b/internal/handler/document.go @@ -0,0 +1,258 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+package handler
+
+import (
+	"net/http"
+	"strconv"
+
+	"github.com/gin-gonic/gin"
+
+	"ragflow/internal/service"
+)
+
+// DocumentHandler document handler
+type DocumentHandler struct {
+	documentService *service.DocumentService
+}
+
+// NewDocumentHandler create document handler
+func NewDocumentHandler(documentService *service.DocumentService) *DocumentHandler {
+	return &DocumentHandler{
+		documentService: documentService,
+	}
+}
+
+// CreateDocument create document
+// @Summary Create Document
+// @Description Create new document
+// @Tags documents
+// @Accept json
+// @Produce json
+// @Param request body service.CreateDocumentRequest true "document info"
+// @Success 200 {object} map[string]interface{}
+// @Router /api/v1/documents [post]
+func (h *DocumentHandler) CreateDocument(c *gin.Context) {
+	var req service.CreateDocumentRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"error": err.Error(),
+		})
+		return
+	}
+
+	document, err := h.documentService.CreateDocument(&req)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"error": err.Error(),
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"message": "created successfully",
+		"data":    document,
+	})
+}
+
+// GetDocumentByID get document by ID
+// @Summary Get Document Info
+// @Description Get document details by ID
+// @Tags documents
+// @Accept json
+// @Produce json
+// @Param id path string true "document ID"
+// @Success 200 {object} map[string]interface{}
+// @Router /api/v1/documents/{id} [get]
func (h *DocumentHandler) GetDocumentByID(c *gin.Context) {
+	id := c.Param("id")
+	if id == "" {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"error": "invalid document id",
+		})
+		return
+	}
+
+	document, err := h.documentService.GetDocumentByID(id)
+	if err != nil {
+		c.JSON(http.StatusNotFound, gin.H{
+			"error": "document not found",
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"data": document,
+	})
+}
+
+// UpdateDocument update document
+// @Summary Update Document
+// @Description Update document info
+// @Tags documents
+// @Accept json
+// @Produce json
+// @Param id path string true "document ID"
+// @Param request body service.UpdateDocumentRequest true "update info"
+// @Success 200 {object} map[string]interface{}
+// @Router /api/v1/documents/{id} [put]
func (h *DocumentHandler) UpdateDocument(c *gin.Context) {
+	id := c.Param("id")
+	if id == "" {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"error": "invalid document id",
+		})
+		return
+	}
+
+	var req service.UpdateDocumentRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"error": err.Error(),
+		})
+		return
+	}
+
+	if err := h.documentService.UpdateDocument(id, &req); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"error": err.Error(),
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"message": "updated successfully",
+	})
+}
+
+// DeleteDocument delete document
+// @Summary Delete Document
+// @Description Delete specified document
+// @Tags documents
+// @Accept json
+// @Produce json
+// @Param id path 
string true "document ID"
+// @Success 200 {object} map[string]interface{}
+// @Router /api/v1/documents/{id} [delete]
func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
+	id := c.Param("id")
+	if id == "" {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"error": "invalid document id",
+		})
+		return
+	}
+
+	if err := h.documentService.DeleteDocument(id); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"error": err.Error(),
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"message": "deleted successfully",
+	})
+}
+
+// ListDocuments document list
+// @Summary Document List
+// @Description Get paginated document list
+// @Tags documents
+// @Accept json
+// @Produce json
+// @Param page query int false "page number" default(1)
+// @Param page_size query int false "items per page" default(10)
+// @Success 200 {object} map[string]interface{}
+// @Router /api/v1/documents [get]
func (h *DocumentHandler) ListDocuments(c *gin.Context) {
+	page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
+	pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10"))
+
+	if page < 1 {
+		page = 1
+	}
+	if pageSize < 1 || pageSize > 100 {
+		pageSize = 10
+	}
+
+	documents, total, err := h.documentService.ListDocuments(page, pageSize)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"error": "failed to get documents",
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"data": gin.H{
+			"items":     documents,
+			"total":     total,
+			"page":      page,
+			"page_size": pageSize,
+		},
+	})
+}
+
+// GetDocumentsByAuthorID get documents by author ID
+// @Summary Get Author Documents
+// @Description Get paginated document list by author ID
+// @Tags documents
+// @Accept json
+// @Produce json
+// @Param author_id path int true "author ID"
+// @Param page query int false "page number" default(1)
+// @Param page_size query int false "items per page" default(10)
+// @Success 200 {object} map[string]interface{}
+// @Router /api/v1/authors/{author_id}/documents [get]
func (h *DocumentHandler) GetDocumentsByAuthorID(c *gin.Context) {
+	authorIDStr := c.Param("author_id")
+	authorID, err := strconv.Atoi(authorIDStr)
+	if err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"error": "invalid author id",
+		})
+		return
+	}
+
+	page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
+	pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10"))
+
+	if page < 1 {
+		page = 1
+	}
+	if pageSize < 1 || pageSize > 100 {
+		pageSize = 10
+	}
+
+	documents, total, err := h.documentService.GetDocumentsByAuthorID(authorID, page, pageSize)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"error": "failed to get documents",
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"data": gin.H{
+			"items":     documents,
+			"total":     total,
+			"page":      page,
+			"page_size": pageSize,
+		},
+	})
+}
diff --git a/internal/handler/error.go b/internal/handler/error.go
new file mode 100644
index 00000000000..9ca6b6c5fd9
--- /dev/null
+++ b/internal/handler/error.go
@@ -0,0 +1,46 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" + + "ragflow/internal/logger" +) + +// HandleNoRoute handles requests to undefined routes +func HandleNoRoute(c *gin.Context) { + // Log the request details on server side + logger.Logger.Warn("The requested URL was not found", + zap.String("method", c.Request.Method), + zap.String("path", c.Request.URL.Path), + zap.String("query", c.Request.URL.RawQuery), + zap.String("remote_addr", c.ClientIP()), + zap.String("user_agent", c.Request.UserAgent()), + ) + + // Return JSON error response + c.JSON(http.StatusNotFound, gin.H{ + "code": 404, + "message": "Not Found: " + c.Request.URL.Path, + "data": nil, + "error": "Not Found", + }) +} diff --git a/internal/handler/file.go b/internal/handler/file.go new file mode 100644 index 00000000000..974d3bbd688 --- /dev/null +++ b/internal/handler/file.go @@ -0,0 +1,283 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "net/http" + "strconv" + + "github.com/gin-gonic/gin" + + "ragflow/internal/service" +) + +// FileHandler file handler +type FileHandler struct { + fileService *service.FileService + userService *service.UserService +} + +// NewFileHandler create file handler +func NewFileHandler(fileService *service.FileService, userService *service.UserService) *FileHandler { + return &FileHandler{ + fileService: fileService, + userService: userService, + } +} + +// ListFiles list files +// @Summary List Files +// @Description Get list of files for the current user with filtering, pagination and sorting +// @Tags file +// @Accept json +// @Produce json +// @Param parent_id query string false "parent folder ID" +// @Param keywords query string false "search keywords" +// @Param page query int false "page number (default: 1)" +// @Param page_size query int false "items per page (default: 15)" +// @Param orderby query string false "order by field (default: create_time)" +// @Param desc query bool false "descending order (default: true)" +// @Success 200 {object} service.ListFilesResponse +// @Router /v1/file/list [get] +func (h *FileHandler) ListFiles(c *gin.Context) { + // Get access token from Authorization header + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + return + } + + // Get user by access token + user, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Invalid access token", + }) + return + } + userID := user.ID + + // Parse query parameters + parentID := c.Query("parent_id") + keywords := c.Query("keywords") + + // Parse page (default: 1) + page := 1 + if pageStr := c.Query("page"); pageStr != "" { + if p, err := strconv.Atoi(pageStr); 
err == nil && p > 0 { + page = p + } + } + + // Parse page_size (default: 15) + pageSize := 15 + if pageSizeStr := c.Query("page_size"); pageSizeStr != "" { + if ps, err := strconv.Atoi(pageSizeStr); err == nil && ps > 0 { + pageSize = ps + } + } + + // Parse orderby (default: create_time) + orderby := c.DefaultQuery("orderby", "create_time") + + // Parse desc (default: true) + desc := true + if descStr := c.Query("desc"); descStr != "" { + desc = descStr != "false" + } + + // List files + result, err := h.fileService.ListFiles(userID, parentID, page, pageSize, orderby, desc, keywords) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": result, + "message": "success", + }) +} + +// GetRootFolder gets root folder for current user +// @Summary Get Root Folder +// @Description Get or create root folder for the current user +// @Tags file +// @Accept json +// @Produce json +// @Success 200 {object} map[string]interface{} +// @Router /v1/file/root_folder [get] +func (h *FileHandler) GetRootFolder(c *gin.Context) { + // Get access token from Authorization header + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + return + } + + // Get user by access token + user, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Invalid access token", + }) + return + } + userID := user.ID + + // Get root folder + rootFolder, err := h.fileService.GetRootFolder(userID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": gin.H{"root_folder": rootFolder}, + "message": "success", + }) +} + +// GetParentFolder gets parent folder of a file +// @Summary Get Parent Folder +// @Description Get parent folder of a file by file ID +// @Tags file +// @Accept json +// @Produce json +// @Param file_id query string true "file ID" +// @Success 200 {object} map[string]interface{} +// @Router /v1/file/parent_folder [get] +func (h *FileHandler) GetParentFolder(c *gin.Context) { + // Get access token from Authorization header + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + return + } + + // Get user by access token (for validation) + _, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Invalid access token", + }) + return + } + + // Get file_id from query + fileID := c.Query("file_id") + if fileID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "file_id is required", + }) + return + } + + // Get parent folder + parentFolder, err := h.fileService.GetParentFolder(fileID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": gin.H{"parent_folder": parentFolder}, + "message": "success", + }) +} + +// GetAllParentFolders gets all parent folders in path +// @Summary Get All Parent Folders +// @Description Get all parent folders in path from file to root +// @Tags file +// @Accept json +// @Produce json +// @Param file_id query string true "file 
ID" +// @Success 200 {object} map[string]interface{} +// @Router /v1/file/all_parent_folder [get] +func (h *FileHandler) GetAllParentFolders(c *gin.Context) { + // Get access token from Authorization header + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + return + } + + // Get user by access token (for validation) + _, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Invalid access token", + }) + return + } + + // Get file_id from query + fileID := c.Query("file_id") + if fileID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "file_id is required", + }) + return + } + + // Get all parent folders + parentFolders, err := h.fileService.GetAllParentFolders(fileID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": gin.H{"parent_folders": parentFolders}, + "message": "success", + }) +} diff --git a/internal/handler/kb.go b/internal/handler/kb.go new file mode 100644 index 00000000000..1c482fa89f1 --- /dev/null +++ b/internal/handler/kb.go @@ -0,0 +1,158 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package handler + +import ( + "net/http" + "strconv" + + "github.com/gin-gonic/gin" + + "ragflow/internal/service" +) + +// KnowledgebaseHandler knowledge base handler +type KnowledgebaseHandler struct { + kbService *service.KnowledgebaseService + userService *service.UserService +} + +// NewKnowledgebaseHandler create knowledge base handler +func NewKnowledgebaseHandler(kbService *service.KnowledgebaseService, userService *service.UserService) *KnowledgebaseHandler { + return &KnowledgebaseHandler{ + kbService: kbService, + userService: userService, + } +} + +// ListKbs list knowledge bases +// @Summary List Knowledge Bases +// @Description Get list of knowledge bases with filtering and pagination +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Param keywords query string false "search keywords" +// @Param page query int false "page number" +// @Param page_size query int false "items per page" +// @Param parser_id query string false "parser ID filter" +// @Param orderby query string false "order by field" +// @Param desc query bool false "descending order" +// @Param request body service.ListKbsRequest true "filter options" +// @Success 200 {object} service.ListKbsResponse +// @Router /v1/kb/list [post] +func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) { + // Parse request body - allow empty body + var req service.ListKbsRequest + if c.Request.ContentLength > 0 { + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": err.Error(), + }) + return + } + } + + // Extract parameters from query or request body with defaults + keywords := "" + if req.Keywords != nil { + keywords = *req.Keywords + } else if queryKeywords := c.Query("keywords"); queryKeywords != "" { + keywords = queryKeywords + } + + page := 0 + if req.Page != nil { + page = *req.Page + } else if pageStr := c.Query("page"); pageStr != "" { + if p, err := strconv.Atoi(pageStr); err == nil && p > 0 { + page = p + } + } + + pageSize := 0 + if req.PageSize != nil { + pageSize = *req.PageSize + } else if pageSizeStr := c.Query("page_size"); pageSizeStr != "" { + if ps, err := strconv.Atoi(pageSizeStr); err == nil && ps > 0 { + pageSize = ps + } + } + + parserID := "" + if req.ParserID != nil { + parserID = *req.ParserID + } else if queryParserID := c.Query("parser_id"); queryParserID != "" { + parserID = queryParserID + } + + orderby := "update_time" + if req.Orderby != nil { + orderby = *req.Orderby + } else if queryOrderby := c.Query("orderby"); queryOrderby != "" { + orderby = queryOrderby + } + + desc := true + if req.Desc != nil { + desc = *req.Desc + } else if descStr := c.Query("desc"); descStr != "" { + desc = descStr == "true" + } + + var ownerIDs []string + if req.OwnerIDs != nil { + ownerIDs = *req.OwnerIDs + } + + // Get access token from Authorization header + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + return + } + + // Get user by access token + user, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Invalid access token", + }) + return + } + userID := user.ID + + // List knowledge bases + result, err := h.kbService.ListKbs(keywords, page, pageSize, parserID, orderby, desc, ownerIDs, userID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": err.Error(), + }) + return + } + + 
c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": result, + "message": "success", + }) +} diff --git a/internal/handler/llm.go b/internal/handler/llm.go new file mode 100644 index 00000000000..bcad7f2be1d --- /dev/null +++ b/internal/handler/llm.go @@ -0,0 +1,247 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + "ragflow/internal/dao" + "ragflow/internal/service" +) + +// FactoryResponse represents a model provider factory +type FactoryResponse struct { + Name string `json:"name"` + Logo string `json:"logo"` + Tags string `json:"tags"` + Status string `json:"status"` + Rank string `json:"rank"` + ModelTypes []string `json:"model_types"` +} + +// LLMHandler LLM handler +type LLMHandler struct { + llmService *service.LLMService + userService *service.UserService +} + +// NewLLMHandler create LLM handler +func NewLLMHandler(llmService *service.LLMService, userService *service.UserService) *LLMHandler { + return &LLMHandler{ + llmService: llmService, + userService: userService, + } +} + +// GetMyLLMs get my LLMs +// @Summary Get My LLMs +// @Description Get LLM list for current tenant +// @Tags llm +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param include_details query string false "Include detailed fields" default(false) +// @Success 200 {object} map[string]interface{} +// @Router /v1/llm/my_llms [get] +func (h *LLMHandler) GetMyLLMs(c *gin.Context) { + // Extract token from request + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + return + } + + // Get user by token + user, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "error": "Invalid access token", + }) + return + } + + // Get tenant ID from user + tenantID := user.ID + if tenantID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "User has no tenant ID", + }) + return + } + + // Parse include_details query parameter + includeDetailsStr := c.DefaultQuery("include_details", "false") + includeDetails := includeDetailsStr == "true" + + // Get LLMs for tenant + llms, err := h.llmService.GetMyLLMs(tenantID, includeDetails) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to get LLMs", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "data": llms, + }) +} + +// Factories get model provider factories +// @Summary Get Model Provider Factories +// @Description Get list of model provider factories +// @Tags llm +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Success 200 {array} FactoryResponse +// @Router /v1/llm/factories [get] +func (h *LLMHandler) Factories(c *gin.Context) { + // Extract token from request + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + 
"message": "Missing Authorization header", + }) + return + } + + // Get user by token + _, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "error": "Invalid access token", + }) + return + } + + // Get model providers + dao := dao.NewModelProviderDAO() + providers := dao.GetAllProviders() + + // Filter out unwanted providers + filtered := make([]FactoryResponse, 0) + excluded := map[string]bool{ + "Youdao": true, + "FastEmbed": true, + "BAAI": true, + "Builtin": true, + } + + for _, provider := range providers { + if excluded[provider.Name] { + continue + } + + // Collect unique model types from LLMs + modelTypes := make(map[string]bool) + for _, llm := range provider.LLMs { + modelTypes[llm.ModelType] = true + } + + // Convert to slice + modelTypeSlice := make([]string, 0, len(modelTypes)) + for mt := range modelTypes { + modelTypeSlice = append(modelTypeSlice, mt) + } + + // If no model types found, use defaults + if len(modelTypeSlice) == 0 { + modelTypeSlice = []string{"chat", "embedding", "rerank", "image2text", "speech2text", "tts", "ocr"} + } + + filtered = append(filtered, FactoryResponse{ + Name: provider.Name, + Logo: provider.Logo, + Tags: provider.Tags, + Status: provider.Status, + Rank: provider.Rank, + ModelTypes: modelTypeSlice, + }) + } + + c.JSON(http.StatusOK, gin.H{ + "data": filtered, + }) +} + +// ListApp lists LLMs grouped by factory +// @Summary List LLMs +// @Description Get list of LLMs grouped by factory with availability info +// @Tags llm +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param model_type query string false "Filter by model type" +// @Success 200 {object} map[string][]service.LLMListItem +// @Router /v1/llm/list [get] +func (h *LLMHandler) ListApp(c *gin.Context) { + // Extract token from request + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + return + } + + // Get user by token + user, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Invalid access token", + }) + return + } + + // Get tenant ID from user + tenantID := user.ID + if tenantID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "User has no tenant ID", + }) + return + } + + // Parse model_type query parameter + modelType := c.Query("model_type") + + // Get LLM list + llms, err := h.llmService.ListLLMs(tenantID, modelType) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": llms, + "message": "success", + }) +} diff --git a/internal/handler/search.go b/internal/handler/search.go new file mode 100644 index 00000000000..5a6317b183f --- /dev/null +++ b/internal/handler/search.go @@ -0,0 +1,129 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "net/http" + "strconv" + + "github.com/gin-gonic/gin" + + "ragflow/internal/service" +) + +// SearchHandler search handler +type SearchHandler struct { + searchService *service.SearchService + userService *service.UserService +} + +// NewSearchHandler create search handler +func NewSearchHandler(searchService *service.SearchService, userService *service.UserService) *SearchHandler { + return &SearchHandler{ + searchService: searchService, + userService: userService, + } +} + +// ListSearchApps list search apps +// @Summary List Search Apps +// @Description Get list of search apps for the current user with filtering, pagination and sorting +// @Tags search +// @Accept json +// @Produce json +// @Param keywords query string false "search keywords" +// @Param page query int false "page number" +// @Param page_size query int false "items per page" +// @Param orderby query string false "order by field (default: create_time)" +// @Param desc query bool false "descending order (default: true)" +// @Param request body service.ListSearchAppsRequest true "filter options including owner_ids" +// @Success 200 {object} service.ListSearchAppsResponse +// @Router /v1/search/list [post] +func (h *SearchHandler) ListSearchApps(c *gin.Context) { + // Get access token from Authorization header + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + return + } + + // Get user by access token + user, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Invalid access token", + }) + return + } + userID := user.ID + + // Parse query parameters + keywords := c.Query("keywords") + + page := 0 + if pageStr := c.Query("page"); pageStr != "" { + if p, err := strconv.Atoi(pageStr); err == nil && p > 0 { + page = p + } + } + + pageSize := 0 + if pageSizeStr := c.Query("page_size"); pageSizeStr != "" { + if ps, err := strconv.Atoi(pageSizeStr); err == nil && ps > 0 { + pageSize = ps + } + } + + orderby := c.DefaultQuery("orderby", "create_time") + + desc := true + if descStr := c.Query("desc"); descStr != "" { + desc = descStr != "false" + } + + // Parse request body for owner_ids + var req service.ListSearchAppsRequest + if c.Request.ContentLength > 0 { + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": err.Error(), + }) + return + } + } + + // List search apps with filtering + result, err := h.searchService.ListSearchApps(userID, keywords, page, pageSize, orderby, desc, req.OwnerIDs) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": result, + "message": "success", + }) +} diff --git a/internal/handler/system.go b/internal/handler/system.go new file mode 100644 index 00000000000..da7fe52f625 --- /dev/null +++ b/internal/handler/system.go @@ -0,0 +1,125 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "net/http" + "ragflow/internal/server" + + "github.com/gin-gonic/gin" + + "ragflow/internal/service" +) + +// SystemHandler system handler +type SystemHandler struct { + systemService *service.SystemService +} + +// NewSystemHandler create system handler +func NewSystemHandler(systemService *service.SystemService) *SystemHandler { + return &SystemHandler{ + systemService: systemService, + } +} + +// Ping health check endpoint +// @Summary Ping +// @Description Simple ping endpoint +// @Tags system +// @Produce plain +// @Success 200 {string} string "pong" +// @Router /v1/system/ping [get] +func (h *SystemHandler) Ping(c *gin.Context) { + c.String(http.StatusOK, "pong") +} + +// GetConfig get system configuration +// @Summary Get System Configuration +// @Description Get system configuration including register enabled status +// @Tags system +// @Accept json +// @Produce json +// @Success 200 {object} map[string]interface{} +// @Router /v1/system/config [get] +func (h *SystemHandler) GetConfig(c *gin.Context) { + config, err := h.systemService.GetConfig() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": "Failed to get system configuration", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + "data": config, + }) +} + +// GetConfigs get all system configurations +// @Summary Get All System Configurations +// @Description Get all system configurations from globalConfig +// @Tags system +// @Accept json +// @Produce json +// @Success 200 {object} config.Config +// @Router /v1/system/configs [get] +func (h *SystemHandler) GetConfigs(c *gin.Context) { + cfg := server.GetConfig() + if cfg == nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": "Configuration not initialized", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + "data": cfg, + }) +} + +// GetVersion get RAGFlow version +// @Summary Get RAGFlow Version +// @Description Get the current version of the application +// @Tags system +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Success 200 {object} map[string]interface{} +// @Router /v1/system/version [get] +func (h *SystemHandler) GetVersion(c *gin.Context) { + version, err := h.systemService.GetVersion() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": "Failed to get version", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + "data": version.Version, + }) +} diff --git a/internal/handler/tenant.go b/internal/handler/tenant.go new file mode 100644 index 00000000000..ab96f958ce4 --- /dev/null +++ b/internal/handler/tenant.go @@ -0,0 +1,135 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+package handler
+
+import (
+	"net/http"
+
+	"github.com/gin-gonic/gin"
+
+	"ragflow/internal/service"
+)
+
+// TenantHandler tenant handler
+type TenantHandler struct {
+	tenantService *service.TenantService
+	userService   *service.UserService
+}
+
+// NewTenantHandler create tenant handler
+func NewTenantHandler(tenantService *service.TenantService, userService *service.UserService) *TenantHandler {
+	return &TenantHandler{
+		tenantService: tenantService,
+		userService:   userService,
+	}
+}
+
+// TenantInfo get tenant information
+// @Summary Get Tenant Information
+// @Description Get current user's tenant information (owner tenant)
+// @Tags tenants
+// @Accept json
+// @Produce json
+// @Security ApiKeyAuth
+// @Success 200 {object} map[string]interface{}
+// @Router /v1/user/tenant_info [get]
+func (h *TenantHandler) TenantInfo(c *gin.Context) {
+	// Extract token from request
+	token := c.GetHeader("Authorization")
+	if token == "" {
+		c.JSON(http.StatusUnauthorized, gin.H{
+			"code":    401,
+			"message": "Missing Authorization header",
+		})
+		return
+	}
+	// Get user by token
+	user, err := h.userService.GetUserByToken(token)
+	if err != nil {
+		c.JSON(http.StatusUnauthorized, gin.H{
+			"code":    401,
+			"message": "Invalid access token",
+		})
+		return
+	}
+
+	// Get tenant info
+	tenantInfo, err := h.tenantService.GetTenantInfo(user.ID)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"code":    500,
+			"message": "Failed to get tenant information",
+		})
+		return
+	}
+
+	if tenantInfo == nil {
+		c.JSON(http.StatusNotFound, gin.H{
+			"code":    404,
+			"message": "Tenant not found",
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"code": 0,
+		"data": tenantInfo,
+	})
+}
+
+// TenantList get tenant list for current user
+// @Summary Get Tenant List
+// @Description Get all tenants that the current user belongs to
+// @Tags tenants
+// @Accept json
+// @Produce json
+// @Security ApiKeyAuth
+// @Success 200 {object} map[string]interface{}
+// @Router /v1/tenant/list [get]
+func (h *TenantHandler) TenantList(c *gin.Context) {
+	// Extract token from request
+	token := c.GetHeader("Authorization")
+	if token == "" {
+		c.JSON(http.StatusUnauthorized, gin.H{
+			"code":    401,
+			"message": "Missing Authorization header",
+		})
+		return
+	}
+
+	// Get user by token
+	user, err := h.userService.GetUserByToken(token)
+	if err != nil {
+		c.JSON(http.StatusUnauthorized, gin.H{
+			"code":    401,
+			"message": "Invalid access token",
+		})
+		return
+	}
+
+	// Get tenant list
+	tenantList, err := h.tenantService.GetTenantList(user.ID)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"code":    500,
+			"message": "Failed to get tenant list",
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"code": 0,
+		"data": tenantList,
+	})
+}
diff --git a/internal/handler/user.go b/internal/handler/user.go
new file mode 100644
index 00000000000..2a4091857fd
--- /dev/null
+++ b/internal/handler/user.go
@@ -0,0 +1,456 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "net/http" + "ragflow/internal/server" + "ragflow/internal/utility" + "strconv" + + "github.com/gin-gonic/gin" + + "ragflow/internal/service" +) + +// UserHandler user handler +type UserHandler struct { + userService *service.UserService +} + +// NewUserHandler create user handler +func NewUserHandler(userService *service.UserService) *UserHandler { + return &UserHandler{ + userService: userService, + } +} + +// Register user registration +// @Summary User Registration +// @Description Create new user +// @Tags users +// @Accept json +// @Produce json +// @Param request body service.RegisterRequest true "registration info" +// @Success 200 {object} map[string]interface{} +// @Router /api/v1/users/register [post] +func (h *UserHandler) Register(c *gin.Context) { + var req service.RegisterRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": err.Error(), + }) + return + } + + user, err := h.userService.Register(&req) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "message": "registration successful", + "data": gin.H{ + "id": user.ID, + "nickname": user.Nickname, + "email": user.Email, + }, + }) +} + +// Login user login +// @Summary User Login +// @Description User login verification +// @Tags users +// @Accept json +// @Produce json +// @Param request body service.LoginRequest true "login info" +// @Success 200 {object} map[string]interface{} +// @Router /api/v1/users/login [post] +func (h *UserHandler) Login(c *gin.Context) { + var req service.LoginRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": err.Error(), + }) + return + } + + user, err := h.userService.Login(&req) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": err.Error(), + }) + return + } + + // Set Authorization header with access_token + if user.AccessToken != nil { + c.Header("Authorization", *user.AccessToken) + } + // Set CORS headers + c.Header("Access-Control-Allow-Origin", "*") + c.Header("Access-Control-Allow-Methods", "*") + c.Header("Access-Control-Allow-Headers", "*") + c.Header("Access-Control-Expose-Headers", "Authorization") + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "Welcome back!", + "data": user, + }) +} + +// LoginByEmail user login by email +// @Summary User Login by Email +// @Description User login verification using email +// @Tags users +// @Accept json +// @Produce json +// @Param request body service.EmailLoginRequest true "login info with email" +// @Success 200 {object} map[string]interface{} +// @Router /v1/user/login [post] +func (h *UserHandler) LoginByEmail(c *gin.Context) { + var req service.EmailLoginRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": err.Error(), + }) + 
return
+	}
+
+	user, err := h.userService.LoginByEmail(&req)
+	if err != nil {
+		c.JSON(http.StatusUnauthorized, gin.H{
+			"code":    401,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	variables := server.GetVariables()
+	secretKey := variables.SecretKey
+
+	// Set Authorization header with the signed access_token. Guard the nil
+	// pointer before dereferencing it, and surface signing errors instead of
+	// silently discarding them.
+	if user.AccessToken != nil {
+		authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{
+				"code":    500,
+				"message": "Failed to sign access token: " + err.Error(),
+			})
+			return
+		}
+		c.Header("Authorization", authToken)
+	}
+	// Set CORS headers
+	c.Header("Access-Control-Allow-Origin", "*")
+	c.Header("Access-Control-Allow-Methods", "*")
+	c.Header("Access-Control-Allow-Headers", "*")
+	c.Header("Access-Control-Expose-Headers", "Authorization")
+
+	c.JSON(http.StatusOK, gin.H{
+		"code":    0,
+		"message": "Welcome back!",
+		"data":    user,
+	})
+}
+
+// GetUserByID get user by ID
+// @Summary Get User Info
+// @Description Get user details by ID
+// @Tags users
+// @Accept json
+// @Produce json
+// @Param id path int true "user ID"
+// @Success 200 {object} map[string]interface{}
+// @Router /api/v1/users/{id} [get]
+func (h *UserHandler) GetUserByID(c *gin.Context) {
+	idStr := c.Param("id")
+	id, err := strconv.ParseUint(idStr, 10, 32)
+	if err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"error": "invalid user id",
+		})
+		return
+	}
+
+	user, err := h.userService.GetUserByID(uint(id))
+	if err != nil {
+		c.JSON(http.StatusNotFound, gin.H{
+			"error": "user not found",
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"data": user,
+	})
+}
+
+// ListUsers user list
+// @Summary User List
+// @Description Get paginated user list
+// @Tags users
+// @Accept json
+// @Produce json
+// @Param page query int false "page number" default(1)
+// @Param page_size query int false "items per page" default(10)
+// @Success 200 {object} map[string]interface{}
+// @Router /api/v1/users [get]
+func (h *UserHandler) ListUsers(c *gin.Context) {
+	page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
+	pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10"))
+
+	if page < 1 {
+		page = 1
+	}
+	if pageSize < 1 || pageSize > 100 {
+		pageSize = 10
+	}
+
+	users, total, err := h.userService.ListUsers(page, pageSize)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"error": "failed to get users",
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"data": gin.H{
+			"items":     users,
+			"total":     total,
+			"page":      page,
+			"page_size": pageSize,
+		},
+	})
+}
+
+// Logout user logout
+// @Summary User Logout
+// @Description Logout user and invalidate access token
+// @Tags users
+// @Accept json
+// @Produce json
+// @Security ApiKeyAuth
+// @Success 200 {object} map[string]interface{}
+// @Router /v1/user/logout [post]
+func (h *UserHandler) Logout(c *gin.Context) {
+	// Extract token from request
+	token := c.GetHeader("Authorization")
+	if token == "" {
+		c.JSON(http.StatusUnauthorized, gin.H{
+			"code":    401,
+			"message": "Missing Authorization header",
+		})
+		return
+	}
+
+	// Get user by token
+	user, err := h.userService.GetUserByToken(token)
+	if err != nil {
+		c.JSON(http.StatusUnauthorized, gin.H{
+			"code":    401,
+			"message": "Invalid access token",
+		})
+		return
+	}
+
+	// Logout user
+	if err := h.userService.Logout(user); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"code":    500,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"code":    0,
+		"data":    true,
+		"message": "success",
+	})
+}
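+
+// NOTE: every handler in this file repeats the same Authorization-header
+// check. A helper along the lines of the sketch below could factor it out.
+// This is an illustrative suggestion only: the helper name `currentUser` and
+// the return type `*model.User` are assumptions, not part of the existing
+// wiring.
+//
+//	func (h *UserHandler) currentUser(c *gin.Context) (*model.User, bool) {
+//		token := c.GetHeader("Authorization")
+//		if token == "" {
+//			c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "Missing Authorization header"})
+//			return nil, false
+//		}
+//		user, err := h.userService.GetUserByToken(token)
+//		if err != nil {
+//			c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "Invalid access token"})
+//			return nil, false
+//		}
+//		return user, true
+//	}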
+
+// Info get user profile information
+// @Summary Get User Profile
+// @Description Get current user's profile information
+// @Tags users
+// @Accept json
+// @Produce json
+// @Security ApiKeyAuth
+// @Success 200 {object} map[string]interface{}
+// @Router /v1/user/info [get]
+func (h *UserHandler) Info(c *gin.Context) {
+	// Extract token from request
+	token := c.GetHeader("Authorization")
+	if token == "" {
+		c.JSON(http.StatusUnauthorized, gin.H{
+			"code":    401,
+			"message": "Missing Authorization header",
+		})
+		return
+	}
+
+	// Get user by token
+	user, err := h.userService.GetUserByToken(token)
+	if err != nil {
+		c.JSON(http.StatusUnauthorized, gin.H{
+			"code":    401,
+			"message": "Invalid access token",
+		})
+		return
+	}
+
+	// Get user profile
+	profile := h.userService.GetUserProfile(user)
+
+	c.JSON(http.StatusOK, gin.H{
+		"code": 0,
+		"data": profile,
+	})
+}
+
+// Setting update user settings
+// @Summary Update User Settings
+// @Description Update current user's settings
+// @Tags users
+// @Accept json
+// @Produce json
+// @Security ApiKeyAuth
+// @Param request body service.UpdateSettingsRequest true "user settings"
+// @Success 200 {object} map[string]interface{}
+// @Router /v1/user/setting [post]
+func (h *UserHandler) Setting(c *gin.Context) {
+	// Extract token from request
+	token := c.GetHeader("Authorization")
+	if token == "" {
+		c.JSON(http.StatusUnauthorized, gin.H{
+			"code":    401,
+			"message": "Missing Authorization header",
+		})
+		return
+	}
+
+	// Get user by token
+	user, err := h.userService.GetUserByToken(token)
+	if err != nil {
+		c.JSON(http.StatusUnauthorized, gin.H{
+			"code":    401,
+			"message": "Invalid access token",
+		})
+		return
+	}
+
+	// Parse request
+	var req service.UpdateSettingsRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"code":    400,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	// Update user settings
+	if err := h.userService.UpdateUserSettings(user, &req); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"code":    500,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"code":    0,
+		"message": "settings updated successfully",
+	})
+}
+
+// ChangePassword change user password
+// @Summary Change User Password
+// @Description Change current user's password
+// @Tags users
+// @Accept json
+// @Produce json
+// @Security ApiKeyAuth
+// @Param request body service.ChangePasswordRequest true "password change info"
+// @Success 200 {object} map[string]interface{}
+// @Router /v1/user/setting/password [post]
+func (h *UserHandler) ChangePassword(c *gin.Context) {
+	// Extract token from request
+	token := c.GetHeader("Authorization")
+	if token == "" {
+		c.JSON(http.StatusUnauthorized, gin.H{
+			"code":    401,
+			"message": "Missing Authorization header",
+		})
+		return
+	}
+
+	// Get user by token
+	user, err := h.userService.GetUserByToken(token)
+	if err != nil {
+		c.JSON(http.StatusUnauthorized, gin.H{
+			"code":    401,
+			"message": "Invalid access token",
+		})
+		return
+	}
+
+	// Parse request
+	var req service.ChangePasswordRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"code":    400,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	// Change password
+	if err := h.userService.ChangePassword(user, &req); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"code":    500,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"code":    0,
+		"message": "password changed successfully",
+	})
+}
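+
+// Example call (illustrative only; the field names of
+// service.ChangePasswordRequest live in the service package and are not
+// shown in this patch, so the JSON body is left as a placeholder):
+//
+//	curl -X POST "http://<host>/v1/user/setting/password" \
+//	  -H "Authorization: <access token>" \
+//	  -H "Content-Type: application/json" \
+//	  -d '{ ... }'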
+
+// GetLoginChannels get all supported authentication channels
+// @Summary Get Login Channels
+// @Description Get all supported OAuth authentication channels
+// @Tags users
+// @Accept json
+// @Produce json
+// @Success 200 {object} map[string]interface{}
+// @Router /v1/user/login/channels [get]
+func (h *UserHandler) GetLoginChannels(c *gin.Context) {
+	channels, err := h.userService.GetLoginChannels()
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"code":    500,
+			"message": "Failed to load login channels: " + err.Error(),
+			"data":    []interface{}{},
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"code":    0,
+		"message": "success",
+		"data":    channels,
+	})
+}
diff --git a/internal/logger/README.md b/internal/logger/README.md
new file mode 100644
index 00000000000..adc941baf14
--- /dev/null
+++ b/internal/logger/README.md
@@ -0,0 +1,70 @@
+# Logger Package
+
+This package provides structured logging using Uber's Zap library.
+
+## Installation
+
+Install the zap dependency:
+
+```bash
+go get go.uber.org/zap
+```
+
+## Usage
+
+The logger is initialized in `cmd/server_main.go` and is available throughout the application; a minimal initialization sketch appears at the end of this document.
+
+### Basic Usage
+
+```go
+import (
+    "ragflow/internal/logger"
+    "go.uber.org/zap"
+)
+
+// Log with structured fields
+logger.Info("User login", zap.String("user_id", userID), zap.String("ip", clientIP))
+
+// Log error
+logger.Error("Failed to connect database", err)
+
+// Log fatal (exits application); Fatal takes zap fields, so wrap the error
+logger.Fatal("Failed to start server", zap.Error(err))
+
+// Debug level
+logger.Debug("Processing request", zap.String("request_id", reqID))
+
+// Warning level
+logger.Warn("Slow query", zap.Duration("duration", duration))
+```
+
+### Access Logger Directly
+
+If you need the underlying Zap logger:
+
+```go
+logger.Logger.Info("Message", zap.String("key", "value"))
+```
+
+Or use the SugaredLogger for a more flexible, loosely typed API:
+
+```go
+logger.Sugar.Infow("Message", "key", "value")
+```
+
+## Behavior When Uninitialized
+
+If `Init` is never called or fails, `Logger` stays `nil`: `Info`, `Warn`, `Error`, and `Debug` become no-ops, and `Fatal` panics. There is no automatic fallback to the standard library `log` package, so call `Init` early in startup, before any logging.
+
+## Log Levels
+
+The logger supports the following levels:
+- `debug` - Detailed information for debugging
+- `info` - General informational messages
+- `warn` - Warning messages
+- `error` - Error messages
+- `fatal` - Fatal errors that stop the application
+
+The log level is configured via the server mode in the configuration:
+- `debug` mode uses `debug` level
+- `release` mode uses `info` level
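+
+## Initialization Sketch
+
+`Init` must run before any logging call. A minimal sketch of the startup wiring (the real call site is `cmd/server_main.go`; the `"info"` level and the use of the standard `log` package on the failure path are assumptions for illustration):
+
+```go
+package main
+
+import (
+    "log"
+
+    "ragflow/internal/logger"
+)
+
+func main() {
+    // Initialize the global zap-backed logger; bail out if it cannot start.
+    if err := logger.Init("info"); err != nil {
+        log.Fatalf("failed to initialize logger: %v", err)
+    }
+    // Flush buffered entries on shutdown.
+    defer logger.Sync()
+
+    logger.Info("server starting")
+}
+```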
diff --git a/internal/logger/logger.go b/internal/logger/logger.go
new file mode 100644
index 00000000000..d45313d37e1
--- /dev/null
+++ b/internal/logger/logger.go
@@ -0,0 +1,138 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+package logger
+
+import (
+	"fmt"
+	"runtime"
+
+	"go.uber.org/zap"
+	"go.uber.org/zap/zapcore"
+)
+
+var (
+	Logger *zap.Logger
+	Sugar  *zap.SugaredLogger
+)
+
+// Init initializes the global logger
+// Note: This requires zap to be installed: go get go.uber.org/zap
+func Init(level string) error {
+	// Parse log level
+	var zapLevel zapcore.Level
+	switch level {
+	case "debug":
+		zapLevel = zapcore.DebugLevel
+	case "info":
+		zapLevel = zapcore.InfoLevel
+	case "warn":
+		zapLevel = zapcore.WarnLevel
+	case "error":
+		zapLevel = zapcore.ErrorLevel
+	default:
+		zapLevel = zapcore.InfoLevel
+	}
+
+	// Custom encoder config to control output format
+	encoderConfig := zapcore.EncoderConfig{
+		TimeKey:        "timestamp",
+		LevelKey:       "level",
+		NameKey:        "logger",
+		CallerKey:      "", // Disable caller/line number
+		FunctionKey:    "",
+		MessageKey:     "msg",
+		StacktraceKey:  "stacktrace",
+		LineEnding:     zapcore.DefaultLineEnding,
+		EncodeLevel:    zapcore.LowercaseLevelEncoder,
+		EncodeTime:     zapcore.TimeEncoderOfLayout("2006-01-02 15:04:05"), // Human-readable time format
+		EncodeDuration: zapcore.SecondsDurationEncoder,
+		EncodeCaller:   zapcore.ShortCallerEncoder, // Not used since CallerKey is empty
+	}
+
+	// Configure zap
+	config := zap.Config{
+		Level:            zap.NewAtomicLevelAt(zapLevel),
+		Development:      false,
+		Encoding:         "console",
+		EncoderConfig:    encoderConfig,
+		OutputPaths:      []string{"stdout"},
+		ErrorOutputPaths: []string{"stderr"},
+	}
+
+	// Build logger
+	logger, err := config.Build(zap.AddCallerSkip(1))
+	if err != nil {
+		return err
+	}
+
+	Logger = logger
+	Sugar = logger.Sugar()
+
+	return nil
+}
+
+// Sync flushes any buffered log entries
+func Sync() {
+	if Logger != nil {
+		_ = Logger.Sync()
+	}
+}
+
+// Fatal logs a fatal message with caller info and exits; it panics if the
+// logger has not been initialized
+func Fatal(msg string, fields ...zap.Field) {
+	if Logger == nil {
+		panic("logger not initialized")
+	}
+	// Get caller info (skip this function to get the actual caller)
+	_, file, line, ok := runtime.Caller(1)
+	if ok {
+		fields = append(fields, zap.String("caller", fmt.Sprintf("%s:%d", file, line)))
+	}
+	Logger.Fatal(msg, fields...)
+}
+
+// Info logs an info message; it is a no-op if the logger is not initialized
func Info(msg string, fields ...zap.Field) {
+	if Logger == nil {
+		return
+	}
+	Logger.Info(msg, fields...)
+}
+
+// Error logs an error message; it is a no-op if the logger is not initialized
+func Error(msg string, err error) {
+	if Logger == nil {
+		return
+	}
+	Logger.Error(msg, zap.Error(err))
+}
+
+// Debug logs a debug message; it is a no-op if the logger is not initialized
+func Debug(msg string, fields ...zap.Field) {
+	if Logger == nil {
+		return
+	}
+	Logger.Debug(msg, fields...)
+}
+
+// Warn logs a warning message; it is a no-op if the logger is not initialized
+func Warn(msg string, fields ...zap.Field) {
+	if Logger == nil {
+		return
+	}
+	Logger.Warn(msg, fields...)
+}
diff --git a/internal/model/api.go b/internal/model/api.go
new file mode 100644
index 00000000000..afc3a985fb2
--- /dev/null
+++ b/internal/model/api.go
@@ -0,0 +1,54 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +package model + +// APIToken API token model +type APIToken struct { + TenantID string `gorm:"column:tenant_id;size:32;not null;primaryKey" json:"tenant_id"` + Token string `gorm:"column:token;size:255;not null;primaryKey" json:"token"` + DialogID *string `gorm:"column:dialog_id;size:32;index" json:"dialog_id,omitempty"` + Source *string `gorm:"column:source;size:16;index" json:"source,omitempty"` + Beta *string `gorm:"column:beta;size:255;index" json:"beta,omitempty"` + BaseModel +} + +// TableName specify table name +func (APIToken) TableName() string { + return "api_token" +} + +// API4Conversation API for conversation model +type API4Conversation struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + DialogID string `gorm:"column:dialog_id;size:32;not null;index" json:"dialog_id"` + UserID string `gorm:"column:user_id;size:255;not null;index" json:"user_id"` + Message JSONMap `gorm:"column:message;type:json" json:"message,omitempty"` + Reference JSONMap `gorm:"column:reference;type:json;default:'[]'" json:"reference"` + Tokens int64 `gorm:"column:tokens;default:0" json:"tokens"` + Source *string `gorm:"column:source;size:16;index" json:"source,omitempty"` + DSL JSONMap `gorm:"column:dsl;type:json" json:"dsl,omitempty"` + Duration float64 `gorm:"column:duration;default:0;index" json:"duration"` + Round int64 `gorm:"column:round;default:0;index" json:"round"` + ThumbUp int64 `gorm:"column:thumb_up;default:0;index" json:"thumb_up"` + Errors *string `gorm:"column:errors;type:longtext" json:"errors,omitempty"` + BaseModel +} + +// TableName specify table name +func (API4Conversation) TableName() string { + return "api_4_conversation" +} diff --git a/internal/model/base.go b/internal/model/base.go new file mode 100644 index 00000000000..dfccc45a80b --- /dev/null +++ b/internal/model/base.go @@ -0,0 +1,79 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package model + +import ( + "database/sql/driver" + "encoding/json" + "time" +) + +// BaseModel base model +type BaseModel struct { + CreateTime int64 `gorm:"column:create_time;index" json:"create_time"` + CreateDate *time.Time `gorm:"column:create_date;index" json:"create_date,omitempty"` + UpdateTime *int64 `gorm:"column:update_time;index" json:"update_time,omitempty"` + UpdateDate *time.Time `gorm:"column:update_date;index" json:"update_date,omitempty"` +} + +// JSONMap is a map type that can store JSON data +type JSONMap map[string]interface{} + +// Value implements driver.Valuer interface +func (j JSONMap) Value() (driver.Value, error) { + if j == nil { + return nil, nil + } + return json.Marshal(j) +} + +// Scan implements sql.Scanner interface +func (j *JSONMap) Scan(value interface{}) error { + if value == nil { + *j = nil + return nil + } + b, ok := value.([]byte) + if !ok { + return json.Unmarshal([]byte(value.(string)), j) + } + return json.Unmarshal(b, j) +} + +// JSONSlice is a slice type that can store JSON array data +type JSONSlice []interface{} + +// Value implements driver.Valuer interface +func (j JSONSlice) Value() (driver.Value, error) { + if j == nil { + return nil, nil + } + return json.Marshal(j) +} + +// Scan implements sql.Scanner interface +func (j *JSONSlice) Scan(value interface{}) error { + if value == nil { + *j = nil + return nil + } + b, ok := value.([]byte) + if !ok { + return json.Unmarshal([]byte(value.(string)), j) + } + return json.Unmarshal(b, j) +} diff --git a/internal/model/canvas.go b/internal/model/canvas.go new file mode 100644 index 00000000000..06a0be3edd1 --- /dev/null +++ b/internal/model/canvas.go @@ -0,0 +1,68 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package model + +// UserCanvas user canvas model +type UserCanvas struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + Avatar *string `gorm:"column:avatar;type:longtext" json:"avatar,omitempty"` + UserID string `gorm:"column:user_id;size:255;not null;index" json:"user_id"` + Title *string `gorm:"column:title;size:255" json:"title,omitempty"` + Permission string `gorm:"column:permission;size:16;not null;default:me;index" json:"permission"` + Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"` + CanvasType *string `gorm:"column:canvas_type;size:32;index" json:"canvas_type,omitempty"` + CanvasCategory string `gorm:"column:canvas_category;size:32;not null;default:agent_canvas;index" json:"canvas_category"` + DSL JSONMap `gorm:"column:dsl;type:json" json:"dsl,omitempty"` + BaseModel +} + +// TableName specify table name +func (UserCanvas) TableName() string { + return "user_canvas" +} + +// CanvasTemplate canvas template model +type CanvasTemplate struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + Avatar *string `gorm:"column:avatar;type:longtext" json:"avatar,omitempty"` + Title JSONMap `gorm:"column:title;type:json;default:'{}'" json:"title"` + Description JSONMap `gorm:"column:description;type:json;default:'{}'" json:"description"` + CanvasType *string `gorm:"column:canvas_type;size:32;index" json:"canvas_type,omitempty"` + CanvasCategory string `gorm:"column:canvas_category;size:32;not null;default:agent_canvas;index" json:"canvas_category"` + DSL JSONMap `gorm:"column:dsl;type:json" json:"dsl,omitempty"` + BaseModel +} + +// TableName specify table name +func (CanvasTemplate) TableName() string { + return "canvas_template" +} + +// UserCanvasVersion user canvas version model +type UserCanvasVersion struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + UserCanvasID string `gorm:"column:user_canvas_id;size:255;not null;index" json:"user_canvas_id"` + Title *string `gorm:"column:title;size:255" json:"title,omitempty"` + Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"` + DSL JSONMap `gorm:"column:dsl;type:json" json:"dsl,omitempty"` + BaseModel +} + +// TableName specify table name +func (UserCanvasVersion) TableName() string { + return "user_canvas_version" +} diff --git a/internal/model/chat.go b/internal/model/chat.go new file mode 100644 index 00000000000..2bb54aec40b --- /dev/null +++ b/internal/model/chat.go @@ -0,0 +1,64 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//
+
+package model
+
+import "encoding/json"
+
+// Chat chat model (mapped to dialog table)
+type Chat struct {
+	ID          string  `gorm:"column:id;primaryKey;size:32" json:"id"`
+	TenantID    string  `gorm:"column:tenant_id;size:32;not null;index" json:"tenant_id"`
+	Name        *string `gorm:"column:name;size:255;index" json:"name,omitempty"`
+	Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"`
+	Icon        *string `gorm:"column:icon;type:longtext" json:"icon,omitempty"`
+	Language    *string `gorm:"column:language;size:32;index" json:"language,omitempty"`
+	LLMID       string  `gorm:"column:llm_id;size:128;not null" json:"llm_id"`
+	LLMSetting  JSONMap `gorm:"column:llm_setting;type:json;not null;default:'{\"temperature\":0.1,\"top_p\":0.3,\"frequency_penalty\":0.7,\"presence_penalty\":0.4,\"max_tokens\":512}'" json:"llm_setting"`
+	PromptType  string  `gorm:"column:prompt_type;size:16;not null;default:simple;index" json:"prompt_type"`
+	// Note: the doubled quote in I''m keeps the single-quoted SQL default literal valid.
+	PromptConfig           JSONMap   `gorm:"column:prompt_config;type:json;not null;default:'{\"system\":\"\",\"prologue\":\"Hi! I''m your assistant. What can I do for you?\",\"parameters\":[],\"empty_response\":\"Sorry! No relevant content was found in the knowledge base!\"}'" json:"prompt_config"`
+	MetaDataFilter         *JSONMap  `gorm:"column:meta_data_filter;type:json" json:"meta_data_filter,omitempty"`
+	SimilarityThreshold    float64   `gorm:"column:similarity_threshold;default:0.2" json:"similarity_threshold"`
+	VectorSimilarityWeight float64   `gorm:"column:vector_similarity_weight;default:0.3" json:"vector_similarity_weight"`
+	TopN                   int64     `gorm:"column:top_n;default:6" json:"top_n"`
+	TopK                   int64     `gorm:"column:top_k;default:1024" json:"top_k"`
+	DoRefer                string    `gorm:"column:do_refer;size:1;not null;default:1" json:"do_refer"`
+	RerankID               string    `gorm:"column:rerank_id;size:128;not null;default:''" json:"rerank_id"`
+	KBIDs                  JSONSlice `gorm:"column:kb_ids;type:json;not null;default:'[]'" json:"kb_ids"`
+	Status                 *string   `gorm:"column:status;size:1;index" json:"status,omitempty"`
+	BaseModel
+}
+
+// TableName specify table name
+func (Chat) TableName() string {
+	return "dialog"
+}
+
+// ChatSession chat session model (mapped to conversation table)
+type ChatSession struct {
+	ID        string          `gorm:"column:id;primaryKey;size:32" json:"id"`
+	DialogID  string          `gorm:"column:dialog_id;size:32;not null;index" json:"dialog_id"`
+	Name      *string         `gorm:"column:name;size:255;index" json:"name,omitempty"`
+	Message   json.RawMessage `gorm:"column:message;type:json" json:"message,omitempty"`
+	Reference json.RawMessage `gorm:"column:reference;type:json;default:'[]'" json:"reference"`
+	UserID    *string         `gorm:"column:user_id;size:255;index" json:"user_id,omitempty"`
+	BaseModel
+}
+
+// TableName specify table name
+func (ChatSession) TableName() string {
+	return "conversation"
+}
diff --git a/internal/model/connector.go b/internal/model/connector.go
new file mode 100644
index 00000000000..893c12fb63b
--- /dev/null
+++ b/internal/model/connector.go
@@ -0,0 +1,78 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +package model + +import "time" + +// Connector connector model +type Connector struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + TenantID string `gorm:"column:tenant_id;size:32;not null;index" json:"tenant_id"` + Name string `gorm:"column:name;size:128;not null" json:"name"` + Source string `gorm:"column:source;size:128;not null;index" json:"source"` + InputType string `gorm:"column:input_type;size:128;not null;index" json:"input_type"` + Config JSONMap `gorm:"column:config;type:json;not null;default:'{}'" json:"config"` + RefreshFreq int64 `gorm:"column:refresh_freq;default:0" json:"refresh_freq"` + PruneFreq int64 `gorm:"column:prune_freq;default:0" json:"prune_freq"` + TimeoutSecs int64 `gorm:"column:timeout_secs;default:3600" json:"timeout_secs"` + IndexingStart *time.Time `gorm:"column:indexing_start;index" json:"indexing_start,omitempty"` + Status string `gorm:"column:status;size:16;not null;default:schedule;index" json:"status"` + BaseModel +} + +// TableName specify table name +func (Connector) TableName() string { + return "connector" +} + +// Connector2Kb connector to knowledge base mapping model +type Connector2Kb struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + ConnectorID string `gorm:"column:connector_id;size:32;not null;index" json:"connector_id"` + KbID string `gorm:"column:kb_id;size:32;not null;index" json:"kb_id"` + AutoParse string `gorm:"column:auto_parse;size:1;not null;default:1" json:"auto_parse"` + BaseModel +} + +// TableName specify table name +func (Connector2Kb) TableName() string { + return "connector2kb" +} + +// SyncLogs sync logs model +type SyncLogs struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + ConnectorID string `gorm:"column:connector_id;size:32;index" json:"connector_id"` + Status string `gorm:"column:status;size:128;not null;index" json:"status"` + FromBeginning *string `gorm:"column:from_beginning;size:1" json:"from_beginning,omitempty"` + NewDocsIndexed int64 `gorm:"column:new_docs_indexed;default:0" json:"new_docs_indexed"` + TotalDocsIndexed int64 `gorm:"column:total_docs_indexed;default:0" json:"total_docs_indexed"` + DocsRemovedFromIndex int64 `gorm:"column:docs_removed_from_index;default:0" json:"docs_removed_from_index"` + ErrorMsg string `gorm:"column:error_msg;type:longtext;not null;default:''" json:"error_msg"` + ErrorCount int64 `gorm:"column:error_count;default:0" json:"error_count"` + FullExceptionTrace *string `gorm:"column:full_exception_trace;type:longtext" json:"full_exception_trace,omitempty"` + TimeStarted *time.Time `gorm:"column:time_started;index" json:"time_started,omitempty"` + PollRangeStart *string `gorm:"column:poll_range_start;size:255;index" json:"poll_range_start,omitempty"` + PollRangeEnd *string `gorm:"column:poll_range_end;size:255;index" json:"poll_range_end,omitempty"` + KbID string `gorm:"column:kb_id;size:32;not null;index" json:"kb_id"` + BaseModel +} + +// TableName specify table name +func (SyncLogs) TableName() string { + return "sync_logs" +} diff --git a/internal/model/document.go b/internal/model/document.go new file mode 100644 index 00000000000..a161e08f772 --- /dev/null +++ b/internal/model/document.go @@ -0,0 +1,51 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package model + +import "time" + +// Document document model +type Document struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + Thumbnail *string `gorm:"column:thumbnail;type:longtext" json:"thumbnail,omitempty"` + KbID string `gorm:"column:kb_id;size:256;not null;index" json:"kb_id"` + ParserID string `gorm:"column:parser_id;size:32;not null;index" json:"parser_id"` + PipelineID *string `gorm:"column:pipeline_id;size:32;index" json:"pipeline_id,omitempty"` + ParserConfig JSONMap `gorm:"column:parser_config;type:json;not null;default:'{\"pages\":[[1,1000000]],\"table_context_size\":0,\"image_context_size\":0}'" json:"parser_config"` + SourceType string `gorm:"column:source_type;size:128;not null;default:local;index" json:"source_type"` + Type string `gorm:"column:type;size:32;not null;index" json:"type"` + CreatedBy string `gorm:"column:created_by;size:32;not null;index" json:"created_by"` + Name *string `gorm:"column:name;size:255;index" json:"name,omitempty"` + Location *string `gorm:"column:location;size:255;index" json:"location,omitempty"` + Size int64 `gorm:"column:size;default:0;index" json:"size"` + TokenNum int64 `gorm:"column:token_num;default:0;index" json:"token_num"` + ChunkNum int64 `gorm:"column:chunk_num;default:0;index" json:"chunk_num"` + Progress float64 `gorm:"column:progress;default:0;index" json:"progress"` + ProgressMsg *string `gorm:"column:progress_msg;type:longtext" json:"progress_msg,omitempty"` + ProcessBeginAt *time.Time `gorm:"column:process_begin_at;index" json:"process_begin_at,omitempty"` + ProcessDuration float64 `gorm:"column:process_duration;default:0" json:"process_duration"` + MetaFields *JSONMap `gorm:"column:meta_fields;type:json" json:"meta_fields,omitempty"` + Suffix string `gorm:"column:suffix;size:32;not null;index" json:"suffix"` + Run *string `gorm:"column:run;size:1;index" json:"run,omitempty"` + Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"` + BaseModel +} + +// TableName specify table name +func (Document) TableName() string { + return "document" +} diff --git a/internal/model/evaluation.go b/internal/model/evaluation.go new file mode 100644 index 00000000000..5b9bac787ac --- /dev/null +++ b/internal/model/evaluation.go @@ -0,0 +1,87 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package model + +// EvaluationDataset evaluation dataset model +type EvaluationDataset struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + TenantID string `gorm:"column:tenant_id;size:32;not null;index" json:"tenant_id"` + Name string `gorm:"column:name;size:255;not null;index" json:"name"` + Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"` + KbIDs JSONMap `gorm:"column:kb_ids;type:json;not null" json:"kb_ids"` + CreatedBy string `gorm:"column:created_by;size:32;not null;index" json:"created_by"` + Status int64 `gorm:"column:status;default:1;index" json:"status"` + BaseModel +} + +// TableName specify table name +func (EvaluationDataset) TableName() string { + return "evaluation_datasets" +} + +// EvaluationCase evaluation case model +type EvaluationCase struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + DatasetID string `gorm:"column:dataset_id;size:32;not null;index" json:"dataset_id"` + Question string `gorm:"column:question;type:longtext;not null" json:"question"` + ReferenceAnswer *string `gorm:"column:reference_answer;type:longtext" json:"reference_answer,omitempty"` + RelevantDocIDs *JSONMap `gorm:"column:relevant_doc_ids;type:json" json:"relevant_doc_ids,omitempty"` + RelevantChunkIDs *JSONMap `gorm:"column:relevant_chunk_ids;type:json" json:"relevant_chunk_ids,omitempty"` + Metadata *JSONMap `gorm:"column:metadata;type:json" json:"metadata,omitempty"` + BaseModel +} + +// TableName specify table name +func (EvaluationCase) TableName() string { + return "evaluation_cases" +} + +// EvaluationRun evaluation run model +type EvaluationRun struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + DatasetID string `gorm:"column:dataset_id;size:32;not null;index" json:"dataset_id"` + DialogID string `gorm:"column:dialog_id;size:32;not null;index" json:"dialog_id"` + Name string `gorm:"column:name;size:255;not null" json:"name"` + ConfigSnapshot JSONMap `gorm:"column:config_snapshot;type:json;not null" json:"config_snapshot"` + MetricsSummary *JSONMap `gorm:"column:metrics_summary;type:json" json:"metrics_summary,omitempty"` + Status string `gorm:"column:status;size:32;not null;default:PENDING" json:"status"` + CreatedBy string `gorm:"column:created_by;size:32;not null;index" json:"created_by"` + BaseModel +} + +// TableName specify table name +func (EvaluationRun) TableName() string { + return "evaluation_runs" +} + +// EvaluationResult evaluation result model +type EvaluationResult struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + RunID string `gorm:"column:run_id;size:32;not null;index" json:"run_id"` + CaseID string `gorm:"column:case_id;size:32;not null;index" json:"case_id"` + GeneratedAnswer string `gorm:"column:generated_answer;type:longtext;not null" json:"generated_answer"` + RetrievedChunks JSONMap `gorm:"column:retrieved_chunks;type:json;not null" json:"retrieved_chunks"` + Metrics JSONMap `gorm:"column:metrics;type:json;not null" json:"metrics"` + ExecutionTime float64 `gorm:"column:execution_time;not null" json:"execution_time"` + TokenUsage *JSONMap `gorm:"column:token_usage;type:json" json:"token_usage,omitempty"` + BaseModel +} + +// TableName specify table name +func (EvaluationResult) TableName() string { + return "evaluation_results" +} diff --git a/internal/model/file.go b/internal/model/file.go new file mode 100644 index 00000000000..096ce27079c --- /dev/null +++ b/internal/model/file.go @@ -0,0 +1,49 @@ +// +// Copyright 2026 The InfiniFlow 
Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package model + +// File file model +type File struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + ParentID string `gorm:"column:parent_id;size:32;not null;index" json:"parent_id"` + TenantID string `gorm:"column:tenant_id;size:32;not null;index" json:"tenant_id"` + CreatedBy string `gorm:"column:created_by;size:32;not null;index" json:"created_by"` + Name string `gorm:"column:name;size:255;not null;index" json:"name"` + Location *string `gorm:"column:location;size:255;index" json:"location,omitempty"` + Size int64 `gorm:"column:size;default:0;index" json:"size"` + Type string `gorm:"column:type;size:32;not null;index" json:"type"` + SourceType string `gorm:"column:source_type;size:128;not null;default:'';index" json:"source_type"` + BaseModel +} + +// TableName specify table name +func (File) TableName() string { + return "file" +} + +// File2Document file to document mapping model +type File2Document struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + FileID *string `gorm:"column:file_id;size:32;index" json:"file_id,omitempty"` + DocumentID *string `gorm:"column:document_id;size:32;index" json:"document_id,omitempty"` + BaseModel +} + +// TableName specify table name +func (File2Document) TableName() string { + return "file2document" +} diff --git a/internal/model/kb.go b/internal/model/kb.go new file mode 100644 index 00000000000..8862b1e1acc --- /dev/null +++ b/internal/model/kb.go @@ -0,0 +1,70 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package model + +import "time" + +// Knowledgebase knowledge base model +type Knowledgebase struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + Avatar *string `gorm:"column:avatar;type:longtext" json:"avatar,omitempty"` + TenantID string `gorm:"column:tenant_id;size:32;not null;index" json:"tenant_id"` + Name string `gorm:"column:name;size:128;not null;index" json:"name"` + Language *string `gorm:"column:language;size:32;index" json:"language,omitempty"` + Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"` + EmbdID string `gorm:"column:embd_id;size:128;not null;index" json:"embd_id"` + Permission string `gorm:"column:permission;size:16;not null;default:me;index" json:"permission"` + CreatedBy string `gorm:"column:created_by;size:32;not null;index" json:"created_by"` + DocNum int64 `gorm:"column:doc_num;default:0;index" json:"doc_num"` + TokenNum int64 `gorm:"column:token_num;default:0;index" json:"token_num"` + ChunkNum int64 `gorm:"column:chunk_num;default:0;index" json:"chunk_num"` + SimilarityThreshold float64 `gorm:"column:similarity_threshold;default:0.2;index" json:"similarity_threshold"` + VectorSimilarityWeight float64 `gorm:"column:vector_similarity_weight;default:0.3;index" json:"vector_similarity_weight"` + ParserID string `gorm:"column:parser_id;size:32;not null;default:naive;index" json:"parser_id"` + PipelineID *string `gorm:"column:pipeline_id;size:32;index" json:"pipeline_id,omitempty"` + ParserConfig JSONMap `gorm:"column:parser_config;type:json;not null;default:'{\"pages\":[[1,1000000]],\"table_context_size\":0,\"image_context_size\":0}'" json:"parser_config"` + Pagerank int64 `gorm:"column:pagerank;default:0" json:"pagerank"` + GraphragTaskID *string `gorm:"column:graphrag_task_id;size:32;index" json:"graphrag_task_id,omitempty"` + GraphragTaskFinishAt *time.Time `gorm:"column:graphrag_task_finish_at" json:"graphrag_task_finish_at,omitempty"` + RaptorTaskID *string `gorm:"column:raptor_task_id;size:32;index" json:"raptor_task_id,omitempty"` + RaptorTaskFinishAt *time.Time `gorm:"column:raptor_task_finish_at" json:"raptor_task_finish_at,omitempty"` + MindmapTaskID *string `gorm:"column:mindmap_task_id;size:32;index" json:"mindmap_task_id,omitempty"` + MindmapTaskFinishAt *time.Time `gorm:"column:mindmap_task_finish_at" json:"mindmap_task_finish_at,omitempty"` + Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"` + BaseModel +} + +// TableName specify table name +func (Knowledgebase) TableName() string { + return "knowledgebase" +} + +// InvitationCode invitation code model +type InvitationCode struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + Code string `gorm:"column:code;size:32;not null;index" json:"code"` + VisitTime *time.Time `gorm:"column:visit_time;index" json:"visit_time,omitempty"` + UserID *string `gorm:"column:user_id;size:32;index" json:"user_id,omitempty"` + TenantID *string `gorm:"column:tenant_id;size:32;index" json:"tenant_id,omitempty"` + Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"` + BaseModel +} + +// TableName specify table name +func (InvitationCode) TableName() string { + return "invitation_code" +} diff --git a/internal/model/llm.go b/internal/model/llm.go new file mode 100644 index 00000000000..96377d1ebb1 --- /dev/null +++ b/internal/model/llm.go @@ -0,0 +1,76 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package model + +// LLMFactories LLM factory model +type LLMFactories struct { + Name string `gorm:"column:name;primaryKey;size:128" json:"name"` + Logo *string `gorm:"column:logo;type:longtext" json:"logo,omitempty"` + Tags string `gorm:"column:tags;size:255;not null;index" json:"tags"` + Rank int64 `gorm:"column:rank;default:0" json:"rank"` + Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"` + BaseModel +} + +// TableName specify table name +func (LLMFactories) TableName() string { + return "llm_factories" +} + +// LLM LLM model +type LLM struct { + LLMName string `gorm:"column:llm_name;size:128;not null;primaryKey" json:"llm_name"` + ModelType string `gorm:"column:model_type;size:128;not null;index" json:"model_type"` + FID string `gorm:"column:fid;size:128;not null;primaryKey" json:"fid"` + MaxTokens int64 `gorm:"column:max_tokens;default:0" json:"max_tokens"` + Tags string `gorm:"column:tags;size:255;not null;index" json:"tags"` + IsTools bool `gorm:"column:is_tools;default:false" json:"is_tools"` + Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"` + BaseModel +} + +// TableName specify table name +func (LLM) TableName() string { + return "llm" +} + +// TenantLangfuse tenant langfuse model +type TenantLangfuse struct { + TenantID string `gorm:"column:tenant_id;primaryKey;size:32" json:"tenant_id"` + SecretKey string `gorm:"column:secret_key;size:2048;not null;index" json:"secret_key"` + PublicKey string `gorm:"column:public_key;size:2048;not null;index" json:"public_key"` + Host string `gorm:"column:host;size:128;not null;index" json:"host"` + BaseModel +} + +// TableName specify table name +func (TenantLangfuse) TableName() string { + return "tenant_langfuse" +} + +// MyLLM represents LLM information for a tenant with factory details +type MyLLM struct { + LLMFactory string `gorm:"column:llm_factory" json:"llm_factory"` + Logo *string `gorm:"column:logo" json:"logo,omitempty"` + Tags string `gorm:"column:tags" json:"tags"` + ModelType string `gorm:"column:model_type" json:"model_type"` + LLMName string `gorm:"column:llm_name" json:"llm_name"` + UsedTokens int64 `gorm:"column:used_tokens" json:"used_tokens"` + Status string `gorm:"column:status" json:"status"` + APIBase string `gorm:"column:api_base" json:"api_base,omitempty"` + MaxTokens int64 `gorm:"column:max_tokens" json:"max_tokens,omitempty"` +} diff --git a/internal/model/mcp.go b/internal/model/mcp.go new file mode 100644 index 00000000000..044bbdab149 --- /dev/null +++ b/internal/model/mcp.go @@ -0,0 +1,35 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package model + +// MCPServer MCP server model +type MCPServer struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + Name string `gorm:"column:name;size:255;not null" json:"name"` + TenantID string `gorm:"column:tenant_id;size:32;not null;index" json:"tenant_id"` + URL string `gorm:"column:url;size:2048;not null" json:"url"` + ServerType string `gorm:"column:server_type;size:32;not null" json:"server_type"` + Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"` + Variables JSONMap `gorm:"column:variables;type:json;default:'{}'" json:"variables,omitempty"` + Headers JSONMap `gorm:"column:headers;type:json;default:'{}'" json:"headers,omitempty"` + BaseModel +} + +// TableName specify table name +func (MCPServer) TableName() string { + return "mcp_server" +} diff --git a/internal/model/memory.go b/internal/model/memory.go new file mode 100644 index 00000000000..28f9f58c1c1 --- /dev/null +++ b/internal/model/memory.go @@ -0,0 +1,42 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package model + +// Memory memory model +type Memory struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + Name string `gorm:"column:name;size:128;not null" json:"name"` + Avatar *string `gorm:"column:avatar;type:longtext" json:"avatar,omitempty"` + TenantID string `gorm:"column:tenant_id;size:32;not null;index" json:"tenant_id"` + MemoryType int64 `gorm:"column:memory_type;default:1;index" json:"memory_type"` + StorageType string `gorm:"column:storage_type;size:32;not null;default:table;index" json:"storage_type"` + EmbdID string `gorm:"column:embd_id;size:128;not null" json:"embd_id"` + LLMID string `gorm:"column:llm_id;size:128;not null" json:"llm_id"` + Permissions string `gorm:"column:permissions;size:16;not null;default:me;index" json:"permissions"` + Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"` + MemorySize int64 `gorm:"column:memory_size;default:5242880;not null" json:"memory_size"` + ForgettingPolicy string `gorm:"column:forgetting_policy;size:32;not null;default:FIFO" json:"forgetting_policy"` + Temperature float64 `gorm:"column:temperature;default:0.5;not null" json:"temperature"` + SystemPrompt *string `gorm:"column:system_prompt;type:longtext" json:"system_prompt,omitempty"` + UserPrompt *string `gorm:"column:user_prompt;type:longtext" json:"user_prompt,omitempty"` + BaseModel +} + +// TableName specify table name +func (Memory) TableName() string { + return "memory" +} diff --git a/internal/model/pipeline.go b/internal/model/pipeline.go new file mode 100644 index 00000000000..a47d6119871 --- /dev/null +++ b/internal/model/pipeline.go @@ -0,0 +1,49 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package model + +import "time" + +// PipelineOperationLog pipeline operation log model +type PipelineOperationLog struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + DocumentID string `gorm:"column:document_id;size:32;index" json:"document_id"` + TenantID string `gorm:"column:tenant_id;size:32;not null;index" json:"tenant_id"` + KbID string `gorm:"column:kb_id;size:32;not null;index" json:"kb_id"` + PipelineID *string `gorm:"column:pipeline_id;size:32;index" json:"pipeline_id,omitempty"` + PipelineTitle *string `gorm:"column:pipeline_title;size:32;index" json:"pipeline_title,omitempty"` + ParserID string `gorm:"column:parser_id;size:32;not null;index" json:"parser_id"` + DocumentName string `gorm:"column:document_name;size:255;not null" json:"document_name"` + DocumentSuffix string `gorm:"column:document_suffix;size:255;not null" json:"document_suffix"` + DocumentType string `gorm:"column:document_type;size:255;not null" json:"document_type"` + SourceFrom string `gorm:"column:source_from;size:255;not null" json:"source_from"` + Progress float64 `gorm:"column:progress;default:0;index" json:"progress"` + ProgressMsg *string `gorm:"column:progress_msg;type:longtext" json:"progress_msg,omitempty"` + ProcessBeginAt *time.Time `gorm:"column:process_begin_at;index" json:"process_begin_at,omitempty"` + ProcessDuration float64 `gorm:"column:process_duration;default:0" json:"process_duration"` + DSL JSONMap `gorm:"column:dsl;type:json" json:"dsl,omitempty"` + TaskType string `gorm:"column:task_type;size:32;not null;default:''" json:"task_type"` + OperationStatus string `gorm:"column:operation_status;size:32;not null" json:"operation_status"` + Avatar *string `gorm:"column:avatar;type:longtext" json:"avatar,omitempty"` + Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"` + BaseModel +} + +// TableName specify table name +func (PipelineOperationLog) TableName() string { + return "pipeline_operation_log" +} diff --git a/internal/model/search.go b/internal/model/search.go new file mode 100644 index 00000000000..da95ccd6939 --- /dev/null +++ b/internal/model/search.go @@ -0,0 +1,35 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//
+
+package model
+
+// Search search model
+type Search struct {
+ ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
+ Avatar *string `gorm:"column:avatar;type:longtext" json:"avatar,omitempty"`
+ TenantID string `gorm:"column:tenant_id;size:32;not null;index" json:"tenant_id"`
+ Name string `gorm:"column:name;size:128;not null;index" json:"name"`
+ Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"`
+ CreatedBy string `gorm:"column:created_by;size:32;not null;index" json:"created_by"`
+ SearchConfig JSONMap `gorm:"column:search_config;type:json;not null;default:'{\"kb_ids\":[],\"doc_ids\":[],\"similarity_threshold\":0.2,\"vector_similarity_weight\":0.3,\"use_kg\":false,\"rerank_id\":\"\",\"top_k\":1024,\"summary\":false,\"chat_id\":\"\",\"chat_setting\":{},\"cross_languages\":[],\"highlight\":false,\"keyword\":false,\"web_search\":false,\"related_search\":false,\"query_mindmap\":false}'" json:"search_config"`
+ Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"`
+ BaseModel
+}
+
+// TableName specify table name
+func (Search) TableName() string {
+ return "search"
+}
diff --git a/internal/model/system.go b/internal/model/system.go
new file mode 100644
index 00000000000..48775561136
--- /dev/null
+++ b/internal/model/system.go
@@ -0,0 +1,30 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+package model
+
+// SystemSettings system settings model
+type SystemSettings struct {
+ Name string `gorm:"column:name;primaryKey;size:128" json:"name"`
+ Source string `gorm:"column:source;size:32;not null" json:"source"`
+ DataType string `gorm:"column:data_type;size:32;not null" json:"data_type"`
+ Value string `gorm:"column:value;size:1024;not null" json:"value"`
+}
+
+// TableName specify table name
+func (SystemSettings) TableName() string {
+ return "system_settings"
+}
diff --git a/internal/model/task.go b/internal/model/task.go
new file mode 100644
index 00000000000..94fe3f27838
--- /dev/null
+++ b/internal/model/task.go
@@ -0,0 +1,42 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// + +package model + +import "time" + +// Task task model +type Task struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + DocID string `gorm:"column:doc_id;size:32;not null;index" json:"doc_id"` + FromPage int64 `gorm:"column:from_page;default:0" json:"from_page"` + ToPage int64 `gorm:"column:to_page;default:100000000" json:"to_page"` + TaskType string `gorm:"column:task_type;size:32;not null;default:''" json:"task_type"` + Priority int64 `gorm:"column:priority;default:0" json:"priority"` + BeginAt *time.Time `gorm:"column:begin_at;index" json:"begin_at,omitempty"` + ProcessDuration float64 `gorm:"column:process_duration;default:0" json:"process_duration"` + Progress float64 `gorm:"column:progress;default:0;index" json:"progress"` + ProgressMsg *string `gorm:"column:progress_msg;type:longtext" json:"progress_msg,omitempty"` + RetryCount int64 `gorm:"column:retry_count;default:0" json:"retry_count"` + Digest *string `gorm:"column:digest;type:longtext" json:"digest,omitempty"` + ChunkIDs *string `gorm:"column:chunk_ids;type:longtext" json:"chunk_ids,omitempty"` + BaseModel +} + +// TableName specify table name +func (Task) TableName() string { + return "task" +} diff --git a/internal/model/tenant.go b/internal/model/tenant.go new file mode 100644 index 00000000000..f7f76df8d20 --- /dev/null +++ b/internal/model/tenant.go @@ -0,0 +1,39 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package model + +// Tenant tenant model +type Tenant struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + Name *string `gorm:"column:name;size:100;index" json:"name,omitempty"` + PublicKey *string `gorm:"column:public_key;size:255;index" json:"public_key,omitempty"` + LLMID string `gorm:"column:llm_id;size:128;not null;index" json:"llm_id"` + EmbDID string `gorm:"column:embd_id;size:128;not null;index" json:"embd_id"` + ASRID string `gorm:"column:asr_id;size:128;not null;index" json:"asr_id"` + Img2TxtID string `gorm:"column:img2txt_id;size:128;not null;index" json:"img2txt_id"` + RerankID string `gorm:"column:rerank_id;size:128;not null;index" json:"rerank_id"` + TTSID *string `gorm:"column:tts_id;size:256;index" json:"tts_id,omitempty"` + ParserIDs string `gorm:"column:parser_ids;size:256;not null" json:"parser_ids"` + Credit int64 `gorm:"column:credit;default:512;index" json:"credit"` + Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"` + BaseModel +} + +// TableName specify table name +func (Tenant) TableName() string { + return "tenant" +} diff --git a/internal/model/tenant_llm.go b/internal/model/tenant_llm.go new file mode 100644 index 00000000000..dbadca6bd95 --- /dev/null +++ b/internal/model/tenant_llm.go @@ -0,0 +1,36 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package model + +// TenantLLM tenant LLM model +type TenantLLM struct { + TenantID string `gorm:"column:tenant_id;size:32;not null;primaryKey" json:"tenant_id"` + LLMFactory string `gorm:"column:llm_factory;size:128;not null;primaryKey" json:"llm_factory"` + ModelType string `gorm:"column:model_type;size:128;not null;index" json:"model_type"` + LLMName string `gorm:"column:llm_name;size:128;not null;primaryKey;default:\"\"" json:"llm_name"` + APIKey string `gorm:"column:api_key;type:longtext" json:"api_key,omitempty"` + APIBase string `gorm:"column:api_base;size:255" json:"api_base,omitempty"` + MaxTokens int64 `gorm:"column:max_tokens;default:8192;index" json:"max_tokens"` + UsedTokens int64 `gorm:"column:used_tokens;default:0;index" json:"used_tokens"` + Status string `gorm:"column:status;size:1;not null;default:1;index" json:"status"` + BaseModel +} + +// TableName specify table name +func (TenantLLM) TableName() string { + return "tenant_llm" +} diff --git a/internal/model/types.go b/internal/model/types.go new file mode 100644 index 00000000000..7c534c559fb --- /dev/null +++ b/internal/model/types.go @@ -0,0 +1,71 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package model + +// ModelType represents the type of model +type ModelType string + +const ( + // ModelTypeChat chat model + ModelTypeChat ModelType = "chat" + // ModelTypeEmbedding embedding model + ModelTypeEmbedding ModelType = "embedding" + // ModelTypeSpeech2Text speech to text model + ModelTypeSpeech2Text ModelType = "speech2text" + // ModelTypeImage2Text image to text model + ModelTypeImage2Text ModelType = "image2text" + // ModelTypeRerank rerank model + ModelTypeRerank ModelType = "rerank" + // ModelTypeTTS text to speech model + ModelTypeTTS ModelType = "tts" + // ModelTypeOCR optical character recognition model + ModelTypeOCR ModelType = "ocr" +) + +// EmbeddingModel interface for embedding models +type EmbeddingModel interface { + // Encode encodes a list of texts into embeddings + Encode(texts []string) ([][]float64, error) + // EncodeQuery encodes a single query string into embedding + EncodeQuery(query string) ([]float64, error) +} + +// ChatModel interface for chat models +type ChatModel interface { + // Chat sends a message and returns response + Chat(system string, history []map[string]string, genConf map[string]interface{}) (string, error) + // ChatStreamly sends a message and streams response + ChatStreamly(system string, history []map[string]string, genConf map[string]interface{}) (<-chan string, error) +} + +// RerankModel interface for rerank models +type RerankModel interface { + // Similarity calculates similarity between query and texts + Similarity(query string, texts []string) ([]float64, error) +} + +// ModelConfig represents configuration for a model +type ModelConfig struct { + TenantID string `json:"tenant_id"` + LLMFactory string `json:"llm_factory"` + ModelType ModelType `json:"model_type"` + LLMName string `json:"llm_name"` + APIKey string `json:"api_key"` + APIBase string `json:"api_base"` + MaxTokens int64 `json:"max_tokens"` + IsTools bool `json:"is_tools"` +} diff --git a/internal/model/user.go b/internal/model/user.go new file mode 100644 index 00000000000..05f5633517e --- /dev/null +++ b/internal/model/user.go @@ -0,0 +1,45 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package model + +import "time" + +// User user model +type User struct { + ID string `gorm:"column:id;size:32;primaryKey" json:"id"` + AccessToken *string `gorm:"column:access_token;size:255;index" json:"access_token,omitempty"` + Nickname string `gorm:"column:nickname;size:100;not null;index" json:"nickname"` + Password *string `gorm:"column:password;size:255;index" json:"-"` + Email string `gorm:"column:email;size:255;not null;index" json:"email"` + Avatar *string `gorm:"column:avatar;type:longtext" json:"avatar,omitempty"` + Language *string `gorm:"column:language;size:32;index" json:"language,omitempty"` + ColorSchema *string `gorm:"column:color_schema;size:32;index" json:"color_schema,omitempty"` + Timezone *string `gorm:"column:timezone;size:64;index" json:"timezone,omitempty"` + LastLoginTime *time.Time `gorm:"column:last_login_time;index" json:"last_login_time,omitempty"` + IsAuthenticated string `gorm:"column:is_authenticated;size:1;not null;default:1;index" json:"is_authenticated"` + IsActive string `gorm:"column:is_active;size:1;not null;default:1;index" json:"is_active"` + IsAnonymous string `gorm:"column:is_anonymous;size:1;not null;default:0;index" json:"is_anonymous"` + LoginChannel *string `gorm:"column:login_channel;index" json:"login_channel,omitempty"` + Status *string `gorm:"column:status;size:1;default:1;index" json:"status"` + IsSuperuser *bool `gorm:"column:is_superuser;index" json:"is_superuser,omitempty"` + BaseModel +} + +// TableName specify table name +func (User) TableName() string { + return "user" +} diff --git a/internal/model/user_tenant.go b/internal/model/user_tenant.go new file mode 100644 index 00000000000..963a6dbe545 --- /dev/null +++ b/internal/model/user_tenant.go @@ -0,0 +1,33 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package model + +// UserTenant user tenant relationship model +type UserTenant struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + UserID string `gorm:"column:user_id;size:32;not null;index" json:"user_id"` + TenantID string `gorm:"column:tenant_id;size:32;not null;index" json:"tenant_id"` + Role string `gorm:"column:role;size:32;not null;index" json:"role"` + InvitedBy string `gorm:"column:invited_by;size:32;not null;index" json:"invited_by"` + Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"` + BaseModel +} + +// TableName specify table name +func (UserTenant) TableName() string { + return "user_tenant" +} diff --git a/internal/router/router.go b/internal/router/router.go new file mode 100644 index 00000000000..a7a77867727 --- /dev/null +++ b/internal/router/router.go @@ -0,0 +1,194 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package router + +import ( + "github.com/gin-gonic/gin" + + "ragflow/internal/handler" +) + +// Router router +type Router struct { + userHandler *handler.UserHandler + tenantHandler *handler.TenantHandler + documentHandler *handler.DocumentHandler + systemHandler *handler.SystemHandler + knowledgebaseHandler *handler.KnowledgebaseHandler + chunkHandler *handler.ChunkHandler + llmHandler *handler.LLMHandler + chatHandler *handler.ChatHandler + chatSessionHandler *handler.ChatSessionHandler + connectorHandler *handler.ConnectorHandler + searchHandler *handler.SearchHandler + fileHandler *handler.FileHandler +} + +// NewRouter create router +func NewRouter( + userHandler *handler.UserHandler, + tenantHandler *handler.TenantHandler, + documentHandler *handler.DocumentHandler, + systemHandler *handler.SystemHandler, + knowledgebaseHandler *handler.KnowledgebaseHandler, + chunkHandler *handler.ChunkHandler, + llmHandler *handler.LLMHandler, + chatHandler *handler.ChatHandler, + chatSessionHandler *handler.ChatSessionHandler, + connectorHandler *handler.ConnectorHandler, + searchHandler *handler.SearchHandler, + fileHandler *handler.FileHandler, +) *Router { + return &Router{ + userHandler: userHandler, + tenantHandler: tenantHandler, + documentHandler: documentHandler, + systemHandler: systemHandler, + knowledgebaseHandler: knowledgebaseHandler, + chunkHandler: chunkHandler, + llmHandler: llmHandler, + chatHandler: chatHandler, + chatSessionHandler: chatSessionHandler, + connectorHandler: connectorHandler, + searchHandler: searchHandler, + fileHandler: fileHandler, + } +} + +// Setup setup routes +func (r *Router) Setup(engine *gin.Engine) { + // Health check + engine.GET("/health", func(c *gin.Context) { + c.JSON(200, gin.H{ + "status": "ok", + }) + }) + + // System endpoints + engine.GET("/v1/system/ping", r.systemHandler.Ping) + engine.GET("/v1/system/config", r.systemHandler.GetConfig) + engine.GET("/v1/system/configs", r.systemHandler.GetConfigs) + engine.GET("/v1/system/version", r.systemHandler.GetVersion) + + // User login by email endpoint + engine.POST("/v1/user/login", r.userHandler.LoginByEmail) + // User login channels endpoint + engine.GET("/v1/user/login/channels", r.userHandler.GetLoginChannels) + // User logout endpoint + engine.GET("/v1/user/logout", r.userHandler.Logout) + // User info endpoint + engine.GET("/v1/user/info", r.userHandler.Info) + // User tenant info endpoint + engine.GET("/v1/user/tenant_info", r.tenantHandler.TenantInfo) + // Tenant list endpoint + engine.GET("/v1/tenant/list", r.tenantHandler.TenantList) + // User settings endpoint + engine.POST("/v1/user/setting", r.userHandler.Setting) + // User change password endpoint + engine.POST("/v1/user/setting/password", r.userHandler.ChangePassword) + + // API v1 route group + v1 := engine.Group("/api/v1") + { + // User routes + users := v1.Group("/users") + { + users.POST("/register", r.userHandler.Register) + users.POST("/login", r.userHandler.Login) + users.GET("", r.userHandler.ListUsers) + users.GET("/:id", r.userHandler.GetUserByID) + } + + // Document routes + documents := 
v1.Group("/documents") + { + documents.POST("", r.documentHandler.CreateDocument) + documents.GET("", r.documentHandler.ListDocuments) + documents.GET("/:id", r.documentHandler.GetDocumentByID) + documents.PUT("/:id", r.documentHandler.UpdateDocument) + documents.DELETE("/:id", r.documentHandler.DeleteDocument) + } + + // Author routes + authors := v1.Group("/authors") + { + authors.GET("/:author_id/documents", r.documentHandler.GetDocumentsByAuthorID) + } + + // Knowledge base routes + kb := engine.Group("/v1/kb") + { + kb.POST("/list", r.knowledgebaseHandler.ListKbs) + } + + // Chunk routes + chunk := engine.Group("/v1/chunk") + { + chunk.POST("/retrieval_test", r.chunkHandler.RetrievalTest) + } + + // LLM routes + llm := engine.Group("/v1/llm") + { + llm.GET("/my_llms", r.llmHandler.GetMyLLMs) + llm.GET("/factories", r.llmHandler.Factories) + llm.GET("/list", r.llmHandler.ListApp) + } + + // Chat routes + chat := engine.Group("/v1/dialog") + { + chat.GET("/list", r.chatHandler.ListChats) + chat.POST("/next", r.chatHandler.ListChatsNext) + chat.POST("/set", r.chatHandler.SetDialog) + chat.POST("/rm", r.chatHandler.RemoveChats) + } + + // Chat session (conversation) routes + session := engine.Group("/v1/conversation") + { + session.POST("/set", r.chatSessionHandler.SetChatSession) + session.POST("/rm", r.chatSessionHandler.RemoveChatSessions) + session.GET("/list", r.chatSessionHandler.ListChatSessions) + session.POST("/completion", r.chatSessionHandler.Completion) + } + + // Connector routes + connector := engine.Group("/v1/connector") + { + connector.GET("/list", r.connectorHandler.ListConnectors) + } + + // Search routes + search := engine.Group("/v1/search") + { + search.POST("/list", r.searchHandler.ListSearchApps) + } + + // File routes + file := engine.Group("/v1/file") + { + file.GET("/list", r.fileHandler.ListFiles) + file.GET("/root_folder", r.fileHandler.GetRootFolder) + file.GET("/parent_folder", r.fileHandler.GetParentFolder) + file.GET("/all_parent_folder", r.fileHandler.GetAllParentFolders) + } + } + + // Handle undefined routes + engine.NoRoute(handler.HandleNoRoute) +} diff --git a/internal/server/config.go b/internal/server/config.go new file mode 100644 index 00000000000..b29cef02996 --- /dev/null +++ b/internal/server/config.go @@ -0,0 +1,294 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package server + +import ( + "fmt" + "os" + "strconv" + "strings" + "time" + + "github.com/spf13/viper" + "go.uber.org/zap" +) + +// DefaultConnectTimeout default connection timeout for external services +const DefaultConnectTimeout = 5 * time.Second + +// Config application configuration +type Config struct { + Server ServerConfig `mapstructure:"server"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + Log LogConfig `mapstructure:"log"` + DocEngine DocEngineConfig `mapstructure:"doc_engine"` + RegisterEnabled int `mapstructure:"register_enabled"` + OAuth map[string]OAuthConfig `mapstructure:"oauth"` +} + +// OAuthConfig OAuth configuration for a channel +type OAuthConfig struct { + DisplayName string `mapstructure:"display_name"` + Icon string `mapstructure:"icon"` +} + +// ServerConfig server configuration +type ServerConfig struct { + Mode string `mapstructure:"mode"` // debug, release + Port int `mapstructure:"port"` +} + +// DatabaseConfig database configuration +type DatabaseConfig struct { + Driver string `mapstructure:"driver"` // mysql + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` + Database string `mapstructure:"database"` + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + Charset string `mapstructure:"charset"` +} + +// LogConfig logging configuration +type LogConfig struct { + Level string `mapstructure:"level"` // debug, info, warn, error + Format string `mapstructure:"format"` // json, text +} + +// DocEngineConfig document engine configuration +type DocEngineConfig struct { + Type EngineType `mapstructure:"type"` + ES *ElasticsearchConfig `mapstructure:"es"` + Infinity *InfinityConfig `mapstructure:"infinity"` +} + +// EngineType document engine type +type EngineType string + +const ( + EngineElasticsearch EngineType = "elasticsearch" + EngineInfinity EngineType = "infinity" +) + +// ElasticsearchConfig Elasticsearch configuration +type ElasticsearchConfig struct { + Hosts string `mapstructure:"hosts"` + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` +} + +// InfinityConfig Infinity configuration +type InfinityConfig struct { + URI string `mapstructure:"uri"` + PostgresPort int `mapstructure:"postgres_port"` + DBName string `mapstructure:"db_name"` +} + +// RedisConfig Redis configuration +type RedisConfig struct { + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` + Password string `mapstructure:"password"` + DB int `mapstructure:"db"` +} + +var ( + globalConfig *Config + globalViper *viper.Viper + zapLogger *zap.Logger +) + +// Init initialize configuration +func Init(configPath string) error { + v := viper.New() + + // Set configuration file path + if configPath != "" { + v.SetConfigFile(configPath) + } else { + // Try to load service_conf.yaml from conf directory first + v.SetConfigName("service_conf") + v.SetConfigType("yaml") + v.AddConfigPath("./conf") + v.AddConfigPath(".") + v.AddConfigPath("./config") + v.AddConfigPath("./internal/config") + v.AddConfigPath("/etc/ragflow/") + } + + // Read environment variables + v.SetEnvPrefix("RAGFLOW") + v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + v.AutomaticEnv() + + // Read configuration file + if err := v.ReadInConfig(); err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); !ok { + return fmt.Errorf("read config file error: %w", err) + } + zapLogger.Info("Config file not found, using environment variables only") + } + + // Save viper 
instance
+ globalViper = v
+
+ // Unmarshal configuration to globalConfig
+ // Note: This will only unmarshal fields that match the Config struct
+ if err := v.Unmarshal(&globalConfig); err != nil {
+ return fmt.Errorf("unmarshal config error: %w", err)
+ }
+
+ // Load REGISTER_ENABLED from environment variable (default: 1)
+ registerEnabled := 1
+ if envVal := os.Getenv("REGISTER_ENABLED"); envVal != "" {
+ if parsed, err := strconv.Atoi(envVal); err == nil {
+ registerEnabled = parsed
+ }
+ }
+ globalConfig.RegisterEnabled = registerEnabled
+
+ // If we loaded service_conf.yaml, map mysql fields to DatabaseConfig
+ if globalConfig != nil && globalConfig.Database.Host == "" {
+ // Try to map from mysql section
+ if v.IsSet("mysql") {
+ mysqlConfig := v.Sub("mysql")
+ if mysqlConfig != nil {
+ globalConfig.Database.Driver = "mysql"
+ globalConfig.Database.Host = mysqlConfig.GetString("host")
+ globalConfig.Database.Port = mysqlConfig.GetInt("port")
+ globalConfig.Database.Database = mysqlConfig.GetString("name")
+ globalConfig.Database.Username = mysqlConfig.GetString("user")
+ globalConfig.Database.Password = mysqlConfig.GetString("password")
+ globalConfig.Database.Charset = "utf8mb4"
+ }
+ }
+ }
+
+ // Map ragflow section to ServerConfig
+ if globalConfig != nil && globalConfig.Server.Port == 0 {
+ // Try to map from ragflow section
+ if v.IsSet("ragflow") {
+ ragflowConfig := v.Sub("ragflow")
+ if ragflowConfig != nil {
+ // This service listens on http_port + 2 (9382 with the default http_port of 9380)
+ globalConfig.Server.Port = ragflowConfig.GetInt("http_port") + 2
+ // If mode is not set, default to release
+ if globalConfig.Server.Mode == "" {
+ globalConfig.Server.Mode = "release"
+ }
+ }
+ }
+ }
+
+ // Map redis section to RedisConfig
+ if globalConfig != nil && globalConfig.Redis.Host != "" {
+ if v.IsSet("redis") {
+ redisConfig := v.Sub("redis")
+ if redisConfig != nil {
+ hostStr := redisConfig.GetString("host")
+ // Handle host:port format (e.g., "localhost:6379")
+ if hostStr == "" {
+ return fmt.Errorf("empty host in redis configuration")
+ }
+
+ if idx := strings.LastIndex(hostStr, ":"); idx != -1 {
+ globalConfig.Redis.Host = hostStr[:idx]
+ if portStr := hostStr[idx+1:]; portStr != "" {
+ if port, err := strconv.Atoi(portStr); err == nil {
+ globalConfig.Redis.Port = port
+ }
+ }
+ } else {
+ return fmt.Errorf("invalid redis address format: %s", hostStr)
+ }
+
+ globalConfig.Redis.Password = redisConfig.GetString("password")
+ globalConfig.Redis.DB = redisConfig.GetInt("db")
+ }
+ }
+ }
+
+ // Map doc_engine section to DocEngineConfig
+ if globalConfig != nil && globalConfig.DocEngine.Type == "" {
+ // Try to map from doc_engine section
+ if v.IsSet("doc_engine") {
+ docEngineConfig := v.Sub("doc_engine")
+ if docEngineConfig != nil {
+ globalConfig.DocEngine.Type = EngineType(docEngineConfig.GetString("type"))
+ }
+ }
+ // Also check legacy es section for backward compatibility
+ if v.IsSet("es") {
+ esConfig := v.Sub("es")
+ if esConfig != nil {
+ if globalConfig.DocEngine.Type == "" {
+ globalConfig.DocEngine.Type = EngineElasticsearch
+ }
+ if globalConfig.DocEngine.ES == nil {
+ globalConfig.DocEngine.ES = &ElasticsearchConfig{
+ Hosts: esConfig.GetString("hosts"),
+ Username: esConfig.GetString("username"),
+ Password: esConfig.GetString("password"),
+ }
+ }
+ }
+ }
+ if v.IsSet("infinity") {
+ infConfig := v.Sub("infinity")
+ if infConfig != nil {
+ if globalConfig.DocEngine.Type == "" {
+ globalConfig.DocEngine.Type = EngineInfinity
+ }
+ if globalConfig.DocEngine.Infinity == nil {
+ globalConfig.DocEngine.Infinity = &InfinityConfig{
+ URI: infConfig.GetString("uri"),
+ PostgresPort: infConfig.GetInt("postgres_port"),
+ DBName: infConfig.GetString("db_name"),
+ }
+ }
+ }
+ }
+ }
+
+ return nil
+}
+
+// GetConfig returns the global configuration
+func GetConfig() *Config {
+ return globalConfig
+}
+
+// SetLogger sets the logger instance
+func SetLogger(l *zap.Logger) {
+ zapLogger = l
+}
+
+// PrintAll prints all configuration settings
+func PrintAll() {
+ if globalViper == nil {
+ zapLogger.Info("Configuration not initialized")
+ return
+ }
+
+ allSettings := globalViper.AllSettings()
+ zapLogger.Info("=== All Configuration Settings ===")
+ for key, value := range allSettings {
+ zapLogger.Info("config", zap.String("key", key), zap.Any("value", value))
+ }
+ zapLogger.Info("=== End Configuration ===")
+}
diff --git a/internal/server/model_provider.go b/internal/server/model_provider.go
new file mode 100644
index 00000000000..c94a41e91bd
--- /dev/null
+++ b/internal/server/model_provider.go
@@ -0,0 +1,116 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+package server
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+ "sync"
+)
+
+// ModelProvider represents a model provider configuration
+type ModelProvider struct {
+ Name string `json:"name"`
+ Logo string `json:"logo"`
+ Tags string `json:"tags"`
+ Status string `json:"status"`
+ Rank string `json:"rank"`
+ LLMs []LLM `json:"llm"`
+ DefaultEmbeddingURL string `json:"default_embedding_url,omitempty"`
+}
+
+// LLM represents a language model within a provider
+type LLM struct {
+ LLMName string `json:"llm_name"`
+ Tags string `json:"tags"`
+ MaxTokens int `json:"max_tokens"`
+ ModelType string `json:"model_type"`
+ IsTools bool `json:"is_tools"`
+}
+
+var (
+ modelProviders []ModelProvider
+ modelProviderMap map[string]int // name -> index in modelProviders slice
+ modelProvidersOnce sync.Once
+ modelProvidersErr error
+)
+
+// LoadModelProviders loads model providers from JSON file.
+// If path is empty, it defaults to "conf/llm_factories.json" relative to current working directory.
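+//
+// A minimal usage sketch; the provider name "OpenAI" is only an assumption
+// about what the JSON file contains, not something this package guarantees:
+//
+//	if err := LoadModelProviders(""); err != nil {
+//		log.Fatal(err)
+//	}
+//	if p := GetModelProviderByName("OpenAI"); p != nil {
+//		fmt.Printf("%s: %d models\n", p.Name, len(p.LLMs))
+//	}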
+func LoadModelProviders(path string) error { + modelProvidersOnce.Do(func() { + if path == "" { + path = "conf/llm_factories.json" + //path = "conf/model_providers.json" + } + + data, err := os.ReadFile(path) + if err != nil { + modelProvidersErr = fmt.Errorf("failed to read model providers file %s: %w", path, err) + return + } + + var root struct { + Providers []ModelProvider `json:"factory_llm_infos"` + } + if err := json.Unmarshal(data, &root); err != nil { + modelProvidersErr = fmt.Errorf("failed to unmarshal model providers JSON: %w", err) + return + } + + modelProviders = root.Providers + // Build name to index map for fast lookup + modelProviderMap = make(map[string]int, len(modelProviders)) + for i, provider := range modelProviders { + modelProviderMap[provider.Name] = i + } + }) + + return modelProvidersErr +} + +// GetModelProviders returns the loaded model providers. +// Call LoadModelProviders first, otherwise returns empty slice. +func GetModelProviders() []ModelProvider { + return modelProviders +} + +// GetModelProviderByName returns the model provider with the given name. +func GetModelProviderByName(name string) *ModelProvider { + if modelProviderMap == nil { + return nil + } + if idx, ok := modelProviderMap[name]; ok { + return &modelProviders[idx] + } + return nil +} + +// GetLLMByProviderAndName returns the LLM with the given provider name and model name. +func GetLLMByProviderAndName(providerName, modelName string) *LLM { + provider := GetModelProviderByName(providerName) + if provider == nil { + return nil + } + for i := range provider.LLMs { + if provider.LLMs[i].LLMName == modelName { + return &provider.LLMs[i] + } + } + return nil +} diff --git a/internal/server/variable.go b/internal/server/variable.go new file mode 100644 index 00000000000..23f1b4c94b9 --- /dev/null +++ b/internal/server/variable.go @@ -0,0 +1,259 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//
+
+package server
+
+import (
+ "context"
+ "fmt"
+ "sync"
+ "time"
+
+ "go.uber.org/zap"
+
+ "ragflow/internal/logger"
+ "ragflow/internal/utility"
+)
+
+// Variables holds all runtime variables that can be changed during system operation
+// Unlike Config, these can be modified at runtime
+type Variables struct {
+ SecretKey string `json:"secret_key"`
+}
+
+// VariableStore interface for persistent storage (e.g., Redis)
+type VariableStore interface {
+ Get(key string) (string, error)
+ Set(key string, value string, exp time.Duration) bool
+ SetNX(key string, value string, exp time.Duration) bool
+}
+
+var (
+ globalVariables *Variables
+ variablesOnce sync.Once
+ variablesMu sync.RWMutex
+)
+
+const (
+ // DefaultSecretKey is used when no secret key is found in storage
+ DefaultSecretKey = "infiniflow-token"
+ // SecretKeyRedisKey is the Redis key for storing secret key
+ SecretKeyRedisKey = "ragflow:system:secret_key"
+ // SecretKeyTTL is the TTL for secret key in Redis (0 = no expiration)
+ SecretKeyTTL = 0
+)
+
+// InitVariables initializes all runtime variables from persistent storage
+// This should be called after Config and Cache are initialized
+func InitVariables(store VariableStore) error {
+ var initErr error
+ variablesOnce.Do(func() {
+ globalVariables = &Variables{}
+
+ generatedKey, err := utility.GenerateSecretKey()
+ if err != nil {
+ initErr = fmt.Errorf("failed to generate secret key: %w", err)
+ // Bail out; otherwise an empty key could be stored below
+ return
+ }
+
+ // Initialize SecretKey
+ secretKey, err := GetOrCreateKey(store, SecretKeyRedisKey, generatedKey)
+ if err != nil {
+ initErr = fmt.Errorf("failed to initialize secret key: %w", err)
+ } else {
+ globalVariables.SecretKey = secretKey
+ logger.Info("Secret key initialized from store")
+ }
+
+ logger.Info("Server variables initialized successfully")
+ })
+ return initErr
+}
+
+// GetVariables returns the global variables instance
+func GetVariables() *Variables {
+ variablesMu.RLock()
+ defer variablesMu.RUnlock()
+ return globalVariables
+}
+
+// GetSecretKey returns the current secret key
+func GetSecretKey() string {
+ variablesMu.RLock()
+ defer variablesMu.RUnlock()
+ if globalVariables == nil {
+ return DefaultSecretKey
+ }
+ return globalVariables.SecretKey
+}
+
+// SetSecretKey updates the secret key at runtime
+func SetSecretKey(key string) {
+ variablesMu.Lock()
+ defer variablesMu.Unlock()
+ if globalVariables != nil {
+ globalVariables.SecretKey = key
+ logger.Info("Secret key updated at runtime")
+ }
+}
+
+// GetOrCreateKey gets a key from store, or creates it if not exists
+// - If key exists in store, returns the stored value
+// - If key doesn't exist, stores newValue and returns it
+// - Uses SetNX to ensure atomic creation (only one caller succeeds when key doesn't exist)
+func GetOrCreateKey(store VariableStore, key string, newValue string) (string, error) {
+ if store == nil {
+ err := fmt.Errorf("store is nil")
+ logger.Warn("VariableStore is nil, cannot get or create key", zap.String("key", key))
+ return "", err
+ }
+
+ // Try to get existing value
+ value, err := store.Get(key)
+ if err != nil {
+ logger.Warn("Failed to get key from store", zap.String("key", key), zap.Error(err))
+ return "", err
+ }
+
+ // Key exists, return the value
+ if value != "" {
+ logger.Debug("Key found in store", zap.String("key", key))
+ return value, nil
+ }
+
+ // Key doesn't exist, generate new value
+ logger.Info("Generating new value for key", zap.String("key", key))
+
+ // Try to set with NX (only if not exists) - 
ensures atomicity + if store.SetNX(key, newValue, SecretKeyTTL) { + logger.Info("New value stored successfully", zap.String("key", key)) + return newValue, nil + } + + // Another process might have set it, try to get again + value, err = store.Get(key) + if err != nil { + logger.Warn("Failed to get key after SetNX", zap.String("key", key), zap.Error(err)) + return newValue, nil // Return our generated value as fallback + } + + if value != "" { + logger.Info("Using value set by another process", zap.String("key", key)) + return value, nil + } + + // If still empty, use our generated value + return newValue, nil +} + +// RefreshVariables refreshes all variables from storage +// Call this when you want to reload variables from persistent storage +func RefreshVariables(store VariableStore) error { + if store == nil { + return fmt.Errorf("store is nil") + } + + variablesMu.Lock() + defer variablesMu.Unlock() + + if globalVariables == nil { + globalVariables = &Variables{} + } + + // Refresh SecretKey + secretKey, err := store.Get(SecretKeyRedisKey) + if err != nil { + logger.Warn("Failed to refresh secret key from store", zap.Error(err)) + return err + } + if secretKey != "" { + globalVariables.SecretKey = secretKey + logger.Info("Secret key refreshed from store") + } + + return nil +} + +// VariableWatcher watches for variable changes in storage +// This can be used to detect changes made by other instances +type VariableWatcher struct { + store VariableStore + stopChan chan struct{} + wg sync.WaitGroup +} + +// NewVariableWatcher creates a new variable watcher +func NewVariableWatcher(store VariableStore) *VariableWatcher { + return &VariableWatcher{ + store: store, + stopChan: make(chan struct{}), + } +} + +// Start starts watching for variable changes +func (w *VariableWatcher) Start(interval time.Duration) { + w.wg.Add(1) + go func() { + defer w.wg.Done() + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := RefreshVariables(w.store); err != nil { + logger.Debug("Failed to refresh variables", zap.Error(err)) + } + case <-w.stopChan: + return + } + } + }() + logger.Info("Variable watcher started", zap.Duration("interval", interval)) +} + +// Stop stops the variable watcher +func (w *VariableWatcher) Stop() { + close(w.stopChan) + w.wg.Wait() + logger.Info("Variable watcher stopped") +} + +// SaveToStorage saves current variables to persistent storage +func SaveToStorage(store VariableStore) error { + if store == nil { + return fmt.Errorf("store is nil") + } + + variablesMu.RLock() + defer variablesMu.RUnlock() + + if globalVariables == nil { + return fmt.Errorf("variables not initialized") + } + + // Save SecretKey + if !store.Set(SecretKeyRedisKey, globalVariables.SecretKey, SecretKeyTTL) { + return fmt.Errorf("failed to save secret key to store") + } + + logger.Info("Variables saved to storage") + return nil +} + +// WithTimeout creates a context with timeout for variable operations +func WithTimeout(timeout time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), timeout) +} diff --git a/internal/service/chat.go b/internal/service/chat.go new file mode 100644 index 00000000000..3192a2152d9 --- /dev/null +++ b/internal/service/chat.go @@ -0,0 +1,623 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "errors" + "fmt" + "strings" + "time" + "unicode/utf8" + + "github.com/google/uuid" + + "ragflow/internal/dao" + "ragflow/internal/model" +) + +// ChatService chat service +type ChatService struct { + chatDAO *dao.ChatDAO + kbDAO *dao.KnowledgebaseDAO + userTenantDAO *dao.UserTenantDAO + tenantDAO *dao.TenantDAO +} + +// NewChatService create chat service +func NewChatService() *ChatService { + return &ChatService{ + chatDAO: dao.NewChatDAO(), + kbDAO: dao.NewKnowledgebaseDAO(), + userTenantDAO: dao.NewUserTenantDAO(), + tenantDAO: dao.NewTenantDAO(), + } +} + +// ChatWithKBNames chat with knowledge base names +type ChatWithKBNames struct { + *model.Chat + KBNames []string `json:"kb_names"` +} + +// ListChatsResponse list chats response +type ListChatsResponse struct { + Chats []*ChatWithKBNames `json:"chats"` +} + +// ListChats list chats for a user +func (s *ChatService) ListChats(userID string, status string) (*ListChatsResponse, error) { + // Get tenant IDs by user ID + tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) + if err != nil { + return nil, err + } + + // For now, use the first tenant ID (primary tenant) + // This matches the Python implementation behavior + var tenantID string + if len(tenantIDs) > 0 { + tenantID = tenantIDs[0] + } else { + tenantID = userID + } + + // Query chats by tenant ID + chats, err := s.chatDAO.ListByTenantID(tenantID, status) + if err != nil { + return nil, err + } + + // Enrich with knowledge base names + var chatsWithKBNames []*ChatWithKBNames + for _, chat := range chats { + kbNames := s.getKBNames(chat.KBIDs) + chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{ + Chat: chat, + KBNames: kbNames, + }) + } + + return &ListChatsResponse{ + Chats: chatsWithKBNames, + }, nil +} + +// ListChatsNextRequest list chats next request +type ListChatsNextRequest struct { + OwnerIDs []string `json:"owner_ids,omitempty"` +} + +// ListChatsNextResponse list chats next response +type ListChatsNextResponse struct { + Chats []*ChatWithKBNames `json:"dialogs"` + Total int64 `json:"total"` +} + +// ListChatsNext list chats with advanced filtering (equivalent to list_dialogs_next) +func (s *ChatService) ListChatsNext(userID string, keywords string, page, pageSize int, orderby string, desc bool, ownerIDs []string) (*ListChatsNextResponse, error) { + var chats []*model.Chat + var total int64 + var err error + + if len(ownerIDs) == 0 { + // Get tenant IDs by user ID (joined tenants) + tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) + if err != nil { + return nil, err + } + + // Use database pagination + chats, total, err = s.chatDAO.ListByTenantIDs(tenantIDs, userID, page, pageSize, orderby, desc, keywords) + if err != nil { + return nil, err + } + } else { + // Filter by owner IDs, manual pagination + chats, total, err = s.chatDAO.ListByOwnerIDs(ownerIDs, userID, orderby, desc, keywords) + if err != nil { + return nil, err + } + + // Manual pagination + if page > 0 && pageSize > 0 { + start := (page - 1) * pageSize + end := start + pageSize + if start < int(total) { + if end > 
int(total) { + end = int(total) + } + chats = chats[start:end] + } else { + chats = []*model.Chat{} + } + } + } + + // Enrich with knowledge base names + var chatsWithKBNames []*ChatWithKBNames + for _, chat := range chats { + kbNames := s.getKBNames(chat.KBIDs) + chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{ + Chat: chat, + KBNames: kbNames, + }) + } + + return &ListChatsNextResponse{ + Chats: chatsWithKBNames, + Total: total, + }, nil +} + +// getKBNames gets knowledge base names by IDs +func (s *ChatService) getKBNames(kbIDs model.JSONSlice) []string { + var names []string + for _, kbID := range kbIDs { + kbIDStr, ok := kbID.(string) + if !ok { + continue + } + kb, err := s.kbDAO.GetByID(kbIDStr) + if err != nil || kb == nil { + continue + } + // Only include valid KBs + if kb.Status != nil && *kb.Status == "1" { + names = append(names, kb.Name) + } + } + return names +} + +// ParameterConfig parameter configuration in prompt_config +type ParameterConfig struct { + Key string `json:"key"` + Optional bool `json:"optional"` +} + +// PromptConfig prompt configuration +type PromptConfig struct { + System string `json:"system"` + Prologue string `json:"prologue"` + Parameters []ParameterConfig `json:"parameters"` + EmptyResponse string `json:"empty_response"` + TavilyAPIKey string `json:"tavily_api_key,omitempty"` + Keyword bool `json:"keyword,omitempty"` + Quote bool `json:"quote,omitempty"` + Reasoning bool `json:"reasoning,omitempty"` + RefineMultiturn bool `json:"refine_multiturn,omitempty"` + TocEnhance bool `json:"toc_enhance,omitempty"` + TTS bool `json:"tts,omitempty"` + UseKG bool `json:"use_kg,omitempty"` +} + +// SetDialogRequest set chat request +type SetDialogRequest struct { + DialogID string `json:"dialog_id,omitempty"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Icon string `json:"icon,omitempty"` + TopN int64 `json:"top_n,omitempty"` + TopK int64 `json:"top_k,omitempty"` + RerankID string `json:"rerank_id,omitempty"` + SimilarityThreshold float64 `json:"similarity_threshold,omitempty"` + VectorSimilarityWeight float64 `json:"vector_similarity_weight,omitempty"` + LLMSetting map[string]interface{} `json:"llm_setting,omitempty"` + MetaDataFilter map[string]interface{} `json:"meta_data_filter,omitempty"` + PromptConfig *PromptConfig `json:"prompt_config" binding:"required"` + KBIDs []string `json:"kb_ids,omitempty"` + LLMID string `json:"llm_id,omitempty"` +} + +// SetDialogResponse set chat response +type SetDialogResponse struct { + *model.Chat + KBNames []string `json:"kb_names"` +} + +// SetDialog create or update a chat +func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialogResponse, error) { + // Determine if this is a create or update operation + isCreate := req.DialogID == "" + + // Validate and process name + name := req.Name + if name == "" { + name = "New Chat" + } + + // Validate name type and content + if strings.TrimSpace(name) == "" { + return nil, errors.New("Chat name can't be empty") + } + + // Check name length (UTF-8 byte length) + if len(name) > 255 { + return nil, fmt.Errorf("Chat name length is %d which is larger than 255", len(name)) + } + + name = strings.TrimSpace(name) + + // Get tenant ID (use userID as default tenant) + tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) + if err != nil { + return nil, err + } + + var tenantID string + if len(tenantIDs) > 0 { + tenantID = tenantIDs[0] + } else { + tenantID = userID + } + + // For create: check for 
duplicate names and generate unique name + if isCreate { + existingNames, err := s.chatDAO.GetExistingNames(tenantID, "1") + if err != nil { + return nil, err + } + + // Check if name exists (case-insensitive) + nameLower := strings.ToLower(name) + for _, existing := range existingNames { + if strings.ToLower(existing) == nameLower { + // Generate unique name + name = s.generateUniqueName(name, existingNames) + break + } + } + } + + // Set default values + description := req.Description + if description == "" { + description = "A helpful chat" + } + + topN := req.TopN + if topN == 0 { + topN = 6 + } + + topK := req.TopK + if topK == 0 { + topK = 1024 + } + + rerankID := req.RerankID + + similarityThreshold := req.SimilarityThreshold + if similarityThreshold == 0 { + similarityThreshold = 0.1 + } + + vectorSimilarityWeight := req.VectorSimilarityWeight + if vectorSimilarityWeight == 0 { + vectorSimilarityWeight = 0.3 + } + + llmSetting := req.LLMSetting + if llmSetting == nil { + llmSetting = make(map[string]interface{}) + } + + metaDataFilter := req.MetaDataFilter + if metaDataFilter == nil { + metaDataFilter = make(map[string]interface{}) + } + + promptConfig := req.PromptConfig + + // Process kb_ids + kbIDs := req.KBIDs + if kbIDs == nil { + kbIDs = []string{} + } + + // Set default parameters for datasets with knowledge retrieval + // Check if parameters is missing or empty and kb_ids is provided + if len(kbIDs) > 0 && (promptConfig.Parameters == nil || len(promptConfig.Parameters) == 0) { + // Check if system prompt uses {knowledge} placeholder + if strings.Contains(promptConfig.System, "{knowledge}") { + // Set default parameters for any dataset with knowledge placeholder + promptConfig.Parameters = []ParameterConfig{ + {Key: "knowledge", Optional: false}, + } + } + } + + // For update: validate that {knowledge} is not used when no KBs or Tavily + if !isCreate { + if len(kbIDs) == 0 && promptConfig.TavilyAPIKey == "" && strings.Contains(promptConfig.System, "{knowledge}") { + return nil, errors.New("Please remove `{knowledge}` in system prompt since no dataset / Tavily used here") + } + } + + // Validate parameters + for _, p := range promptConfig.Parameters { + if p.Optional { + continue + } + placeholder := fmt.Sprintf("{%s}", p.Key) + if !strings.Contains(promptConfig.System, placeholder) { + return nil, fmt.Errorf("Parameter '%s' is not used", p.Key) + } + } + + // Check knowledge bases and their embedding models + if len(kbIDs) > 0 { + kbs, err := s.kbDAO.GetByIDs(kbIDs) + if err != nil { + return nil, err + } + + // Check if all KBs use the same embedding model + var embdID string + for i, kb := range kbs { + if i == 0 { + embdID = kb.EmbdID + } else { + // Extract base model name (remove vendor suffix) + embdBase := s.splitModelNameAndFactory(embdID) + kbEmbdBase := s.splitModelNameAndFactory(kb.EmbdID) + if embdBase != kbEmbdBase { + return nil, fmt.Errorf("Datasets use different embedding models: %v", getEmbdIDs(kbs)) + } + } + } + } + + // Get LLM ID (use tenant's default if not provided) + llmID := req.LLMID + if llmID == "" { + tenant, err := s.tenantDAO.GetByID(tenantID) + if err != nil { + return nil, errors.New("Tenant not found") + } + llmID = tenant.LLMID + } + + // Convert prompt config to JSONMap with all fields + promptConfigMap := model.JSONMap{ + "system": promptConfig.System, + "prologue": promptConfig.Prologue, + "empty_response": promptConfig.EmptyResponse, + "keyword": promptConfig.Keyword, + "quote": promptConfig.Quote, + "reasoning": promptConfig.Reasoning, 
+ "refine_multiturn": promptConfig.RefineMultiturn, + "toc_enhance": promptConfig.TocEnhance, + "tts": promptConfig.TTS, + "use_kg": promptConfig.UseKG, + } + if promptConfig.TavilyAPIKey != "" { + promptConfigMap["tavily_api_key"] = promptConfig.TavilyAPIKey + } + if len(promptConfig.Parameters) > 0 { + params := make([]map[string]interface{}, len(promptConfig.Parameters)) + for i, p := range promptConfig.Parameters { + params[i] = map[string]interface{}{ + "key": p.Key, + "optional": p.Optional, + } + } + promptConfigMap["parameters"] = params + } + + // Convert kbIDs to JSONSlice + kbIDsJSON := make(model.JSONSlice, len(kbIDs)) + for i, id := range kbIDs { + kbIDsJSON[i] = id + } + + if isCreate { + // Generate UUID for new chat + newID := uuid.New().String() + newID = strings.ReplaceAll(newID, "-", "") + if len(newID) > 32 { + newID = newID[:32] + } + + // Get current time + now := time.Now() + createTime := now.UnixMilli() + + // Set default language + language := "English" + + // Create new chat + chat := &model.Chat{ + ID: newID, + TenantID: tenantID, + Name: &name, + Description: &description, + Icon: &req.Icon, + Language: &language, + LLMID: llmID, + LLMSetting: llmSetting, + PromptConfig: promptConfigMap, + MetaDataFilter: (*model.JSONMap)(&metaDataFilter), + TopN: topN, + TopK: topK, + RerankID: rerankID, + SimilarityThreshold: similarityThreshold, + VectorSimilarityWeight: vectorSimilarityWeight, + KBIDs: kbIDsJSON, + Status: strPtr("1"), + } + chat.CreateTime = createTime + chat.CreateDate = &now + chat.UpdateTime = &createTime + chat.UpdateDate = &now + + if err := s.chatDAO.Create(chat); err != nil { + return nil, errors.New("Fail to new a chat") + } + + // Get KB names + kbNames := s.getKBNames(chat.KBIDs) + + return &SetDialogResponse{ + Chat: chat, + KBNames: kbNames, + }, nil + } + + // Update existing chat - also update update_time + now := time.Now() + updateTime := now.UnixMilli() + updateData := map[string]interface{}{ + "name": name, + "description": description, + "icon": req.Icon, + "llm_id": llmID, + "llm_setting": llmSetting, + "prompt_config": promptConfigMap, + "meta_data_filter": metaDataFilter, + "top_n": topN, + "top_k": topK, + "rerank_id": rerankID, + "similarity_threshold": similarityThreshold, + "vector_similarity_weight": vectorSimilarityWeight, + "kb_ids": kbIDsJSON, + "update_time": updateTime, + "update_date": now, + } + + if err := s.chatDAO.UpdateByID(req.DialogID, updateData); err != nil { + return nil, errors.New("Dialog not found") + } + + // Get updated chat + chat, err := s.chatDAO.GetByID(req.DialogID) + if err != nil { + return nil, errors.New("Fail to update a chat") + } + + // Get KB names + kbNames := s.getKBNames(chat.KBIDs) + + return &SetDialogResponse{ + Chat: chat, + KBNames: kbNames, + }, nil +} + +// generateUniqueName generates a unique name by appending a number +func (s *ChatService) generateUniqueName(name string, existingNames []string) string { + baseName := name + counter := 1 + + // Check if name already has a suffix like "(1)" + if idx := strings.LastIndex(name, "("); idx > 0 { + if idx2 := strings.LastIndex(name, ")"); idx2 > idx { + if num, err := fmt.Sscanf(name[idx+1:idx2], "%d", &counter); err == nil && num == 1 { + baseName = strings.TrimSpace(name[:idx]) + counter++ + } + } + } + + existingMap := make(map[string]bool) + for _, n := range existingNames { + existingMap[strings.ToLower(n)] = true + } + + newName := name + for { + if !existingMap[strings.ToLower(newName)] { + return newName + } + newName = 
fmt.Sprintf("%s(%d)", baseName, counter) + counter++ + } +} + +// splitModelNameAndFactory extracts the base model name (removes vendor suffix) +func (s *ChatService) splitModelNameAndFactory(embdID string) string { + // Remove vendor suffix (e.g., "model@openai" -> "model") + if idx := strings.LastIndex(embdID, "@"); idx > 0 { + return embdID[:idx] + } + return embdID +} + +// getEmbdIDs extracts embedding IDs from knowledge bases +func getEmbdIDs(kbs []*model.Knowledgebase) []string { + ids := make([]string, len(kbs)) + for i, kb := range kbs { + ids[i] = kb.EmbdID + } + return ids +} + +// RemoveChats removes dialogs by setting their status to invalid (soft delete) +// Only the owner of the chat can perform this operation +func (s *ChatService) RemoveChats(userID string, chatIDs []string) error { + // Get user's tenants + tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) + if err != nil { + return err + } + + // Build a set of user's tenant IDs for quick lookup + tenantIDSet := make(map[string]bool) + for _, tid := range tenantIDs { + tenantIDSet[tid] = true + } + // Also add userID itself as a tenant (for cases where tenant_id = user_id) + tenantIDSet[userID] = true + + // Check each chat and build update list + var updates []map[string]interface{} + for _, chatID := range chatIDs { + // Get the chat to check ownership + chat, err := s.chatDAO.GetByID(chatID) + if err != nil { + return fmt.Errorf("chat not found: %s", chatID) + } + + // Check if user is the owner (chat's tenant_id must be in user's tenants) + if !tenantIDSet[chat.TenantID] { + return errors.New("only owner of chat authorized for this operation") + } + + // Add to update list (soft delete by setting status to "0") + updates = append(updates, map[string]interface{}{ + "id": chatID, + "status": "0", + }) + } + + // Batch update all dialogs + if err := s.chatDAO.UpdateManyByID(updates); err != nil { + return err + } + + return nil +} + +// strPtr returns a pointer to a string +func strPtr(s string) *string { + return &s +} + +// Helper to count UTF-8 characters (not bytes) +func (s *ChatService) countRunes(str string) int { + return utf8.RuneCountInString(str) +} diff --git a/internal/service/chat_session.go b/internal/service/chat_session.go new file mode 100644 index 00000000000..7de702e92c6 --- /dev/null +++ b/internal/service/chat_session.go @@ -0,0 +1,893 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//
+
+package service
+
+import (
+    "encoding/json"
+    "errors"
+    "fmt"
+    "strings"
+    "time"
+
+    "github.com/google/uuid"
+
+    "ragflow/internal/dao"
+    "ragflow/internal/model"
+)
+
+// ChatSessionService chat session (conversation) service
+type ChatSessionService struct {
+    chatSessionDAO *dao.ChatSessionDAO
+    chatDAO        *dao.ChatDAO
+    userTenantDAO  *dao.UserTenantDAO
+}
+
+// NewChatSessionService create chat session service
+func NewChatSessionService() *ChatSessionService {
+    return &ChatSessionService{
+        chatSessionDAO: dao.NewChatSessionDAO(),
+        chatDAO:        dao.NewChatDAO(),
+        userTenantDAO:  dao.NewUserTenantDAO(),
+    }
+}
+
+// SetChatSessionRequest set chat session request
+type SetChatSessionRequest struct {
+    SessionID string `json:"conversation_id,omitempty"`
+    DialogID  string `json:"dialog_id,omitempty"`
+    Name      string `json:"name,omitempty"`
+    IsNew     bool   `json:"is_new"`
+}
+
+// SetChatSessionResponse set chat session response
+type SetChatSessionResponse struct {
+    *model.ChatSession
+}
+
+// SetChatSession creates or updates a chat session
+func (s *ChatSessionService) SetChatSession(userID string, req *SetChatSessionRequest) (*SetChatSessionResponse, error) {
+    name := req.Name
+    if name == "" {
+        name = "New chat session"
+    }
+    // Limit name length to 255 characters (count runes, not bytes, so a
+    // multi-byte name is not cut in the middle of a character)
+    if runes := []rune(name); len(runes) > 255 {
+        name = string(runes[:255])
+    }
+
+    if !req.IsNew {
+        // Update existing chat session
+        updates := map[string]interface{}{
+            "name":        name,
+            "user_id":     userID,
+            "update_time": time.Now().UnixMilli(),
+            "update_date": time.Now(),
+        }
+
+        if err := s.chatSessionDAO.UpdateByID(req.SessionID, updates); err != nil {
+            return nil, errors.New("Chat session not found")
+        }
+
+        // Get updated chat session
+        session, err := s.chatSessionDAO.GetByID(req.SessionID)
+        if err != nil {
+            return nil, errors.New("Fail to update a chat session")
+        }
+
+        return &SetChatSessionResponse{ChatSession: session}, nil
+    }
+
+    // Create new chat session
+    // Check if dialog exists
+    dialog, err := s.chatSessionDAO.GetDialogByID(req.DialogID)
+    if err != nil {
+        return nil, errors.New("Dialog not found")
+    }
+
+    // Generate UUID for new chat session
+    newID := uuid.New().String()
+    newID = strings.ReplaceAll(newID, "-", "")
+    if len(newID) > 32 {
+        newID = newID[:32]
+    }
+
+    // Get prologue from dialog's prompt_config
+    prologue := "Hi! I'm your assistant. What can I do for you?"
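+    // Prefer the prologue configured in the dialog's prompt_config; the
+    // greeting above is only a fallback when none is set.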
+ if dialog.PromptConfig != nil { + if p, ok := dialog.PromptConfig["prologue"].(string); ok && p != "" { + prologue = p + } + } + + now := time.Now() + createTime := now.UnixMilli() + + // Create initial message - store as JSON object with messages array + messagesObj := map[string]interface{}{ + "messages": []map[string]interface{}{ + { + "role": "assistant", + "content": prologue, + }, + }, + } + messagesJSON, _ := json.Marshal(messagesObj) + + // Create reference - store as JSON array + referenceJSON, _ := json.Marshal([]interface{}{}) + + // Create chat session + session := &model.ChatSession{ + ID: newID, + DialogID: req.DialogID, + Name: &name, + Message: messagesJSON, + UserID: &userID, + Reference: referenceJSON, + } + session.CreateTime = createTime + session.CreateDate = &now + session.UpdateTime = &createTime + session.UpdateDate = &now + + if err := s.chatSessionDAO.Create(session); err != nil { + return nil, errors.New("Fail to create a chat session") + } + + return &SetChatSessionResponse{ChatSession: session}, nil +} + +// RemoveChatSessionRequest remove chat sessions request +type RemoveChatSessionRequest struct { + ChatSessions []string `json:"conversation_ids" binding:"required"` +} + +// RemoveChatSessions removes chat sessions (hard delete) +func (s *ChatSessionService) RemoveChatSessions(userID string, chatSessions []string) error { + // Get user's tenants + tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) + if err != nil { + return err + } + + // Build a set of user's tenant IDs for quick lookup + tenantIDSet := make(map[string]bool) + for _, tid := range tenantIDs { + tenantIDSet[tid] = true + } + tenantIDSet[userID] = true + + // Check each chat session + for _, convID := range chatSessions { + // Get the chat session + session, err := s.chatSessionDAO.GetByID(convID) + if err != nil { + return fmt.Errorf("Chat session not found: %s", convID) + } + + // Check if user is the owner by checking dialog ownership + isOwner := false + for tenantID := range tenantIDSet { + exists, err := s.chatSessionDAO.CheckDialogExists(tenantID, session.DialogID) + if err != nil { + return err + } + if exists { + isOwner = true + break + } + } + + if !isOwner { + return errors.New("Only owner of chat session authorized for this operation") + } + + // Delete the chat session + if err := s.chatSessionDAO.DeleteByID(convID); err != nil { + return err + } + } + + return nil +} + +// ListChatSessionsRequest list chat sessions request +type ListChatSessionsRequest struct { + DialogID string `json:"dialog_id" binding:"required"` +} + +// ListChatSessionsResponse list chat sessions response +type ListChatSessionsResponse struct { + Sessions []*model.ChatSession +} + +// ListChatSessions lists chat sessions for a dialog +func (s *ChatSessionService) ListChatSessions(userID string, dialogID string) (*ListChatSessionsResponse, error) { + // Get user's tenants + tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) + if err != nil { + return nil, err + } + + // Check if user is the owner of the dialog + isOwner := false + for _, tenantID := range tenantIDs { + exists, err := s.chatSessionDAO.CheckDialogExists(tenantID, dialogID) + if err != nil { + return nil, err + } + if exists { + isOwner = true + break + } + } + + // Also check with userID as tenant + if !isOwner { + exists, err := s.chatSessionDAO.CheckDialogExists(userID, dialogID) + if err != nil { + return nil, err + } + isOwner = exists + } + + if !isOwner { + return nil, errors.New("Only owner of dialog authorized 
for this operation") + } + + // List chat sessions + sessions, err := s.chatSessionDAO.ListByDialogID(dialogID) + if err != nil { + return nil, err + } + + return &ListChatSessionsResponse{Sessions: sessions}, nil +} + +// Completion performs chat completion with full RAG support +func (s *ChatSessionService) Completion(userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string) (map[string]interface{}, error) { + // Validate the last message is from user + if len(messages) == 0 { + return nil, errors.New("messages cannot be empty") + } + lastRole, _ := messages[len(messages)-1]["role"].(string) + if lastRole != "user" { + return nil, errors.New("the last content of this conversation is not from user") + } + + // Get conversation + session, err := s.chatSessionDAO.GetByID(conversationID) + if err != nil { + return nil, errors.New("Conversation not found") + } + + // Get dialog + dialog, err := s.chatSessionDAO.GetDialogByID(session.DialogID) + if err != nil { + return nil, errors.New("Dialog not found") + } + + // Deep copy messages to session + sessionMessages := s.buildSessionMessages(session, messages) + + // Initialize reference if empty + reference := s.initializeReference(session) + + // Check if custom LLM is specified and validate API key + isEmbedded := llmID != "" + if llmID != "" { + hasKey, err := s.checkTenantLLMAPIKey(dialog.TenantID, llmID) + if err != nil || !hasKey { + return nil, fmt.Errorf("Cannot use specified model %s", llmID) + } + dialog.LLMID = llmID + if chatModelConfig != nil { + dialog.LLMSetting = chatModelConfig + } + } + + // Perform chat completion with RAG + result, err := s.asyncChat(dialog, session, messages, chatModelConfig, messageID, reference, false) + if err != nil { + return nil, err + } + + // Update conversation if not embedded + if !isEmbedded { + s.updateSessionMessages(session, sessionMessages, reference) + } + + return result, nil +} + +// CompletionStream performs streaming chat completion with full RAG support +func (s *ChatSessionService) CompletionStream(userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string, streamChan chan<- string) error { + // Validate the last message is from user + if len(messages) == 0 { + streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "messages cannot be empty", "data": {"answer": "**ERROR**: messages cannot be empty", "reference": []}}`) + return errors.New("messages cannot be empty") + } + lastRole, _ := messages[len(messages)-1]["role"].(string) + if lastRole != "user" { + streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "the last content of this conversation is not from user", "data": {"answer": "**ERROR**: the last content of this conversation is not from user", "reference": []}}`) + return errors.New("the last content of this conversation is not from user") + } + + // Get conversation + session, err := s.chatSessionDAO.GetByID(conversationID) + if err != nil { + streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "Conversation not found", "data": {"answer": "**ERROR**: Conversation not found", "reference": []}}`) + return errors.New("Conversation not found") + } + + // Get dialog + dialog, err := s.chatSessionDAO.GetDialogByID(session.DialogID) + if err != nil { + streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "Dialog not found", "data": {"answer": "**ERROR**: 
Dialog not found", "reference": []}}`) + return errors.New("Dialog not found") + } + + // Deep copy messages to session + sessionMessages := s.buildSessionMessages(session, messages) + + // Initialize reference if empty + reference := s.initializeReference(session) + + // Check if custom LLM is specified and validate API key + isEmbedded := llmID != "" + if llmID != "" { + hasKey, err := s.checkTenantLLMAPIKey(dialog.TenantID, llmID) + if err != nil || !hasKey { + errMsg := fmt.Sprintf(`{"code": 500, "message": "Cannot use specified model %s", "data": {"answer": "**ERROR**: Cannot use specified model", "reference": []}}`, llmID) + streamChan <- fmt.Sprintf("data: %s\n\n", errMsg) + return fmt.Errorf("Cannot use specified model %s", llmID) + } + dialog.LLMID = llmID + if chatModelConfig != nil { + dialog.LLMSetting = chatModelConfig + } + } + + // Perform streaming chat completion with RAG + resultChan, err := s.asyncChatStream(dialog, session, messages, chatModelConfig, messageID, reference) + if err != nil { + streamChan <- fmt.Sprintf("data: %s\n\n", fmt.Sprintf(`{"code": 500, "message": "%s", "data": {"answer": "**ERROR**: %s", "reference": []}}`, err.Error(), err.Error())) + return err + } + + // Stream results + for result := range resultChan { + data, _ := json.Marshal(map[string]interface{}{ + "code": 0, + "message": "", + "data": result, + }) + streamChan <- fmt.Sprintf("data: %s\n\n", string(data)) + } + + // Send final completion signal + finalData, _ := json.Marshal(map[string]interface{}{ + "code": 0, + "message": "", + "data": true, + }) + streamChan <- fmt.Sprintf("data: %s\n\n", string(finalData)) + + // Update conversation if not embedded + if !isEmbedded { + s.updateSessionMessages(session, sessionMessages, reference) + } + + return nil +} + +// Helper methods + +func (s *ChatSessionService) buildSessionMessages(session *model.ChatSession, messages []map[string]interface{}) []map[string]interface{} { + // Deep copy messages to session + sessionMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + sessionMessages[i] = make(map[string]interface{}) + for k, v := range msg { + sessionMessages[i][k] = v + } + } + return sessionMessages +} + +func (s *ChatSessionService) initializeReference(session *model.ChatSession) []interface{} { + var reference []interface{} + if len(session.Reference) > 0 { + json.Unmarshal(session.Reference, &reference) + } + // Filter out nil entries and append new reference + var filtered []interface{} + for _, r := range reference { + if r != nil { + filtered = append(filtered, r) + } + } + filtered = append(filtered, map[string]interface{}{ + "chunks": []interface{}{}, + "doc_aggs": []interface{}{}, + }) + return filtered +} + +func (s *ChatSessionService) checkTenantLLMAPIKey(tenantID, modelName string) (bool, error) { + // Simplified check - in real implementation, check if tenant has API key for this model + return true, nil +} + +func (s *ChatSessionService) performChat(dialog *model.Chat, messages []map[string]interface{}, config map[string]interface{}) (string, error) { + // Get system prompt from dialog + systemPrompt := "" + if dialog.PromptConfig != nil { + if sys, ok := dialog.PromptConfig["system"].(string); ok { + systemPrompt = sys + } + } + + // Convert messages to history format + history := make([]map[string]string, 0) + for _, msg := range messages { + role, _ := msg["role"].(string) + content, _ := msg["content"].(string) + if role != "" && content != "" { + history = append(history, 
map[string]string{ + "role": role, + "content": content, + }) + } + } + + // Use ModelBundle to perform chat + bundle, err := NewModelBundle(dialog.TenantID, model.ModelTypeChat, dialog.LLMID) + if err != nil { + return "", err + } + + // Merge dialog's LLM setting with request config + genConf := make(map[string]interface{}) + if dialog.LLMSetting != nil { + for k, v := range dialog.LLMSetting { + genConf[k] = v + } + } + for k, v := range config { + genConf[k] = v + } + + response, _, err := bundle.Chat(systemPrompt, history, genConf) + return response, err +} + +func (s *ChatSessionService) performChatStream(dialog *model.Chat, messages []map[string]interface{}, config map[string]interface{}) (<-chan string, error) { + // Get system prompt from dialog + systemPrompt := "" + if dialog.PromptConfig != nil { + if sys, ok := dialog.PromptConfig["system"].(string); ok { + systemPrompt = sys + } + } + + // Convert messages to history format + history := make([]map[string]string, 0) + for _, msg := range messages { + role, _ := msg["role"].(string) + content, _ := msg["content"].(string) + if role != "" && content != "" { + history = append(history, map[string]string{ + "role": role, + "content": content, + }) + } + } + + // Use ModelBundle to perform streaming chat + bundle, err := NewModelBundle(dialog.TenantID, model.ModelTypeChat, dialog.LLMID) + if err != nil { + return nil, err + } + + // Merge dialog's LLM setting with request config + genConf := make(map[string]interface{}) + if dialog.LLMSetting != nil { + for k, v := range dialog.LLMSetting { + genConf[k] = v + } + } + for k, v := range config { + genConf[k] = v + } + + // Get chat model and call ChatStreamly + chatModel, ok := bundle.GetModel().(model.ChatModel) + if !ok { + return nil, fmt.Errorf("model is not a chat model") + } + + return chatModel.ChatStreamly(systemPrompt, history, genConf) +} + +func (s *ChatSessionService) structureAnswer(session *model.ChatSession, answer string, messageID, conversationID string, reference []interface{}) map[string]interface{} { + return map[string]interface{}{ + "answer": answer, + "reference": reference, + "conversation_id": conversationID, + "message_id": messageID, + } +} + +func (s *ChatSessionService) updateSessionMessages(session *model.ChatSession, messages []map[string]interface{}, reference []interface{}) { + // Update session with new messages and reference + messagesJSON, _ := json.Marshal(map[string]interface{}{ + "messages": messages, + }) + referenceJSON, _ := json.Marshal(reference) + + updates := map[string]interface{}{ + "message": messagesJSON, + "reference": referenceJSON, + "update_time": time.Now().UnixMilli(), + "update_date": time.Now(), + } + s.chatSessionDAO.UpdateByID(session.ID, updates) +} + +// asyncChat performs chat with RAG support (non-streaming) +func (s *ChatSessionService) asyncChat(dialog *model.Chat, session *model.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) { + // Check if we need RAG (knowledge base or tavily) + hasKB := len(dialog.KBIDs) > 0 + hasTavily := false + if dialog.PromptConfig != nil { + if tavilyKey, ok := dialog.PromptConfig["tavily_api_key"].(string); ok && tavilyKey != "" { + hasTavily = true + } + } + + if !hasKB && !hasTavily { + // Simple chat without RAG + return s.asyncChatSolo(dialog, session, messages, config, messageID, reference, stream) + } + + // TODO: Full RAG implementation with knowledge base retrieval 
+ // This would include: + // 1. Get embedding model and rerank model + // 2. Extract questions from messages + // 3. Retrieve chunks from knowledge bases + // 4. Rerank chunks + // 5. Build prompt with context + // 6. Call LLM + + // For now, fall back to solo chat + return s.asyncChatSolo(dialog, session, messages, config, messageID, reference, stream) +} + +// asyncChatStream performs streaming chat with RAG support +func (s *ChatSessionService) asyncChatStream(dialog *model.Chat, session *model.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}) (<-chan map[string]interface{}, error) { + resultChan := make(chan map[string]interface{}) + + go func() { + defer close(resultChan) + + // Check if we need RAG + hasKB := len(dialog.KBIDs) > 0 + hasTavily := false + if dialog.PromptConfig != nil { + if tavilyKey, ok := dialog.PromptConfig["tavily_api_key"].(string); ok && tavilyKey != "" { + hasTavily = true + } + } + + if !hasKB && !hasTavily { + // Simple chat without RAG + s.asyncChatSoloStream(dialog, session, messages, config, messageID, reference, resultChan) + return + } + + // TODO: Full RAG streaming implementation + // For now, fall back to solo chat + s.asyncChatSoloStream(dialog, session, messages, config, messageID, reference, resultChan) + }() + + return resultChan, nil +} + +// asyncChatSolo performs simple chat without RAG (non-streaming) +func (s *ChatSessionService) asyncChatSolo(dialog *model.Chat, session *model.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) { + // Get system prompt + systemPrompt := s.buildSystemPrompt(dialog) + + // Process messages - handle attachments and image files + processedMessages := s.processMessages(messages, dialog) + + // Get LLM type + llmType := s.getLLMType(dialog.LLMID) + + // Build generation config + genConf := s.buildGenConf(dialog, config) + + // Create ModelBundle for chat + var bundle *ModelBundle + var err error + if llmType == "image2text" { + bundle, err = NewModelBundle(dialog.TenantID, model.ModelTypeImage2Text, dialog.LLMID) + } else { + bundle, err = NewModelBundle(dialog.TenantID, model.ModelTypeChat, dialog.LLMID) + } + if err != nil { + return nil, err + } + + // Convert messages to history format + history := s.convertToHistory(processedMessages) + + // Perform chat + response, _, err := bundle.Chat(systemPrompt, history, genConf) + if err != nil { + return nil, err + } + + // Structure the answer + ans := map[string]interface{}{ + "answer": response, + "reference": reference[len(reference)-1], + "final": true, + } + + return s.structureAnswerWithConv(session, ans, messageID, session.ID, reference), nil +} + +// asyncChatSoloStream performs simple streaming chat without RAG +func (s *ChatSessionService) asyncChatSoloStream(dialog *model.Chat, session *model.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, resultChan chan<- map[string]interface{}) { + // Get system prompt + systemPrompt := s.buildSystemPrompt(dialog) + + // Process messages + processedMessages := s.processMessages(messages, dialog) + + // Get LLM type + llmType := s.getLLMType(dialog.LLMID) + + // Build generation config + genConf := s.buildGenConf(dialog, config) + + // Create ModelBundle + var bundle *ModelBundle + var err error + if llmType == "image2text" { + bundle, err = 
NewModelBundle(dialog.TenantID, model.ModelTypeImage2Text, dialog.LLMID) + } else { + bundle, err = NewModelBundle(dialog.TenantID, model.ModelTypeChat, dialog.LLMID) + } + if err != nil { + resultChan <- s.structureAnswer(session, "**ERROR**: "+err.Error(), messageID, session.ID, reference) + return + } + + // Convert messages to history + history := s.convertToHistory(processedMessages) + + // Get chat model + chatModel, ok := bundle.GetModel().(model.ChatModel) + if !ok { + resultChan <- s.structureAnswer(session, "**ERROR**: model is not a chat model", messageID, session.ID, reference) + return + } + + // Perform streaming chat + streamChan, err := chatModel.ChatStreamly(systemPrompt, history, genConf) + if err != nil { + resultChan <- s.structureAnswer(session, "**ERROR**: "+err.Error(), messageID, session.ID, reference) + return + } + + // Stream results + fullAnswer := "" + for chunk := range streamChan { + fullAnswer += chunk + // Clean up reasoning content + fullAnswer = s.removeReasoningContent(fullAnswer) + ans := s.structureAnswer(session, fullAnswer, messageID, session.ID, reference) + resultChan <- ans + } +} + +// buildSystemPrompt builds the system prompt from dialog configuration +func (s *ChatSessionService) buildSystemPrompt(dialog *model.Chat) string { + if dialog.PromptConfig == nil { + return "" + } + + system, _ := dialog.PromptConfig["system"].(string) + return system +} + +// processMessages processes messages and handles attachments +func (s *ChatSessionService) processMessages(messages []map[string]interface{}, dialog *model.Chat) []map[string]interface{} { + // Process each message + processed := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + processed[i] = make(map[string]interface{}) + for k, v := range msg { + processed[i][k] = v + } + + // Clean content - remove file markers + if content, ok := msg["content"].(string); ok { + content = s.cleanContent(content) + processed[i]["content"] = content + } + } + + return processed +} + +// cleanContent removes file markers from content +func (s *ChatSessionService) cleanContent(content string) string { + // Remove ##N$$ markers + // This is a simplified version - full implementation would use regex + return content +} + +// convertToHistory converts messages to history format for LLM +func (s *ChatSessionService) convertToHistory(messages []map[string]interface{}) []map[string]string { + history := make([]map[string]string, 0) + for _, msg := range messages { + role, _ := msg["role"].(string) + content, _ := msg["content"].(string) + if role != "" && content != "" && role != "system" { + history = append(history, map[string]string{ + "role": role, + "content": content, + }) + } + } + return history +} + +// buildGenConf builds generation config from dialog and request +func (s *ChatSessionService) buildGenConf(dialog *model.Chat, config map[string]interface{}) map[string]interface{} { + genConf := make(map[string]interface{}) + + // Start with dialog's LLM setting + if dialog.LLMSetting != nil { + for k, v := range dialog.LLMSetting { + genConf[k] = v + } + } + + // Override with request config + for k, v := range config { + genConf[k] = v + } + + return genConf +} + +// getLLMType gets the LLM type from model ID +func (s *ChatSessionService) getLLMType(llmID string) string { + // Simplified - would need to query TenantLLMService + if strings.Contains(llmID, "image") || strings.Contains(llmID, "vision") { + return "image2text" + } + return "chat" +} + +// removeReasoningContent 
removes reasoning/thinking content from answer
+func (s *ChatSessionService) removeReasoningContent(answer string) string {
+    // Strip a trailing "</think>" closing tag so reasoning markers do not
+    // leak into the displayed answer
+    if strings.HasSuffix(answer, "</think>") {
+        answer = answer[:len(answer)-len("</think>")]
+    }
+    return answer
+}
+
+// structureAnswerWithConv structures the answer with conversation update (like Python's structure_answer)
+func (s *ChatSessionService) structureAnswerWithConv(session *model.ChatSession, ans map[string]interface{}, messageID, conversationID string, reference []interface{}) map[string]interface{} {
+    // Extract reference from answer
+    ref, _ := ans["reference"].(map[string]interface{})
+    if ref == nil {
+        ref = map[string]interface{}{
+            "chunks":   []interface{}{},
+            "doc_aggs": []interface{}{},
+        }
+        ans["reference"] = ref
+    }
+
+    // Format chunks
+    chunkList := s.chunksFormat(ref)
+    ref["chunks"] = chunkList
+
+    // Add message ID and session ID
+    ans["id"] = messageID
+    ans["session_id"] = conversationID
+
+    // Update session message; emit the reasoning tag markers when the model
+    // starts or stops thinking
+    content, _ := ans["answer"].(string)
+    if ans["start_to_think"] != nil {
+        content = "<think>"
+    } else if ans["end_to_think"] != nil {
+        content = "</think>"
+    }
+
+    // Parse existing messages
+    var messagesObj map[string]interface{}
+    if len(session.Message) > 0 {
+        json.Unmarshal(session.Message, &messagesObj)
+    }
+    messages, _ := messagesObj["messages"].([]interface{})
+
+    // Update or append assistant message
+    if len(messages) == 0 || s.getLastRole(messages) != "assistant" {
+        messages = append(messages, map[string]interface{}{
+            "role":       "assistant",
+            "content":    content,
+            "created_at": float64(time.Now().Unix()),
+            "id":         messageID,
+        })
+    } else {
+        lastIdx := len(messages) - 1
+        lastMsg, _ := messages[lastIdx].(map[string]interface{})
+        if lastMsg != nil {
+            if ans["final"] == true && ans["answer"] != nil {
+                lastMsg["content"] = ans["answer"]
+            } else {
+                // Use a checked assertion so non-string content cannot panic
+                prev, _ := lastMsg["content"].(string)
+                lastMsg["content"] = prev + content
+            }
+            lastMsg["created_at"] = float64(time.Now().Unix())
+            lastMsg["id"] = messageID
+            messages[lastIdx] = lastMsg
+        }
+    }
+
+    // Update reference
+    if len(reference) > 0 {
+        reference[len(reference)-1] = ref
+    }
+
+    return ans
+}
+
+// getLastRole gets the role of the last message
+func (s *ChatSessionService) getLastRole(messages []interface{}) string {
+    if len(messages) == 0 {
+        return ""
+    }
+    lastMsg, _ := messages[len(messages)-1].(map[string]interface{})
+    if lastMsg != nil {
+        role, _ := lastMsg["role"].(string)
+        return role
+    }
+    return ""
+}
+
+// chunksFormat formats chunks for reference (simplified version)
+func (s *ChatSessionService) chunksFormat(reference map[string]interface{}) []interface{} {
+    chunks, _ := reference["chunks"].([]interface{})
+    if chunks == nil {
+        return []interface{}{}
+    }
+
+    // Format each chunk
+    formatted := make([]interface{}, len(chunks))
+    for i, chunk := range chunks {
+        formatted[i] = chunk
+    }
+    return formatted
+}
diff --git a/internal/service/chunk.go b/internal/service/chunk.go
new file mode 100644
index 00000000000..cbed1665d41
--- /dev/null
+++ b/internal/service/chunk.go
@@ -0,0 +1,465 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "context" + "fmt" + "ragflow/internal/server" + + "go.uber.org/zap" + + "ragflow/internal/dao" + "ragflow/internal/engine" + "ragflow/internal/logger" + "ragflow/internal/model" + "ragflow/internal/service/nlp" + "ragflow/internal/utility" +) + +// ChunkService chunk service +type ChunkService struct { + docEngine engine.DocEngine + engineType server.EngineType + modelProvider ModelProvider + embeddingCache *utility.EmbeddingLRU + kbDAO *dao.KnowledgebaseDAO + userTenantDAO *dao.UserTenantDAO +} + +// NewChunkService creates chunk service +func NewChunkService() *ChunkService { + cfg := server.GetConfig() + return &ChunkService{ + docEngine: engine.Get(), + engineType: cfg.DocEngine.Type, + modelProvider: NewModelProvider(), + embeddingCache: utility.NewEmbeddingLRU(1000), // default capacity + kbDAO: dao.NewKnowledgebaseDAO(), + userTenantDAO: dao.NewUserTenantDAO(), + } +} + +// RetrievalTestRequest retrieval test request +type RetrievalTestRequest struct { + KbID interface{} `json:"kb_id" binding:"required"` // string or []string + Question string `json:"question" binding:"required"` + Page *int `json:"page,omitempty"` + Size *int `json:"size,omitempty"` + DocIDs []string `json:"doc_ids,omitempty"` + UseKG *bool `json:"use_kg,omitempty"` + TopK *int `json:"top_k,omitempty"` + CrossLanguages []string `json:"cross_languages,omitempty"` + SearchID *string `json:"search_id,omitempty"` + MetaDataFilter map[string]interface{} `json:"meta_data_filter,omitempty"` + RerankID *string `json:"rerank_id,omitempty"` + Keyword *bool `json:"keyword,omitempty"` + SimilarityThreshold *float64 `json:"similarity_threshold,omitempty"` + VectorSimilarityWeight *float64 `json:"vector_similarity_weight,omitempty"` + TenantIDs []string `json:"tenant_ids,omitempty"` +} + +// RetrievalTestResponse retrieval test response +type RetrievalTestResponse struct { + Chunks []map[string]interface{} `json:"chunks"` + Labels []map[string]interface{} `json:"labels"` + Total int64 `json:"total,omitempty"` +} + +// RetrievalTest performs retrieval test +func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) (*RetrievalTestResponse, error) { + if s.docEngine == nil { + return nil, fmt.Errorf("doc engine not initialized") + } + + // Validate question is required + if req.Question == "" { + return nil, fmt.Errorf("question is required") + } + + ctx := context.Background() + + // Get user's tenants + tenants, err := s.userTenantDAO.GetByUserID(userID) + if err != nil { + return nil, fmt.Errorf("failed to get user tenants: %w", err) + } + if len(tenants) == 0 { + return nil, fmt.Errorf("user has no accessible tenants") + } + logger.Debug("Retrieved user tenants from database", zap.String("userID", userID), zap.Int("tenantCount", len(tenants))) + + // Determine kb_id list + var kbIDs []string + switch v := req.KbID.(type) { + case string: + kbIDs = []string{v} + case []interface{}: + for _, item := range v { + if str, ok := item.(string); ok { + kbIDs = append(kbIDs, str) + } else { + return nil, fmt.Errorf("kb_id array must contain strings") + } + } + 
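+    // Typed string slices can also reach this point when the service is
+    // invoked from Go code directly rather than through JSON binding.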
case []string: + kbIDs = v + default: + return nil, fmt.Errorf("kb_id must be string or array of strings") + } + + if len(kbIDs) == 0 { + return nil, fmt.Errorf("kb_id cannot be empty") + } + + // Check permission for each kb_id + var tenantIDs []string + var kbRecords []*model.Knowledgebase + + for _, kbID := range kbIDs { + found := false + for _, tenant := range tenants { + kb, err := s.kbDAO.GetByIDAndTenantID(kbID, tenant.TenantID) + if err == nil && kb != nil { + logger.Debug("Found knowledge base record in database", + zap.String("kbID", kbID), + zap.String("tenantID", tenant.TenantID), + zap.String("kbName", kb.Name), + zap.String("embdID", kb.EmbdID)) + tenantIDs = append(tenantIDs, tenant.TenantID) + kbRecords = append(kbRecords, kb) + found = true + break + } + } + if !found { + return nil, fmt.Errorf("only owner of dataset is authorized for this operation") + } + } + + // Check if all kb records have the same embedding model + if len(kbRecords) > 1 { + firstEmbdID := kbRecords[0].EmbdID + for i := 1; i < len(kbRecords); i++ { + if kbRecords[i].EmbdID != firstEmbdID { + return nil, fmt.Errorf("cannot retrieve across datasets with different embedding models") + } + } + } + + // Get user's owner tenants to prioritize + ownerTenants, err := s.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return nil, fmt.Errorf("failed to get user owner tenants: %w", err) + } + logger.Debug("Retrieved owner tenants from database", + zap.String("userID", userID), + zap.Int("ownerTenantCount", len(ownerTenants))) + + req.TenantIDs = tenantIDs + // Choose target tenant: prioritize owner tenant if available in tenantIDs + targetTenantID := tenantIDs[0] + + // Get embedding model for the target tenant + embeddingModel, err := s.modelProvider.GetEmbeddingModel(ctx, targetTenantID, kbRecords[0].EmbdID) + if err != nil { + return nil, fmt.Errorf("failed to get embedding model: %w", err) + } + logger.Debug("Retrieved embedding model from database", + zap.String("targetTenantID", targetTenantID), + zap.String("embdID", kbRecords[0].EmbdID)) + + // Try to get embedding from cache first + embdID := kbRecords[0].EmbdID + var questionVector []float64 + + if s.embeddingCache != nil { + if cachedVector, ok := s.embeddingCache.Get(req.Question, embdID); ok { + logger.Debug("Embedding cache hit", + zap.String("question", req.Question), + zap.String("embdID", embdID), + zap.Int("cacheSize", s.embeddingCache.Len())) + questionVector = cachedVector + } else { + // Cache miss, encode and store + questionVector, err = embeddingModel.EncodeQuery(req.Question) + if err != nil { + return nil, fmt.Errorf("failed to encode query: %w", err) + } + s.embeddingCache.Put(req.Question, embdID, questionVector) + logger.Debug("Embedding cache miss, stored", + zap.String("question", req.Question), + zap.String("embdID", embdID), + zap.Int("vectorDim", len(questionVector)), + zap.Int("cacheSize", s.embeddingCache.Len())) + } + } else { + // No cache, just encode + questionVector, err = embeddingModel.EncodeQuery(req.Question) + if err != nil { + return nil, fmt.Errorf("failed to encode query: %w", err) + } + } + + // Use global QueryBuilder to process question and get matchText and keywords + // Reference: rag/nlp/search.py L115 + queryBuilder := nlp.GetQueryBuilder() + if queryBuilder == nil { + return nil, fmt.Errorf("query builder not initialized") + } + matchTextExpr, keywords := queryBuilder.Question(req.Question, "qa", 0.6) + + //if matchTextExpr == nil { + // return nil, fmt.Errorf("failed to process 
question") + //} + logger.Debug("QueryBuilder processed question", + zap.String("original", req.Question), + zap.String("matchingText", matchTextExpr.MatchingText), + zap.Strings("keywords", keywords)) + + // Build unified search request + searchReq := &engine.SearchRequest{ + IndexNames: buildIndexNames(tenantIDs), + Question: req.Question, + MatchText: matchTextExpr.MatchingText, + Keywords: keywords, + Vector: questionVector, + KbIDs: kbIDs, + DocIDs: req.DocIDs, + Page: getPageNum(req.Page), + Size: getPageSize(req.Size), + TopK: getTopK(req.TopK), + KeywordOnly: req.Keyword != nil && *req.Keyword, + SimilarityThreshold: getSimilarityThreshold(req.SimilarityThreshold), + VectorSimilarityWeight: getVectorSimilarityWeight(req.VectorSimilarityWeight), + } + + // Execute search through unified engine interface + result, err := s.docEngine.Search(ctx, searchReq) + if err != nil { + return nil, fmt.Errorf("search failed: %w", err) + } + + // Convert result to unified response + searchResp, ok := result.(*engine.SearchResponse) + if !ok { + return nil, fmt.Errorf("invalid search response type") + } + + //return &RetrievalTestResponse{ + // Chunks: searchResp.Chunks, + // Labels: []map[string]interface{}{}, // Empty labels for now + // Total: searchResp.Total, + //}, nil + + //// Build SearchResult for reranker + //sres := buildSearchResult(searchResp, questionVector) + // + // Get rerank model if RerankID is specified (can be nil) + var rerankModel nlp.RerankModel + if req.RerankID != nil && *req.RerankID != "" { + rerankModel, err = s.modelProvider.GetRerankModel(ctx, targetTenantID, *req.RerankID) + if err != nil { + logger.Warn("Failed to get rerank model, falling back to standard reranking", zap.Error(err)) + rerankModel = nil + } + } + + // Perform reranking + // Reference: rag/nlp/search.py L404-L429 + tkWeight := 1.0 - *req.VectorSimilarityWeight + vtWeight := *req.VectorSimilarityWeight + useInfinity := s.engineType == server.EngineInfinity + + sim, term_similarity, vector_similarity := nlp.Rerank( + rerankModel, + searchResp, + keywords, + questionVector, + nil, + req.Question, + tkWeight, + vtWeight, + useInfinity, + "content_ltks", + queryBuilder, + ) + // + // Apply similarity threshold and sort chunks + similarityThreshold := getSimilarityThreshold(req.SimilarityThreshold) + filteredChunks := applyRerankResults(searchResp.Chunks, sim, similarityThreshold) + for idx, _ := range filteredChunks { + filteredChunks[idx]["similarity"] = sim[idx] + filteredChunks[idx]["term_similarity"] = term_similarity[idx] + filteredChunks[idx]["vector_similarity"] = vector_similarity[idx] + } + + convertedChunks := buildRetrievalTestResults(filteredChunks) + + return &RetrievalTestResponse{ + Chunks: convertedChunks, + Labels: []map[string]interface{}{}, // Empty labels for now + Total: int64(len(convertedChunks)), + }, nil +} + +// Helper functions + +func getPageNum(page *int) int { + if page != nil && *page > 0 { + return *page + } + return 1 +} + +func getPageSize(size *int) int { + if size != nil && *size > 0 { + return *size + } + return 30 +} + +func getTopK(topk *int) int { + if topk != nil && *topk > 0 { + return *topk + } + return 1024 +} + +func getSimilarityThreshold(threshold *float64) float64 { + if threshold != nil && *threshold >= 0 { + return *threshold + } + return 0.1 +} + +func getVectorSimilarityWeight(weight *float64) float64 { + //if weight != nil && *weight >= 0 && *weight <= 1 { + // return *weight + //} + return 0.95 +} + +func buildIndexNames(tenantIDs []string) []string { 
+ indexNames := make([]string, len(tenantIDs)) + for i, tenantID := range tenantIDs { + indexNames[i] = fmt.Sprintf("ragflow_%s", tenantID) + } + return indexNames +} + +// buildSearchResult converts engine.SearchResponse to nlp.SearchResult for reranking +func buildSearchResult(resp *engine.SearchResponse, queryVector []float64) *nlp.SearchResult { + field := make(map[string]map[string]interface{}) + ids := make([]string, 0, len(resp.Chunks)) + + for i, chunk := range resp.Chunks { + // Extract ID from chunk + id := "" + if idVal, ok := chunk["_id"].(string); ok { + id = idVal + } else { + id = fmt.Sprintf("chunk_%d", i) + } + ids = append(ids, id) + + // Store fields by id + field[id] = chunk + } + + return &nlp.SearchResult{ + Total: len(resp.Chunks), + IDs: ids, + QueryVector: queryVector, + Field: field, + } +} + +// applyRerankResults sorts and filters chunks based on reranking results +// Reference: rag/nlp/search.py L430-L439 +func applyRerankResults(chunks []map[string]interface{}, sim []float64, threshold float64) []map[string]interface{} { + if len(chunks) == 0 || len(sim) == 0 { + return chunks + } + + // Get sorted indices (descending by similarity) + sortedIndices := nlp.ArgsortDescending(sim) + + // Sort and filter chunks based on reranking results + var filteredChunks []map[string]interface{} + for _, idx := range sortedIndices { + if idx < 0 || idx >= len(chunks) { + continue + } + if sim[idx] >= threshold { + chunk := chunks[idx] + // Add similarity score to chunk + chunk["_score"] = sim[idx] + filteredChunks = append(filteredChunks, chunk) + } + } + + return filteredChunks +} + +// buildRetrievalTestResults converts filtered chunks to retrieval test results with renamed keys +func buildRetrievalTestResults(filteredChunks []map[string]interface{}) []map[string]interface{} { + results := make([]map[string]interface{}, 0, len(filteredChunks)) + + for _, chunk := range filteredChunks { + result := make(map[string]interface{}) + + // Key mappings + if v, ok := chunk["_id"]; ok { + result["chunk_id"] = v + } + if v, ok := chunk["content_ltks"]; ok { + result["content_ltks"] = v + } + if v, ok := chunk["content_with_weight"]; ok { + result["content_with_weight"] = v + } + if v, ok := chunk["doc_id"]; ok { + result["doc_id"] = v + } + if v, ok := chunk["docnm_kwd"]; ok { + result["docnm_kwd"] = v + } + if v, ok := chunk["img_id"]; ok { + result["image_id"] = v + } + if v, ok := chunk["kb_id"]; ok { + result["kb_id"] = v + } + if v, ok := chunk["position_int"]; ok { + result["positions"] = v + } + if v, ok := chunk["similarity"]; ok { + result["similarity"] = v + } + if v, ok := chunk["term_similarity"]; ok { + result["term_similarity"] = v + } + if v, ok := chunk["vector_similarity"]; ok { + result["vector_similarity"] = v + } + + results = append(results, result) + } + + return results +} diff --git a/internal/service/connector.go b/internal/service/connector.go new file mode 100644 index 00000000000..bebf8e5e81e --- /dev/null +++ b/internal/service/connector.go @@ -0,0 +1,69 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "ragflow/internal/dao" +) + +// ConnectorService connector service +type ConnectorService struct { + connectorDAO *dao.ConnectorDAO + userTenantDAO *dao.UserTenantDAO +} + +// NewConnectorService create connector service +func NewConnectorService() *ConnectorService { + return &ConnectorService{ + connectorDAO: dao.NewConnectorDAO(), + userTenantDAO: dao.NewUserTenantDAO(), + } +} + +// ListConnectorsResponse list connectors response +type ListConnectorsResponse struct { + Connectors []*dao.ConnectorListItem `json:"connectors"` +} + +// ListConnectors list connectors for a user +// Equivalent to Python's ConnectorService.list(current_user.id) +func (s *ConnectorService) ListConnectors(userID string) (*ListConnectorsResponse, error) { + // Get tenant IDs by user ID + tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) + if err != nil { + return nil, err + } + + // For now, use the first tenant ID (primary tenant) + // This matches the Python implementation behavior + var tenantID string + if len(tenantIDs) > 0 { + tenantID = tenantIDs[0] + } else { + tenantID = userID + } + + // Query connectors by tenant ID + connectors, err := s.connectorDAO.ListByTenantID(tenantID) + if err != nil { + return nil, err + } + + return &ListConnectorsResponse{ + Connectors: connectors, + }, nil +} diff --git a/internal/service/document.go b/internal/service/document.go new file mode 100644 index 00000000000..94267b797e0 --- /dev/null +++ b/internal/service/document.go @@ -0,0 +1,208 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//
+
+package service
+
+import (
+    "fmt"
+    "path/filepath"
+    "strings"
+    "time"
+
+    "ragflow/internal/dao"
+    "ragflow/internal/model"
+)
+
+// DocumentService document service
+type DocumentService struct {
+    documentDAO *dao.DocumentDAO
+}
+
+// NewDocumentService create document service
+func NewDocumentService() *DocumentService {
+    return &DocumentService{
+        documentDAO: dao.NewDocumentDAO(),
+    }
+}
+
+// CreateDocumentRequest create document request
+type CreateDocumentRequest struct {
+    Name      string `json:"name" binding:"required"`
+    KbID      string `json:"kb_id" binding:"required"`
+    ParserID  string `json:"parser_id" binding:"required"`
+    CreatedBy string `json:"created_by" binding:"required"`
+    Type      string `json:"type"`
+    Source    string `json:"source"`
+}
+
+// UpdateDocumentRequest update document request
+type UpdateDocumentRequest struct {
+    Name        *string  `json:"name"`
+    Run         *string  `json:"run"`
+    TokenNum    *int64   `json:"token_num"`
+    ChunkNum    *int64   `json:"chunk_num"`
+    Progress    *float64 `json:"progress"`
+    ProgressMsg *string  `json:"progress_msg"`
+}
+
+// DocumentResponse document response
+type DocumentResponse struct {
+    ID              string  `json:"id"`
+    Name            *string `json:"name,omitempty"`
+    KbID            string  `json:"kb_id"`
+    ParserID        string  `json:"parser_id"`
+    PipelineID      *string `json:"pipeline_id,omitempty"`
+    Type            string  `json:"type"`
+    SourceType      string  `json:"source_type"`
+    CreatedBy       string  `json:"created_by"`
+    Location        *string `json:"location,omitempty"`
+    Size            int64   `json:"size"`
+    TokenNum        int64   `json:"token_num"`
+    ChunkNum        int64   `json:"chunk_num"`
+    Progress        float64 `json:"progress"`
+    ProgressMsg     *string `json:"progress_msg,omitempty"`
+    ProcessDuration float64 `json:"process_duration"`
+    Suffix          string  `json:"suffix"`
+    Run             *string `json:"run,omitempty"`
+    Status          *string `json:"status,omitempty"`
+    CreatedAt       string  `json:"created_at"`
+    UpdatedAt       string  `json:"updated_at"`
+}
+
+// CreateDocument create document
+func (s *DocumentService) CreateDocument(req *CreateDocumentRequest) (*model.Document, error) {
+    document := &model.Document{
+        Name:       &req.Name,
+        KbID:       req.KbID,
+        ParserID:   req.ParserID,
+        CreatedBy:  req.CreatedBy,
+        Type:       req.Type,
+        SourceType: req.Source,
+        // Derive the suffix from the file name instead of hardcoding ".doc"
+        Suffix: strings.ToLower(filepath.Ext(req.Name)),
+        Status: strPtr("0"),
+    }
+
+    if err := s.documentDAO.Create(document); err != nil {
+        return nil, fmt.Errorf("failed to create document: %w", err)
+    }
+
+    return document, nil
+}
+
+// GetDocumentByID get document by ID
+func (s *DocumentService) GetDocumentByID(id string) (*DocumentResponse, error) {
+    document, err := s.documentDAO.GetByID(id)
+    if err != nil {
+        return nil, err
+    }
+
+    return s.toResponse(document), nil
+}
+
+// UpdateDocument update document
+func (s *DocumentService) UpdateDocument(id string, req *UpdateDocumentRequest) error {
+    document, err := s.documentDAO.GetByID(id)
+    if err != nil {
+        return err
+    }
+
+    if req.Name != nil {
+        document.Name = req.Name
+    }
+    if req.Run != nil {
+        document.Run = req.Run
+    }
+    if req.TokenNum != nil {
+        document.TokenNum = *req.TokenNum
+    }
+    if req.ChunkNum != nil {
+        document.ChunkNum = *req.ChunkNum
+    }
+    if req.Progress != nil {
+        document.Progress = *req.Progress
+    }
+    if req.ProgressMsg != nil {
+        document.ProgressMsg = req.ProgressMsg
+    }
+
+    return s.documentDAO.Update(document)
+}
+
+// DeleteDocument delete document
+func (s *DocumentService) DeleteDocument(id string) error {
+    return s.documentDAO.Delete(id)
+}
+
+// ListDocuments list documents
+func (s *DocumentService) ListDocuments(page, pageSize int) ([]*DocumentResponse, int64, error) {
+    offset := (page - 1) * pageSize
+    documents, total, err := s.documentDAO.List(offset, pageSize)
+    if err != nil {
+        return nil, 0, err
+    }
+
+    responses := make([]*DocumentResponse, len(documents))
+    for i, doc := range documents {
+        responses[i] = s.toResponse(doc)
+    }
+
+    return responses, total, nil
+}
+
+// GetDocumentsByAuthorID get documents by author ID.
+// Author IDs are string UUIDs elsewhere in this codebase, so take a string
+// and pass it through instead of formatting an int with %d.
+func (s *DocumentService) GetDocumentsByAuthorID(authorID string, page, pageSize int) ([]*DocumentResponse, int64, error) {
+    offset := (page - 1) * pageSize
+    documents, total, err := s.documentDAO.GetByAuthorID(authorID, offset, pageSize)
+    if err != nil {
+        return nil, 0, err
+    }
+
+    responses := make([]*DocumentResponse, len(documents))
+    for i, doc := range documents {
+        responses[i] = s.toResponse(doc)
+    }
+
+    return responses, total, nil
+}
+
+// toResponse convert model.Document to DocumentResponse
+func (s *DocumentService) toResponse(doc *model.Document) *DocumentResponse {
+    // Create/update times are stored as Unix milliseconds (as in chat.go),
+    // so convert with UnixMilli rather than treating them as seconds
+    createdAt := time.UnixMilli(doc.CreateTime).Format("2006-01-02 15:04:05")
+    updatedAt := ""
+    if doc.UpdateTime != nil {
+        updatedAt = time.UnixMilli(*doc.UpdateTime).Format("2006-01-02 15:04:05")
+    }
+    return &DocumentResponse{
+        ID:              doc.ID,
+        Name:            doc.Name,
+        KbID:            doc.KbID,
+        ParserID:        doc.ParserID,
+        PipelineID:      doc.PipelineID,
+        Type:            doc.Type,
+        SourceType:      doc.SourceType,
+        CreatedBy:       doc.CreatedBy,
+        Location:        doc.Location,
+        Size:            doc.Size,
+        TokenNum:        doc.TokenNum,
+        ChunkNum:        doc.ChunkNum,
+        Progress:        doc.Progress,
+        ProgressMsg:     doc.ProgressMsg,
+        ProcessDuration: doc.ProcessDuration,
+        Suffix:          doc.Suffix,
+        Run:             doc.Run,
+        Status:          doc.Status,
+        CreatedAt:       createdAt,
+        UpdatedAt:       updatedAt,
+    }
+}
diff --git a/internal/service/file.go b/internal/service/file.go
new file mode 100644
index 00000000000..34a08123d27
--- /dev/null
+++ b/internal/service/file.go
@@ -0,0 +1,220 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// + +package service + +import ( + "ragflow/internal/dao" + "ragflow/internal/model" +) + +// FileService file service +type FileService struct { + fileDAO *dao.FileDAO + file2DocumentDAO *dao.File2DocumentDAO +} + +// NewFileService create file service +func NewFileService() *FileService { + return &FileService{ + fileDAO: dao.NewFileDAO(), + file2DocumentDAO: dao.NewFile2DocumentDAO(), + } +} + +// FileInfo file info with additional fields +type FileInfo struct { + *model.File + Size int64 `json:"size"` + KbsInfo []map[string]interface{} `json:"kbs_info"` + HasChildFolder bool `json:"has_child_folder,omitempty"` +} + +// ListFilesResponse list files response +type ListFilesResponse struct { + Total int64 `json:"total"` + Files []map[string]interface{} `json:"files"` + ParentFolder map[string]interface{} `json:"parent_folder"` +} + +// GetRootFolder gets or creates root folder for tenant +func (s *FileService) GetRootFolder(tenantID string) (map[string]interface{}, error) { + file, err := s.fileDAO.GetRootFolder(tenantID) + if err != nil { + return nil, err + } + return s.toFileResponse(file), nil +} + +// ListFiles lists files by parent folder ID +func (s *FileService) ListFiles(tenantID, pfID string, page, pageSize int, orderby string, desc bool, keywords string) (*ListFilesResponse, error) { + // If pfID is empty, get root folder + if pfID == "" { + rootFolder, err := s.fileDAO.GetRootFolder(tenantID) + if err != nil { + return nil, err + } + pfID = rootFolder.ID + } + + // Check if parent folder exists + if _, err := s.fileDAO.GetByID(pfID); err != nil { + return nil, err + } + + // Get files by parent folder ID + files, total, err := s.fileDAO.GetByPfID(tenantID, pfID, page, pageSize, orderby, desc, keywords) + if err != nil { + return nil, err + } + + // Get parent folder + parentFolder, err := s.fileDAO.GetParentFolder(pfID) + if err != nil { + return nil, err + } + + // Process files to add additional info + fileResponses := make([]map[string]interface{}, len(files)) + for i, file := range files { + fileInfo := s.toFileInfo(file) + + // If folder, calculate size and check for child folders + if file.Type == "folder" { + folderSize, err := s.fileDAO.GetFolderSize(file.ID) + if err == nil { + fileInfo.Size = folderSize + } + hasChild, err := s.fileDAO.HasChildFolder(file.ID) + if err == nil { + fileInfo.HasChildFolder = hasChild + } + fileInfo.KbsInfo = []map[string]interface{}{} + } else { + // Get KB info for non-folder files + kbsInfo, err := s.file2DocumentDAO.GetKBInfoByFileID(file.ID) + if err != nil { + kbsInfo = []map[string]interface{}{} + } + fileInfo.KbsInfo = kbsInfo + } + + fileResponses[i] = s.fileInfoToResponse(fileInfo) + } + + return &ListFilesResponse{ + Total: total, + Files: fileResponses, + ParentFolder: s.toFileResponse(parentFolder), + }, nil +} + +// toFileResponse converts file model to response format +func (s *FileService) toFileResponse(file *model.File) map[string]interface{} { + result := map[string]interface{}{ + "id": file.ID, + "parent_id": file.ParentID, + "tenant_id": file.TenantID, + "created_by": file.CreatedBy, + "name": file.Name, + "size": file.Size, + "type": file.Type, + "create_time": file.CreateTime, + "update_time": file.UpdateTime, + } + + if file.Location != nil { + result["location"] = *file.Location + } + result["source_type"] = file.SourceType + + return result +} + +// toFileInfo converts file model to FileInfo +func (s *FileService) toFileInfo(file *model.File) *FileInfo { + return &FileInfo{ + File: file, + Size: file.Size, + 
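+        // Defaults only: ListFiles overwrites these with the folder size,
+        // child-folder flag, or linked knowledge-base info as appropriate.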
KbsInfo: []map[string]interface{}{}, + HasChildFolder: false, + } +} + +// fileInfoToResponse converts FileInfo to response map +func (s *FileService) fileInfoToResponse(info *FileInfo) map[string]interface{} { + result := map[string]interface{}{ + "id": info.File.ID, + "parent_id": info.File.ParentID, + "tenant_id": info.File.TenantID, + "created_by": info.File.CreatedBy, + "name": info.File.Name, + "size": info.Size, + "type": info.File.Type, + "create_time": info.File.CreateTime, + "update_time": info.File.UpdateTime, + "kbs_info": info.KbsInfo, + } + + if info.File.Location != nil { + result["location"] = *info.File.Location + } + result["source_type"] = info.File.SourceType + + if info.File.Type == "folder" { + result["has_child_folder"] = info.HasChildFolder + } + + return result +} + +// GetParentFolder gets parent folder of a file +func (s *FileService) GetParentFolder(fileID string) (map[string]interface{}, error) { + // Check if file exists + if _, err := s.fileDAO.GetByID(fileID); err != nil { + return nil, err + } + + // Get parent folder + parentFolder, err := s.fileDAO.GetParentFolder(fileID) + if err != nil { + return nil, err + } + + return s.toFileResponse(parentFolder), nil +} + +// GetAllParentFolders gets all parent folders in path +func (s *FileService) GetAllParentFolders(fileID string) ([]map[string]interface{}, error) { + // Check if file exists + if _, err := s.fileDAO.GetByID(fileID); err != nil { + return nil, err + } + + // Get all parent folders + parentFolders, err := s.fileDAO.GetAllParentFolders(fileID) + if err != nil { + return nil, err + } + + // Convert to response format + result := make([]map[string]interface{}, len(parentFolders)) + for i, folder := range parentFolders { + result[i] = s.toFileResponse(folder) + } + + return result, nil +} diff --git a/internal/service/kb.go b/internal/service/kb.go new file mode 100644 index 00000000000..8b982ebe6f6 --- /dev/null +++ b/internal/service/kb.go @@ -0,0 +1,82 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//
+
+package service
+
+import (
+    "ragflow/internal/dao"
+    "ragflow/internal/model"
+)
+
+// KnowledgebaseService knowledge base service
+type KnowledgebaseService struct {
+    kbDAO         *dao.KnowledgebaseDAO
+    userTenantDAO *dao.UserTenantDAO
+}
+
+// NewKnowledgebaseService create knowledge base service
+func NewKnowledgebaseService() *KnowledgebaseService {
+    return &KnowledgebaseService{
+        kbDAO:         dao.NewKnowledgebaseDAO(),
+        userTenantDAO: dao.NewUserTenantDAO(),
+    }
+}
+
+// ListKbsRequest list knowledge bases request
+type ListKbsRequest struct {
+    Keywords *string   `json:"keywords,omitempty"`
+    Page     *int      `json:"page,omitempty"`
+    PageSize *int      `json:"page_size,omitempty"`
+    ParserID *string   `json:"parser_id,omitempty"`
+    Orderby  *string   `json:"orderby,omitempty"`
+    Desc     *bool     `json:"desc,omitempty"`
+    OwnerIDs *[]string `json:"owner_ids,omitempty"`
+}
+
+// ListKbsResponse list knowledge bases response
+type ListKbsResponse struct {
+    KBs   []*model.Knowledgebase `json:"kbs"`
+    Total int64                  `json:"total"`
+}
+
+// ListKbs list knowledge bases
+func (s *KnowledgebaseService) ListKbs(keywords string, page int, pageSize int, parserID string, orderby string, desc bool, ownerIDs []string, userID string) (*ListKbsResponse, error) {
+    var kbs []*model.Knowledgebase
+    var total int64
+    var err error
+
+    // If owner IDs are provided, filter by them (len() is safe on a nil slice)
+    if len(ownerIDs) > 0 {
+        kbs, total, err = s.kbDAO.ListByOwnerIDs(ownerIDs, page, pageSize, orderby, desc, keywords, parserID)
+    } else {
+        // Get tenant IDs by user ID. Use a distinct error variable here so the
+        // outer err is not shadowed by := and the error returned by
+        // ListByTenantIDs below is not silently dropped.
+        tenantIDs, terr := s.userTenantDAO.GetTenantIDsByUserID(userID)
+        if terr != nil {
+            return nil, terr
+        }
+
+        kbs, total, err = s.kbDAO.ListByTenantIDs(tenantIDs, userID, page, pageSize, orderby, desc, keywords, parserID)
+    }
+
+    if err != nil {
+        return nil, err
+    }
+
+    return &ListKbsResponse{
+        KBs:   kbs,
+        Total: total,
+    }, nil
+}
diff --git a/internal/service/llm.go b/internal/service/llm.go
new file mode 100644
index 00000000000..5478f3d18fc
--- /dev/null
+++ b/internal/service/llm.go
@@ -0,0 +1,248 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// + +package service + +import ( + "strings" + + "ragflow/internal/dao" +) + +// LLMService LLM service +type LLMService struct { + tenantLLMDAO *dao.TenantLLMDAO + llmDAO *dao.LLMDAO +} + +// NewLLMService create LLM service +func NewLLMService() *LLMService { + return &LLMService{ + tenantLLMDAO: dao.NewTenantLLMDAO(), + llmDAO: dao.NewLLMDAO(), + } +} + +// MyLLMItem represents a single LLM item in the response +type MyLLMItem struct { + Type string `json:"type"` + Name string `json:"name"` + UsedToken int64 `json:"used_token"` + Status string `json:"status"` + APIBase string `json:"api_base,omitempty"` + MaxTokens int64 `json:"max_tokens,omitempty"` +} + +// MyLLMResponse represents the response structure for my LLMs +type MyLLMResponse struct { + Tags string `json:"tags"` + LLM []MyLLMItem `json:"llm"` +} + +// GetMyLLMs get my LLMs for a tenant +func (s *LLMService) GetMyLLMs(tenantID string, includeDetails bool) (map[string]MyLLMResponse, error) { + // Get LLM list from database + myLLMs, err := s.tenantLLMDAO.GetMyLLMs(tenantID, includeDetails) + if err != nil { + return nil, err + } + + // Group by factory + result := make(map[string]MyLLMResponse) + providerDAO := dao.NewModelProviderDAO() + for _, llm := range myLLMs { + // Get or create factory entry + resp, exists := result[llm.LLMFactory] + if !exists { + resp = MyLLMResponse{ + Tags: llm.Tags, + LLM: []MyLLMItem{}, + } + } + + // Create LLM item + item := MyLLMItem{ + Type: llm.ModelType, + Name: llm.LLMName, + UsedToken: llm.UsedTokens, + Status: llm.Status, + } + + // Add detailed fields if requested + if includeDetails { + item.APIBase = llm.APIBase + item.MaxTokens = llm.MaxTokens + + // If APIBase is empty, try to get from model provider configuration + if item.APIBase == "" { + provider := providerDAO.GetProviderByName(llm.LLMFactory) + if provider != nil { + // Determine appropriate API base URL based on model type + switch llm.ModelType { + case "embedding": + if provider.DefaultEmbeddingURL != "" { + item.APIBase = provider.DefaultEmbeddingURL + } + // Add other model types here if needed + // case "chat": + // case "rerank": + // etc. 
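+                                // Hypothetical sketch only: extending this fallback to
+                                // other model types would need matching provider-config
+                                // fields (DefaultChatURL / DefaultRerankURL are assumed
+                                // names and do NOT exist on the provider model yet):
+                                //
+                                //  case "chat":
+                                //      if provider.DefaultChatURL != "" {
+                                //          item.APIBase = provider.DefaultChatURL
+                                //      }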
+ } + } + } + } + + resp.LLM = append(resp.LLM, item) + result[llm.LLMFactory] = resp + } + + return result, nil +} + +// LLMListItem represents a single LLM item in the list response +type LLMListItem struct { + LLMName string `json:"llm_name"` + ModelType string `json:"model_type"` + FID string `json:"fid"` + Available bool `json:"available"` + Status string `json:"status"` + MaxTokens int64 `json:"max_tokens,omitempty"` + CreateDate *string `json:"create_date,omitempty"` + CreateTime int64 `json:"create_time,omitempty"` + UpdateDate *string `json:"update_date,omitempty"` + UpdateTime *int64 `json:"update_time,omitempty"` + IsTools bool `json:"is_tools"` + Tags string `json:"tags,omitempty"` +} + +// ListLLMsResponse represents the response for list LLMs +type ListLLMsResponse map[string][]LLMListItem + +// ListLLMs lists LLMs for a tenant with availability info +func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsResponse, error) { + selfDeployed := map[string]bool{ + "FastEmbed": true, + "Ollama": true, + "Xinference": true, + "LocalAI": true, + "LM-Studio": true, + "GPUStack": true, + } + + // Get tenant LLMs + tenantLLMs, err := s.tenantLLMDAO.ListAllByTenant(tenantID) + if err != nil { + return nil, err + } + + // Build set of factories with valid API keys + facts := make(map[string]bool) + // Build set of valid LLM names@factories + status := make(map[string]bool) + for _, tl := range tenantLLMs { + if tl.APIKey != "" && tl.Status == "1" { + facts[tl.LLMFactory] = true + } + key := tl.LLMName + "@" + tl.LLMFactory + if tl.Status == "1" { + status[key] = true + } + } + + // Get all valid LLMs + allLLMs, err := s.llmDAO.GetAllValid() + if err != nil { + return nil, err + } + + // Filter and build result + llmSet := make(map[string]bool) + result := make(ListLLMsResponse) + + for _, llm := range allLLMs { + if llm.Status == nil || *llm.Status != "1" { + continue + } + + key := llm.LLMName + "@" + llm.FID + + // Check if valid (Builtin factory or in status set) + if llm.FID != "Builtin" && !status[key] { + continue + } + + // Filter by model type if specified + if modelType != "" && !strings.Contains(llm.ModelType, modelType) { + continue + } + + // Determine availability + available := facts[llm.FID] || selfDeployed[llm.FID] || llm.LLMName == "flag-embedding" + + item := LLMListItem{ + LLMName: llm.LLMName, + ModelType: llm.ModelType, + FID: llm.FID, + Available: available, + Status: "1", + MaxTokens: llm.MaxTokens, + IsTools: llm.IsTools, + Tags: llm.Tags, + } + + // Add BaseModel fields + if llm.CreateDate != nil { + createDateStr := llm.CreateDate.Format("2006-01-02T15:04:05") + item.CreateDate = &createDateStr + } + item.CreateTime = llm.CreateTime + if llm.UpdateDate != nil { + updateDateStr := llm.UpdateDate.Format("2006-01-02T15:04:05") + item.UpdateDate = &updateDateStr + } + if llm.UpdateTime != nil { + item.UpdateTime = llm.UpdateTime + } + + result[llm.FID] = append(result[llm.FID], item) + llmSet[key] = true + } + + // Add tenant LLMs that are not in the global list + for _, tl := range tenantLLMs { + key := tl.LLMName + "@" + tl.LLMFactory + if llmSet[key] { + continue + } + + // Filter by model type if specified + if modelType != "" && !strings.Contains(tl.ModelType, modelType) { + continue + } + + item := LLMListItem{ + LLMName: tl.LLMName, + ModelType: tl.ModelType, + FID: tl.LLMFactory, + Available: true, + Status: tl.Status, + } + + result[tl.LLMFactory] = append(result[tl.LLMFactory], item) + } + + return result, nil +} diff --git 
a/internal/service/model_bundle.go b/internal/service/model_bundle.go new file mode 100644 index 00000000000..0fff9652c9c --- /dev/null +++ b/internal/service/model_bundle.go @@ -0,0 +1,173 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "context" + "fmt" + + "ragflow/internal/model" +) + +// ModelBundle provides a unified interface for various model operations +// Similar to Python's LLMBundle but with a more generic name +type ModelBundle struct { + tenantID string + modelType model.ModelType + modelName string + model interface{} // underlying model instance +} + +// NewModelBundle creates a new ModelBundle for the given tenant and model type +// If modelName is empty, uses the default model for the tenant and type +func NewModelBundle(tenantID string, modelType model.ModelType, modelName ...string) (*ModelBundle, error) { + bundle := &ModelBundle{ + tenantID: tenantID, + modelType: modelType, + } + + // Use provided model name if available + if len(modelName) > 0 && modelName[0] != "" { + bundle.modelName = modelName[0] + } + + // Get model instance based on type + provider := NewModelProvider() + switch modelType { + case model.ModelTypeEmbedding: + embeddingModel, err := provider.GetEmbeddingModel(context.Background(), tenantID, bundle.modelName) + if err != nil { + return nil, fmt.Errorf("failed to get embedding model: %w", err) + } + bundle.model = embeddingModel + case model.ModelTypeChat: + chatModel, err := provider.GetChatModel(context.Background(), tenantID, bundle.modelName) + if err != nil { + return nil, fmt.Errorf("failed to get chat model: %w", err) + } + bundle.model = chatModel + case model.ModelTypeRerank: + rerankModel, err := provider.GetRerankModel(context.Background(), tenantID, bundle.modelName) + if err != nil { + return nil, fmt.Errorf("failed to get rerank model: %w", err) + } + bundle.model = rerankModel + default: + return nil, fmt.Errorf("unsupported model type: %s", modelType) + } + + return bundle, nil +} + +// Encode encodes a list of texts into embeddings +// Returns embeddings and token count (for compatibility with Python interface) +func (b *ModelBundle) Encode(texts []string) ([][]float64, int64, error) { + if b.modelType != model.ModelTypeEmbedding { + return nil, 0, fmt.Errorf("model type %s does not support encode", b.modelType) + } + + embeddingModel, ok := b.model.(model.EmbeddingModel) + if !ok { + return nil, 0, fmt.Errorf("model is not an embedding model") + } + + embeddings, err := embeddingModel.Encode(texts) + if err != nil { + return nil, 0, err + } + + // TODO: Calculate actual token count + // For now, return a dummy token count + tokenCount := int64(0) + for _, text := range texts { + tokenCount += int64(len(text) / 4) // rough approximation + } + + return embeddings, tokenCount, nil +} + +// EncodeQuery encodes a single query string into embedding +// Returns embedding and token count +func (b *ModelBundle) 
EncodeQuery(query string) ([]float64, int64, error) { + if b.modelType != model.ModelTypeEmbedding { + return nil, 0, fmt.Errorf("model type %s does not support encode query", b.modelType) + } + + embeddingModel, ok := b.model.(model.EmbeddingModel) + if !ok { + return nil, 0, fmt.Errorf("model is not an embedding model") + } + + embedding, err := embeddingModel.EncodeQuery(query) + if err != nil { + return nil, 0, err + } + + // TODO: Calculate actual token count + tokenCount := int64(len(query) / 4) + + return embedding, tokenCount, nil +} + +// Chat sends a chat message and returns response +func (b *ModelBundle) Chat(system string, history []map[string]string, genConf map[string]interface{}) (string, int64, error) { + if b.modelType != model.ModelTypeChat { + return "", 0, fmt.Errorf("model type %s does not support chat", b.modelType) + } + + chatModel, ok := b.model.(model.ChatModel) + if !ok { + return "", 0, fmt.Errorf("model is not a chat model") + } + + response, err := chatModel.Chat(system, history, genConf) + if err != nil { + return "", 0, err + } + + // TODO: Calculate actual token count + tokenCount := int64(len(response) / 4) + + return response, tokenCount, nil +} + +// Similarity calculates similarity between query and texts +func (b *ModelBundle) Similarity(query string, texts []string) ([]float64, int64, error) { + if b.modelType != model.ModelTypeRerank { + return nil, 0, fmt.Errorf("model type %s does not support similarity", b.modelType) + } + + rerankModel, ok := b.model.(model.RerankModel) + if !ok { + return nil, 0, fmt.Errorf("model is not a rerank model") + } + + similarities, err := rerankModel.Similarity(query, texts) + if err != nil { + return nil, 0, err + } + + // TODO: Calculate actual token count + tokenCount := int64(len(query)/4) + int64(len(texts)*10) + + return similarities, tokenCount, nil +} + +// GetModel returns the underlying model instance +func (b *ModelBundle) GetModel() interface{} { + return b.model +} diff --git a/internal/service/model_service.go b/internal/service/model_service.go new file mode 100644 index 00000000000..423c6856079 --- /dev/null +++ b/internal/service/model_service.go @@ -0,0 +1,117 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package service + +import ( + "context" + "fmt" + "net/http" + "ragflow/internal/dao" + "strings" + "time" + + "ragflow/internal/model" + "ragflow/internal/service/models" +) + +// ModelProvider provides model instances based on tenant and model type +type ModelProvider interface { + // GetEmbeddingModel returns an embedding model for the given tenant + GetEmbeddingModel(ctx context.Context, tenantID string, modelName string) (model.EmbeddingModel, error) + // GetChatModel returns a chat model for the given tenant + GetChatModel(ctx context.Context, tenantID string, modelName string) (model.ChatModel, error) + // GetRerankModel returns a rerank model for the given tenant + GetRerankModel(ctx context.Context, tenantID string, modelName string) (model.RerankModel, error) +} + +// ModelProviderImpl implements ModelProvider +type ModelProviderImpl struct { + httpClient *http.Client +} + +// NewModelProvider creates a new ModelProvider +func NewModelProvider() *ModelProviderImpl { + return &ModelProviderImpl{ + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +// parseModelName parses a composite model name in format "model_name@provider" +// Returns modelName and provider separately +func parseModelName(compositeName string) (modelName, provider string, err error) { + parts := strings.Split(compositeName, "@") + if len(parts) == 2 { + return parts[0], parts[1], nil + } else if len(parts) == 1 { + return parts[0], "", fmt.Errorf("provider name missing in model name: %s", compositeName) + } else { + return "", "", fmt.Errorf("invalid model name format: %s", compositeName) + } +} + +// GetEmbeddingModel returns an embedding model for the given tenant +func (p *ModelProviderImpl) GetEmbeddingModel(ctx context.Context, tenantID string, compositeModelName string) (model.EmbeddingModel, error) { + // Parse composite model name to extract model name and provider + modelName, provider, err := parseModelName(compositeModelName) + if err != nil { + return nil, err + } + + // Get API key and configuration + embeddingModel, err := dao.NewTenantLLMDAO().GetByTenantFactoryAndModelName(tenantID, provider, modelName) + if err != nil { + return nil, err + } + + apiKey := embeddingModel.APIKey + if apiKey == "" { + return nil, fmt.Errorf("no API key found for tenant %s and model %s", tenantID, compositeModelName) + } + // Always get API base from model provider configuration + providerDAO := dao.NewModelProviderDAO() + providerConfig := providerDAO.GetProviderByName(provider) + if providerConfig == nil || providerConfig.DefaultEmbeddingURL == "" { + return nil, fmt.Errorf("no API base found for provider %s", provider) + } + apiBase := providerConfig.DefaultEmbeddingURL + + return models.CreateEmbeddingModel(provider, apiKey, apiBase, modelName, p.httpClient) +} + +// GetChatModel returns a chat model for the given tenant +func (p *ModelProviderImpl) GetChatModel(ctx context.Context, tenantID string, compositeModelName string) (model.ChatModel, error) { + // Parse composite model name to extract model name and provider + _, _, err := parseModelName(compositeModelName) + if err != nil { + return nil, err + } + // TODO: implement chat model creation + return nil, fmt.Errorf("chat model not implemented yet for model: %s", compositeModelName) +} + +// GetRerankModel returns a rerank model for the given tenant +func (p *ModelProviderImpl) GetRerankModel(ctx context.Context, tenantID string, compositeModelName string) (model.RerankModel, error) { + // Parse composite model name to extract 
model name and provider + _, _, err := parseModelName(compositeModelName) + if err != nil { + return nil, err + } + // TODO: implement rerank model creation + return nil, fmt.Errorf("rerank model not implemented yet for model: %s", compositeModelName) +} diff --git a/internal/service/models/deepseek_model.go b/internal/service/models/deepseek_model.go new file mode 100644 index 00000000000..0f7ccf37c77 --- /dev/null +++ b/internal/service/models/deepseek_model.go @@ -0,0 +1,33 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "net/http" + "ragflow/internal/model" +) + +func init() { + RegisterEmbeddingModelFactory("DeepSeek", func(apiKey, apiBase, modelName string, httpClient *http.Client) model.EmbeddingModel { + return &openAIEmbeddingModel{ + apiKey: apiKey, + apiBase: apiBase, + model: modelName, + httpClient: httpClient, + } + }) +} diff --git a/internal/service/models/factory.go b/internal/service/models/factory.go new file mode 100644 index 00000000000..36ad4d71e55 --- /dev/null +++ b/internal/service/models/factory.go @@ -0,0 +1,58 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "fmt" + "net/http" + "ragflow/internal/model" + "sync" +) + +// EmbeddingModelFactory creates an EmbeddingModel instance +type EmbeddingModelFactory func(apiKey, apiBase, modelName string, httpClient *http.Client) model.EmbeddingModel + +var ( + embeddingModelFactories = make(map[string]EmbeddingModelFactory) + factoryMu sync.RWMutex +) + +// RegisterEmbeddingModelFactory registers a factory for a provider name. +// Should be called from init() functions of provider implementations. +func RegisterEmbeddingModelFactory(providerName string, factory EmbeddingModelFactory) { + factoryMu.Lock() + defer factoryMu.Unlock() + embeddingModelFactories[providerName] = factory +} + +// GetEmbeddingModelFactory returns the factory for the given provider name. +// Returns nil if not found. +func GetEmbeddingModelFactory(providerName string) EmbeddingModelFactory { + factoryMu.RLock() + defer factoryMu.RUnlock() + return embeddingModelFactories[providerName] +} + +// CreateEmbeddingModel creates an EmbeddingModel instance for the given provider. +// Returns error if provider not registered. 
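+//
+// Hypothetical usage sketch (assumes the "OpenAI" factory has been registered
+// via its init(), and that apiKey holds a valid key; the base URL and model
+// name below are illustrative values, not project defaults):
+//
+//	m, err := CreateEmbeddingModel("OpenAI", apiKey, "https://api.openai.com/v1", "text-embedding-3-small", http.DefaultClient)
+//	if err != nil {
+//		return err
+//	}
+//	vecs, err := m.Encode([]string{"hello", "world"})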
+func CreateEmbeddingModel(providerName, apiKey, apiBase, modelName string, httpClient *http.Client) (model.EmbeddingModel, error) { + factory := GetEmbeddingModelFactory(providerName) + if factory == nil { + return nil, fmt.Errorf("no embedding model factory registered for provider %s", providerName) + } + return factory(apiKey, apiBase, modelName, httpClient), nil +} diff --git a/internal/service/models/gitee_model.go b/internal/service/models/gitee_model.go new file mode 100644 index 00000000000..5b7e2d447c5 --- /dev/null +++ b/internal/service/models/gitee_model.go @@ -0,0 +1,126 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "ragflow/internal/model" + "strings" +) + +// giteeEmbeddingModel implements EmbeddingModel for GiteeAI API (assumed OpenAI-compatible) +type giteeEmbeddingModel struct { + apiKey string + apiBase string + model string + httpClient *http.Client +} + +// GiteeEmbeddingRequest represents GiteeAI embedding request +type GiteeEmbeddingRequest struct { + Model string `json:"model"` + Input []string `json:"input"` + EncodeFormat string `json:"encode_format"` +} + +// GiteeEmbeddingResponse represents GiteeAI embedding response +type GiteeEmbeddingResponse struct { + Data []struct { + Embedding []float64 `json:"embedding"` + Index int `json:"index"` + } `json:"data"` +} + +// Encode encodes a list of texts into embeddings using GiteeAI API +func (m *giteeEmbeddingModel) Encode(texts []string) ([][]float64, error) { + if len(texts) == 0 { + return [][]float64{}, nil + } + + reqBody := GiteeEmbeddingRequest{ + Model: m.model, + Input: texts, + EncodeFormat: "float", + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", m.apiBase, strings.NewReader(string(jsonData))) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+m.apiKey) + + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("GiteeAI API error: %s, body: %s", resp.Status, string(body)) + } + + var embeddingResp GiteeEmbeddingResponse + if err := json.NewDecoder(resp.Body).Decode(&embeddingResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + // Sort embeddings by index to ensure correct order + embeddings := make([][]float64, len(texts)) + for _, data := range embeddingResp.Data { + if data.Index < len(embeddings) { + embeddings[data.Index] = data.Embedding + } + } + + return embeddings, nil +} + +// EncodeQuery encodes a single query string 
into embedding +func (m *giteeEmbeddingModel) EncodeQuery(query string) ([]float64, error) { + embeddings, err := m.Encode([]string{query}) + if err != nil { + return nil, err + } + if len(embeddings) == 0 { + return nil, fmt.Errorf("no embedding returned") + } + return embeddings[0], nil +} + +// init registers the GiteeAI embedding model factory +func init() { + RegisterEmbeddingModelFactory("GiteeAI", func(apiKey, apiBase, modelName string, httpClient *http.Client) model.EmbeddingModel { + return &giteeEmbeddingModel{ + apiKey: apiKey, + apiBase: apiBase, + model: modelName, + httpClient: httpClient, + } + }) +} diff --git a/internal/service/models/moonshot_model.go b/internal/service/models/moonshot_model.go new file mode 100644 index 00000000000..ed0d3f72c3c --- /dev/null +++ b/internal/service/models/moonshot_model.go @@ -0,0 +1,33 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "net/http" + "ragflow/internal/model" +) + +func init() { + RegisterEmbeddingModelFactory("Moonshot", func(apiKey, apiBase, modelName string, httpClient *http.Client) model.EmbeddingModel { + return &openAIEmbeddingModel{ + apiKey: apiKey, + apiBase: apiBase, + model: modelName, + httpClient: httpClient, + } + }) +} diff --git a/internal/service/models/openai_api_compatible_model.go b/internal/service/models/openai_api_compatible_model.go new file mode 100644 index 00000000000..56f33af83e6 --- /dev/null +++ b/internal/service/models/openai_api_compatible_model.go @@ -0,0 +1,33 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "net/http" + "ragflow/internal/model" +) + +func init() { + RegisterEmbeddingModelFactory("OpenAI-API-Compatible", func(apiKey, apiBase, modelName string, httpClient *http.Client) model.EmbeddingModel { + return &openAIEmbeddingModel{ + apiKey: apiKey, + apiBase: apiBase, + model: modelName, + httpClient: httpClient, + } + }) +} diff --git a/internal/service/models/openai_model.go b/internal/service/models/openai_model.go new file mode 100644 index 00000000000..f52e4f04be3 --- /dev/null +++ b/internal/service/models/openai_model.go @@ -0,0 +1,123 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "ragflow/internal/model" + "strings" +) + +// openAIEmbeddingModel implements EmbeddingModel for OpenAI API +type openAIEmbeddingModel struct { + apiKey string + apiBase string + model string + httpClient *http.Client +} + +// OpenAIEmbeddingRequest represents OpenAI embedding request +type OpenAIEmbeddingRequest struct { + Model string `json:"model"` + Input []string `json:"input"` +} + +// OpenAIEmbeddingResponse represents OpenAI embedding response +type OpenAIEmbeddingResponse struct { + Data []struct { + Embedding []float64 `json:"embedding"` + Index int `json:"index"` + } `json:"data"` +} + +// Encode encodes a list of texts into embeddings using OpenAI API +func (m *openAIEmbeddingModel) Encode(texts []string) ([][]float64, error) { + if len(texts) == 0 { + return [][]float64{}, nil + } + + reqBody := OpenAIEmbeddingRequest{ + Model: m.model, + Input: texts, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", m.apiBase+"/embeddings", strings.NewReader(string(jsonData))) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+m.apiKey) + + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("OpenAI API error: %s, body: %s", resp.Status, string(body)) + } + + var embeddingResp OpenAIEmbeddingResponse + if err := json.NewDecoder(resp.Body).Decode(&embeddingResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + // Sort embeddings by index to ensure correct order + embeddings := make([][]float64, len(texts)) + for _, data := range embeddingResp.Data { + if data.Index < len(embeddings) { + embeddings[data.Index] = data.Embedding + } + } + + return embeddings, nil +} + +// EncodeQuery encodes a single query string into embedding +func (m *openAIEmbeddingModel) EncodeQuery(query string) ([]float64, error) { + embeddings, err := m.Encode([]string{query}) + if err != nil { + return nil, err + } + if len(embeddings) == 0 { + return nil, fmt.Errorf("no embedding returned") + } + return embeddings[0], nil +} + +// init registers the OpenAI embedding model factory +func init() { + RegisterEmbeddingModelFactory("OpenAI", func(apiKey, apiBase, modelName string, httpClient *http.Client) model.EmbeddingModel { + return &openAIEmbeddingModel{ + apiKey: apiKey, + apiBase: apiBase, + model: modelName, + httpClient: httpClient, + } + }) +} diff --git a/internal/service/models/siliconflow_model.go b/internal/service/models/siliconflow_model.go new file mode 100644 index 00000000000..2b40976c4da --- /dev/null +++ b/internal/service/models/siliconflow_model.go @@ -0,0 +1,123 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "ragflow/internal/model" + "strings" +) + +// siliconflowEmbeddingModel implements EmbeddingModel for SILICONFLOW API (OpenAI-compatible) +type siliconflowEmbeddingModel struct { + apiKey string + apiBase string + model string + httpClient *http.Client +} + +// SiliconflowEmbeddingRequest represents SILICONFLOW embedding request +type SiliconflowEmbeddingRequest struct { + Model string `json:"model"` + Input []string `json:"input"` +} + +// SiliconflowEmbeddingResponse represents SILICONFLOW embedding response +type SiliconflowEmbeddingResponse struct { + Data []struct { + Embedding []float64 `json:"embedding"` + Index int `json:"index"` + } `json:"data"` +} + +// Encode encodes a list of texts into embeddings using SILICONFLOW API +func (m *siliconflowEmbeddingModel) Encode(texts []string) ([][]float64, error) { + if len(texts) == 0 { + return [][]float64{}, nil + } + + reqBody := SiliconflowEmbeddingRequest{ + Model: m.model, + Input: texts, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", m.apiBase+"/embeddings", strings.NewReader(string(jsonData))) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+m.apiKey) + + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("SILICONFLOW API error: %s, body: %s", resp.Status, string(body)) + } + + var embeddingResp SiliconflowEmbeddingResponse + if err := json.NewDecoder(resp.Body).Decode(&embeddingResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + // Sort embeddings by index to ensure correct order + embeddings := make([][]float64, len(texts)) + for _, data := range embeddingResp.Data { + if data.Index < len(embeddings) { + embeddings[data.Index] = data.Embedding + } + } + + return embeddings, nil +} + +// EncodeQuery encodes a single query string into embedding +func (m *siliconflowEmbeddingModel) EncodeQuery(query string) ([]float64, error) { + embeddings, err := m.Encode([]string{query}) + if err != nil { + return nil, err + } + if len(embeddings) == 0 { + return nil, fmt.Errorf("no embedding returned") + } + return embeddings[0], nil +} + +// init registers the SILICONFLOW embedding model factory +func init() { + RegisterEmbeddingModelFactory("SILICONFLOW", func(apiKey, apiBase, modelName string, httpClient *http.Client) model.EmbeddingModel { + return &siliconflowEmbeddingModel{ + apiKey: apiKey, + apiBase: apiBase, + model: modelName, + httpClient: httpClient, + } + }) +} diff --git a/internal/service/models/zhipu_model.go 
b/internal/service/models/zhipu_model.go new file mode 100644 index 00000000000..617cdf56472 --- /dev/null +++ b/internal/service/models/zhipu_model.go @@ -0,0 +1,33 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "net/http" + "ragflow/internal/model" +) + +func init() { + RegisterEmbeddingModelFactory("ZHIPU-AI", func(apiKey, apiBase, modelName string, httpClient *http.Client) model.EmbeddingModel { + return &openAIEmbeddingModel{ + apiKey: apiKey, + apiBase: apiBase, + model: modelName, + httpClient: httpClient, + } + }) +} diff --git a/internal/service/nlp/query_builder.go b/internal/service/nlp/query_builder.go new file mode 100644 index 00000000000..1a4cdf37b39 --- /dev/null +++ b/internal/service/nlp/query_builder.go @@ -0,0 +1,655 @@ +// Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nlp + +import ( + "fmt" + "path/filepath" + "regexp" + "sort" + "strings" + "sync" + + "ragflow/internal/engine/infinity" + "ragflow/internal/tokenizer" + + "github.com/siongui/gojianfan" +) + +var ( + // globalQueryBuilder is the global query builder instance + globalQueryBuilder *QueryBuilder + // qbOnce ensures the query builder is initialized only once + qbOnce sync.Once + // qbInitError stores any error during initialization + qbInitError error +) + +// QueryBuilder provides functionality to build query expressions based on text, referencing Python's FulltextQueryer and QueryBase. +type QueryBuilder struct { + queryFields []string + termWeight *TermWeightDealer + synonym *Synonym +} + +// InitQueryBuilder initializes the global QueryBuilder with the given wordnet directory. +// It should be called during the initialization phase of main.go, after tokenizer.Init. +// The wordnetDir is typically filepath.Join(tokenizer.Config.DictPath, "wordnet") +func InitQueryBuilder(wordnetDir string) error { + qbOnce.Do(func() { + globalQueryBuilder = &QueryBuilder{ + queryFields: []string{ + "title_tks^10", + "title_sm_tks^5", + "important_kwd^30", + "important_tks^20", + "question_tks^20", + "content_ltks^2", + "content_sm_ltks", + }, + termWeight: NewTermWeightDealer(""), + synonym: NewSynonym(nil, "", wordnetDir), + } + }) + return qbInitError +} + +// InitQueryBuilderFromTokenizer initializes the global QueryBuilder using tokenizer's DictPath. 
+// The wordnet directory is derived from tokenizer's DictPath as: DictPath/wordnet +// This should be called after tokenizer.Init(). +func InitQueryBuilderFromTokenizer(tokenizerDictPath string) error { + wordnetDir := filepath.Join(tokenizerDictPath, "wordnet") + return InitQueryBuilder(wordnetDir) +} + +// GetQueryBuilder returns the global QueryBuilder instance. +// Returns nil if InitQueryBuilder has not been called. +func GetQueryBuilder() *QueryBuilder { + return globalQueryBuilder +} + +// NewQueryBuilder creates a new QueryBuilder with default query fields. +// Deprecated: Use GetQueryBuilder() to get the global instance for better performance. +func NewQueryBuilder() *QueryBuilder { + return &QueryBuilder{ + queryFields: []string{ + "title_tks^10", + "title_sm_tks^5", + "important_kwd^30", + "important_tks^20", + "question_tks^20", + "content_ltks^2", + "content_sm_ltks", + }, + termWeight: NewTermWeightDealer(""), + synonym: NewSynonym(nil, "", ""), + } +} + +// IsChinese determines whether a line of text is primarily Chinese. +// Algorithm: split by whitespace, if segments <=3 return true; otherwise count ratio of non-pure-alphabet segments, return true if ratio >=0.7. +func (qb *QueryBuilder) IsChinese(line string) bool { + fields := strings.Fields(line) + if len(fields) <= 3 { + return true + } + nonAlpha := 0 + for _, f := range fields { + matched, _ := regexp.MatchString(`^[a-zA-Z]+$`, f) + if !matched { + nonAlpha++ + } + } + return float64(nonAlpha)/float64(len(fields)) >= 0.7 +} + +// SubSpecialChar escapes special characters for use in queries. +func (qb *QueryBuilder) SubSpecialChar(line string) string { + // Regex matches : { } / [ ] - * " ( ) | + ~ ^ and prepends backslash + re := regexp.MustCompile(`([:{}/\[\]\-\*"\(\)\|\+~\^])`) + return re.ReplaceAllString(line, `\$1`) +} + +// RmWWW removes common stop words and question words from queries. +func (qb *QueryBuilder) RmWWW(txt string) string { + patterns := []struct { + regex string + repl string + }{ + // Chinese stop words + {`是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*`, ""}, + // English stop words (case-insensitive) + {`(^| )(what|who|how|which|where|why)('re|'s)? `, " "}, + {`(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) `, " "}, + } + original := txt + for _, p := range patterns { + re := regexp.MustCompile(`(?i)` + p.regex) + txt = re.ReplaceAllString(txt, p.repl) + } + if txt == "" { + txt = original + } + return txt +} + +// AddSpaceBetweenEngZh adds spaces between English letters and Chinese characters to improve tokenization. 
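+// For example, "ragflow检索" becomes "ragflow 检索" and "版本v2" becomes "版本 v2";
+// text that is already spaced, such as "hello 世界", is left unchanged.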
+func (qb *QueryBuilder) AddSpaceBetweenEngZh(txt string) string { + // (ENG/ENG+NUM) + ZH: e.g., "ABC123中文" -> "ABC123 中文" + re1 := regexp.MustCompile(`([A-Za-z]+[0-9]*)([\x{4e00}-\x{9fa5}]+)`) + txt = re1.ReplaceAllString(txt, "$1 $2") + + // ENG + ZH: e.g., "ABC中文" -> "ABC 中文" + re2 := regexp.MustCompile(`([A-Za-z])([\x{4e00}-\x{9fa5}]+)`) + txt = re2.ReplaceAllString(txt, "$1 $2") + + // ZH + (ENG/ENG+NUM): e.g., "中文ABC123" -> "中文 ABC123" + re3 := regexp.MustCompile(`([\x{4e00}-\x{9fa5}]+)([A-Za-z]+[0-9]*)`) + txt = re3.ReplaceAllString(txt, "$1 $2") + + // ZH + ENG: e.g., "中文ABC" -> "中文 ABC" + re4 := regexp.MustCompile(`([\x{4e00}-\x{9fa5}]+)([A-Za-z])`) + txt = re4.ReplaceAllString(txt, "$1 $2") + return txt +} + +// StrFullWidth2HalfWidth converts full-width characters to half-width characters. +// Algorithm: For each character: +// - Full-width space (U+3000) is converted to half-width space (U+0020). +// - For other characters, subtract 0xFEE0 from its code point. +// - If the resulting code point is not in the half-width character range (0x0020 to 0x7E), +// the original character is kept. +func (qb *QueryBuilder) StrFullWidth2HalfWidth(ustring string) string { + var rstring strings.Builder + for _, uchar := range ustring { + insideCode := int32(uchar) + if insideCode == 0x3000 { + insideCode = 0x0020 + } else { + insideCode -= 0xFEE0 + } + if insideCode < 0x0020 || insideCode > 0x7E { + rstring.WriteRune(uchar) + } else { + rstring.WriteRune(insideCode) + } + } + return rstring.String() +} + +// Traditional2Simplified converts traditional Chinese characters to simplified Chinese characters. +// Uses gojianfan library which provides conversion similar to Python's HanziConv. +func (qb *QueryBuilder) Traditional2Simplified(line string) string { + return gojianfan.T2S(line) +} + +// NeedFineGrainedTokenize determines if fine-grained tokenization is needed for a token. +// Reference: rag/nlp/query.py L88-93 +func (qb *QueryBuilder) NeedFineGrainedTokenize(tk string) bool { + if len(tk) < 3 { + return false + } + if matched, _ := regexp.MatchString(`^[0-9a-z\.\+#_\*-]+$`, tk); matched { + return false + } + return true +} + +// Question builds a full-text query expression based on input text. +// References Python FulltextQueryer.question method. +// Currently, a simplified version, returns basic MatchTextExpr; future integration of term weight and synonyms. +func (qb *QueryBuilder) Question(txt string, tbl string, minMatch float64) (*infinity.MatchTextExpr, []string) { + // originalQuery stores the original input text for later use in query expression. + originalQuery := txt + + // Add space between English and Chinese + txtWithSpaces := qb.AddSpaceBetweenEngZh(txt) + + // Convert to lowercase and remove punctuation (simplified) + txtLower := strings.ToLower(txtWithSpaces) + + // Convert to half-width + txtHalfWidth := qb.StrFullWidth2HalfWidth(txtLower) + + // Convert to simplified Chinese + txtSimplified := qb.Traditional2Simplified(txtHalfWidth) + + // Replace punctuation and special characters with space + // Reference: rag/nlp/query.py L44-48 + // re is the regex pattern for matching punctuation and special characters. + re := regexp.MustCompile(`[ :|\r\n\t,,.。??/\` + "`" + `!!&^%()\[\]{}<>]+`) + // txtCleaned is the text after removing punctuation and special characters. 
+ txtCleaned := re.ReplaceAllString(txtSimplified, " ") + + // Remove stop words + txtNoStopWords := qb.RmWWW(txtCleaned) + + // Determine if text is Chinese + if !qb.IsChinese(txtNoStopWords) { + // Non-Chinese processing + // Reference: rag/nlp/query.py L52-88 + + // Remove stop words again + // txtFinal is the text after removing stop words again. + txtFinal := qb.RmWWW(txtNoStopWords) + + // Tokenize using rag_tokenizer + tokenized, err := tokenizer.Tokenize(txtFinal) + if err != nil { + // If tokenizer fails, use simple split + tokenized = txtFinal + } + + // tks are tokens obtained by splitting the tokenized text by whitespace. + tks := strings.Fields(tokenized) + // keywords stores the non‑empty tokens as keywords. + keywords := make([]string, 0, len(tks)) + for _, t := range tks { + if t != "" { + keywords = append(keywords, t) + } + } + + // Calculate term weights using TermWeightDealer + // Reference: rag/nlp/query.py L56 + // tws holds the term weight list for each token. + tws := qb.termWeight.Weights(tks, false) + + // Clean tokens and filter + // Reference: rag/nlp/query.py L57-60 + type tokenWeight struct { + tk string + w float64 + } + // tksW holds the cleaned tokens with their weights. + var tksW []tokenWeight + for _, tw := range tws { + tk := tw.Term + w := tw.Weight + + // Clean token: remove special chars + tk = regexp.MustCompile(`[ \"'^]+`).ReplaceAllString(tk, "") + // Remove single alphanumeric chars + tk = regexp.MustCompile(`^[a-z0-9]$`).ReplaceAllString(tk, "") + // Remove leading +/- + tk = regexp.MustCompile(`^[\+\-]+`).ReplaceAllString(tk, "") + tk = strings.TrimSpace(tk) + + if tk == "" { + continue + } + tksW = append(tksW, tokenWeight{tk, w}) + } + + // Limit to 256 tokens + // Reference: rag/nlp/query.py L62 + if len(tksW) > 256 { + tksW = tksW[:256] + } + + // TODO: Synonym expansion (reference L61-67) + // For now, use empty synonyms + // syns is a placeholder for synonym expansion (currently empty). + syns := make([]string, len(tksW)) + + // Build query parts + // Reference: rag/nlp/query.py L69-70 + // q collects the query part strings. + var q []string + for i, tw := range tksW { + tk := tw.tk + w := tw.w + // Skip tokens with special regex chars + if matched, _ := regexp.MatchString(`[.^+\(\)-]`, tk); matched { + continue + } + // Format: (token^weight synonym) + q = append(q, fmt.Sprintf("(%s^%.4f %s)", tk, w, syns[i])) + } + + // Add phrase queries for adjacent tokens + // Reference: rag/nlp/query.py L71-82 + for i := 1; i < len(tksW); i++ { + left := strings.TrimSpace(tksW[i-1].tk) + right := strings.TrimSpace(tksW[i].tk) + if left == "" || right == "" { + continue + } + // maxW is the maximum weight between two adjacent tokens. + maxW := tksW[i-1].w + if tksW[i].w > maxW { + maxW = tksW[i].w + } + q = append(q, fmt.Sprintf(`"%s %s"^%.4f`, left, right, maxW*2)) + } + + if len(q) == 0 { + q = append(q, txtFinal) + } + + // query is the final query string built from all query parts. + query := strings.Join(q, " ") + return &infinity.MatchTextExpr{ + Fields: qb.queryFields, + MatchingText: query, + TopN: 100, + ExtraOptions: map[string]interface{}{ + "original_query": originalQuery, + }, + }, keywords + } + // Chinese processing + // Reference: rag/nlp/query.py L88-172 + + // Save original text before removing stop words (for fallback) + // otxt holds the original text before removing stop words, used as fallback. 
+ otxt := txtNoStopWords + + // Remove stop words for Chinese processing + // txtChinese is the text after removing stop words for Chinese processing. + txtChinese := qb.RmWWW(txtNoStopWords) + + // qs collects query strings for each segment. + var qs []string + // keywords stores keywords extracted from segments. + var keywords []string + + // Split text and process each segment (limit to 256) + // segments are the text segments after splitting by term weight. + segments := qb.termWeight.Split(txtChinese) + if len(segments) > 256 { + segments = segments[:256] + } + + for _, segment := range segments { + if segment == "" { + continue + } + keywords = append(keywords, segment) + + // Get term weights + // termWeightList holds term weights for the current segment. + termWeightList := qb.termWeight.Weights([]string{segment}, true) + + // Lookup synonyms + // syns are synonyms for the current segment. + syns := qb.synonym.Lookup(segment, 8) + if len(syns) > 0 && len(keywords) < 32 { + keywords = append(keywords, syns...) + } + + // Sort by weight descending + sort.Slice(termWeightList, func(i, j int) bool { + return termWeightList[i].Weight > termWeightList[j].Weight + }) + + // terms stores term strings with their weights for the current segment. + var terms []struct { + term string + weight float64 + } + + for _, termWeight := range termWeightList { + term := termWeight.Term + weight := termWeight.Weight + + // Fine-grained tokenization if needed + // sm holds fine‑grained tokens for the current term. + var sm []string + if qb.NeedFineGrainedTokenize(term) { + fineGrained, err := tokenizer.FineGrainedTokenize(term) + if err == nil && fineGrained != "" { + sm = strings.Fields(fineGrained) + } + } + + // Clean special characters from sm + // cleanSm holds cleaned fine‑grained tokens with special characters removed. + var cleanSm []string + // specialCharRe is the regex pattern for matching special characters. + specialCharRe := regexp.MustCompile(`[,\.\/;'\[\]\\\` + "`" + `~!@#$%\^&\*\(\)=\+_<>\?:"\{\}\|,。;'‘’【】、!¥……()——《》?:"""-]+`) + for _, m := range sm { + m = specialCharRe.ReplaceAllString(m, "") + m = qb.SubSpecialChar(m) + if len(m) > 1 { + cleanSm = append(cleanSm, m) + } + } + sm = cleanSm + + // Add to keywords if under limit + if len(keywords) < 32 { + // cleanTk is the term with quotes and spaces removed. + cleanTk := regexp.MustCompile(`[ \"']+`).ReplaceAllString(term, "") + if cleanTk != "" { + keywords = append(keywords, cleanTk) + } + keywords = append(keywords, sm...) + } + + // Lookup synonyms for this token + // tkSyns are synonyms for the current term. + tkSyns := qb.synonym.Lookup(term, 8) + for i, s := range tkSyns { + tkSyns[i] = qb.SubSpecialChar(s) + } + if len(keywords) < 32 { + for _, s := range tkSyns { + if s != "" { + keywords = append(keywords, s) + } + } + } + + // Fine-grained tokenize synonyms + // fineGrainedSyns holds fine‑grained tokenized synonyms. 
+            var fineGrainedSyns []string
+            for _, s := range tkSyns {
+                if s == "" {
+                    continue
+                }
+                fg, err := tokenizer.FineGrainedTokenize(s)
+                if err == nil && fg != "" {
+                    // Quote if contains space
+                    if strings.Contains(fg, " ") {
+                        fg = fmt.Sprintf(`"%s"`, fg)
+                    }
+                    fineGrainedSyns = append(fineGrainedSyns, fg)
+                }
+            }
+
+            if len(keywords) >= 32 {
+                break
+            }
+
+            // Clean token for query
+            term = qb.SubSpecialChar(term)
+            if term == "" {
+                continue
+            }
+
+            // Quote if contains space
+            if strings.Contains(term, " ") {
+                term = fmt.Sprintf(`"%s"`, term)
+            }
+
+            // Build query part with synonyms
+            if len(fineGrainedSyns) > 0 {
+                term = fmt.Sprintf("(%s OR (%s)^0.2)", term, strings.Join(fineGrainedSyns, " "))
+            }
+            if len(sm) > 0 {
+                smStr := strings.Join(sm, " ")
+                term = fmt.Sprintf(`%s OR "%s" OR ("%s"~2)^0.5`, term, smStr, smStr)
+            }
+
+            terms = append(terms, struct {
+                term   string
+                weight float64
+            }{term, weight})
+        }
+
+        // Build query string for this segment
+        // termParts collects query parts for each term in the segment.
+        var termParts []string
+        for _, termWeight := range terms {
+            termParts = append(termParts, fmt.Sprintf("(%s)^%.4f", termWeight.term, termWeight.weight))
+        }
+        // tmsStr is the query string for the current segment.
+        tmsStr := strings.Join(termParts, " ")
+
+        // Add proximity query if multiple tokens
+        if len(termWeightList) > 1 {
+            // tokenized is the tokenized version of the segment.
+            tokenized, _ := tokenizer.Tokenize(segment)
+            if tokenized != "" {
+                tmsStr += fmt.Sprintf(` ("%s"~2)^1.5`, tokenized)
+            }
+        }
+
+        // Add segment-level synonyms
+        if len(syns) > 0 && tmsStr != "" {
+            // synParts collects synonym query parts.
+            var synParts []string
+            for _, s := range syns {
+                s = qb.SubSpecialChar(s)
+                if s != "" {
+                    tokenized, _ := tokenizer.Tokenize(s)
+                    if tokenized != "" {
+                        synParts = append(synParts, fmt.Sprintf(`"%s"`, tokenized))
+                    }
+                }
+            }
+            if len(synParts) > 0 {
+                tmsStr = fmt.Sprintf("(%s)^5 OR (%s)^0.7", tmsStr, strings.Join(synParts, " OR "))
+            }
+        }
+
+        // An empty tmsStr simply contributes nothing to the query;
+        // no need for the leftover debug print here.
+        if tmsStr != "" {
+            qs = append(qs, tmsStr)
+        }
+    }
+
+    // Build final query
+    if len(qs) > 0 {
+        // queryParts collects final query parts for each segment.
+        var queryParts []string
+        for _, q := range qs {
+            if q != "" {
+                queryParts = append(queryParts, fmt.Sprintf("(%s)", q))
+            }
+        }
+        // query is the final query string built from all segments.
+        query := strings.Join(queryParts, " OR ")
+        if query == "" {
+            query = otxt
+        }
+        return &infinity.MatchTextExpr{
+            Fields:       qb.queryFields,
+            MatchingText: query,
+            TopN:         100,
+            ExtraOptions: map[string]interface{}{
+                "minimum_should_match": minMatch,
+                "original_query":       originalQuery,
+            },
+        }, keywords
+    }
+
+    return nil, keywords
+}
+
+// Paragraph builds a query expression based on content terms and keywords.
+// References Python FulltextQueryer.paragraph method.
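+//
+// Illustrative sketch of a call (field names follow the defaults above; the
+// contentTks argument is currently ignored by this simplified version):
+//
+//	expr := qb.Paragraph("machine learn tutori", []string{"machine learning", "tutorial"}, 30)
+//	// expr.MatchingText == `"machine learning" "tutorial"`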
+func (qb *QueryBuilder) Paragraph(contentTks string, keywords []string, keywordsTopN int) *infinity.MatchTextExpr {
+    // Simplified implementation: merge keywords and content terms
+    allTerms := make([]string, 0, len(keywords))
+    for _, k := range keywords {
+        k = strings.TrimSpace(k)
+        if k != "" {
+            allTerms = append(allTerms, `"`+k+`"`)
+        }
+    }
+    // Limit number of keywords
+    if keywordsTopN > 0 && len(allTerms) > keywordsTopN {
+        allTerms = allTerms[:keywordsTopN]
+    }
+    // Could add content term processing here, e.g., tokenization, weight calculation
+    // Currently only uses keywords
+    query := strings.Join(allTerms, " ")
+    // TODO: derive minimum_should_match from len(allTerms) (e.g. max(3, len(allTerms)/10))
+    // and pass it through ExtraOptions once the retrieval layer consumes it.
+    return &infinity.MatchTextExpr{
+        Fields:       qb.queryFields,
+        MatchingText: query,
+        TopN:         100,
+    }
+}
+
+// Similarity calculates similarity between two term weight dictionaries.
+// Algorithm: s = sum(qtwt[k] for k in qtwt if k in dtwt) / sum(qtwt[k])
+func (qb *QueryBuilder) Similarity(qtwt map[string]float64, dtwt map[string]float64) float64 {
+    if len(qtwt) == 0 {
+        return 0.0
+    }
+    var sum float64
+    for k, v := range qtwt {
+        if _, ok := dtwt[k]; ok {
+            sum += v
+        }
+    }
+    var total float64
+    for _, v := range qtwt {
+        total += v
+    }
+    if total == 0 {
+        return 0.0
+    }
+    return sum / total
+}
+
+// TokenSimilarity calculates similarity between query terms and multiple document term sets.
+// To be implemented: requires term weight processing module.
+func (qb *QueryBuilder) TokenSimilarity(atks string, btkss []string) []float64 {
+    // Placeholder implementation, returns zero values
+    result := make([]float64, len(btkss))
+    for i := range result {
+        result[i] = 0.0
+    }
+    return result
+}
+
+// HybridSimilarity calculates weighted combination of vector similarity and term similarity.
+// To be implemented: requires vector cosine similarity calculation.
+func (qb *QueryBuilder) HybridSimilarity(avec []float64, bvecs [][]float64, atks string, btkss []string, tkweight float64, vtweight float64) ([]float64, []float64, []float64) {
+    // Placeholder implementation, returns zero values
+    n := len(btkss)
+    sims := make([]float64, n)
+    tksim := make([]float64, n)
+    vecsim := make([]float64, n)
+    return sims, tksim, vecsim
+}
+
+// SetQueryFields sets the list of query fields.
+func (qb *QueryBuilder) SetQueryFields(fields []string) {
+    qb.queryFields = fields
+}
diff --git a/internal/service/nlp/query_builder_test.go b/internal/service/nlp/query_builder_test.go
new file mode 100644
index 00000000000..238a40317a7
--- /dev/null
+++ b/internal/service/nlp/query_builder_test.go
@@ -0,0 +1,471 @@
+// Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nlp
+
+import (
+    "reflect"
+    "testing"
+
+    "ragflow/internal/engine/infinity"
+)
+
+func TestNewQueryBuilder(t *testing.T) {
+    qb := NewQueryBuilder()
+    if qb == nil {
+        t.Fatal("NewQueryBuilder returned nil")
+    }
+    // Check default fields
+    expectedFields := []string{
+        "title_tks^10",
+        "title_sm_tks^5",
+        "important_kwd^30",
+        "important_tks^20",
+        "question_tks^20",
+        "content_ltks^2",
+        "content_sm_ltks",
+    }
+    if !reflect.DeepEqual(qb.queryFields, expectedFields) {
+        t.Errorf("Default query fields mismatch, got %v, want %v", qb.queryFields, expectedFields)
+    }
+}
+
+func TestQueryBuilder_IsChinese(t *testing.T) {
+    qb := NewQueryBuilder()
+    tests := []struct {
+        name     string
+        line     string
+        expected bool
+    }{
+        {"Empty", "", true}, // fields <=3
+        {"Single Chinese char", "中", true},
+        {"Two Chinese chars", "中文", true},
+        {"Three Chinese chars", "中文字", true},
+        {"Four Chinese chars", "中文字符", true}, // ratio >=0.7
+        {"Mixed with English", "hello world", true}, // fields=2 <=3
+        {"Mostly Chinese", "hello 世界 测试", true}, // fields=3 <=3
+        {"Mostly English", "hello world test", true}, // fields=3 <=3
+        {"English with punctuation", "Hello, world!", true}, // fields=2 <=3 (after split)
+        {"Chinese with spaces", "这 是 一个 测试", true}, // fields=4, non-alpha=4, ratio=1 >=0.7
+        {"Mixed with numbers", "123 abc", true}, // fields=2 <=3
+        // Additional cases where fields >3 and the ratio determines the result
+        {"Many English words", "this is a long english sentence", false}, // fields=6, non-alpha=0, ratio=0 <0.7
+        {"Mixed with mostly Chinese", "hello world 中文 测试 多个", false}, // fields=5, non-alpha=3, ratio=0.6 <0.7 => false
+        {"Mostly Chinese with many words", "这 是 一个 中文 测试 多个 汉字", true}, // fields=7, non-alpha=7, ratio=1 >=0.7
+        {"English with Chinese suffix", "hello world 中文", true}, // fields=3 <=3
+        {"Chinese with English suffix", "中文 test", true}, // fields=2 <=3
+    }
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            result := qb.IsChinese(tt.line)
+            if result != tt.expected {
+                t.Errorf("IsChinese(%q) = %v, want %v", tt.line, result, tt.expected)
+            }
+        })
+    }
+}
+
+func TestQueryBuilder_SubSpecialChar(t *testing.T) {
+    qb := NewQueryBuilder()
+    tests := []struct {
+        name     string
+        input    string
+        expected string
+    }{
+        {"No special chars", "hello world", "hello world"},
+        {"Colon", "test: colon", `test\: colon`},
+        {"Curly braces", "{braces}", `\{braces\}`},
+        {"Slash", "path/to/file", `path\/to\/file`},
+        {"Square brackets", "[brackets]", `\[brackets\]`},
+        {"Hyphen", "a-b-c", `a\-b\-c`},
+        {"Asterisk", "a*b", `a\*b`},
+        {"Quote", `"quote"`, `\"quote\"`},
+        {"Parentheses", "(parens)", `\(parens\)`},
+        {"Pipe", "a|b", `a\|b`},
+        {"Plus", "a+b", `a\+b`},
+        {"Tilde", "~tilde", `\~tilde`},
+        {"Caret", "^caret", `\^caret`},
+        {"Multiple", `:{}/[]-*"()|+~^`, `\:\{\}\/\[\]\-\*\"\(\)\|\+\~\^`},
+    }
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            result := qb.SubSpecialChar(tt.input)
+            if result != tt.expected {
+                t.Errorf("SubSpecialChar(%q) = %q, want %q", tt.input, result, tt.expected)
+            }
+        })
+    }
+}
+
+func TestQueryBuilder_RmWWW(t *testing.T) {
+    qb := NewQueryBuilder()
+    tests := []struct {
+        name     string
+        input    string
+        expected string
+    }{
+        {"Empty", "", ""},
+        {"No stop words", "普通文本", "普通文本"},
+        {"Chinese question word", "请问如何操作", "操作"}, // "请问" and "如何" both matched
+        {"Chinese stop word 怎么办", "怎么办安装", "安装"},
+        {"English what", "what is this", " this"}, // removes "what " and "is "
+        {"English who", "who are you", " you"}, // removes "who " and "are "
+        {"Mixed stop words", "请问what is the problem", " the problem"}, // Chinese removed, "what " and "is " removed
+        {"All removed becomes empty", "请问", "请问"}, // should revert to the original input
+        {"English articles", "the cat is on a mat", " cat on mat"}, // removes "the ", "is ", "a "
+        {"Case insensitive", "WHAT IS THIS", " THIS"}, // removes "WHAT " and "IS "
+    }
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            result := qb.RmWWW(tt.input)
+            if result != tt.expected {
+                t.Errorf("RmWWW(%q) = %q, want %q", tt.input, result, tt.expected)
+            }
+        })
+    }
+}
+
+func TestQueryBuilder_AddSpaceBetweenEngZh(t *testing.T) {
+    qb := NewQueryBuilder()
+    tests := []struct {
+        name     string
+        input    string
+        expected string
+    }{
+        {"Empty", "", ""},
+        {"English only", "hello world", "hello world"},
+        {"Chinese only", "你好世界", "你好世界"},
+        {"ENG+ZH", "hello世界", "hello 世界"},
+        {"ZH+ENG", "世界hello", "世界 hello"},
+        {"ENG+NUM+ZH", "abc123测试", "abc123 测试"},
+        {"ZH+ENG+NUM", "测试abc123", "测试 abc123"},
+        {"Multiple", "hello世界test测试", "hello 世界 test 测试"},
+        {"Already spaced", "hello 世界", "hello 世界"},
+    }
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            result := qb.AddSpaceBetweenEngZh(tt.input)
+            if result != tt.expected {
+                t.Errorf("AddSpaceBetweenEngZh(%q) = %q, want %q", tt.input, result, tt.expected)
+            }
+        })
+    }
+}
+
+func TestQueryBuilder_StrFullWidth2HalfWidth(t *testing.T) {
+    qb := NewQueryBuilder()
+    tests := []struct {
+        name     string
+        input    string
+        expected string
+    }{
+        {"Empty", "", ""},
+        {"Half-width remains", "hello world 123", "hello world 123"},
+        {"Full-width uppercase", "ＡＢＣＤＥＦＧＨＩＪＫＬＭＮＯＰＱＲＳＴＵＶＷＸＹＺ", "ABCDEFGHIJKLMNOPQRSTUVWXYZ"},
+        {"Full-width lowercase", "ａｂｃｄｅｆｇｈｉｊｋｌｍｎｏｐｑｒｓｔｕｖｗｘｙｚ", "abcdefghijklmnopqrstuvwxyz"},
+        {"Full-width digits", "０１２３４５６７８９", "0123456789"},
+        {"Full-width punctuation", "！＂＃＄％＆＇（）＊＋，－．／：；＜＝＞？＠［＼］＾＿｀｛｜｝～", "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"},
+        {"Full-width space", "　", " "},
+        {"Mixed full-width and half-width", "Hello Ｗｏｒｌｄ！１２３", "Hello World!123"},
+        {"Chinese characters unchanged", "你好世界", "你好世界"},
+        {"Japanese characters unchanged", "こんにちは", "こんにちは"},
+        {"Korean characters unchanged", "안녕하세요", "안녕하세요"},
+        {"Full-width at signs convert", "＠＠＠", "@@@"}, // full-width '@' is U+FF20, which maps to U+0040
+        {"Edge case: character just below range", "\u001F", "\u001F"}, // U+001F is < 0x0020, should remain
+        {"Edge case: character just above range", "\u007F", "\u007F"}, // U+007F is > 0x7E, should remain
+    }
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            result := qb.StrFullWidth2HalfWidth(tt.input)
+            if result != tt.expected {
+                t.Errorf("StrFullWidth2HalfWidth(%q) = %q, want %q", tt.input, result, tt.expected)
+            }
+        })
+    }
+}
+
+func TestQueryBuilder_Traditional2Simplified(t *testing.T) {
+    qb := NewQueryBuilder()
+    tests := []struct {
+        name     string
+        input    string
+        expected string
+    }{
+        {"Empty", "", ""},
+        {"Simplified unchanged", "简体中文测试", "简体中文测试"},
+        {"Traditional conversion", "繁體中文測試", "繁体中文测试"},
+        {"Traditional sentence", "我學習中文已經三年了", "我学习中文已经三年了"},
+        {"Traditional with numbers", "電話號碼123", "电话号码123"},
+        {"Traditional with English", "Hello世界", "Hello世界"},
+        {"Traditional punctuation", "請問，你好嗎？", "请问，你好吗？"},
+        {"Mixed traditional and simplified", "這是一個简体测试", "这是一个简体测试"},
+    }
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            result := qb.Traditional2Simplified(tt.input)
+            if result != tt.expected {
+                t.Errorf("Traditional2Simplified(%q) = %q, want %q", tt.input, result, 
tt.expected) + } + }) + } +} + +func TestQueryBuilder_Question(t *testing.T) { + qb := NewQueryBuilder() + tests := []struct { + name string + txt string + tbl string + minMatch float64 + expectNil bool + checkExpr func(*infinity.MatchTextExpr) bool + checkKeywords func([]string) bool + }{ + { + name: "Chinese text", + txt: "请问如何安装软件", + tbl: "test", + minMatch: 0.5, + checkExpr: func(expr *infinity.MatchTextExpr) bool { + // Should return a valid query expression with processed text + return expr != nil && expr.MatchingText != "" + }, + checkKeywords: func(keywords []string) bool { + // Should return extracted keywords + return len(keywords) > 0 + }, + }, + { + name: "English text", + txt: "How to install software", + tbl: "test", + minMatch: 0.5, + checkExpr: func(expr *infinity.MatchTextExpr) bool { + // Should return a valid query expression with processed text + return expr != nil && expr.MatchingText != "" + }, + checkKeywords: func(keywords []string) bool { + // Should return extracted keywords + return len(keywords) > 0 + }, + }, + { + name: "Mixed text", + txt: "hello世界", + tbl: "test", + minMatch: 0.5, + checkExpr: func(expr *infinity.MatchTextExpr) bool { + // Should return a valid query expression with processed text + return expr != nil && expr.MatchingText != "" + }, + checkKeywords: func(keywords []string) bool { + // Should return extracted keywords + return len(keywords) > 0 + }, + }, + { + name: "Empty text", + txt: "", + tbl: "test", + minMatch: 0.5, + expectNil: true, + checkExpr: func(expr *infinity.MatchTextExpr) bool { + return expr == nil + }, + checkKeywords: func(keywords []string) bool { + return len(keywords) == 0 + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expr, keywords := qb.Question(tt.txt, tt.tbl, tt.minMatch) + if tt.expectNil && expr != nil { + t.Errorf("Question(%q) expected nil expr, got %v", tt.txt, expr) + } + if !tt.expectNil && expr == nil { + t.Errorf("Question(%q) returned nil expr", tt.txt) + } + if expr != nil && !tt.checkExpr(expr) { + t.Errorf("Question(%q) expr check failed, got %+v", tt.txt, expr) + } + if tt.checkKeywords != nil && !tt.checkKeywords(keywords) { + t.Errorf("Question(%q) keywords check failed, got %v", tt.txt, keywords) + } + }) + } +} + +func TestQueryBuilder_Paragraph(t *testing.T) { + qb := NewQueryBuilder() + tests := []struct { + name string + contentTks string + keywords []string + keywordsTopN int + expectedQuery string + }{ + { + name: "No keywords", + contentTks: "some content terms", + keywords: []string{}, + keywordsTopN: 0, + expectedQuery: "", + }, + { + name: "Single keyword", + contentTks: "content", + keywords: []string{"hello"}, + keywordsTopN: 0, + expectedQuery: `"hello"`, + }, + { + name: "Multiple keywords", + contentTks: "content", + keywords: []string{"hello", "world", "test"}, + keywordsTopN: 0, + expectedQuery: `"hello" "world" "test"`, + }, + { + name: "Trim spaces", + contentTks: "", + keywords: []string{" hello ", " world "}, + keywordsTopN: 0, + expectedQuery: `"hello" "world"`, + }, + { + name: "TopN limit", + contentTks: "", + keywords: []string{"a", "b", "c", "d", "e"}, + keywordsTopN: 3, + expectedQuery: `"a" "b" "c"`, + }, + { + name: "TopN larger than slice", + contentTks: "", + keywords: []string{"a", "b"}, + keywordsTopN: 10, + expectedQuery: `"a" "b"`, + }, + { + name: "Empty keyword filtered", + contentTks: "", + keywords: []string{"a", "", "b"}, + keywordsTopN: 0, + expectedQuery: `"a" "b"`, + }, + } + for _, tt := range tests { + t.Run(tt.name, 
func(t *testing.T) { + expr := qb.Paragraph(tt.contentTks, tt.keywords, tt.keywordsTopN) + if expr == nil { + t.Fatal("Paragraph returned nil expr") + } + if expr.MatchingText != tt.expectedQuery { + t.Errorf("Paragraph query mismatch, got %q, want %q", expr.MatchingText, tt.expectedQuery) + } + // Check default fields + defaultFields := []string{ + "title_tks^10", + "title_sm_tks^5", + "important_kwd^30", + "important_tks^20", + "question_tks^20", + "content_ltks^2", + "content_sm_ltks", + } + if !reflect.DeepEqual(expr.Fields, defaultFields) { + t.Errorf("Paragraph fields mismatch, got %v, want %v", expr.Fields, defaultFields) + } + if expr.TopN != 100 { + t.Errorf("Paragraph TopN mismatch, got %d, want 100", expr.TopN) + } + }) + } +} + +func TestQueryBuilder_Similarity(t *testing.T) { + qb := NewQueryBuilder() + tests := []struct { + name string + qtwt map[string]float64 + dtwt map[string]float64 + expected float64 + }{ + {"Empty query", map[string]float64{}, map[string]float64{"a": 1.0}, 0.0}, + {"Empty doc", map[string]float64{"a": 1.0}, map[string]float64{}, 0.0}, + {"Exact match", map[string]float64{"a": 1.0, "b": 2.0}, map[string]float64{"a": 5.0, "b": 3.0}, 1.0}, + {"Partial match", map[string]float64{"a": 1.0, "b": 2.0, "c": 3.0}, map[string]float64{"a": 1.0, "c": 1.0}, (1.0 + 3.0) / (1.0 + 2.0 + 3.0)}, // sum=4, total=6 => 0.666... + {"No match", map[string]float64{"a": 1.0}, map[string]float64{"b": 2.0}, 0.0}, + {"Zero total weight", map[string]float64{"a": 0.0, "b": 0.0}, map[string]float64{"a": 1.0}, 0.0}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := qb.Similarity(tt.qtwt, tt.dtwt) + // Use tolerance for floating point + if result < tt.expected-1e-9 || result > tt.expected+1e-9 { + t.Errorf("Similarity(%v, %v) = %v, want %v", tt.qtwt, tt.dtwt, result, tt.expected) + } + }) + } +} + +func TestQueryBuilder_TokenSimilarity(t *testing.T) { + qb := NewQueryBuilder() + // Currently placeholder returns zero slice + atks := "query terms" + btkss := []string{"doc1", "doc2", "doc3"} + result := qb.TokenSimilarity(atks, btkss) + if len(result) != len(btkss) { + t.Errorf("TokenSimilarity length mismatch, got %d, want %d", len(result), len(btkss)) + } + for i, v := range result { + if v != 0.0 { + t.Errorf("TokenSimilarity[%d] = %v, want 0.0", i, v) + } + } +} + +func TestQueryBuilder_HybridSimilarity(t *testing.T) { + qb := NewQueryBuilder() + avec := []float64{1.0, 2.0} + bvecs := [][]float64{{1.0, 2.0}, {3.0, 4.0}} + atks := "query" + btkss := []string{"doc1", "doc2"} + tkweight := 0.5 + vtweight := 0.5 + sims, tksim, vecsim := qb.HybridSimilarity(avec, bvecs, atks, btkss, tkweight, vtweight) + if len(sims) != 2 || len(tksim) != 2 || len(vecsim) != 2 { + t.Errorf("HybridSimilarity returned slices of wrong length: sims=%d, tksim=%d, vecsim=%d", len(sims), len(tksim), len(vecsim)) + } + for i := range sims { + if sims[i] != 0.0 || tksim[i] != 0.0 || vecsim[i] != 0.0 { + t.Errorf("HybridSimilarity[%d] non-zero: sims=%v, tksim=%v, vecsim=%v", i, sims[i], tksim[i], vecsim[i]) + } + } +} + +func TestQueryBuilder_SetQueryFields(t *testing.T) { + qb := NewQueryBuilder() + newFields := []string{"field1", "field2^5"} + qb.SetQueryFields(newFields) + if !reflect.DeepEqual(qb.queryFields, newFields) { + t.Errorf("SetQueryFields failed, got %v, want %v", qb.queryFields, newFields) + } + // Ensure other methods use updated fields + expr := qb.Paragraph("", []string{"test"}, 0) + if !reflect.DeepEqual(expr.Fields, newFields) { + t.Errorf("Paragraph fields not 
updated after SetQueryFields, got %v, want %v", expr.Fields, newFields)
+    }
+}
\ No newline at end of file
diff --git a/internal/service/nlp/reranker.go b/internal/service/nlp/reranker.go
new file mode 100644
index 00000000000..17699a43d6b
--- /dev/null
+++ b/internal/service/nlp/reranker.go
@@ -0,0 +1,471 @@
+// Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nlp
+
+import (
+    "math"
+    "ragflow/internal/engine"
+    "sort"
+    "strconv"
+    "strings"
+)
+
+// RerankModel defines the interface for reranker models.
+// This matches the model.RerankModel interface.
+type RerankModel interface {
+    // Similarity calculates similarity between query and texts
+    Similarity(query string, texts []string) ([]float64, error)
+}
+
+// SearchResult represents the result of a search operation
+type SearchResult struct {
+    Total       int
+    IDs         []string
+    QueryVector []float64
+    Field       map[string]map[string]interface{} // id -> fields
+}
+
+// Rerank performs reranking based on whether a reranker model is provided.
+// This implements the logic from rag/nlp/search.py L404-L429.
+// Parameters:
+//   - rerankModel: the reranker model (can be nil)
+//   - resp: the raw engine search response
+//   - keywords: query keywords extracted upstream
+//   - questionVector: the query embedding
+//   - sres: search results keyed by chunk ID
+//   - query: the query string
+//   - tkWeight: weight for token similarity
+//   - vtWeight: weight for vector similarity
+//   - useInfinity: whether using the Infinity engine
+//   - cfield: content field name (default: "content_ltks")
+//   - qb: QueryBuilder instance for token processing
+//
+// Returns:
+//   - sim: combined similarity scores
+//   - tsim: token similarity scores
+//   - vsim: vector similarity scores
+func Rerank(
+    rerankModel RerankModel,
+    resp *engine.SearchResponse,
+    keywords []string,
+    questionVector []float64,
+    sres *SearchResult,
+    query string,
+    tkWeight, vtWeight float64,
+    useInfinity bool,
+    cfield string,
+    qb *QueryBuilder,
+) (sim []float64, tsim []float64, vsim []float64) {
+    // If a reranker model is provided and there are results, use model reranking
+    if rerankModel != nil && resp.Total > 0 {
+        return RerankByModel(rerankModel, sres, query, tkWeight, vtWeight, cfield, qb)
+    }
+
+    // Otherwise, use fallback logic based on engine type
+    if useInfinity {
+        // For Infinity: scores are already normalized before fusion.
+        // Just extract the scores from the results.
+        return RerankInfinityFallback(sres)
+    }
+
+    // For Elasticsearch: need to perform reranking
+    return RerankStandard(resp, keywords, questionVector, sres, query, tkWeight, vtWeight, cfield, qb)
+}
+
+// RerankByModel performs reranking using a reranker model
+// Reference: rag/nlp/search.py L333-L354
+func RerankByModel(
+    rerankModel RerankModel,
+    sres *SearchResult,
+    query string,
+    tkWeight, vtWeight float64,
+    cfield string,
+    qb *QueryBuilder,
+) (sim []float64, tsim []float64, vsim []float64) {
+    if sres == nil || sres.Total == 0 || len(sres.IDs) == 0 {
+        return []float64{}, []float64{}, []float64{}
+    }
+
+    // Extract keywords from query
+    _, keywords := qb.Question(query, "qa", 0.6)
+
+    // Build token lists and 
document texts for each chunk + insTw := make([][]string, 0, len(sres.IDs)) + docs := make([]string, 0, len(sres.IDs)) + + for _, id := range sres.IDs { + fields := sres.Field[id] + if fields == nil { + insTw = append(insTw, []string{}) + docs = append(docs, "") + continue + } + + contentLtks := extractContentTokens(fields, cfield) + titleTks := extractTitleTokens(fields) + importantKwd := extractImportantKeywords(fields) + + // Combine tokens without repetition (simpler version for model reranking) + tks := make([]string, 0, len(contentLtks)+len(titleTks)+len(importantKwd)) + tks = append(tks, contentLtks...) + tks = append(tks, titleTks...) + tks = append(tks, importantKwd...) + insTw = append(insTw, tks) + + // Build document text for model reranking + docText := removeRedundantSpaces(strings.Join(tks, " ")) + docs = append(docs, docText) + } + + // Calculate token similarity + tsim = TokenSimilarity(keywords, insTw, qb) + + // Get similarity scores from reranker model + modelSim, err := rerankModel.Similarity(query, docs) + if err != nil { + // If model fails, fall back to token similarity only + modelSim = make([]float64, len(tsim)) + } + + // Combine token similarity with model similarity + // Model similarity is treated as vector similarity component + sim = make([]float64, len(tsim)) + for i := range tsim { + sim[i] = tkWeight*tsim[i] + vtWeight*modelSim[i] + } + + return sim, tsim, modelSim +} + +// RerankStandard performs standard reranking without a reranker model +// Used for Elasticsearch when no reranker model is provided +// Reference: rag/nlp/search.py L294-L331 +func RerankStandard( + resp *engine.SearchResponse, + keywords []string, + questionVector []float64, + sres *SearchResult, + query string, + tkWeight, vtWeight float64, + cfield string, + qb *QueryBuilder, +) (sim []float64, tsim []float64, vsim []float64) { + chunkCount := len(resp.Chunks) + if resp.Total == 0 || chunkCount == 0 { + return []float64{}, []float64{}, []float64{} + } + + // Get vector information + vectorSize := len(questionVector) + vectorColumn := getVectorColumnName(vectorSize) + zeroVector := make([]float64, vectorSize) + + // Extract embeddings and tokens from search results + insEmbd := make([][]float64, 0, chunkCount) + insTw := make([][]string, 0, chunkCount) + + for index := range resp.Chunks { + // Extract vector + chunk := resp.Chunks[index] + chunkVector := extractVector(chunk, vectorColumn, zeroVector) + insEmbd = append(insEmbd, chunkVector) + + // Extract tokens + contentLtks := extractContentTokens(chunk, cfield) + titleTks := extractTitleTokens(chunk) + questionTks := extractQuestionTokens(chunk) + importantKwd := extractImportantKeywords(chunk) + + // Combine tokens with weights: content + title*2 + important_kwd*5 + question_tks*6 + tks := make([]string, 0, len(contentLtks)+len(titleTks)*2+len(importantKwd)*5+len(questionTks)*6) + tks = append(tks, contentLtks...) + for i := 0; i < 2; i++ { + tks = append(tks, titleTks...) + } + for i := 0; i < 5; i++ { + tks = append(tks, importantKwd...) + } + for i := 0; i < 6; i++ { + tks = append(tks, questionTks...) 
+ } + insTw = append(insTw, tks) + } + + if len(insEmbd) == 0 { + return []float64{}, []float64{}, []float64{} + } + + // Calculate hybrid similarity + return HybridSimilarity(questionVector, insEmbd, keywords, insTw, tkWeight, vtWeight, qb) +} + +// RerankInfinityFallback extracts scores from Infinity search results +// Infinity normalizes each way score before fusion, so we just extract them +func RerankInfinityFallback(sres *SearchResult) (sim []float64, tsim []float64, vsim []float64) { + sim = make([]float64, len(sres.IDs)) + for i, id := range sres.IDs { + if fields := sres.Field[id]; fields != nil { + if score, ok := fields["_score"].(float64); ok { + sim[i] = score + } + } + } + // For Infinity, tsim and vsim are the same as overall similarity + return sim, sim, sim +} + +// HybridSimilarity calculates hybrid similarity between query and documents +// Reference: rag/nlp/query.py L174-L182 +func HybridSimilarity( + avec []float64, + bvecs [][]float64, + atks []string, + btkss [][]string, + tkWeight, vtWeight float64, + qb *QueryBuilder, +) (sim []float64, tsim []float64, vsim []float64) { + // Calculate vector similarities using cosine similarity + vsim = make([]float64, len(bvecs)) + for i, bvec := range bvecs { + vsim[i] = cosineSimilarity(avec, bvec) + } + + tsim = TokenSimilarity(atks, btkss, qb) + + // Check if all vector similarities are zero + allZero := true + for _, s := range vsim { + if s != 0 { + allZero = false + break + } + } + + if allZero { + return tsim, tsim, vsim + } + + // Combine similarities + sim = make([]float64, len(tsim)) + for i := range tsim { + sim[i] = vsim[i]*vtWeight + tsim[i]*tkWeight + } + + return sim, tsim, vsim +} + +// TokenSimilarity calculates token-based similarity +// Reference: rag/nlp/query.py L184-L199 +func TokenSimilarity(atks []string, btkss [][]string, qb *QueryBuilder) []float64 { + atksDict := tokensToDict(atks, qb) + btkssDicts := make([]map[string]float64, len(btkss)) + for i, btks := range btkss { + btkssDicts[i] = tokensToDict(btks, qb) + } + + similarities := make([]float64, len(btkssDicts)) + for i, btkDict := range btkssDicts { + similarities[i] = tokenDictSimilarity(atksDict, btkDict) + } + + return similarities +} + +// tokensToDict converts tokens to a weighted dictionary +// Reference: rag/nlp/query.py L185-L195 +func tokensToDict(tks []string, qb *QueryBuilder) map[string]float64 { + d := make(map[string]float64) + wts := qb.termWeight.Weights(tks, false) + + for i, tw := range wts { + t := tw.Term + c := tw.Weight + d[t] += c * 0.4 + if i+1 < len(wts) { + _t := wts[i+1].Term + _c := wts[i+1].Weight + d[t+_t] += math.Max(c, _c) * 0.6 + } + } + + return d +} + +// tokenDictSimilarity calculates similarity between two token dictionaries +// Reference: rag/nlp/query.py L201-L213 +func tokenDictSimilarity(qtwt, dtwt map[string]float64) float64 { + if len(qtwt) == 0 || len(dtwt) == 0 { + return 0.0 + } + + // s = sum of query weights for matching tokens + s := 1e-9 + for t, qw := range qtwt { + if _, ok := dtwt[t]; ok { + s += qw + } + } + + // q = sum of all query weights (L1 normalization) + q := 1e-9 + for _, qw := range qtwt { + q += qw + } + + return s / q +} + +// ArgsortDescending returns indices sorted by values in descending order +func ArgsortDescending(values []float64) []int { + indices := make([]int, len(values)) + for i := range indices { + indices[i] = i + } + + sort.Slice(indices, func(i, j int) bool { + return values[indices[i]] > values[indices[j]] + }) + + return indices +} + +// Helper functions + +// 
getVectorColumnName returns the vector column name based on dimension
+func getVectorColumnName(dim int) string {
+    return "q_" + strconv.Itoa(dim) + "_vec"
+}
+
+// extractVector extracts a vector from chunk fields
+func extractVector(fields map[string]interface{}, column string, zeroVector []float64) []float64 {
+    v, ok := fields[column]
+    if !ok {
+        return zeroVector
+    }
+
+    switch val := v.(type) {
+    case []float64:
+        return val
+    case []interface{}:
+        vec := make([]float64, len(val))
+        for i, item := range val {
+            // Guard the type assertion: a non-numeric element falls back to
+            // the zero vector rather than panicking.
+            f, ok := item.(float64)
+            if !ok {
+                return zeroVector
+            }
+            vec[i] = f
+        }
+        return vec
+    default:
+        return zeroVector
+    }
+}
+
+// extractContentTokens extracts content tokens from chunk fields
+func extractContentTokens(fields map[string]interface{}, cfield string) []string {
+    v, ok := fields[cfield].(string)
+    if !ok {
+        return []string{}
+    }
+
+    // Remove duplicates while preserving order
+    seen := make(map[string]bool)
+    var result []string
+    for _, t := range strings.Fields(v) {
+        if !seen[t] {
+            seen[t] = true
+            result = append(result, t)
+        }
+    }
+    return result
+}
+
+// extractTitleTokens extracts title tokens from chunk fields
+func extractTitleTokens(fields map[string]interface{}) []string {
+    v, ok := fields["title_tks"].(string)
+    if !ok {
+        return []string{}
+    }
+    var result []string
+    for _, t := range strings.Fields(v) {
+        if t != "" {
+            result = append(result, t)
+        }
+    }
+    return result
+}
+
+// extractQuestionTokens extracts question tokens from chunk fields
+func extractQuestionTokens(fields map[string]interface{}) []string {
+    v, ok := fields["question_tks"].(string)
+    if !ok {
+        return []string{}
+    }
+    var result []string
+    for _, t := range strings.Fields(v) {
+        if t != "" {
+            result = append(result, t)
+        }
+    }
+    return result
+}
+
+// extractImportantKeywords extracts important keywords from chunk fields
+func extractImportantKeywords(fields map[string]interface{}) []string {
+    v, ok := fields["important_kwd"]
+    if !ok {
+        return []string{}
+    }
+
+    switch val := v.(type) {
+    case string:
+        return []string{val}
+    case []string:
+        return val
+    case []interface{}:
+        result := make([]string, 0, len(val))
+        for _, item := range val {
+            if s, ok := item.(string); ok {
+                result = append(result, s)
+            }
+        }
+        return result
+    default:
+        return []string{}
+    }
+}
+
+// cosineSimilarity calculates cosine similarity between two vectors
+func cosineSimilarity(a, b []float64) float64 {
+    if len(a) != len(b) {
+        return 0.0
+    }
+
+    var dot, normA, normB float64
+    for i := range a {
+        dot += a[i] * b[i]
+        normA += a[i] * a[i]
+        normB += b[i] * b[i]
+    }
+
+    if normA == 0 || normB == 0 {
+        return 0.0
+    }
+
+    return dot / (math.Sqrt(normA) * math.Sqrt(normB))
+}
+
+// removeRedundantSpaces removes redundant spaces from text
+func removeRedundantSpaces(s string) string {
+    return strings.Join(strings.Fields(s), " ")
+}
+
+// parseFloat parses a string to float64
+func parseFloat(s string) (float64, error) {
+    return strconv.ParseFloat(strings.TrimSpace(s), 64)
+}
diff --git a/internal/service/nlp/synonym.go b/internal/service/nlp/synonym.go
new file mode 100644
index 00000000000..f5f0871cd99
--- /dev/null
+++ b/internal/service/nlp/synonym.go
@@ -0,0 +1,222 @@
+// Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nlp
+
+import (
+    "encoding/json"
+    "os"
+    "path/filepath"
+    "regexp"
+    "strings"
+    "time"
+
+    "ragflow/internal/logger"
+
+    "go.uber.org/zap"
+)
+
+// Synonym provides synonym lookup functionality
+// Reference: rag/nlp/synonym.py Dealer class
+type Synonym struct {
+    lookupNum  int
+    loadTm     time.Time
+    dictionary map[string][]string
+    redis      RedisClient // Optional Redis client for real-time synonym loading
+    wordNet    *WordNet
+    resPath    string
+}
+
+// RedisClient is the interface for Redis operations.
+// It should be implemented by the caller if Redis support is needed.
+type RedisClient interface {
+    Get(key string) (string, error)
+}
+
+// NewSynonym creates a new Synonym instance
+// Reference: synonym.py Dealer.__init__
+// wordnetDir: path to the wordnet directory (e.g., "/usr/share/infinity/resource/wordnet").
+//
+// If empty, WordNet will not be initialized.
+func NewSynonym(redis RedisClient, resPath string, wordnetDir string) *Synonym {
+    s := &Synonym{
+        lookupNum:  100000000,
+        loadTm:     time.Now().Add(-1000000 * time.Second),
+        dictionary: make(map[string][]string),
+        redis:      redis,
+        wordNet:    nil, // Initialized below when wordnetDir is provided
+        resPath:    resPath,
+    }
+
+    if resPath == "" {
+        s.resPath = "rag/res"
+    }
+
+    // Initialize WordNet with the provided path
+    if wordnetDir != "" {
+        wordNet, err := NewWordNet(wordnetDir)
+        if err != nil {
+            // WordNet is optional, continue without it
+            s.wordNet = nil
+        } else {
+            s.wordNet = wordNet
+        }
+    }
+
+    // Load synonym.json
+    path := filepath.Join(s.resPath, "synonym.json")
+    if data, err := os.ReadFile(path); err == nil {
+        var dict map[string]interface{}
+        if err := json.Unmarshal(data, &dict); err == nil {
+            // Convert to lowercase keys and string slices
+            for k, v := range dict {
+                key := strings.ToLower(k)
+                switch val := v.(type) {
+                case string:
+                    s.dictionary[key] = []string{val}
+                case []interface{}:
+                    strSlice := make([]string, 0, len(val))
+                    for _, item := range val {
+                        if str, ok := item.(string); ok {
+                            strSlice = append(strSlice, str)
+                        }
+                    }
+                    s.dictionary[key] = strSlice
+                }
+            }
+        } else {
+            logger.Warn("Failed to parse synonym.json", zap.Error(err))
+        }
+    } else {
+        logger.Warn("Missing synonym.json", zap.Error(err))
+    }
+
+    if redis == nil {
+        logger.Warn("Realtime synonym is disabled, since there is no redis connection.")
+    }
+
+    if len(s.dictionary) == 0 {
+        logger.Warn("Failed to load any synonyms")
+    }
+
+    s.load()
+
+    return s
+}
+
+// load refreshes the synonym dictionary from Redis when a client is
+// configured, at most once per hour and only after enough lookups.
+// Reference: synonym.py Dealer.load
+func (s *Synonym) load() {
+    if s.redis == nil {
+        return
+    }
+
+    if s.lookupNum < 100 {
+        return
+    }
+
+    tm := time.Now()
+    if tm.Sub(s.loadTm).Seconds() < 3600 {
+        return
+    }
+
+    s.loadTm = time.Now()
+    s.lookupNum = 0
+
+    data, err := s.redis.Get("kevin_synonyms")
+    if err != nil || data == "" {
+        return
+    }
+
+    var dict map[string][]string
+    if jsonErr := json.Unmarshal([]byte(data), &dict); jsonErr != nil {
+        logger.Error("Failed to load synonyms from redis", zap.Error(jsonErr))
+        return
+    }
+
+    s.dictionary = dict
+}
+
+// Lookup looks up synonyms for a given token 
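+// (keys are whitespace-normalized and lower-cased; for instance, given a
+// dictionary entry "hello": ["hi", "greetings", "hey"], Lookup(" HELLO ", 2)
+// returns ["hi", "greetings"], as exercised by the tests below)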
+// Reference: synonym.py Dealer.lookup
+func (s *Synonym) Lookup(tk string, topN int) []string {
+    if tk == "" {
+        return []string{}
+    }
+
+    if topN <= 0 {
+        topN = 8
+    }
+
+    // 1) Check the custom dictionary first, counting the lookup and giving
+    // load() a chance to refresh the dictionary from Redis.
+    s.lookupNum++
+    s.load()
+
+    key := regexp.MustCompile(`[ \t]+`).ReplaceAllString(strings.TrimSpace(tk), " ")
+    key = strings.ToLower(key)
+
+    if res, ok := s.dictionary[key]; ok {
+        if len(res) > topN {
+            return res[:topN]
+        }
+        return res
+    }
+
+    // 2) If not found and tk is purely alphabetical, fall back to WordNet
+    if matched, _ := regexp.MatchString(`^[a-z]+$`, tk); matched && s.wordNet != nil {
+        wnSet := make(map[string]struct{})
+        synsets := s.wordNet.Synsets(tk, "")
+        for _, syn := range synsets {
+            // Extract the word from the synset name (format: word.pos.num)
+            parts := strings.Split(syn.Name, ".")
+            if len(parts) > 0 {
+                word := strings.ReplaceAll(parts[0], "_", " ")
+                wnSet[word] = struct{}{}
+            }
+        }
+        // Remove the original token itself
+        delete(wnSet, tk)
+
+        // Convert to slice
+        wnRes := make([]string, 0, len(wnSet))
+        for w := range wnSet {
+            if w != "" {
+                wnRes = append(wnRes, w)
+            }
+        }
+
+        if len(wnRes) > topN {
+            return wnRes[:topN]
+        }
+        return wnRes
+    }
+
+    // 3) Nothing found in either source
+    return []string{}
+}
+
+// GetDictionary returns the synonym dictionary
+func (s *Synonym) GetDictionary() map[string][]string {
+    return s.dictionary
+}
+
+// GetLookupNum returns the number of lookups since the last load
+func (s *Synonym) GetLookupNum() int {
+    return s.lookupNum
+}
+
+// GetLoadTime returns the last load time
+func (s *Synonym) GetLoadTime() time.Time {
+    return s.loadTm
+}
diff --git a/internal/service/nlp/synonym_test.go b/internal/service/nlp/synonym_test.go
new file mode 100644
index 00000000000..3667d906d4e
--- /dev/null
+++ b/internal/service/nlp/synonym_test.go
@@ -0,0 +1,444 @@
+// Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package nlp + +import ( + "encoding/json" + "os" + "path/filepath" + "reflect" + "testing" + "time" +) + +var testSynonymWordNetDir string + +func init() { + // Find project root by locating go.mod file + dir, err := os.Getwd() + if err != nil { + panic(err) + } + for { + goModPath := filepath.Join(dir, "go.mod") + if _, err := os.Stat(goModPath); err == nil { + // Found go.mod, project root is dir + testSynonymWordNetDir = filepath.Join(dir, "resource", "wordnet") + return + } + parent := filepath.Dir(dir) + if parent == dir { + // Reached root directory + break + } + dir = parent + } + // Fallback to relative path if go.mod not found + testSynonymWordNetDir = "../../../resource/wordnet" +} + +// MockRedisClient is a mock implementation of RedisClient for testing +type MockRedisClient struct { + data map[string]string +} + +func NewMockRedisClient() *MockRedisClient { + return &MockRedisClient{ + data: make(map[string]string), + } +} + +func (m *MockRedisClient) Get(key string) (string, error) { + return m.data[key], nil +} + +func (m *MockRedisClient) Set(key, value string) { + m.data[key] = value +} + +// TestNewSynonym tests the constructor +func TestNewSynonym(t *testing.T) { + t.Run("without redis", func(t *testing.T) { + s := NewSynonym(nil, "", testSynonymWordNetDir) + if s == nil { + t.Fatal("NewSynonym returned nil") + } + if s.dictionary == nil { + t.Error("Dictionary not initialized") + } + if s.wordNet == nil { + t.Error("WordNet not initialized") + } + }) + + t.Run("with redis", func(t *testing.T) { + redis := NewMockRedisClient() + s := NewSynonym(redis, "", testSynonymWordNetDir) + if s == nil { + t.Fatal("NewSynonym returned nil") + } + if s.redis != redis { + t.Error("Redis client not set") + } + }) +} + +// TestNewSynonymWithMockFile tests loading from synonym.json +func TestNewSynonymWithMockFile(t *testing.T) { + tmpDir := t.TempDir() + + // Create mock synonym.json + synonymData := map[string]interface{}{ + "happy": []string{"joyful", "cheerful", "glad"}, + "sad": []string{"unhappy", "sorrowful"}, + "test": "single", // Test string value + "UPPER": []string{"lower"}, // Test case conversion + } + data, _ := json.Marshal(synonymData) + if err := os.WriteFile(filepath.Join(tmpDir, "synonym.json"), data, 0644); err != nil { + t.Fatalf("Failed to create mock synonym.json: %v", err) + } + + s := NewSynonym(nil, tmpDir, testSynonymWordNetDir) + + // Check dictionary loaded correctly + if len(s.dictionary) != 4 { + t.Errorf("Expected 4 entries, got %d", len(s.dictionary)) + } + + // Check case conversion (UPPER -> upper) + if _, ok := s.dictionary["upper"]; !ok { + t.Error("Expected 'upper' key (converted from UPPER)") + } + + // Check string value converted to slice (test -> [single]) + if val, ok := s.dictionary["test"]; !ok || len(val) != 1 || val[0] != "single" { + t.Error("Expected 'test' to be converted to single-element slice") + } +} + +// TestSynonymLookup tests the Lookup method +func TestSynonymLookup(t *testing.T) { + tmpDir := t.TempDir() + + // Create mock synonym.json + synonymData := map[string]interface{}{ + "hello": []string{"hi", "greetings", "hey"}, + "world": []string{"earth", "globe"}, + } + data, _ := json.Marshal(synonymData) + os.WriteFile(filepath.Join(tmpDir, "synonym.json"), data, 0644) + + s := NewSynonym(nil, tmpDir, testSynonymWordNetDir) + + tests := []struct { + name string + tk string + topN int + expected []string + }{ + { + name: "found in dictionary", + tk: "hello", + topN: 8, + expected: []string{"hi", "greetings", "hey"}, + }, + { + 
name: "found with topN limit", + tk: "hello", + topN: 2, + expected: []string{"hi", "greetings"}, + }, + { + name: "not found", + tk: "xyzabc123", + topN: 8, + expected: []string{}, + }, + { + name: "empty token", + tk: "", + topN: 8, + expected: []string{}, + }, + { + name: "whitespace normalization", + tk: " hello ", + topN: 8, + expected: []string{"hi", "greetings", "hey"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := s.Lookup(tt.tk, tt.topN) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("Lookup(%q, %d) = %v, expected %v", tt.tk, tt.topN, result, tt.expected) + } + }) + } +} + +// TestSynonymLookupFromWordNet tests WordNet fallback +func TestSynonymLookupFromWordNet(t *testing.T) { + // Create synonym with empty dictionary to force WordNet fallback + s := NewSynonym(nil, "", "") + s.dictionary = make(map[string][]string) // Clear dictionary + + t.Run("pure alphabetical token", func(t *testing.T) { + // Since WordNet is a placeholder, it should return empty + result := s.Lookup("test", 8) + // WordNet placeholder returns empty, so we expect empty result + if len(result) != 0 { + t.Logf("WordNet returned: %v (placeholder implementation)", result) + } + }) + + t.Run("non-alphabetical token", func(t *testing.T) { + result := s.Lookup("test123", 8) + if len(result) != 0 { + t.Errorf("Expected empty result for non-alphabetical token, got %v", result) + } + }) +} + +// TestSynonymLoad tests loading from Redis +func TestSynonymLoad(t *testing.T) { + tmpDir := t.TempDir() + + // Create initial synonym.json + synonymData := map[string]interface{}{ + "initial": []string{"first"}, + } + data, _ := json.Marshal(synonymData) + os.WriteFile(filepath.Join(tmpDir, "synonym.json"), data, 0644) + + redis := NewMockRedisClient() + + // Set up Redis data + redisData := map[string][]string{ + "redis_key": []string{"from", "redis"}, + } + redisBytes, _ := json.Marshal(redisData) + redis.Set("kevin_synonyms", string(redisBytes)) + + s := NewSynonym(redis, tmpDir, testSynonymWordNetDir) + + // Simulate multiple lookups to trigger load + s.lookupNum = 200 // Set above threshold + s.loadTm = time.Now().Add(-4000 * time.Second) // Set load time > 1 hour ago + + // Call load directly + s.load() + + // After load, dictionary should be updated from Redis + if _, ok := s.dictionary["redis_key"]; !ok { + t.Log("Dictionary not updated from Redis (may be expected due to timing)") + } +} + +// TestSynonymLoadNoRedis tests load without Redis +func TestSynonymLoadNoRedis(t *testing.T) { + s := NewSynonym(nil, "", "") + + // Should not panic + s.load() + + // Lookup num should remain unchanged + originalNum := s.lookupNum + s.load() + if s.lookupNum != originalNum { + t.Error("Lookup num should not change when Redis is nil") + } +} + +// TestSynonymLoadNotTriggered tests load conditions +func TestSynonymLoadNotTriggered(t *testing.T) { + redis := NewMockRedisClient() + s := NewSynonym(redis, "", "") + + // Set conditions that should prevent load + s.lookupNum = 50 // Below threshold + s.loadTm = time.Now() + + // Call load + s.load() + + // Should not attempt to load from Redis + // (indirect check: lookupNum should not reset) + if s.lookupNum != 50 { + t.Error("Load should not be triggered when lookupNum < 100") + } +} + +// TestGetDictionary tests GetDictionary method +func TestGetDictionary(t *testing.T) { + tmpDir := t.TempDir() + + synonymData := map[string]interface{}{ + "test": []string{"value"}, + } + data, _ := json.Marshal(synonymData) + 
os.WriteFile(filepath.Join(tmpDir, "synonym.json"), data, 0644) + + s := NewSynonym(nil, tmpDir, testSynonymWordNetDir) + + dict := s.GetDictionary() + if dict == nil { + t.Error("GetDictionary returned nil") + } + if len(dict) != 1 { + t.Errorf("Expected 1 entry, got %d", len(dict)) + } +} + +// TestGetLookupNum tests GetLookupNum method +func TestGetLookupNum(t *testing.T) { + s := NewSynonym(nil, "", "") + initialNum := s.GetLookupNum() + + // Perform some lookups + s.Lookup("test1", 8) + s.Lookup("test2", 8) + s.Lookup("test3", 8) + + newNum := s.GetLookupNum() + if newNum != initialNum+3 { + t.Errorf("Expected lookup num %d, got %d", initialNum+3, newNum) + } +} + +// TestGetLoadTime tests GetLoadTime method +func TestGetLoadTime(t *testing.T) { + s := NewSynonym(nil, "", "") + loadTime := s.GetLoadTime() + + // Load time should be in the past (since we set it to -1000000 seconds) + if loadTime.After(time.Now()) { + t.Error("Load time should be in the past") + } +} + +// TestLookupCaseSensitivity tests case insensitivity +func TestLookupCaseSensitivity(t *testing.T) { + tmpDir := t.TempDir() + + synonymData := map[string]interface{}{ + "lowercase": []string{"result"}, + } + data, _ := json.Marshal(synonymData) + os.WriteFile(filepath.Join(tmpDir, "synonym.json"), data, 0644) + + s := NewSynonym(nil, tmpDir, testSynonymWordNetDir) + + // Lookup with different cases + tests := []string{"lowercase", "LOWERCASE", "LowerCase", "LoWeRcAsE"} + for _, tk := range tests { + result := s.Lookup(tk, 8) + if len(result) == 0 { + t.Errorf("Expected result for %q, got none", tk) + } + } +} + +// TestLookupWithSpaces tests whitespace normalization +func TestLookupWithSpaces(t *testing.T) { + tmpDir := t.TempDir() + + synonymData := map[string]interface{}{ + "two words": []string{"result"}, + } + data, _ := json.Marshal(synonymData) + os.WriteFile(filepath.Join(tmpDir, "synonym.json"), data, 0644) + + s := NewSynonym(nil, tmpDir, testSynonymWordNetDir) + + // Lookup with various whitespace + tests := []string{ + "two words", + "two words", + "two\twords", + "two\t\twords", + " two words ", + } + + for _, tk := range tests { + result := s.Lookup(tk, 8) + if len(result) == 0 { + t.Errorf("Expected result for %q, got none", tk) + } + } +} + +// TestSynonymMissingFile tests behavior when synonym.json is missing +func TestSynonymMissingFile(t *testing.T) { + tmpDir := t.TempDir() + // Don't create synonym.json + + s := NewSynonym(nil, tmpDir, testSynonymWordNetDir) + + if len(s.dictionary) != 0 { + t.Errorf("Expected empty dictionary, got %d entries", len(s.dictionary)) + } + + // Lookup should return empty + result := s.Lookup("anything", 8) + if len(result) != 0 { + t.Errorf("Expected empty result, got %v", result) + } +} + +// TestSynonymInvalidJSON tests behavior with invalid JSON +func TestSynonymInvalidJSON(t *testing.T) { + tmpDir := t.TempDir() + + // Create invalid JSON file + os.WriteFile(filepath.Join(tmpDir, "synonym.json"), []byte("invalid json"), 0644) + + s := NewSynonym(nil, tmpDir, testSynonymWordNetDir) + + // Should have empty dictionary but not panic + if s.dictionary == nil { + t.Error("Dictionary should be initialized even with invalid JSON") + } +} + +// BenchmarkLookup benchmarks the Lookup method +func BenchmarkLookup(b *testing.B) { + tmpDir := b.TempDir() + + synonymData := map[string]interface{}{ + "test": []string{"synonym1", "synonym2", "synonym3"}, + } + data, _ := json.Marshal(synonymData) + os.WriteFile(filepath.Join(tmpDir, "synonym.json"), data, 0644) + + s := 
NewSynonym(nil, tmpDir, testSynonymWordNetDir) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Lookup("test", 8) + } +} + +// BenchmarkLookupNotFound benchmarks lookup for non-existent tokens +func BenchmarkLookupNotFound(b *testing.B) { + s := NewSynonym(nil, "", "") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Lookup("nonexistent", 8) + } +} diff --git a/internal/service/nlp/term_weight.go b/internal/service/nlp/term_weight.go new file mode 100644 index 00000000000..215d608bacd --- /dev/null +++ b/internal/service/nlp/term_weight.go @@ -0,0 +1,496 @@ +// Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nlp + +import ( + "encoding/json" + "math" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + + "ragflow/internal/logger" + "ragflow/internal/tokenizer" + + "go.uber.org/zap" +) + +// TermWeightDealer calculates term weights for text processing +// Reference: rag/nlp/term_weight.py +type TermWeightDealer struct { + stopWords map[string]struct{} + ne map[string]string // named entities + df map[string]int // document frequency +} + +// TermWeight represents a term and its weight +type TermWeight struct { + Term string + Weight float64 +} + +// NewTermWeightDealer creates a new TermWeightDealer +func NewTermWeightDealer(resPath string) *TermWeightDealer { + d := &TermWeightDealer{ + stopWords: initStopWords(), + ne: make(map[string]string), + df: make(map[string]int), + } + + // Load named entity dictionary + if resPath == "" { + resPath = "rag/res" + } + + nerPath := filepath.Join(resPath, "ner.json") + if data, err := os.ReadFile(nerPath); err == nil { + if err := json.Unmarshal(data, &d.ne); err != nil { + logger.Warn("Failed to load ner.json", zap.Error(err)) + } + } else { + logger.Warn("Failed to load ner.json", zap.Error(err)) + } + + // Load term frequency dictionary + freqPath := filepath.Join(resPath, "term.freq") + d.df = loadDict(freqPath) + + return d +} + +// initStopWords initializes the stop words set +func initStopWords() map[string]struct{} { + words := []string{ + "请问", "您", "你", "我", "他", "是", "的", "就", "有", "于", + "及", "即", "在", "为", "最", "有", "从", "以", "了", "将", + "与", "吗", "吧", "中", "#", "什么", "怎么", "哪个", "哪些", + "啥", "相关", + } + stopWords := make(map[string]struct{}, len(words)) + for _, w := range words { + stopWords[w] = struct{}{} + } + return stopWords +} + +// loadDict loads a dictionary file +// Format: term\tfreq or just term +func loadDict(fnm string) map[string]int { + res := make(map[string]int) + data, err := os.ReadFile(fnm) + if err != nil { + logger.Warn("Failed to load dictionary", zap.String("file", fnm), zap.Error(err)) + return res + } + + lines := strings.Split(string(data), "\n") + totalFreq := 0 + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + arr := strings.Split(line, "\t") + if len(arr) >= 2 { + if freq, err := strconv.Atoi(arr[1]); err == nil { + res[arr[0]] = freq + totalFreq += freq + } + } 
else {
+            res[arr[0]] = 0
+        }
+    }
+
+    // Entries without an explicit frequency keep the value 0, so the map can
+    // also serve as a plain set.
+    return res
+}
+
+// Pretoken preprocesses and tokenizes text
+// Reference: term_weight.py L92-114
+func (d *TermWeightDealer) Pretoken(txt string, num bool, stpwd bool) []string {
+    // Punctuation and symbols to mask out. The character class must end with
+    // an unescaped "]": an escaped `\]` leaves the class unterminated and the
+    // pattern fails to compile, so no punctuation would ever be masked.
+    patt := `[~—\t @#%!<>,\.\?":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;':""【¥ 】…¥!、·()×` + "`" + `&/「」]`
+    punctRe := regexp.MustCompile(patt)
+    singleDigitRe := regexp.MustCompile("^[0-9]$")
+
+    res := []string{}
+    tokenized, err := tokenizer.Tokenize(txt)
+    if err != nil {
+        // Fall back to the raw text split on whitespace
+        tokenized = txt
+    }
+
+    for _, t := range strings.Fields(tokenized) {
+        tk := t
+        // Check stop words
+        if stpwd {
+            if _, isStop := d.stopWords[tk]; isStop {
+                continue
+            }
+        }
+        // Skip single digits (unless num is true)
+        if singleDigitRe.MatchString(tk) && !num {
+            continue
+        }
+        // Mask punctuation tokens
+        if punctRe.MatchString(t) {
+            tk = "#"
+        }
+        if tk != "#" && tk != "" {
+            res = append(res, tk)
+        }
+    }
+    return res
+}
+
+// TokenMerge merges short tokens into phrases
+// Reference: term_weight.py L116-143
+func (d *TermWeightDealer) TokenMerge(tks []string) []string {
+    oneTerm := func(t string) bool {
+        // Use rune count for proper Unicode handling
+        runeCount := len([]rune(t))
+        if runeCount == 1 {
+            return true
+        }
+        // Match 1-2 alphanumeric characters
+        matched, _ := regexp.MatchString("^[0-9a-z]{1,2}$", t)
+        return matched
+    }
+
+    if len(tks) == 0 {
+        return []string{}
+    }
+
+    res := []string{}
+    i := 0
+    for i < len(tks) {
+        // Special case: the first term is a single char and the next one is a
+        // multi-char CJK token (i.e., it does not start with ASCII alnum)
+        if i == 0 && len(tks) > 1 && oneTerm(tks[i]) {
+            nextLen := len([]rune(tks[i+1]))
+            isNextMultiChar := nextLen > 1
+            nextStartsAlnum, _ := regexp.MatchString("^[0-9a-zA-Z]", tks[i+1])
+            if isNextMultiChar && !nextStartsAlnum {
+                res = append(res, tks[0]+" "+tks[1])
+                i = 2
+                continue
+            }
+        }
+
+        j := i
+        for j < len(tks) && tks[j] != "" {
+            if _, isStop := d.stopWords[tks[j]]; isStop {
+                break
+            }
+            if !oneTerm(tks[j]) {
+                break
+            }
+            j++
+        }
+
+        if j-i > 1 {
+            if j-i < 5 {
+                res = append(res, strings.Join(tks[i:j], " "))
+                i = j
+            } else {
+                // Split into pairs for 5+ consecutive short tokens
+                for k := i; k < j; k += 2 {
+                    if k+1 < j {
+                        res = append(res, tks[k]+" "+tks[k+1])
+                    } else {
+                        res = append(res, tks[k])
+                    }
+                }
+                i = j
+            }
+        } else {
+            if len(tks[i]) > 0 {
+                res = append(res, tks[i])
+            }
+            i++
+        }
+    }
+
+    // Filter empty strings
+    filtered := []string{}
+    for _, t := range res {
+        if t != "" {
+            filtered = append(filtered, t)
+        }
+    }
+    return filtered
+}
+
+// Ner gets the named entity type for a term
+// Reference: term_weight.py L145-150
+func (d *TermWeightDealer) Ner(t string) string {
+    if d.ne == nil {
+        return ""
+    }
+    if res, ok := d.ne[t]; ok {
+        return res
+    }
+    return ""
+}
+
+// Split splits text into tokens, merging consecutive English words
+// Reference: term_weight.py L152-161
+func (d *TermWeightDealer) Split(txt string) []string {
+    if txt == "" {
+        return []string{""}
+    }
+
+    tks := []string{}
+    // Normalize spaces (tabs and multiple spaces -> single space)
+    txt = regexp.MustCompile("[ \\t]+").ReplaceAllString(txt, " ")
+    txt = strings.TrimSpace(txt)
+
+    for _, t := range strings.Split(txt, " ") {
+        t = strings.TrimSpace(t)
+        if t == "" {
+            continue
+        }
+        if len(tks) > 0 {
+            prevEndsWithLetter, _ := regexp.MatchString(".*[a-zA-Z]$", tks[len(tks)-1])
+            currEndsWithLetter, _ := regexp.MatchString(".*[a-zA-Z]$", t)
+            prevNE := d.ne[tks[len(tks)-1]]
+            currNE := d.ne[t]
+            if prevEndsWithLetter && 
currEndsWithLetter &&
+                currNE != "func" && prevNE != "func" {
+                tks[len(tks)-1] = tks[len(tks)-1] + " " + t
+                continue
+            }
+        }
+        tks = append(tks, t)
+    }
+    return tks
+}
+
+// Weights calculates weights for tokens
+// Reference: term_weight.py L163-246
+func (d *TermWeightDealer) Weights(tks []string, preprocess bool) []TermWeight {
+    numPattern := regexp.MustCompile("^[0-9,.]{2,}$")
+    shortLetterPattern := regexp.MustCompile("^[a-z]{1,2}$")
+    numSpacePattern := regexp.MustCompile("^[0-9. -]{2,}$")
+    letterPattern := regexp.MustCompile("^[a-z. -]+$")
+
+    // ner weight function
+    nerWeight := func(t string) float64 {
+        if numPattern.MatchString(t) {
+            return 2
+        }
+        if shortLetterPattern.MatchString(t) {
+            return 0.01
+        }
+        if d.ne == nil {
+            return 1
+        }
+        if neType, ok := d.ne[t]; ok {
+            weights := map[string]float64{
+                "toxic": 2, "func": 1, "corp": 3, "loca": 3,
+                "sch": 3, "stock": 3, "firstnm": 1,
+            }
+            if w, exists := weights[neType]; exists {
+                return w
+            }
+        }
+        return 1
+    }
+
+    // postag weight function using the real POS tagger
+    postagWeight := func(t string) float64 {
+        tag := tokenizer.GetTermTag(t)
+        // Map POS tags to weights (matching the Python implementation)
+        if tag == "r" || tag == "c" || tag == "d" {
+            return 0.3
+        }
+        if tag == "ns" || tag == "nt" {
+            return 3
+        }
+        if tag == "n" {
+            return 2
+        }
+        // Fallback heuristic for purely numeric terms; this must test the
+        // term itself, not its tag, as in the Python implementation.
+        if matched, _ := regexp.MatchString("^[0-9-]+$", t); matched {
+            return 2
+        }
+        return 1
+    }
+
+    // freq function using the real frequency dictionary
+    var freq func(t string) float64
+    freq = func(t string) float64 {
+        if numSpacePattern.MatchString(t) {
+            return 3
+        }
+        // Use the tokenizer's frequency table
+        s := tokenizer.GetTermFreq(t)
+        if s == 0 && letterPattern.MatchString(t) {
+            return 300
+        }
+        if s == 0 && len([]rune(t)) >= 4 {
+            // Tokenize further and recurse on the parts
+            fgTokens, _ := tokenizer.Tokenize(t)
+            tokens := strings.Fields(fgTokens)
+
+            var validTokens []float64
+            if len(tokens) > 1 {
+                for _, tt := range tokens {
+                    f := freq(tt)
+                    validTokens = append(validTokens, f)
+                }
+
+                minVal := validTokens[0]
+                for _, v := range validTokens[1:] {
+                    if v < minVal {
+                        minVal = v
+                    }
+                }
+                return minVal / 6.0
+            }
+
+            // Default frequency
+            return 10
+        }
+        return math.Max(float64(s), 10)
+    }
+
+    // df function
+    var df func(t string) float64
+    df = func(t string) float64 {
+        if numSpacePattern.MatchString(t) {
+            return 5
+        }
+        if v, ok := d.df[t]; ok {
+            return float64(v) + 3
+        }
+        if letterPattern.MatchString(t) {
+            return 300
+        }
+        if len([]rune(t)) >= 4 {
+            // Tokenize further and recurse on the parts
+            fgTokens, _ := tokenizer.Tokenize(t)
+            tokens := strings.Fields(fgTokens)
+
+            var validTokens []float64
+            if len(tokens) > 1 {
+                for _, tt := range tokens {
+                    f := df(tt)
+                    validTokens = append(validTokens, f)
+                }
+
+                minVal := validTokens[0]
+                for _, v := range validTokens[1:] {
+                    if v < minVal {
+                        minVal = v
+                    }
+                }
+                return math.Max(3, minVal/6.0)
+            }
+        }
+        return 3
+    }
+
+    // idf function
+    idf := func(s, N float64) float64 {
+        return math.Log10(10 + ((N - s + 0.5) / (s + 0.5)))
+    }
+
+    tw := []TermWeight{}
+
+    if !preprocess {
+        // Direct calculation without preprocessing
+        idf1Vals := make([]float64, len(tks))
+        idf2Vals := make([]float64, len(tks))
+        nerPosVals := make([]float64, len(tks))
+
+        for i, t := range tks {
+            idf1Vals[i] = idf(freq(t), 10000000)
+            idf2Vals[i] = idf(df(t), 1000000000)
+            nerPosVals[i] = nerWeight(t) * postagWeight(t)
+        }
+
+        wts := make([]float64, len(tks))
+        for i := range tks {
+            wts[i] = (0.3*idf1Vals[i] + 
0.7*idf2Vals[i]) * nerPosVals[i] + } + + for i, t := range tks { + tw = append(tw, TermWeight{Term: t, Weight: wts[i]}) + } + } else { + // With preprocessing + for _, tk := range tks { + tokens := d.Pretoken(tk, true, true) + tt := d.TokenMerge(tokens) + if len(tt) == 0 { + continue + } + + idf1Vals := make([]float64, len(tt)) + idf2Vals := make([]float64, len(tt)) + nerPosVals := make([]float64, len(tt)) + + for i, t := range tt { + idf1Vals[i] = idf(freq(t), 10000000) + idf2Vals[i] = idf(df(t), 1000000000) + nerPosVals[i] = nerWeight(t) * postagWeight(t) + } + + wts := make([]float64, len(tt)) + for i := range tt { + wts[i] = (0.3*idf1Vals[i] + 0.7*idf2Vals[i]) * nerPosVals[i] + } + + for i, t := range tt { + tw = append(tw, TermWeight{Term: t, Weight: wts[i]}) + } + } + } + + // Normalize weights + if len(tw) == 0 { + return tw + } + + S := 0.0 + for _, twItem := range tw { + S += twItem.Weight + } + + if S > 0 { + for i := range tw { + tw[i].Weight = tw[i].Weight / S + } + } + + return tw +} + +// GetStopWords returns the stop words set +func (d *TermWeightDealer) GetStopWords() map[string]struct{} { + return d.stopWords +} + +// GetNE returns the named entity dictionary +func (d *TermWeightDealer) GetNE() map[string]string { + return d.ne +} + +// GetDF returns the document frequency dictionary +func (d *TermWeightDealer) GetDF() map[string]int { + return d.df +} diff --git a/internal/service/nlp/term_weight_test.go b/internal/service/nlp/term_weight_test.go new file mode 100644 index 00000000000..f731e2403cb --- /dev/null +++ b/internal/service/nlp/term_weight_test.go @@ -0,0 +1,832 @@ +// Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package nlp + +import ( + "os" + "path/filepath" + "reflect" + "strings" + "testing" +) + +// TestNewTermWeightDealer tests the constructor +func TestNewTermWeightDealer(t *testing.T) { + // Test with empty resPath + d := NewTermWeightDealer("") + if d == nil { + t.Fatal("NewTermWeightDealer returned nil") + } + + // Check stop words are initialized + if len(d.stopWords) == 0 { + t.Error("Stop words not initialized") + } + + // Check stop word exists + if _, ok := d.stopWords["请问"]; !ok { + t.Error("Expected stop word '请问' not found") + } + + // Test with non-existent resPath (should not panic) + d2 := NewTermWeightDealer("/nonexistent/path") + if d2 == nil { + t.Fatal("NewTermWeightDealer returned nil for non-existent path") + } +} + +// TestNewTermWeightDealerWithMockFiles tests with mock dictionary files +func TestNewTermWeightDealerWithMockFiles(t *testing.T) { + // Create temporary directory with mock files + tmpDir := t.TempDir() + + // Create mock ner.json + nerData := `{ + "北京": "loca", + "腾讯": "corp", + "func": "func", + "toxic": "toxic" + }` + if err := os.WriteFile(filepath.Join(tmpDir, "ner.json"), []byte(nerData), 0644); err != nil { + t.Fatalf("Failed to create mock ner.json: %v", err) + } + + // Create mock term.freq + freqData := "hello\t100\nworld\t200\ntest\t50\n" + if err := os.WriteFile(filepath.Join(tmpDir, "term.freq"), []byte(freqData), 0644); err != nil { + t.Fatalf("Failed to create mock term.freq: %v", err) + } + + d := NewTermWeightDealer(tmpDir) + + // Check NE dictionary + if ne := d.Ner("北京"); ne != "loca" { + t.Errorf("Expected NE 'loca' for '北京', got '%s'", ne) + } + if ne := d.Ner("腾讯"); ne != "corp" { + t.Errorf("Expected NE 'corp' for '腾讯', got '%s'", ne) + } + + // Check DF dictionary + if df := d.GetDF(); len(df) != 3 { + t.Errorf("Expected 3 entries in DF, got %d", len(df)) + } +} + +// TestPretoken tests the pretokenization function +func TestPretoken(t *testing.T) { + d := NewTermWeightDealer("") + + tests := []struct { + name string + txt string + num bool + stpwd bool + expected []string + }{ + { + name: "simple text", + txt: "hello world", + num: false, + stpwd: true, + expected: []string{}, // May vary based on tokenizer + }, + { + name: "with stop words", + txt: "请问你好吗", + num: false, + stpwd: true, + expected: []string{}, // Stop words should be removed + }, + { + name: "with numbers (num=true)", + txt: "123", + num: true, + stpwd: true, + expected: []string{}, // Single digit may be filtered + }, + { + name: "empty text", + txt: "", + num: false, + stpwd: true, + expected: []string{}, + }, + { + name: "only punctuation", + txt: ",。!?", + num: false, + stpwd: true, + expected: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := d.Pretoken(tt.txt, tt.num, tt.stpwd) + // Just check it doesn't panic and returns a slice + if result == nil { + t.Error("Pretoken returned nil") + } + }) + } +} + +// TestTokenMerge tests token merging +func TestTokenMerge(t *testing.T) { + d := NewTermWeightDealer("") + + tests := []struct { + name string + tks []string + expected []string + }{ + { + name: "empty input", + tks: []string{}, + expected: []string{}, + }, + { + name: "single token", + tks: []string{"hello"}, + expected: []string{"hello"}, + }, + { + name: "consecutive short tokens", + tks: []string{"a", "b", "c"}, + expected: []string{"a b c"}, // Should merge + }, + { + name: "mixed tokens", + tks: []string{"a", "hello", "b"}, + expected: []string{"a", "hello", "b"}, + }, + { + name: "first term single 
char followed by multi-char", + tks: []string{"多", "工位"}, + expected: []string{"多 工位"}, // Special case + }, + { + name: "too many short tokens (>=5)", + tks: []string{"a", "b", "c", "d", "e", "f"}, + expected: []string{"a b", "c d", "e f"}, // Merge in pairs + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := d.TokenMerge(tt.tks) + if !reflect.DeepEqual(result, tt.expected) { + // Debug: print detailed comparison + t.Errorf("TokenMerge(%v) = %v (len=%d), expected %v (len=%d)", + tt.tks, result, len(result), tt.expected, len(tt.expected)) + for i, r := range result { + t.Errorf(" result[%d] = %q (len=%d)", i, r, len(r)) + } + for i, e := range tt.expected { + t.Errorf(" expected[%d] = %q (len=%d)", i, e, len(e)) + } + } + }) + } +} + +// TestNer tests named entity recognition +func TestNer(t *testing.T) { + tmpDir := t.TempDir() + + // Create mock ner.json + nerData := `{ + "北京": "loca", + "腾讯": "corp", + "阿里巴巴": "corp" + }` + if err := os.WriteFile(filepath.Join(tmpDir, "ner.json"), []byte(nerData), 0644); err != nil { + t.Fatalf("Failed to create mock ner.json: %v", err) + } + + d := NewTermWeightDealer(tmpDir) + + tests := []struct { + term string + expected string + }{ + {"北京", "loca"}, + {"腾讯", "corp"}, + {"阿里巴巴", "corp"}, + {"不存在", ""}, + {"", ""}, + } + + for _, tt := range tests { + t.Run(tt.term, func(t *testing.T) { + result := d.Ner(tt.term) + if result != tt.expected { + t.Errorf("Ner('%s') = '%s', expected '%s'", tt.term, result, tt.expected) + } + }) + } +} + +// TestSplit tests text splitting +func TestSplit(t *testing.T) { + d := NewTermWeightDealer("") + + tests := []struct { + name string + txt string + expected []string + }{ + { + name: "simple split", + txt: "hello world test", + // Consecutive English words ending with letters are merged + expected: []string{"hello world test"}, + }, + { + name: "consecutive English words", + txt: "machine learning algorithm", + expected: []string{"machine learning algorithm"}, // Should merge + }, + { + name: "mixed Chinese and English", + txt: "hello 世界 world", + // "hello" ends with letter, "世界" doesn't start with letter but doesn't end with letter either + expected: []string{"hello", "世界", "world"}, + }, + { + name: "empty string", + txt: "", + expected: []string{""}, + }, + { + name: "multiple spaces", + txt: "hello world", + // Multiple spaces are normalized, then merged if both end with letters + expected: []string{"hello world"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := d.Split(tt.txt) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("Split('%s') = %v (len=%d), expected %v (len=%d)", + tt.txt, result, len(result), tt.expected, len(tt.expected)) + for i, r := range result { + t.Errorf(" result[%d] = %q", i, r) + } + for i, e := range tt.expected { + t.Errorf(" expected[%d] = %q", i, e) + } + } + }) + } +} + +// TestWeights tests weight calculation +func TestWeights(t *testing.T) { + tmpDir := t.TempDir() + + // Create mock ner.json + nerData := `{ + "toxic": "toxic", + "func": "func", + "corp": "corp", + "loca": "loca" + }` + if err := os.WriteFile(filepath.Join(tmpDir, "ner.json"), []byte(nerData), 0644); err != nil { + t.Fatalf("Failed to create mock ner.json: %v", err) + } + + // Create mock term.freq + freqData := "hello\t100\nworld\t200\n" + if err := os.WriteFile(filepath.Join(tmpDir, "term.freq"), []byte(freqData), 0644); err != nil { + t.Fatalf("Failed to create mock term.freq: %v", err) + } + + d := 
NewTermWeightDealer(tmpDir) + + t.Run("without preprocess", func(t *testing.T) { + tks := []string{"hello", "world", "123"} + weights := d.Weights(tks, false) + + if len(weights) != len(tks) { + t.Errorf("Expected %d weights, got %d", len(tks), len(weights)) + } + + // Check weights sum to 1 (normalized) + sum := 0.0 + for _, tw := range weights { + sum += tw.Weight + } + if sum < 0.99 || sum > 1.01 { + t.Errorf("Weights should sum to ~1, got %f", sum) + } + }) + + t.Run("with preprocess", func(t *testing.T) { + tks := []string{"hello world", "test"} + weights := d.Weights(tks, true) + + // Check it doesn't panic and returns results + if weights == nil { + t.Error("Weights returned nil") + } + }) + + t.Run("empty input", func(t *testing.T) { + weights := d.Weights([]string{}, false) + if len(weights) != 0 { + t.Errorf("Expected empty weights for empty input, got %d", len(weights)) + } + }) + + t.Run("ner weight effect", func(t *testing.T) { + tmpDir2 := t.TempDir() + nerData := `{"toxicterm": "toxic"}` + os.WriteFile(filepath.Join(tmpDir2, "ner.json"), []byte(nerData), 0644) + d2 := NewTermWeightDealer(tmpDir2) + + tks := []string{"toxicterm", "normal"} + weights := d2.Weights(tks, false) + + if len(weights) != 2 { + t.Fatalf("Expected 2 weights, got %d", len(weights)) + } + + // toxicterm should have higher weight (nerWeight=2) + if weights[0].Weight <= weights[1].Weight { + t.Error("Expected toxicterm to have higher weight than normal term") + } + }) +} + +// TestWeightsWithNER tests NER type weight effects +func TestWeightsWithNER(t *testing.T) { + tmpDir := t.TempDir() + + // Create mock ner.json with all types + nerData := `{ + "toxic_word": "toxic", + "func_word": "func", + "corp_name": "corp", + "location": "loca", + "school": "sch", + "stock": "stock", + "firstname": "firstnm" + }` + if err := os.WriteFile(filepath.Join(tmpDir, "ner.json"), []byte(nerData), 0644); err != nil { + t.Fatalf("Failed to create mock ner.json: %v", err) + } + + d := NewTermWeightDealer(tmpDir) + + tests := []struct { + term string + expectedType string + }{ + {"toxic_word", "toxic"}, + {"func_word", "func"}, + {"corp_name", "corp"}, + {"location", "loca"}, + {"school", "sch"}, + {"stock", "stock"}, + {"firstname", "firstnm"}, + } + + for _, tt := range tests { + t.Run(tt.term, func(t *testing.T) { + ne := d.Ner(tt.term) + if ne != tt.expectedType { + t.Errorf("Ner('%s') = '%s', expected '%s'", tt.term, ne, tt.expectedType) + } + }) + } +} + +// TestGetters tests the getter methods +func TestGetters(t *testing.T) { + tmpDir := t.TempDir() + + // Create mock files + nerData := `{"test": "type"}` + os.WriteFile(filepath.Join(tmpDir, "ner.json"), []byte(nerData), 0644) + os.WriteFile(filepath.Join(tmpDir, "term.freq"), []byte("word\t10\n"), 0644) + + d := NewTermWeightDealer(tmpDir) + + t.Run("GetStopWords", func(t *testing.T) { + sw := d.GetStopWords() + if len(sw) == 0 { + t.Error("GetStopWords returned empty map") + } + if _, ok := sw["请问"]; !ok { + t.Error("Expected stop word '请问' not in map") + } + }) + + t.Run("GetNE", func(t *testing.T) { + ne := d.GetNE() + if len(ne) != 1 { + t.Errorf("Expected 1 NE entry, got %d", len(ne)) + } + if ne["test"] != "type" { + t.Error("NE dictionary content incorrect") + } + }) + + t.Run("GetDF", func(t *testing.T) { + df := d.GetDF() + if len(df) != 1 { + t.Errorf("Expected 1 DF entry, got %d", len(df)) + } + if df["word"] != 10 { + t.Error("DF dictionary content incorrect") + } + }) +} + +// TestLoadDict tests dictionary loading +func TestLoadDict(t *testing.T) { + 
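+	// loadDict accepts two on-disk layouts (see the subtests below):
+	// tab-separated "term\tfreq" lines loaded as term -> frequency, and
+	// bare "term" lines loaded in set mode with every value left at 0;
+	// blank or malformed lines are skipped rather than treated as errors.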
t.Run("load with frequency", func(t *testing.T) { + tmpDir := t.TempDir() + content := "word1\t100\nword2\t200\nword3\t300\n" + fn := filepath.Join(tmpDir, "test.freq") + os.WriteFile(fn, []byte(content), 0644) + + dict := loadDict(fn) + if len(dict) != 3 { + t.Errorf("Expected 3 entries, got %d", len(dict)) + } + if dict["word1"] != 100 { + t.Errorf("Expected word1=100, got %d", dict["word1"]) + } + }) + + t.Run("load without frequency (set mode)", func(t *testing.T) { + tmpDir := t.TempDir() + content := "word1\nword2\nword3\n" + fn := filepath.Join(tmpDir, "test.freq") + os.WriteFile(fn, []byte(content), 0644) + + dict := loadDict(fn) + if len(dict) != 3 { + t.Errorf("Expected 3 entries, got %d", len(dict)) + } + // All values should be 0 in set mode + for k, v := range dict { + if v != 0 { + t.Errorf("Expected %s=0 in set mode, got %d", k, v) + } + } + }) + + t.Run("load non-existent file", func(t *testing.T) { + dict := loadDict("/nonexistent/file.txt") + if dict == nil { + t.Error("loadDict should return empty map, not nil") + } + if len(dict) != 0 { + t.Error("loadDict should return empty map for non-existent file") + } + }) + + t.Run("load with malformed lines", func(t *testing.T) { + tmpDir := t.TempDir() + content := "word1\t100\n\n\nword2\tnotanumber\nword3" + fn := filepath.Join(tmpDir, "test.freq") + os.WriteFile(fn, []byte(content), 0644) + + dict := loadDict(fn) + // Should handle empty lines and invalid numbers gracefully + if len(dict) < 1 { + t.Error("Should handle malformed lines gracefully") + } + }) +} + +// TestWeightsNormalization tests weight normalization +func TestWeightsNormalization(t *testing.T) { + d := NewTermWeightDealer("") + + tests := []struct { + name string + tks []string + }{ + { + name: "single token", + tks: []string{"hello"}, + }, + { + name: "multiple tokens", + tks: []string{"hello", "world", "test"}, + }, + { + name: "many tokens", + tks: []string{"a", "b", "c", "d", "e"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + weights := d.Weights(tt.tks, false) + + if len(weights) != len(tt.tks) { + t.Fatalf("Expected %d weights, got %d", len(tt.tks), len(weights)) + } + + // Sum should be approximately 1 + sum := 0.0 + for _, tw := range weights { + sum += tw.Weight + // Individual weights should be non-negative + if tw.Weight < 0 { + t.Errorf("Weight for '%s' is negative: %f", tw.Term, tw.Weight) + } + } + + if sum < 0.99 || sum > 1.01 { + t.Errorf("Weights sum to %f, expected ~1.0", sum) + } + }) + } +} + +// TestSplitWithNER tests Split with NER considerations +func TestSplitWithNER(t *testing.T) { + tmpDir := t.TempDir() + + // Create mock ner.json + nerData := `{ + "function": "func" + }` + os.WriteFile(filepath.Join(tmpDir, "ner.json"), []byte(nerData), 0644) + + d := NewTermWeightDealer(tmpDir) + + t.Run("func type should not merge", func(t *testing.T) { + // If one of the words has NE type "func", they should not merge + result := d.Split("hello function") + // "hello" and "function" should not merge because function has type "func" + if len(result) != 2 { + t.Logf("Result: %v", result) + } + }) +} + +// BenchmarkWeights benchmarks the Weights function +func BenchmarkWeights(b *testing.B) { + d := NewTermWeightDealer("") + tks := []string{"hello", "world", "this", "is", "a", "test", "of", "term", "weights", "calculation"} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + d.Weights(tks, false) + } +} + +// BenchmarkTokenMerge benchmarks the TokenMerge function +func BenchmarkTokenMerge(b *testing.B) { + d := 
NewTermWeightDealer("") + tks := []string{"a", "b", "c", "d", "e", "hello", "world", "x", "y", "z"} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + d.TokenMerge(tks) + } +} + +// TestTermWeightStructure tests the TermWeight struct +func TestTermWeightStructure(t *testing.T) { + tw := TermWeight{ + Term: "test", + Weight: 0.5, + } + + if tw.Term != "test" { + t.Error("Term field incorrect") + } + if tw.Weight != 0.5 { + t.Error("Weight field incorrect") + } +} + +// TestIntegration tests an integrated workflow +func TestIntegration(t *testing.T) { + tmpDir := t.TempDir() + + // Create mock dictionaries + nerData := `{ + "北京": "loca", + "腾讯": "corp" + }` + os.WriteFile(filepath.Join(tmpDir, "ner.json"), []byte(nerData), 0644) + os.WriteFile(filepath.Join(tmpDir, "term.freq"), []byte("北京\t1000\n腾讯\t500\n"), 0644) + + d := NewTermWeightDealer(tmpDir) + + // Full workflow: text -> split -> pretoken -> token_merge -> weights + text := "北京 腾讯 公司" + + // Step 1: Split + splitted := d.Split(text) + if len(splitted) == 0 { + t.Fatal("Split returned empty result") + } + + // Step 2: Pretoken + var allTokens []string + for _, s := range splitted { + tokens := d.Pretoken(s, true, true) + allTokens = append(allTokens, tokens...) + } + + // Step 3: Token merge + merged := d.TokenMerge(allTokens) + + // Step 4: Calculate weights + weights := d.Weights(merged, false) + + // Verify results + if len(weights) == 0 && len(merged) > 0 { + t.Error("Weights calculation failed") + } + + // Check weights sum to 1 + sum := 0.0 + for _, w := range weights { + sum += w.Weight + } + if sum < 0.99 || sum > 1.01 { + t.Errorf("Final weights sum to %f, expected ~1.0", sum) + } +} + +// TestWeightsEdgeCases tests edge cases for weight calculation +func TestWeightsEdgeCases(t *testing.T) { + d := NewTermWeightDealer("") + + t.Run("numbers pattern", func(t *testing.T) { + tks := []string{"123,45", "abc"} + weights := d.Weights(tks, false) + if len(weights) != 2 { + t.Fatalf("Expected 2 weights, got %d", len(weights)) + } + // Numbers should get nerWeight=2 + }) + + t.Run("short letters pattern", func(t *testing.T) { + tks := []string{"ab", "abc"} + weights := d.Weights(tks, false) + if len(weights) != 2 { + t.Fatalf("Expected 2 weights, got %d", len(weights)) + } + }) + + t.Run("letter pattern with spaces", func(t *testing.T) { + tks := []string{"hello world test"} + weights := d.Weights(tks, true) + // Should not panic + if weights == nil { + t.Error("Weights returned nil for letter pattern") + } + }) +} + +// TestPretokenWithNumbers tests pretoken with num parameter +func TestPretokenWithNumbers(t *testing.T) { + d := NewTermWeightDealer("") + + t.Run("num=false filters single digits", func(t *testing.T) { + result := d.Pretoken("5", false, true) + // Single digit should be filtered when num=false + found := false + for _, r := range result { + if r == "5" { + found = true + break + } + } + if found { + t.Error("Single digit should be filtered when num=false") + } + }) + + t.Run("num=true keeps single digits", func(t *testing.T) { + result := d.Pretoken("5 123", true, true) + // Check at least something is returned + if len(result) == 0 { + t.Log("Single digit may still be filtered by other rules") + } + }) +} + +// TestPretokenStopWords tests pretoken with stpwd parameter +func TestPretokenStopWords(t *testing.T) { + d := NewTermWeightDealer("") + + t.Run("stpwd=true removes stop words", func(t *testing.T) { + result := d.Pretoken("请问", true, true) + // "请问" is a stop word + for _, r := range result { + if r == "请问" { + 
t.Error("Stop word should be removed when stpwd=true") + } + } + }) + + t.Run("stpwd=false keeps stop words", func(t *testing.T) { + result := d.Pretoken("请问", true, false) + // With tokenizer, this might still filter it + _ = result + }) +} + +// TestTokenMergeEdgeCases tests edge cases for token merging +func TestTokenMergeEdgeCases(t *testing.T) { + d := NewTermWeightDealer("") + + t.Run("nil input", func(t *testing.T) { + result := d.TokenMerge(nil) + if len(result) != 0 { + t.Error("TokenMerge(nil) should return empty slice") + } + }) + + t.Run("empty strings in input", func(t *testing.T) { + result := d.TokenMerge([]string{"", "a", "", "b", ""}) + // Empty strings should be filtered + for _, r := range result { + if r == "" { + t.Error("Empty strings should be filtered") + } + } + }) + + t.Run("exactly 4 short tokens", func(t *testing.T) { + // 4 short tokens should be merged as one group (not split into pairs) + result := d.TokenMerge([]string{"a", "b", "c", "d"}) + expected := []string{"a b c d"} + if !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %v, got %v", expected, result) + } + }) + + t.Run("exactly 5 short tokens", func(t *testing.T) { + // 5 short tokens should be split into pairs + result := d.TokenMerge([]string{"a", "b", "c", "d", "e"}) + // Should be: a b, c d (e is left? depends on implementation) + if len(result) < 2 { + t.Errorf("Expected at least 2 groups for 5 tokens, got %d: %v", len(result), result) + } + }) +} + +// TestSplitEdgeCases tests edge cases for splitting +func TestSplitEdgeCases(t *testing.T) { + d := NewTermWeightDealer("") + + t.Run("tabs and spaces", func(t *testing.T) { + result := d.Split("hello\tworld\t\ttest") + // Tabs should be normalized to single space + hasTab := false + for _, r := range result { + if strings.Contains(r, "\t") { + hasTab = true + break + } + } + if hasTab { + t.Error("Tabs should be normalized") + } + }) + + t.Run("consecutive English with different NE types", func(t *testing.T) { + tmpDir := t.TempDir() + nerData := `{ + "hello": "func", + "world": "corp" + }` + os.WriteFile(filepath.Join(tmpDir, "ner.json"), []byte(nerData), 0644) + d2 := NewTermWeightDealer(tmpDir) + + result := d2.Split("hello world") + // Both have NE types, so they should NOT merge + if len(result) != 2 { + t.Errorf("Expected 2 tokens when both have NE types, got %d: %v", len(result), result) + } + }) +} diff --git a/internal/service/nlp/wordnet.go b/internal/service/nlp/wordnet.go new file mode 100644 index 00000000000..297c4998c51 --- /dev/null +++ b/internal/service/nlp/wordnet.go @@ -0,0 +1,572 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package wordnet provides a Go implementation of NLTK's WordNet synsets functionality. +// This implementation reads WordNet 3.0 database files and provides synonym set lookup. 
+package nlp + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + "sync" +) + +// POS constants for WordNet parts of speech +const ( + NOUN = "n" + VERB = "v" + ADJ = "a" + ADV = "r" +) + +// Morphy substitution rules for each POS +var morphologicalSubstitutions = map[string][][2]string{ + NOUN: { + {"s", ""}, + {"ses", "s"}, + {"ves", "f"}, + {"xes", "x"}, + {"zes", "z"}, + {"ches", "ch"}, + {"shes", "sh"}, + {"men", "man"}, + {"ies", "y"}, + }, + VERB: { + {"s", ""}, + {"ies", "y"}, + {"es", "e"}, + {"es", ""}, + {"ed", "e"}, + {"ed", ""}, + {"ing", "e"}, + {"ing", ""}, + }, + ADJ: { + {"er", ""}, + {"est", ""}, + {"er", "e"}, + {"est", "e"}, + }, + ADV: {}, +} + +// File suffix mapping for POS +var fileMap = map[string]string{ + NOUN: "noun", + VERB: "verb", + ADJ: "adj", + ADV: "adv", +} + +// Synset represents a WordNet synset (synonym set) +type Synset struct { + Name string + POS string + Offset int + Lemmas []string + Definition string + Examples []string +} + +// WordNet is the main struct for WordNet operations +type WordNet struct { + wordNetDir string + lemmaPosOffsetMap map[string]map[string][]int + exceptionMap map[string]map[string][]string + dataFileCache map[string]*os.File + dataFileCacheOffset map[string]int64 + fileMutexes map[string]*sync.Mutex // Mutex for each POS to ensure concurrency safety +} + +// NewWordNet creates a new WordNet instance with the given WordNet directory +func NewWordNet(wordNetDir string) (*WordNet, error) { + wn := &WordNet{ + wordNetDir: wordNetDir, + lemmaPosOffsetMap: make(map[string]map[string][]int), + exceptionMap: make(map[string]map[string][]string), + dataFileCache: make(map[string]*os.File), + dataFileCacheOffset: make(map[string]int64), + fileMutexes: make(map[string]*sync.Mutex), + } + + // Initialize exception maps for all POS + for pos := range fileMap { + wn.exceptionMap[pos] = make(map[string][]string) + } + + // Load exception files + if err := wn.loadExceptionMaps(); err != nil { + return nil, fmt.Errorf("failed to load exception maps: %w", err) + } + + // Load lemma pos offset map + if err := wn.loadLemmaPosOffsetMap(); err != nil { + return nil, fmt.Errorf("failed to load lemma pos offset map: %w", err) + } + + return wn, nil +} + +// Close closes all cached file handles +func (wn *WordNet) Close() { + for pos, f := range wn.dataFileCache { + if mutex, ok := wn.fileMutexes[pos]; ok { + mutex.Lock() + f.Close() + mutex.Unlock() + } else { + f.Close() + } + } +} + +// loadExceptionMaps loads the .exc files for each POS +func (wn *WordNet) loadExceptionMaps() error { + for pos, suffix := range fileMap { + filename := filepath.Join(wn.wordNetDir, suffix+".exc") + file, err := os.Open(filename) + if err != nil { + // It's okay if the file doesn't exist for some POS + continue + } + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Fields(line) + if len(fields) >= 2 { + // First field is the inflected form, rest are base forms + wn.exceptionMap[pos][fields[0]] = fields[1:] + } + } + if err := scanner.Err(); err != nil { + return fmt.Errorf("error reading %s: %w", filename, err) + } + } + return nil +} + +// loadLemmaPosOffsetMap loads the index files for each POS +func (wn *WordNet) loadLemmaPosOffsetMap() error { + for _, suffix := range fileMap { + filename := filepath.Join(wn.wordNetDir, "index."+suffix) + file, err := os.Open(filename) + if err != nil { + return fmt.Errorf("failed to open %s: %w", filename, err) + } + defer 
file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + + // Skip license header lines (lines starting with space) + if len(line) == 0 || line[0] == ' ' { + continue + } + + fields := strings.Fields(line) + if len(fields) < 6 { + continue + } + + // Parse index file format: + // lemma pos n_synsets n_pointers [pointers] n_senses n_ranked_synsets [synset_offsets...] + lemma := strings.ToLower(fields[0]) + filePos := fields[1] + nSynsets, err := strconv.Atoi(fields[2]) + if err != nil { + continue + } + nPointers, err := strconv.Atoi(fields[3]) + if err != nil { + continue + } + + // Calculate field positions + fieldIdx := 4 + + // Skip pointer symbols + for i := 0; i < nPointers && fieldIdx < len(fields); i++ { + fieldIdx++ + } + + // Read n_senses and n_ranked_synsets + if fieldIdx >= len(fields) { + continue + } + _, err = strconv.Atoi(fields[fieldIdx]) // n_senses + if err != nil { + continue + } + fieldIdx++ + + if fieldIdx >= len(fields) { + continue + } + _, err = strconv.Atoi(fields[fieldIdx]) // n_ranked_synsets + if err != nil { + continue + } + fieldIdx++ + + // Read synset offsets + var offsets []int + for i := 0; i < nSynsets && fieldIdx < len(fields); i++ { + offset, err := strconv.Atoi(fields[fieldIdx]) + if err != nil { + continue + } + offsets = append(offsets, offset) + fieldIdx++ + } + + // Store in map + if wn.lemmaPosOffsetMap[lemma] == nil { + wn.lemmaPosOffsetMap[lemma] = make(map[string][]int) + } + wn.lemmaPosOffsetMap[lemma][filePos] = offsets + } + if err := scanner.Err(); err != nil { + return fmt.Errorf("error reading %s: %w", filename, err) + } + } + return nil +} + +// morphy performs morphological analysis to find base forms of a word +func (wn *WordNet) morphy(form string, pos string, checkExceptions bool) []string { + form = strings.ToLower(form) + exceptions := wn.exceptionMap[pos] + substitutions := morphologicalSubstitutions[pos] + + // Helper function to apply substitution rules + applyRules := func(forms []string) []string { + var results []string + for _, f := range forms { + for _, sub := range substitutions { + old, new := sub[0], sub[1] + if strings.HasSuffix(f, old) { + base := f[:len(f)-len(old)] + new + results = append(results, base) + } + } + } + return results + } + + // Helper function to filter forms that exist in WordNet + filterForms := func(forms []string) []string { + var results []string + seen := make(map[string]bool) + for _, f := range forms { + if posMap, ok := wn.lemmaPosOffsetMap[f]; ok { + if _, hasPos := posMap[pos]; hasPos { + if !seen[f] { + results = append(results, f) + seen[f] = true + } + } + } + } + return results + } + + var forms []string + if checkExceptions { + if baseForms, ok := exceptions[form]; ok { + forms = baseForms + } + } + + // If no exception found, apply rules + if len(forms) == 0 { + forms = applyRules([]string{form}) + } + + // Filter to keep only valid forms, also check original form + return filterForms(append([]string{form}, forms...)) +} + +// getDataFile returns the data file for a given POS, with caching +func (wn *WordNet) getDataFile(pos string) (*os.File, *sync.Mutex, error) { + if pos == "s" { // Adjective satellite uses the same file as adjective + pos = ADJ + } + + // Get or create mutex for this POS + mutex, exists := wn.fileMutexes[pos] + if !exists { + mutex = &sync.Mutex{} + wn.fileMutexes[pos] = mutex + } + + if file, ok := wn.dataFileCache[pos]; ok { + return file, mutex, nil + } + + suffix, ok := fileMap[pos] + if !ok { + return nil, nil, 
fmt.Errorf("unknown POS: %s", pos) + } + + filename := filepath.Join(wn.wordNetDir, "data."+suffix) + file, err := os.Open(filename) + if err != nil { + return nil, nil, fmt.Errorf("failed to open %s: %w", filename, err) + } + + wn.dataFileCache[pos] = file + return file, mutex, nil +} + +// parseDataLine parses a line from a data file and returns a Synset +func parseDataLine(line string, pos string) (*Synset, error) { + // Data file format: + // synset_offset lex_filenum ss_type w_cnt word lex_id [word lex_id...] p_cnt [ptr_symbol synset_offset pos src_trgt...] [frames...] | gloss + + parts := strings.SplitN(line, "|", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid line format: no gloss separator") + } + + dataPart := strings.TrimSpace(parts[0]) + glossPart := strings.TrimSpace(parts[1]) + + // Parse gloss to get definition and examples + var definition string + var examples []string + + // Remove quotes from examples + gloss := glossPart + for { + start := strings.Index(gloss, "\"") + if start == -1 { + break + } + end := strings.Index(gloss[start+1:], "\"") + if end == -1 { + break + } + end += start + 1 + + example := gloss[start+1 : end] + if len(examples) == 0 && start > 0 { + definition = strings.TrimSpace(gloss[:start]) + } + examples = append(examples, example) + gloss = gloss[end+1:] + } + + if definition == "" { + definition = strings.Trim(glossPart, "; ") + // Remove quoted examples from definition + definition = regexpRemoveQuotes(definition) + } + + // Final cleanup: trim trailing semicolon and whitespace to match Python NLTK + definition = strings.TrimRight(definition, "; ") + + // Parse data part + fields := strings.Fields(dataPart) + if len(fields) < 4 { + return nil, fmt.Errorf("invalid data line: too few fields") + } + + offset, err := strconv.Atoi(fields[0]) + if err != nil { + return nil, fmt.Errorf("invalid offset: %w", err) + } + + // lexFilenum := fields[1] // Not used currently + ssType := fields[2] + + wCnt, err := strconv.ParseInt(fields[3], 16, 32) + if err != nil { + return nil, fmt.Errorf("invalid word count: %w", err) + } + + // Parse lemmas + var lemmas []string + fieldIdx := 4 + for i := 0; i < int(wCnt) && fieldIdx+1 < len(fields); i++ { + lemma := fields[fieldIdx] + // Remove syntactic marker if present (e.g., "(a)" or "(p)") + if idx := strings.Index(lemma, "("); idx != -1 { + lemma = lemma[:idx] + } + // Keep original case for lemmas (Python NLTK preserves case) + lemmas = append(lemmas, lemma) + fieldIdx += 2 // skip lex_id + } + + if len(lemmas) == 0 { + return nil, fmt.Errorf("no lemmas found") + } + + // Build synset name from first lemma (Python uses lowercase in synset name) + senseIndex := 1 // Default to 1, would need to look up in index for actual sense number + name := fmt.Sprintf("%s.%s.%02d", strings.ToLower(lemmas[0]), ssType, senseIndex) + + return &Synset{ + Name: name, + POS: ssType, + Offset: offset, + Lemmas: lemmas, + Definition: definition, + Examples: examples, + }, nil +} + +// regexpRemoveQuotes removes quoted strings from text (simplified version) +func regexpRemoveQuotes(s string) string { + var result strings.Builder + inQuote := false + for _, ch := range s { + if ch == '"' { + inQuote = !inQuote + continue + } + if !inQuote { + result.WriteRune(ch) + } + } + return strings.TrimSpace(strings.Trim(result.String(), "; ")) +} + +// synsetFromPosAndOffset retrieves a synset by POS and byte offset +func (wn *WordNet) synsetFromPosAndOffset(pos string, offset int) (*Synset, error) { + file, mutex, err := 
wn.getDataFile(pos) + if err != nil { + return nil, err + } + + // Lock only for Seek and Read operations to minimize critical section + mutex.Lock() + + // Seek to the offset + _, err = file.Seek(int64(offset), 0) + if err != nil { + mutex.Unlock() + return nil, fmt.Errorf("failed to seek to offset %d: %w", offset, err) + } + + reader := bufio.NewReader(file) + line, err := reader.ReadString('\n') + mutex.Unlock() // Release lock immediately after reading + + if err != nil { + return nil, fmt.Errorf("failed to read line at offset %d: %w", offset, err) + } + + //if len(line) < 8 { + // fmt.Println(line) + //} + + // Verify the offset matches + lineOffset := strings.TrimSpace(line[:8]) + expectedOffset := fmt.Sprintf("%08d", offset) + if lineOffset != expectedOffset { + return nil, fmt.Errorf("offset mismatch: expected %s, got %s", expectedOffset, lineOffset) + } + + synset, err := parseDataLine(line, pos) + if err != nil { + return nil, err + } + + // Calculate the correct sense number by looking up the offset in the index + // This operation only accesses memory map, no need for file lock + senseNum := wn.findSenseNumber(synset.Lemmas[0], pos, offset) + if senseNum > 0 { + synset.Name = fmt.Sprintf("%s.%s.%02d", synset.Lemmas[0], synset.POS, senseNum) + } + + return synset, nil +} + +// findSenseNumber finds the sense number for a lemma in a given synset +func (wn *WordNet) findSenseNumber(lemma string, pos string, offset int) int { + lemma = strings.ToLower(lemma) + if posMap, ok := wn.lemmaPosOffsetMap[lemma]; ok { + if offsets, hasPos := posMap[pos]; hasPos { + for i, off := range offsets { + if off == offset { + return i + 1 // sense numbers are 1-indexed + } + } + } + } + return 1 // Default to 1 if not found +} + +// Synsets returns all synsets for a given lemma and optional POS. +// If pos is empty, all parts of speech are searched. +// This is the main function equivalent to NLTK's wordnet.synsets() +func (wn *WordNet) Synsets(lemma string, pos string) []*Synset { + lemma = strings.ToLower(lemma) + + var poses []string + if pos == "" { + poses = []string{NOUN, VERB, ADJ, ADV} + } else { + poses = []string{pos} + } + + var results []*Synset + seen := make(map[string]bool) + + for _, p := range poses { + // Get morphological forms + forms := wn.morphy(lemma, p, true) + + for _, form := range forms { + if posMap, ok := wn.lemmaPosOffsetMap[form]; ok { + if offsets, hasPos := posMap[p]; hasPos { + for _, offset := range offsets { + // Create unique key to avoid duplicates + key := fmt.Sprintf("%s-%d", p, offset) + if !seen[key] { + seen[key] = true + synset, err := wn.synsetFromPosAndOffset(p, offset) + if err == nil { + results = append(results, synset) + } + } + } + } + } + } + } + + return results +} + +// Name returns the synset name (e.g., "dog.n.01") +func (s *Synset) NameStr() string { + return s.Name +} + +// String returns a string representation of the synset +func (s *Synset) String() string { + return fmt.Sprintf("Synset('%s')", s.Name) +} diff --git a/internal/service/nlp/wordnet_test.go b/internal/service/nlp/wordnet_test.go new file mode 100644 index 00000000000..6557b2b3e83 --- /dev/null +++ b/internal/service/nlp/wordnet_test.go @@ -0,0 +1,285 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package nlp + +import ( + "os" + "path/filepath" + "reflect" + "sort" + "testing" +) + +var testWordNetDir string + +func TestNewWordNet(t *testing.T) { + wn, err := NewWordNet(testWordNetDir) + if err != nil { + t.Fatalf("Failed to create WordNet: %v", err) + } + defer wn.Close() + + // Verify that some basic data was loaded + if len(wn.lemmaPosOffsetMap) == 0 { + t.Error("lemmaPosOffsetMap is empty") + } + + // Check exception map loaded + if len(wn.exceptionMap[NOUN]) == 0 { + t.Error("NOUN exception map is empty") + } +} + +func TestMorphy(t *testing.T) { + wn, err := NewWordNet(testWordNetDir) + if err != nil { + t.Fatalf("Failed to create WordNet: %v", err) + } + defer wn.Close() + + tests := []struct { + form string + pos string + expected []string + }{ + {"dogs", NOUN, []string{"dog"}}, + {"churches", NOUN, []string{"church"}}, + {"running", VERB, []string{"run"}}, + {"better", ADJ, []string{"good"}}, + } + + for _, tt := range tests { + result := wn.morphy(tt.form, tt.pos, true) + // We just verify that morphy returns some results for known words + // The exact results depend on what's in the exception files + t.Logf("morphy(%q, %q) = %v", tt.form, tt.pos, result) + } +} + +func TestSynsets(t *testing.T) { + wn, err := NewWordNet(testWordNetDir) + if err != nil { + t.Fatalf("Failed to create WordNet: %v", err) + } + defer wn.Close() + + tests := []struct { + lemma string + pos string + minSynsets int + checkNames []string + }{ + // Basic nouns + {"dog", "", 1, []string{"dog.n.01"}}, + {"dog", NOUN, 1, []string{"dog.n.01"}}, + {"entity", NOUN, 1, []string{"entity.n.01"}}, + {"computer", NOUN, 1, nil}, + // Basic verbs + {"run", VERB, 1, nil}, + {"walk", VERB, 1, nil}, + // Basic adjectives/adverbs + {"good", ADJ, 1, nil}, + {"quickly", ADV, 1, nil}, + // Edge case: multi-word phrases + {"physical_entity", NOUN, 1, nil}, + {"hot_dog", NOUN, 1, nil}, + // Edge case: rare words + {"aardvark", NOUN, 1, nil}, + // Edge case: uppercase input (should be converted to lowercase) + {"DOG", NOUN, 1, []string{"dog.n.01"}}, + // Edge case: non-existent words + {"xyznonexistent", "", 0, nil}, + } + + for _, tt := range tests { + synsets := wn.Synsets(tt.lemma, tt.pos) + if len(synsets) < tt.minSynsets { + t.Errorf("Synsets(%q, %q) returned %d synsets, expected at least %d", + tt.lemma, tt.pos, len(synsets), tt.minSynsets) + } + + // Check that expected names are present + if tt.checkNames != nil { + names := make([]string, len(synsets)) + for i, s := range synsets { + names[i] = s.Name + } + for _, expectedName := range tt.checkNames { + found := false + for _, name := range names { + if name == expectedName { + found = true + break + } + } + if !found { + t.Errorf("Synsets(%q, %q) did not contain expected synset %q, got %v", + tt.lemma, tt.pos, expectedName, names) + } + } + } + + t.Logf("Synsets(%q, %q) returned %d synsets", tt.lemma, tt.pos, len(synsets)) + for _, s := range synsets { + t.Logf(" - %s: %s", s.Name, s.Definition) + } + } +} + +func TestSynsetsDetailed(t *testing.T) { + wn, err := NewWordNet(testWordNetDir) + if err != nil { + t.Fatalf("Failed to create 
WordNet: %v", err) + } + defer wn.Close() + + // Test entity - should have at least 1 synset + synsets := wn.Synsets("entity", NOUN) + if len(synsets) == 0 { + t.Fatal("Expected at least 1 synset for 'entity'") + } + + found := false + for _, s := range synsets { + if s.Offset == 1740 { // entity.n.01 offset + found = true + if s.Definition == "" { + t.Error("Expected non-empty definition for entity.n.01") + } + if len(s.Lemmas) == 0 { + t.Error("Expected at least one lemma") + } + } + } + if !found { + t.Errorf("Expected to find synset with offset 1740 for 'entity'") + } +} + +func TestSynsetsConsistencyWithPython(t *testing.T) { + wn, err := NewWordNet(testWordNetDir) + if err != nil { + t.Fatalf("Failed to create WordNet: %v", err) + } + defer wn.Close() + + // These are the expected results from Python NLTK for comparison + // wordnet.synsets('dog') returns synsets with these names: + pythonDogNames := []string{ + "dog.n.01", + "frump.n.01", + "dog.n.03", + "cad.n.01", + "frank.n.02", + "pawl.n.01", + "andiron.n.01", + } + + synsets := wn.Synsets("dog", NOUN) + var goDogNames []string + for _, s := range synsets { + goDogNames = append(goDogNames, s.Name) + } + + // Sort both lists for comparison + sort.Strings(pythonDogNames) + sort.Strings(goDogNames) + + t.Logf("Python expected (approximate): %v", pythonDogNames) + t.Logf("Go result: %v", goDogNames) + + // We may not match exactly due to sense numbering, but we should have some overlap + if len(goDogNames) == 0 { + t.Error("Expected at least some synsets for 'dog'") + } +} + +func TestSynsetContent(t *testing.T) { + wn, err := NewWordNet(testWordNetDir) + if err != nil { + t.Fatalf("Failed to create WordNet: %v", err) + } + defer wn.Close() + + synsets := wn.Synsets("dog", NOUN) + if len(synsets) == 0 { + t.Fatal("Expected at least 1 synset for 'dog'") + } + + // Check synset structure + for _, s := range synsets { + if s.Name == "" { + t.Error("Synset name is empty") + } + if s.POS == "" { + t.Error("Synset POS is empty") + } + if s.Offset == 0 { + t.Error("Synset offset is 0") + } + if len(s.Lemmas) == 0 { + t.Error("Synset has no lemmas") + } + } +} + +func BenchmarkSynsets(b *testing.B) { + wn, err := NewWordNet(testWordNetDir) + if err != nil { + b.Fatalf("Failed to create WordNet: %v", err) + } + defer wn.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + wn.Synsets("dog", NOUN) + } +} + +// Helper function to check if two string slices are equal +func stringSliceEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + sort.Strings(a) + sort.Strings(b) + return reflect.DeepEqual(a, b) +} + +func init() { + // Find project root by locating go.mod file + dir, err := os.Getwd() + if err != nil { + panic(err) + } + for { + goModPath := filepath.Join(dir, "go.mod") + if _, err := os.Stat(goModPath); err == nil { + // Found go.mod, project root is dir + testWordNetDir = filepath.Join(dir, "resource", "wordnet") + return + } + parent := filepath.Dir(dir) + if parent == dir { + // Reached root directory + break + } + dir = parent + } + // Fallback to relative path if go.mod not found + testWordNetDir = "../../../resource/wordnet" +} diff --git a/internal/service/search.go b/internal/service/search.go new file mode 100644 index 00000000000..106379a77e2 --- /dev/null +++ b/internal/service/search.go @@ -0,0 +1,132 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "ragflow/internal/dao" + "ragflow/internal/model" +) + +// SearchService search service +type SearchService struct { + searchDAO *dao.SearchDAO + userTenantDAO *dao.UserTenantDAO +} + +// NewSearchService create search service +func NewSearchService() *SearchService { + return &SearchService{ + searchDAO: dao.NewSearchDAO(), + userTenantDAO: dao.NewUserTenantDAO(), + } +} + +// SearchWithTenantInfo search with tenant info +type SearchWithTenantInfo struct { + *model.Search + Nickname string `json:"nickname"` + TenantAvatar string `json:"tenant_avatar,omitempty"` +} + +// ListSearchAppsRequest list search apps request +type ListSearchAppsRequest struct { + OwnerIDs []string `json:"owner_ids,omitempty"` +} + +// ListSearchAppsResponse list search apps response +type ListSearchAppsResponse struct { + SearchApps []map[string]interface{} `json:"search_apps"` + Total int64 `json:"total"` +} + +// ListSearchApps list search apps with advanced filtering (equivalent to list_search_app) +func (s *SearchService) ListSearchApps(userID string, keywords string, page, pageSize int, orderby string, desc bool, ownerIDs []string) (*ListSearchAppsResponse, error) { + var searches []*model.Search + var total int64 + var err error + + if len(ownerIDs) == 0 { + // Get tenant IDs by user ID (joined tenants) + tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) + if err != nil { + return nil, err + } + + // Use database pagination + searches, total, err = s.searchDAO.ListByTenantIDs(tenantIDs, userID, page, pageSize, orderby, desc, keywords) + if err != nil { + return nil, err + } + } else { + // Filter by owner IDs, manual pagination + searches, total, err = s.searchDAO.ListByOwnerIDs(ownerIDs, userID, orderby, desc, keywords) + if err != nil { + return nil, err + } + + // Manual pagination + if page > 0 && pageSize > 0 { + start := (page - 1) * pageSize + end := start + pageSize + if start < int(total) { + if end > int(total) { + end = int(total) + } + searches = searches[start:end] + } else { + searches = []*model.Search{} + } + } + } + + // Convert to response format + searchApps := make([]map[string]interface{}, len(searches)) + for i, search := range searches { + searchApps[i] = s.toSearchAppResponse(search) + } + + return &ListSearchAppsResponse{ + SearchApps: searchApps, + Total: total, + }, nil +} + +// toSearchAppResponse converts search model to response format +func (s *SearchService) toSearchAppResponse(search *model.Search) map[string]interface{} { + result := map[string]interface{}{ + "id": search.ID, + "tenant_id": search.TenantID, + "name": search.Name, + "description": search.Description, + "created_by": search.CreatedBy, + "status": search.Status, + "create_time": search.CreateTime, + "update_time": search.UpdateTime, + "search_config": search.SearchConfig, + } + + if search.Avatar != nil { + result["avatar"] = *search.Avatar + } + + // Add joined fields from user table + // Note: These 
fields are populated by the DAO query with Select clause + // but GORM will map them to the model's embedded fields if available + // We need to handle the extra fields manually + + return result +} diff --git a/internal/service/system.go b/internal/service/system.go new file mode 100644 index 00000000000..191487633b3 --- /dev/null +++ b/internal/service/system.go @@ -0,0 +1,56 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "ragflow/internal/server" + "ragflow/internal/utility" +) + +// SystemService system service +type SystemService struct{} + +// NewSystemService create system service +func NewSystemService() *SystemService { + return &SystemService{} +} + +// ConfigResponse system configuration response +type ConfigResponse struct { + RegisterEnabled int `json:"registerEnabled"` +} + +// GetConfig get system configuration +func (s *SystemService) GetConfig() (*ConfigResponse, error) { + cfg := server.GetConfig() + return &ConfigResponse{ + RegisterEnabled: cfg.RegisterEnabled, + }, nil +} + +// VersionResponse version response +type VersionResponse struct { + Version string `json:"version"` +} + +// GetVersion get RAGFlow version +func (s *SystemService) GetVersion() (*VersionResponse, error) { + version := utility.GetRAGFlowVersion() + return &VersionResponse{ + Version: version, + }, nil +} diff --git a/internal/service/tenant.go b/internal/service/tenant.go new file mode 100644 index 00000000000..5a024b36c44 --- /dev/null +++ b/internal/service/tenant.go @@ -0,0 +1,120 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package service + +import ( + "time" + + "ragflow/internal/dao" +) + +// TenantService tenant service +type TenantService struct { + tenantDAO *dao.TenantDAO + userTenantDAO *dao.UserTenantDAO +} + +// NewTenantService create tenant service +func NewTenantService() *TenantService { + return &TenantService{ + tenantDAO: dao.NewTenantDAO(), + userTenantDAO: dao.NewUserTenantDAO(), + } +} + +// TenantInfoResponse tenant information response +type TenantInfoResponse struct { + TenantID string `json:"tenant_id"` + Name *string `json:"name,omitempty"` + LLMID string `json:"llm_id"` + EmbDID string `json:"embd_id"` + RerankID string `json:"rerank_id"` + ASRID string `json:"asr_id"` + Img2TxtID string `json:"img2txt_id"` + TTSID *string `json:"tts_id,omitempty"` + ParserIDs string `json:"parser_ids"` + Role string `json:"role"` +} + +// GetTenantInfo get tenant information for the current user (owner tenant) +func (s *TenantService) GetTenantInfo(userID string) (*TenantInfoResponse, error) { + tenantInfos, err := s.tenantDAO.GetInfoByUserID(userID) + if err != nil { + return nil, err + } + if len(tenantInfos) == 0 { + return nil, nil // No tenant found (should not happen for valid user) + } + // Return the first tenant (should be only one owner tenant per user) + ti := tenantInfos[0] + return &TenantInfoResponse{ + TenantID: ti.TenantID, + Name: ti.Name, + LLMID: ti.LLMID, + EmbDID: ti.EmbDID, + RerankID: ti.RerankID, + ASRID: ti.ASRID, + Img2TxtID: ti.Img2TxtID, + TTSID: ti.TTSID, + ParserIDs: ti.ParserIDs, + Role: ti.Role, + }, nil +} + +// TenantListItem tenant list item response +type TenantListItem struct { + TenantID string `json:"tenant_id"` + Role string `json:"role"` + Nickname string `json:"nickname"` + Email string `json:"email"` + Avatar string `json:"avatar"` + UpdateDate string `json:"update_date"` + DeltaSeconds float64 `json:"delta_seconds"` +} + +// GetTenantList get tenant list for a user +func (s *TenantService) GetTenantList(userID string) ([]*TenantListItem, error) { + tenants, err := s.userTenantDAO.GetTenantsByUserID(userID) + if err != nil { + return nil, err + } + + result := make([]*TenantListItem, len(tenants)) + now := time.Now() + + for i, t := range tenants { + // Parse update_date and calculate delta_seconds + var deltaSeconds float64 + if t.UpdateDate != "" { + if updateTime, err := time.Parse("2006-01-02 15:04:05", t.UpdateDate); err == nil { + deltaSeconds = now.Sub(updateTime).Seconds() + } + } + + result[i] = &TenantListItem{ + TenantID: t.TenantID, + Role: t.Role, + Nickname: t.Nickname, + Email: t.Email, + Avatar: t.Avatar, + UpdateDate: t.UpdateDate, + DeltaSeconds: deltaSeconds, + } + } + + return result, nil +} diff --git a/internal/service/user.go b/internal/service/user.go new file mode 100644 index 00000000000..e92541502d8 --- /dev/null +++ b/internal/service/user.go @@ -0,0 +1,621 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package service + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "encoding/hex" + "encoding/pem" + "errors" + "fmt" + "os" + "ragflow/internal/server" + "strconv" + "strings" + "time" + + "github.com/google/uuid" + "golang.org/x/crypto/scrypt" + + "ragflow/internal/dao" + "ragflow/internal/model" + "ragflow/internal/utility" +) + +// UserService user service +type UserService struct { + userDAO *dao.UserDAO +} + +// NewUserService create user service +func NewUserService() *UserService { + return &UserService{ + userDAO: dao.NewUserDAO(), + } +} + +// RegisterRequest registration request +type RegisterRequest struct { + Username string `json:"username" binding:"required,min=3,max=50"` + Password string `json:"password" binding:"required,min=6"` + Email string `json:"email" binding:"required,email"` + Nickname string `json:"nickname"` +} + +// LoginRequest login request +type LoginRequest struct { + Username string `json:"username" binding:"required"` + Password string `json:"password" binding:"required"` +} + +// EmailLoginRequest email login request +type EmailLoginRequest struct { + Email string `json:"email" binding:"required,email"` + Password string `json:"password" binding:"required"` +} + +// UpdateSettingsRequest update user settings request +type UpdateSettingsRequest struct { + Nickname *string `json:"nickname,omitempty"` + Email *string `json:"email,omitempty" binding:"omitempty,email"` + Avatar *string `json:"avatar,omitempty"` + Language *string `json:"language,omitempty"` + ColorSchema *string `json:"color_schema,omitempty"` + Timezone *string `json:"timezone,omitempty"` +} + +// ChangePasswordRequest change password request +type ChangePasswordRequest struct { + Password *string `json:"password,omitempty"` + NewPassword *string `json:"new_password,omitempty"` +} + +// UserResponse user response +type UserResponse struct { + ID string `json:"id"` + Email string `json:"email"` + Nickname string `json:"nickname"` + Status *string `json:"status"` + CreatedAt string `json:"created_at"` +} + +// Register user registration +func (s *UserService) Register(req *RegisterRequest) (*model.User, error) { + // Check if email exists + existUser, _ := s.userDAO.GetByEmail(req.Email) + if existUser != nil { + return nil, errors.New("email already exists") + } + + // Generate password hash + hashedPassword, err := s.HashPassword(req.Password) + if err != nil { + return nil, fmt.Errorf("failed to hash password: %w", err) + } + + // Create user + status := "1" + user := &model.User{ + Password: &hashedPassword, + Email: req.Email, + Nickname: req.Nickname, + Status: &status, + } + + if err := s.userDAO.Create(user); err != nil { + return nil, fmt.Errorf("failed to create user: %w", err) + } + + return user, nil +} + +// Login user login +func (s *UserService) Login(req *LoginRequest) (*model.User, error) { + // Get user by email (using username field as email) + user, err := s.userDAO.GetByEmail(req.Username) + if err != nil { + return nil, errors.New("invalid email or password") + } + + // Decrypt password using RSA + decryptedPassword, err := s.decryptPassword(req.Password) + if err != nil { + return nil, fmt.Errorf("failed to decrypt password: %w", err) + } + + // Verify password + if user.Password == nil || !s.VerifyPassword(*user.Password, decryptedPassword) { + return nil, errors.New("invalid username or password") + } + + // Check user status + if user.Status == nil || *user.Status != "1" { + return nil, errors.New("user is disabled") + } + + // Generate new access 
token + token := s.GenerateToken() + if err := s.UpdateUserAccessToken(user, token); err != nil { + return nil, fmt.Errorf("failed to update access token: %w", err) + } + + // Update timestamp + now := time.Now().Unix() + user.UpdateTime = &now + if err := s.userDAO.Update(user); err != nil { + return nil, fmt.Errorf("failed to update user: %w", err) + } + + return user, nil +} + +// LoginByEmail user login by email +func (s *UserService) LoginByEmail(req *EmailLoginRequest) (*model.User, error) { + // Check for default admin account + if req.Email == "admin@ragflow.io" { + return nil, errors.New("default admin account cannot be used to login normal services") + } + + // Get user by email + user, err := s.userDAO.GetByEmail(req.Email) + if err != nil { + return nil, errors.New("invalid email or password") + } + + // Decrypt password using RSA + decryptedPassword, err := s.decryptPassword(req.Password) + if err != nil { + return nil, fmt.Errorf("failed to decrypt password: %w", err) + } + + // Verify password + if user.Password == nil || !s.VerifyPassword(*user.Password, decryptedPassword) { + return nil, errors.New("invalid email or password") + } + + // Check user status + if user.Status == nil || *user.Status != "1" { + return nil, errors.New("user is disabled") + } + + // Generate new access token + token := s.GenerateToken() + user.AccessToken = &token + + // Update timestamp + now := time.Now().Unix() + user.UpdateTime = &now + now_date := time.Now() + user.UpdateDate = &now_date + if err := s.userDAO.Update(user); err != nil { + return nil, fmt.Errorf("failed to update user: %w", err) + } + + return user, nil +} + +// GetUserByID get user by ID +func (s *UserService) GetUserByID(id uint) (*UserResponse, error) { + user, err := s.userDAO.GetByID(id) + if err != nil { + return nil, err + } + + return &UserResponse{ + ID: user.ID, + Email: user.Email, + Nickname: user.Nickname, + Status: user.Status, + CreatedAt: time.Unix(user.CreateTime, 0).Format("2006-01-02 15:04:05"), + }, nil +} + +// ListUsers list users +func (s *UserService) ListUsers(page, pageSize int) ([]*UserResponse, int64, error) { + offset := (page - 1) * pageSize + users, total, err := s.userDAO.List(offset, pageSize) + if err != nil { + return nil, 0, err + } + + responses := make([]*UserResponse, len(users)) + for i, user := range users { + responses[i] = &UserResponse{ + ID: user.ID, + Email: user.Email, + Nickname: user.Nickname, + Status: user.Status, + CreatedAt: time.Unix(user.CreateTime, 0).Format("2006-01-02 15:04:05"), + } + } + + return responses, total, nil +} + +// HashPassword generate password hash +func (s *UserService) HashPassword(password string) (string, error) { + salt := s.generateSalt() + hash, err := scrypt.Key([]byte(password), salt, 32768, 8, 1, 64) + if err != nil { + return "", err + } + + // Return werkzeug format: scrypt:n:r:p$salt$hash + return fmt.Sprintf("scrypt:32768:8:1$%s$%x", string(salt), hash), nil +} + +// VerifyPassword verify password +func (s *UserService) VerifyPassword(hashedPassword, password string) bool { + // Parse hash format: scrypt:n:r:p$salt$hash + parts := strings.Split(hashedPassword, "$") + if len(parts) != 3 { + return false + } + + params := strings.Split(parts[0], ":") + if len(params) != 4 || params[0] != "scrypt" { + return false + } + + n, err := strconv.ParseUint(params[1], 10, 0) + if err != nil { + return false + } + r, err := strconv.ParseUint(params[2], 10, 0) + if err != nil { + return false + } + p, err := strconv.ParseUint(params[3], 10, 0) + if err 
!= nil {
+		return false
+	}
+
+	saltStr := parts[1]
+	hashHex := parts[2]
+
+	// Compute password hash
+	computed, err := scrypt.Key([]byte(password), []byte(saltStr), int(n), int(r), int(p), len(hashHex)/2)
+	if err != nil {
+		return false
+	}
+
+	decodedHash, err := hex.DecodeString(hashHex)
+	if err != nil {
+		return false
+	}
+
+	// Constant time comparison
+	return s.constantTimeCompare(decodedHash, computed)
+}
+
+// generateSalt generate salt
+func (s *UserService) generateSalt() []byte {
+	return []byte("random_salt_for_user") // TODO: use random salt
+}
+
+// constantTimeCompare constant time comparison
+func (s *UserService) constantTimeCompare(a, b []byte) bool {
+	if len(a) != len(b) {
+		return false
+	}
+
+	var result byte
+	for i := 0; i < len(a); i++ {
+		result |= a[i] ^ b[i]
+	}
+
+	return result == 0
+}
+
+// loadPrivateKey loads and decrypts the RSA private key from conf/private.pem
+// nolint:staticcheck // DecryptPEMBlock is deprecated but still works for traditional PEM encryption
+func (s *UserService) loadPrivateKey() (*rsa.PrivateKey, error) {
+	// Read private key file
+	keyData, err := os.ReadFile("conf/private.pem")
+	if err != nil {
+		return nil, fmt.Errorf("failed to read private key file: %w", err)
+	}
+
+	// Parse PEM block
+	block, _ := pem.Decode(keyData)
+	if block == nil {
+		return nil, errors.New("failed to decode PEM block")
+	}
+
+	// Decrypt the PEM block if it's encrypted
+	var privateKey interface{}
+	if block.Headers["Proc-Type"] == "4,ENCRYPTED" {
+		// Decrypt using password "Welcome"
+		// Note: DecryptPEMBlock is deprecated but still functional for traditional PEM encryption
+		decryptedData, err := x509.DecryptPEMBlock(block, []byte("Welcome"))
+		if err != nil {
+			return nil, fmt.Errorf("failed to decrypt private key: %w", err)
+		}
+
+		// Parse the decrypted key
+		privateKey, err = x509.ParsePKCS1PrivateKey(decryptedData)
+		if err != nil {
+			return nil, fmt.Errorf("failed to parse private key: %w", err)
+		}
+	} else {
+		// Not encrypted, parse directly
+		privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
+		if err != nil {
+			return nil, fmt.Errorf("failed to parse private key: %w", err)
+		}
+	}
+
+	rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey)
+	if !ok {
+		return nil, errors.New("not an RSA private key")
+	}
+
+	return rsaPrivateKey, nil
+}
+
+// decryptPassword decrypts the password using RSA private key
+func (s *UserService) decryptPassword(encryptedPassword string) (string, error) {
+	// Try to decode base64
+	ciphertext, err := base64.StdEncoding.DecodeString(encryptedPassword)
+	if err != nil {
+		// If base64 decoding fails, assume it's already a plain password
+		return encryptedPassword, nil
+	}
+
+	// Load private key
+	privateKey, err := s.loadPrivateKey()
+	if err != nil {
+		return "", err
+	}
+
+	// Decrypt using PKCS#1 v1.5
+	plaintext, err := rsa.DecryptPKCS1v15(nil, privateKey, ciphertext)
+	if err != nil {
+		// If decryption fails, assume it's already a plain password
+		return encryptedPassword, nil
+	}
+
+	return string(plaintext), nil
+}
+
+// GenerateToken generates a new access token
+func (s *UserService) GenerateToken() string {
+	return strings.ReplaceAll(uuid.New().String(), "-", "")
+}
+
+// GetUserByToken gets user by authorization header
+// The token parameter is the authorization header value, which needs to be decrypted
+// using itsdangerous URLSafeTimedSerializer to get the actual access_token
+func (s *UserService) GetUserByToken(authorization string) (*model.User, error) {
+	// Get secret key from config
+	variables := server.GetVariables()
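+	// The secret key must match the Python service's SECRET_KEY so that
+	// tokens signed by its itsdangerous serializer validate here (see
+	// utility.ExtractAccessToken for the matching serializer settings).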
+ secretKey := variables.SecretKey + + // Extract access token from authorization header + // Equivalent to: access_token = str(jwt.loads(authorization)) in Python + accessToken, err := utility.ExtractAccessToken(authorization, secretKey) + if err != nil { + return nil, fmt.Errorf("invalid authorization token: %w", err) + } + + // Validate token format (should be at least 32 chars, UUID format) + if len(accessToken) < 32 { + return nil, errors.New("invalid access token format") + } + + // Get user by access token + return s.userDAO.GetByAccessToken(accessToken) +} + +// UpdateUserAccessToken updates user's access token +func (s *UserService) UpdateUserAccessToken(user *model.User, token string) error { + return s.userDAO.UpdateAccessToken(user, token) +} + +// Logout invalidates user's access token +func (s *UserService) Logout(user *model.User) error { + // Invalidate token by setting it to an invalid value + // Similar to Python implementation: "INVALID_" + secrets.token_hex(16) + invalidToken := "INVALID_" + s.GenerateToken() + return s.UpdateUserAccessToken(user, invalidToken) +} + +// GetUserProfile returns user profile information +func (s *UserService) GetUserProfile(user *model.User) map[string]interface{} { + // Format create time and date (from database fields) + createTime := user.CreateTime + createDate := "" + if user.CreateDate != nil { + createDate = user.CreateDate.Format("2006-01-02T15:04:05") + } + + // Format update time and date (from database fields) + var updateTime int64 + updateDate := "" + if user.UpdateTime != nil { + updateTime = *user.UpdateTime + } + if user.UpdateDate != nil { + updateDate = user.UpdateDate.Format("2006-01-02T15:04:05") + } + + // Format last login time + var lastLoginTime string + if user.LastLoginTime != nil { + lastLoginTime = user.LastLoginTime.Format("2006-01-02T15:04:05") + } + + // Get access token + var accessToken string + if user.AccessToken != nil { + accessToken = *user.AccessToken + } + + // Get avatar + var avatar interface{} + if user.Avatar != nil { + avatar = *user.Avatar + } else { + avatar = nil + } + + // Get color schema + colorSchema := "Bright" + if user.ColorSchema != nil && *user.ColorSchema != "" { + colorSchema = *user.ColorSchema + } + + // Get language + language := "English" + if user.Language != nil && *user.Language != "" { + language = *user.Language + } + + // Get timezone + timezone := "UTC+8\tAsia/Shanghai" + if user.Timezone != nil && *user.Timezone != "" { + timezone = *user.Timezone + } + + // Get login channel + loginChannel := "password" + if user.LoginChannel != nil && *user.LoginChannel != "" { + loginChannel = *user.LoginChannel + } + + // Get password + var password string + if user.Password != nil { + password = *user.Password + } + + // Get status + status := "1" + if user.Status != nil { + status = *user.Status + } + + // Get is_superuser + isSuperuser := false + if user.IsSuperuser != nil { + isSuperuser = *user.IsSuperuser + } + + return map[string]interface{}{ + "access_token": accessToken, + "avatar": avatar, + "color_schema": colorSchema, + "create_date": createDate, + "create_time": createTime, + "email": user.Email, + "id": user.ID, + "is_active": user.IsActive, + "is_anonymous": user.IsAnonymous, + "is_authenticated": user.IsAuthenticated, + "is_superuser": isSuperuser, + "language": language, + "last_login_time": lastLoginTime, + "login_channel": loginChannel, + "nickname": user.Nickname, + "password": password, + "status": status, + "timezone": timezone, + "update_date": updateDate, + 
"update_time": updateTime, + } +} + +// UpdateUserSettings updates user settings +func (s *UserService) UpdateUserSettings(user *model.User, req *UpdateSettingsRequest) error { + // Update fields if provided + if req.Nickname != nil { + user.Nickname = *req.Nickname + } + if req.Email != nil { + user.Email = *req.Email + } + if req.Avatar != nil { + // In Go version, avatar might be stored differently + // For now, just update if field exists + } + if req.Language != nil { + // Store language preference + } + if req.ColorSchema != nil { + // Store color schema preference + } + if req.Timezone != nil { + // Store timezone preference + } + + // Save updated user + return s.userDAO.Update(user) +} + +// ChangePassword changes user password +func (s *UserService) ChangePassword(user *model.User, req *ChangePasswordRequest) error { + // If password is provided, verify current password + if req.Password != nil { + if user.Password == nil || !s.VerifyPassword(*user.Password, *req.Password) { + return errors.New("current password is incorrect") + } + } + + // If new password is provided, update password + if req.NewPassword != nil { + hashedPassword, err := s.HashPassword(*req.NewPassword) + if err != nil { + return fmt.Errorf("failed to hash new password: %w", err) + } + user.Password = &hashedPassword + } + + // Save updated user + return s.userDAO.Update(user) +} + +// LoginChannel represents a login channel response +type LoginChannel struct { + Channel string `json:"channel"` + DisplayName string `json:"display_name"` + Icon string `json:"icon"` +} + +// GetLoginChannels gets all supported authentication channels +func (s *UserService) GetLoginChannels() ([]*LoginChannel, error) { + cfg := server.GetConfig() + channels := make([]*LoginChannel, 0) + + for channel, oauthCfg := range cfg.OAuth { + displayName := oauthCfg.DisplayName + if displayName == "" { + displayName = strings.Title(channel) + } + + icon := oauthCfg.Icon + if icon == "" { + icon = "sso" + } + + channels = append(channels, &LoginChannel{ + Channel: channel, + DisplayName: displayName, + Icon: icon, + }) + } + + return channels, nil +} diff --git a/internal/tokenizer/tokenizer.go b/internal/tokenizer/tokenizer.go new file mode 100644 index 00000000000..54f89d34869 --- /dev/null +++ b/internal/tokenizer/tokenizer.go @@ -0,0 +1,477 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//
+
+package tokenizer
+
+import (
+	"context"
+	"fmt"
+	"runtime"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"go.uber.org/zap"
+
+	rag "ragflow/internal/go_binding"
+	"ragflow/internal/logger"
+)
+
+// PoolConfig configures the elastic analyzer pool
+type PoolConfig struct {
+	DictPath       string        // Path to dictionary files
+	MinSize        int           // Minimum number of pre-warmed instances (default: 2*CPU)
+	MaxSize        int           // Maximum number of instances allowed (default: 16*CPU)
+	IdleTimeout    time.Duration // Idle timeout for shrinking (default: 5 minutes)
+	AcquireTimeout time.Duration // Timeout for acquiring an instance (default: 10 seconds)
+}
+
+// poolInstance wraps an analyzer instance with metadata for pool management
+type poolInstance struct {
+	analyzer   *rag.Analyzer
+	lastUsedAt time.Time
+}
+
+// analyzerPool is the elastic pool for analyzer instances
+type analyzerPool struct {
+	config       PoolConfig
+	baseAnalyzer *rag.Analyzer      // Original analyzer used as template for copying
+	instances    chan *poolInstance // Channel-based pool for available instances
+	currentSize  int32              // Current number of instances (atomic)
+	initialized  bool
+	mu           sync.RWMutex
+	stopCh       chan struct{}
+	wg           sync.WaitGroup
+}
+
+var (
+	globalPool    *analyzerPool
+	poolOnce      sync.Once
+	poolInitError error
+)
+
+// Init initializes the elastic analyzer pool with the given configuration
+// Can be called multiple times if the pool was previously closed
+func Init(cfg *PoolConfig) error {
+	// Check if we need to reset poolOnce (for testing or re-initialization)
+	if globalPool != nil && !globalPool.initialized {
+		// Pool was closed, reset poolOnce for re-initialization
+		poolOnce = sync.Once{}
+	}
+
+	poolOnce.Do(func() {
+		if cfg == nil {
+			cfg = &PoolConfig{}
+		}
+
+		// Set default values
+		if cfg.DictPath == "" {
+			cfg.DictPath = "/usr/share/infinity/resource"
+		}
+		if cfg.MinSize <= 0 {
+			cfg.MinSize = runtime.NumCPU() * 2
+		}
+		if cfg.MaxSize <= 0 {
+			cfg.MaxSize = runtime.NumCPU() * 16
+		}
+		if cfg.MinSize > cfg.MaxSize {
+			cfg.MinSize = cfg.MaxSize
+		}
+		if cfg.IdleTimeout <= 0 {
+			cfg.IdleTimeout = 5 * time.Minute
+		}
+		if cfg.AcquireTimeout <= 0 {
+			cfg.AcquireTimeout = 10 * time.Second
+		}
+
+		logger.Info("Initializing analyzer pool",
+			zap.String("dict_path", cfg.DictPath),
+			zap.Int("min_size", cfg.MinSize),
+			zap.Int("max_size", cfg.MaxSize),
+			zap.Duration("idle_timeout", cfg.IdleTimeout),
+			zap.Duration("acquire_timeout", cfg.AcquireTimeout))
+
+		globalPool = &analyzerPool{
+			config:    *cfg,
+			instances: make(chan *poolInstance, cfg.MaxSize),
+			stopCh:    make(chan struct{}),
+		}
+
+		// Create the base analyzer as template
+		baseAnalyzer, err := rag.NewAnalyzer(cfg.DictPath)
+		if err != nil {
+			poolInitError = fmt.Errorf("failed to create base analyzer: %w", err)
+			logger.Error("Failed to create base analyzer", poolInitError)
+			return
+		}
+
+		if err = baseAnalyzer.Load(); err != nil {
+			poolInitError = fmt.Errorf("failed to load base analyzer: %w", err)
+			logger.Error("Failed to load base analyzer", poolInitError)
+			baseAnalyzer.Close()
+			return
+		}
+
+		globalPool.baseAnalyzer = baseAnalyzer
+
+		// Pre-warm minSize instances
+		for i := 0; i < cfg.MinSize; i++ {
+			instance, err := globalPool.createInstance()
+			if err != nil {
+				poolInitError = fmt.Errorf("failed to create instance %d: %w", i, err)
+				logger.Error("Failed to create pool instance", poolInitError)
+				// Close() is a no-op before initialized is set, so release the
+				// pre-warmed instances and the base analyzer directly here
+				for len(globalPool.instances) > 0 {
+					(<-globalPool.instances).analyzer.Close()
+				}
+				baseAnalyzer.Close()
+				globalPool.baseAnalyzer = nil
+				return
+			}
+			globalPool.instances <- instance
+			atomic.AddInt32(&globalPool.currentSize, 1)
+		}
+
+		globalPool.initialized = true
+		
logger.Info("Analyzer pool initialized successfully", + zap.Int("pre_warmed", cfg.MinSize), + zap.Int32("current_size", atomic.LoadInt32(&globalPool.currentSize))) + + // Start the shrink loop for idle instance cleanup + globalPool.wg.Add(1) + go globalPool.shrinkLoop() + }) + + return poolInitError +} + +// createInstance creates a new analyzer instance by copying the base analyzer +func (p *analyzerPool) createInstance() (*poolInstance, error) { + if p.baseAnalyzer == nil { + return nil, fmt.Errorf("base analyzer is nil") + } + + // Copy the base analyzer to create a new independent instance + copied := p.baseAnalyzer.Copy() + if copied == nil { + return nil, fmt.Errorf("failed to copy analyzer") + } + + return &poolInstance{ + analyzer: copied, + lastUsedAt: time.Now(), + }, nil +} + +// acquire gets an analyzer instance from the pool +// If pool is empty and below max size, creates a new instance dynamically +func (p *analyzerPool) acquire() (*poolInstance, error) { + if !p.initialized { + return nil, fmt.Errorf("pool not initialized") + } + + // Fast path: try to get from pool without blocking + select { + case instance := <-p.instances: + instance.lastUsedAt = time.Now() + return instance, nil + default: + } + + // Slow path: pool is empty, try dynamic expansion or wait + current := atomic.LoadInt32(&p.currentSize) + if current < int32(p.config.MaxSize) { + // Try to increment atomically and create new instance + if atomic.CompareAndSwapInt32(&p.currentSize, current, current+1) { + instance, err := p.createInstance() + if err != nil { + // Decrement counter on failure + atomic.AddInt32(&p.currentSize, -1) + return nil, fmt.Errorf("failed to dynamically create instance: %w", err) + } + logger.Info("Pool expanded dynamically", + zap.Int32("previous_size", current), + zap.Int32("new_size", current+1), + zap.Int("max_size", p.config.MaxSize)) + return instance, nil + } + // CAS failed, another goroutine created an instance, fall through to wait + } + + // Wait for an instance to become available with timeout + ctx, cancel := context.WithTimeout(context.Background(), p.config.AcquireTimeout) + defer cancel() + + select { + case instance := <-p.instances: + instance.lastUsedAt = time.Now() + return instance, nil + case <-ctx.Done(): + return nil, fmt.Errorf("timeout waiting for analyzer instance (current_size=%d, max=%d)", + atomic.LoadInt32(&p.currentSize), p.config.MaxSize) + } +} + +// release returns an analyzer instance to the pool +func (p *analyzerPool) release(instance *poolInstance) { + if instance == nil || instance.analyzer == nil { + return + } + + if !p.initialized { + instance.analyzer.Close() + return + } + + select { + case p.instances <- instance: + // Successfully returned to pool + default: + // Pool is full (shouldn't happen normally), close this instance + logger.Warn("Pool full when releasing instance, destroying it", + zap.Int32("current_size", atomic.LoadInt32(&p.currentSize))) + instance.analyzer.Close() + atomic.AddInt32(&p.currentSize, -1) + } +} + +// shrinkLoop periodically checks and shrinks the pool by removing idle instances +func (p *analyzerPool) shrinkLoop() { + defer p.wg.Done() + + ticker := time.NewTicker(30 * time.Second) // Check every 30 seconds + defer ticker.Stop() + + for { + select { + case <-ticker.C: + p.shrink() + case <-p.stopCh: + return + } + } +} + +// shrink removes idle instances that have exceeded the idle timeout +// while keeping at least MinSize instances +func (p *analyzerPool) shrink() { + if !p.initialized { + return + } + + 
currentSize := atomic.LoadInt32(&p.currentSize)
+	minSize := int32(p.config.MinSize)
+
+	// Only shrink if we have more than minimum instances
+	if currentSize <= minSize {
+		return
+	}
+
+	now := time.Now()
+	timeout := p.config.IdleTimeout
+	var toRemove []*poolInstance
+
+	// Try to collect idle instances without blocking
+collect:
+	for i := 0; i < int(currentSize-minSize); i++ {
+		select {
+		case instance := <-p.instances:
+			if now.Sub(instance.lastUsedAt) > timeout {
+				toRemove = append(toRemove, instance)
+			} else {
+				// Not idle, put back
+				select {
+				case p.instances <- instance:
+				default:
+					// Pool full, should not happen
+					toRemove = append(toRemove, instance)
+				}
+			}
+		default:
+			// No more instances in pool; a plain break would only exit the
+			// select, so break out of the labeled loop instead
+			break collect
+		}
+	}
+
+	if len(toRemove) > 0 {
+		// Close and destroy idle instances
+		for _, instance := range toRemove {
+			instance.analyzer.Close()
+		}
+
+		newSize := atomic.AddInt32(&p.currentSize, -int32(len(toRemove)))
+		logger.Info("Pool shrunk",
+			zap.Int("removed_instances", len(toRemove)),
+			zap.Int32("previous_size", currentSize),
+			zap.Int32("new_size", newSize),
+			zap.Int("min_size", p.config.MinSize))
+	}
+}
+
+// Close closes the pool and releases all resources
+func (p *analyzerPool) Close() {
+	if p == nil {
+		return
+	}
+
+	p.mu.Lock()
+	if !p.initialized {
+		p.mu.Unlock()
+		return
+	}
+	p.initialized = false
+	p.mu.Unlock()
+
+	// Signal shrink loop to stop
+	close(p.stopCh)
+	p.wg.Wait()
+
+	// Close all instances in pool
+	close(p.instances)
+	for instance := range p.instances {
+		if instance != nil && instance.analyzer != nil {
+			instance.analyzer.Close()
+		}
+	}
+
+	// Close base analyzer
+	if p.baseAnalyzer != nil {
+		p.baseAnalyzer.Close()
+		p.baseAnalyzer = nil
+	}
+
+	logger.Info("Analyzer pool closed",
+		zap.Int32("final_size", atomic.LoadInt32(&p.currentSize)))
+}
+
+// GetPoolStats returns current pool statistics
+func GetPoolStats() map[string]interface{} {
+	if globalPool == nil {
+		return map[string]interface{}{
+			"initialized": false,
+		}
+	}
+
+	return map[string]interface{}{
+		"initialized":         globalPool.initialized,
+		"current_size":        atomic.LoadInt32(&globalPool.currentSize),
+		"min_size":            globalPool.config.MinSize,
+		"max_size":            globalPool.config.MaxSize,
+		"idle_timeout":        globalPool.config.IdleTimeout.String(),
+		"instances_available": len(globalPool.instances),
+	}
+}
+
+// Close closes the global pool
+func Close() {
+	if globalPool != nil {
+		globalPool.Close()
+	}
+}
+
+// withAnalyzer executes the given function with an exclusive analyzer instance
+func withAnalyzer(fn func(*rag.Analyzer) error) error {
+	if globalPool == nil {
+		return fmt.Errorf("tokenizer pool not initialized")
+	}
+
+	instance, err := globalPool.acquire()
+	if err != nil {
+		return err
+	}
+	defer globalPool.release(instance)
+
+	return fn(instance.analyzer)
+}
+
+// withAnalyzerResult executes the given function with an exclusive analyzer instance and returns a result
+func withAnalyzerResult[T any](fn func(*rag.Analyzer) (T, error)) (T, error) {
+	var result T
+	if globalPool == nil {
+		return result, fmt.Errorf("tokenizer pool not initialized")
+	}
+
+	instance, err := globalPool.acquire()
+	if err != nil {
+		return result, err
+	}
+	defer globalPool.release(instance)
+
+	return fn(instance.analyzer)
+}
+
+// Tokenize tokenizes the text and returns a space-separated string of tokens
+// Example: "hello world" -> "hello world"
+func Tokenize(text string) (string, error) {
+	return withAnalyzerResult(func(a *rag.Analyzer) (string, error) {
+		return a.Tokenize(text)
+	})
+}
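+
+// A minimal usage sketch (pool sizes are illustrative, not prescriptive);
+// Init must run once before any tokenization helper is called:
+//
+//	if err := tokenizer.Init(&tokenizer.PoolConfig{MinSize: 2, MaxSize: 8}); err != nil {
+//		panic(err)
+//	}
+//	defer tokenizer.Close()
+//	tks, _ := tokenizer.Tokenize("hello world 中文测试")
+//	fmt.Println(tks)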
+ +// TokenizeWithPosition tokenizes the text and returns a list of tokens with position information +func TokenizeWithPosition(text string) ([]rag.TokenWithPosition, error) { + return withAnalyzerResult(func(a *rag.Analyzer) ([]rag.TokenWithPosition, error) { + return a.TokenizeWithPosition(text) + }) +} + +// Analyze analyzes the text and returns all tokens +func Analyze(text string) ([]rag.Token, error) { + return withAnalyzerResult(func(a *rag.Analyzer) ([]rag.Token, error) { + return a.Analyze(text) + }) +} + +// SetFineGrained sets whether to use fine-grained tokenization +// Note: This is a no-op in pool mode as each request uses its own instance +// To configure an instance, modify the base analyzer before Init() or use custom instances +func SetFineGrained(fineGrained bool) { + // In pool mode, we don't set global state on instances + // Each request gets a fresh instance with default settings + logger.Debug("SetFineGrained is no-op in pool mode", zap.Bool("fine_grained", fineGrained)) +} + +// FineGrainedTokenize performs fine-grained tokenization on space-separated tokens +// Input: space-separated tokens (e.g., "hello world 测试") +// Output: space-separated fine-grained tokens (e.g., "hello world 测 试") +func FineGrainedTokenize(tokens string) (string, error) { + return withAnalyzerResult(func(a *rag.Analyzer) (string, error) { + return a.FineGrainedTokenize(tokens) + }) +} + +// SetEnablePosition sets whether to enable position tracking +// Note: This is a no-op in pool mode as each request uses its own instance +func SetEnablePosition(enablePosition bool) { + logger.Debug("SetEnablePosition is no-op in pool mode", zap.Bool("enable_position", enablePosition)) +} + +// IsInitialized checks whether the tokenizer pool has been initialized +func IsInitialized() bool { + return globalPool != nil && globalPool.initialized +} + +// GetTermFreq returns the frequency of a term (matching Python rag_tokenizer.freq) +// Returns: frequency value, or 0 if term not found +func GetTermFreq(term string) int32 { + result, _ := withAnalyzerResult(func(a *rag.Analyzer) (int32, error) { + return a.GetTermFreq(term), nil + }) + return result +} + +// GetTermTag returns the POS tag of a term (matching Python rag_tokenizer.tag) +// Returns: POS tag string (e.g., "n", "v", "ns"), or empty string if term not found or no tag +func GetTermTag(term string) string { + result, _ := withAnalyzerResult(func(a *rag.Analyzer) (string, error) { + return a.GetTermTag(term), nil + }) + return result +} diff --git a/internal/tokenizer/tokenizer_concurrent_test.go b/internal/tokenizer/tokenizer_concurrent_test.go new file mode 100644 index 00000000000..319a693324a --- /dev/null +++ b/internal/tokenizer/tokenizer_concurrent_test.go @@ -0,0 +1,493 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package tokenizer + +import ( + "fmt" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "go.uber.org/zap" + + "ragflow/internal/logger" +) + +func init() { + // Initialize logger for tests + if err := logger.Init("info"); err != nil { + fmt.Printf("Failed to initialize logger: %v\n", err) + } +} + +// TestConcurrentTokenize tests concurrent tokenization with dynamic pool expansion and shrinking +func TestConcurrentTokenize(t *testing.T) { + // Use small pool to test expansion + cfg := &PoolConfig{ + DictPath: "/usr/share/infinity/resource", + MinSize: 2, + MaxSize: 10, + IdleTimeout: 5 * time.Second, + AcquireTimeout: 5 * time.Second, + } + + if err := Init(cfg); err != nil { + t.Fatalf("Failed to initialize pool: %v", err) + } + defer Close() + + // Print initial pool stats + stats := GetPoolStats() + t.Logf("Initial pool stats: %+v", stats) + + // Test texts + texts := []string{ + "Hello world this is a test", + "Natural language processing is amazing", + "Elastic pool handles concurrent requests", + "中文分词测试", + "深度学习与机器学习", + "RAGFlow is an open-source RAG engine", + } + + // Phase 1: High concurrency test - should trigger expansion + t.Log("=== Phase 1: High concurrency test (should trigger expansion) ===") + var expansionDetected int32 + var wg sync.WaitGroup + numGoroutines := 20 + requestsPerGoroutine := 10 + + start := time.Now() + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < requestsPerGoroutine; j++ { + text := texts[(id+j)%len(texts)] + result, err := Tokenize(text) + if err != nil { + t.Errorf("Goroutine %d request %d failed: %v", id, j, err) + return + } + if result == "" { + t.Errorf("Goroutine %d request %d returned empty result", id, j) + } + + // Check pool stats periodically + if j%5 == 0 { + stats := GetPoolStats() + currentSize := stats["current_size"].(int32) + if currentSize > int32(cfg.MinSize) { + atomic.StoreInt32(&expansionDetected, 1) + } + } + } + }(i) + } + wg.Wait() + phase1Duration := time.Since(start) + + stats = GetPoolStats() + t.Logf("Phase 1 completed in %v", phase1Duration) + t.Logf("Pool stats after Phase 1: %+v", stats) + + if atomic.LoadInt32(&expansionDetected) == 1 { + t.Log("✓ Pool expansion detected during high concurrency") + } else { + t.Log("℗ Pool expansion not detected (may need more concurrency)") + } + + currentSize := stats["current_size"].(int32) + if currentSize > int32(cfg.MinSize) { + t.Logf("✓ Current pool size (%d) is greater than minSize (%d)", currentSize, cfg.MinSize) + } + + // Phase 2: Wait for idle timeout - should trigger shrinking + t.Log("=== Phase 2: Waiting for idle timeout (should trigger shrinking) ===") + t.Logf("Waiting %v for idle instances to timeout...", cfg.IdleTimeout) + time.Sleep(cfg.IdleTimeout + 2*time.Second) + + stats = GetPoolStats() + t.Logf("Pool stats after Phase 2 (waiting): %+v", stats) + + currentSize = stats["current_size"].(int32) + if currentSize <= int32(cfg.MinSize) { + t.Logf("✓ Pool shrunk back to minSize or below: current=%d, min=%d", currentSize, cfg.MinSize) + } else { + t.Logf("℗ Pool not yet shrunk: current=%d, min=%d (may need more time)", currentSize, cfg.MinSize) + } + + // Phase 3: Moderate concurrency after shrink - should trigger expansion again + t.Log("=== Phase 3: Moderate concurrency after shrink (should trigger re-expansion) ===") + var reExpansionDetected int32 + start = time.Now() + for i := 0; i < numGoroutines/2; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < 
requestsPerGoroutine/2; j++ { + text := texts[(id+j)%len(texts)] + _, err := Tokenize(text) + if err != nil { + t.Errorf("Phase 3 goroutine %d request %d failed: %v", id, j, err) + return + } + + if j%3 == 0 { + stats := GetPoolStats() + currentSize := stats["current_size"].(int32) + if currentSize > int32(cfg.MinSize) { + atomic.StoreInt32(&reExpansionDetected, 1) + } + } + } + }(i) + } + wg.Wait() + phase3Duration := time.Since(start) + + stats = GetPoolStats() + t.Logf("Phase 3 completed in %v", phase3Duration) + t.Logf("Pool stats after Phase 3: %+v", stats) + + if atomic.LoadInt32(&reExpansionDetected) == 1 { + t.Log("✓ Pool re-expansion detected after shrink") + } + + t.Log("=== Test completed successfully ===") +} + +// TestConcurrentTokenizeWithPosition tests concurrent tokenization with position info +func TestConcurrentTokenizeWithPosition(t *testing.T) { + cfg := &PoolConfig{ + DictPath: "/usr/share/infinity/resource", + MinSize: 2, + MaxSize: 8, + IdleTimeout: 3 * time.Second, + AcquireTimeout: 5 * time.Second, + } + + if err := Init(cfg); err != nil { + t.Fatalf("Failed to initialize pool: %v", err) + } + defer Close() + + text := "This is a test sentence for position tracking" + var wg sync.WaitGroup + numGoroutines := 15 + + t.Log("=== Testing TokenizeWithPosition concurrently ===") + start := time.Now() + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < 5; j++ { + tokens, err := TokenizeWithPosition(text) + if err != nil { + t.Errorf("Goroutine %d request %d failed: %v", id, j, err) + return + } + if len(tokens) == 0 { + t.Errorf("Goroutine %d request %d returned empty tokens", id, j) + return + } + // Verify position info + for _, token := range tokens { + if token.Text == "" { + t.Errorf("Goroutine %d request %d returned empty token text", id, j) + return + } + if token.EndOffset <= token.Offset { + t.Errorf("Goroutine %d request %d has invalid position: offset=%d, end=%d", + id, j, token.Offset, token.EndOffset) + return + } + } + } + }(i) + } + wg.Wait() + + duration := time.Since(start) + stats := GetPoolStats() + t.Logf("Completed %d goroutines x 5 requests in %v", numGoroutines, duration) + t.Logf("Final pool stats: %+v", stats) + t.Log("✓ TokenizeWithPosition concurrent test passed") +} + +// TestPoolExhaustion tests pool exhaustion and timeout behavior +func TestPoolExhaustion(t *testing.T) { + // Very small pool to test exhaustion + cfg := &PoolConfig{ + DictPath: "/usr/share/infinity/resource", + MinSize: 1, + MaxSize: 2, + IdleTimeout: 10 * time.Second, + AcquireTimeout: 500 * time.Millisecond, // Short timeout for faster test + } + + if err := Init(cfg); err != nil { + t.Fatalf("Failed to initialize pool: %v", err) + } + defer Close() + + t.Log("=== Testing pool exhaustion behavior ===") + stats := GetPoolStats() + t.Logf("Initial pool stats: %+v", stats) + + // Use all available instances + var wg sync.WaitGroup + barrier := make(chan struct{}) + errors := make(chan error, 10) + + // Launch goroutines that hold instances + for i := 0; i < 5; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + <-barrier // Wait for signal to start + _, err := Tokenize("test text") + if err != nil { + errors <- fmt.Errorf("goroutine %d: %w", id, err) + } + }(i) + } + + // Release all goroutines at once to create contention + close(barrier) + + // Wait for all to complete + wg.Wait() + close(errors) + + timeoutCount := 0 + for err := range errors { + if err != nil { + t.Logf("Expected error from limited pool: %v", 
err) + timeoutCount++ + } + } + + stats = GetPoolStats() + t.Logf("Final pool stats: %+v", stats) + t.Logf("Timeout errors: %d (expected with small pool)", timeoutCount) + + if timeoutCount > 0 { + t.Log("✓ Pool correctly returned timeout errors when exhausted") + } else { + t.Log("℗ No timeout errors (pool handled all requests, may be too fast)") + } +} + +// TestFineGrainedTokenizeConcurrent tests concurrent fine-grained tokenization +func TestFineGrainedTokenizeConcurrent(t *testing.T) { + cfg := &PoolConfig{ + DictPath: "/usr/share/infinity/resource", + MinSize: 2, + MaxSize: 6, + IdleTimeout: 3 * time.Second, + AcquireTimeout: 5 * time.Second, + } + + if err := Init(cfg); err != nil { + t.Fatalf("Failed to initialize pool: %v", err) + } + defer Close() + + tokens := "hello world 中文测试" + var wg sync.WaitGroup + numGoroutines := 10 + + t.Log("=== Testing FineGrainedTokenize concurrently ===") + start := time.Now() + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < 5; j++ { + result, err := FineGrainedTokenize(tokens) + if err != nil { + t.Errorf("Goroutine %d request %d failed: %v", id, j, err) + return + } + if result == "" { + t.Errorf("Goroutine %d request %d returned empty result", id, j) + } + } + }(i) + } + wg.Wait() + + duration := time.Since(start) + stats := GetPoolStats() + t.Logf("Completed %d goroutines x 5 requests in %v", numGoroutines, duration) + t.Logf("Final pool stats: %+v", stats) + t.Log("✓ FineGrainedTokenize concurrent test passed") +} + +// TestTermFreqAndTagConcurrent tests concurrent term frequency and tag lookups +func TestTermFreqAndTagConcurrent(t *testing.T) { + cfg := &PoolConfig{ + DictPath: "/usr/share/infinity/resource", + MinSize: 2, + MaxSize: 6, + IdleTimeout: 3 * time.Second, + AcquireTimeout: 5 * time.Second, + } + + if err := Init(cfg); err != nil { + t.Fatalf("Failed to initialize pool: %v", err) + } + defer Close() + + terms := []string{"hello", "world", "中文", "test", "natural"} + var wg sync.WaitGroup + numGoroutines := 10 + + t.Log("=== Testing GetTermFreq and GetTermTag concurrently ===") + start := time.Now() + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < 10; j++ { + term := terms[(id+j)%len(terms)] + freq := GetTermFreq(term) + tag := GetTermTag(term) + // We don't validate the results as terms may or may not exist in dictionary + // Just ensuring no panics or errors + _ = freq + _ = tag + } + }(i) + } + wg.Wait() + + duration := time.Since(start) + stats := GetPoolStats() + t.Logf("Completed %d goroutines x 10 requests in %v", numGoroutines, duration) + t.Logf("Final pool stats: %+v", stats) + t.Log("✓ GetTermFreq and GetTermTag concurrent test passed") +} + +// BenchmarkTokenize benchmarks the tokenization performance +func BenchmarkTokenize(b *testing.B) { + cfg := &PoolConfig{ + DictPath: "/usr/share/infinity/resource", + MinSize: runtime.NumCPU() * 2, + MaxSize: runtime.NumCPU() * 4, + IdleTimeout: 5 * time.Minute, + AcquireTimeout: 10 * time.Second, + } + + if err := Init(cfg); err != nil { + b.Fatalf("Failed to initialize pool: %v", err) + } + defer Close() + + text := "This is a benchmark test for tokenization performance with natural language processing" + + // Warm up + for i := 0; i < 100; i++ { + Tokenize(text) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := Tokenize(text) + if err != nil { + b.Errorf("Tokenize failed: %v", err) + } + } + }) + + stats := 
GetPoolStats() + b.Logf("Final pool stats: %+v", stats) +} + +// BenchmarkTokenizeWithPosition benchmarks position-aware tokenization +func BenchmarkTokenizeWithPosition(b *testing.B) { + cfg := &PoolConfig{ + DictPath: "/usr/share/infinity/resource", + MinSize: runtime.NumCPU() * 2, + MaxSize: runtime.NumCPU() * 4, + IdleTimeout: 5 * time.Minute, + AcquireTimeout: 10 * time.Second, + } + + if err := Init(cfg); err != nil { + b.Fatalf("Failed to initialize pool: %v", err) + } + defer Close() + + text := "This is a benchmark test for position-aware tokenization" + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := TokenizeWithPosition(text) + if err != nil { + b.Errorf("TokenizeWithPosition failed: %v", err) + } + } + }) +} + +// ExampleGetPoolStats demonstrates getting pool statistics +func ExampleGetPoolStats() { + cfg := &PoolConfig{ + DictPath: "/usr/share/infinity/resource", + MinSize: 2, + MaxSize: 10, + IdleTimeout: 5 * time.Minute, + AcquireTimeout: 10 * time.Second, + } + + if err := Init(cfg); err != nil { + fmt.Printf("Failed to initialize: %v\n", err) + return + } + defer Close() + + stats := GetPoolStats() + fmt.Printf("Pool initialized: %v\n", stats["initialized"]) + fmt.Printf("Current size: %d\n", stats["current_size"]) + fmt.Printf("Min size: %d\n", stats["min_size"]) + fmt.Printf("Max size: %d\n", stats["max_size"]) + + // Output will vary based on actual initialization +} + +// logPoolStats logs pool statistics using the zap logger +func logPoolStats(msg string) { + stats := GetPoolStats() + logger.Info(msg, + zap.Bool("initialized", stats["initialized"].(bool)), + zap.Int32("current_size", stats["current_size"].(int32)), + zap.Int("min_size", stats["min_size"].(int)), + zap.Int("max_size", stats["max_size"].(int)), + zap.String("idle_timeout", stats["idle_timeout"].(string)), + zap.Int("instances_available", stats["instances_available"].(int)), + ) +} diff --git a/internal/utility/embedding_lru.go b/internal/utility/embedding_lru.go new file mode 100644 index 00000000000..28725d87d8f --- /dev/null +++ b/internal/utility/embedding_lru.go @@ -0,0 +1,141 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package utility + +import ( + "container/list" + "sync" +) + +// EmbeddingLRU is a thread-safe LRU cache for embeddings. +// The key is a combination of question and embedding ID. +type EmbeddingLRU struct { + capacity int + cache map[string]*list.Element + list *list.List + mu sync.RWMutex +} + +// entry holds the key and value in the LRU cache. +type entry struct { + key string + value []float64 +} + +// NewEmbeddingLRU creates a new EmbeddingLRU with the given capacity. +func NewEmbeddingLRU(capacity int) *EmbeddingLRU { + return &EmbeddingLRU{ + capacity: capacity, + cache: make(map[string]*list.Element), + list: list.New(), + } +} + +// buildKey creates a composite key from question and embedding ID. 
+func buildKey(question, embeddingID string) string {
+	// Use a delimiter that is unlikely to appear in the strings.
+	// If needed, a more robust key generation can be implemented.
+	return question + "::" + embeddingID
+}
+
+// Get retrieves the embedding for the given question and embedding ID.
+// Returns the embedding and true if found, otherwise nil and false.
+func (lru *EmbeddingLRU) Get(question, embeddingID string) ([]float64, bool) {
+	key := buildKey(question, embeddingID)
+	// A full lock is required here: MoveToFront mutates the list, so a
+	// read lock would race with concurrent Get/Put calls.
+	lru.mu.Lock()
+	defer lru.mu.Unlock()
+
+	if elem, ok := lru.cache[key]; ok {
+		// Move to front (most recently used)
+		lru.list.MoveToFront(elem)
+		ent := elem.Value.(*entry)
+		// Return a copy to prevent external modification of cached slice
+		embedding := make([]float64, len(ent.value))
+		copy(embedding, ent.value)
+		return embedding, true
+	}
+	return nil, false
+}
+
+// Put stores an embedding for the given question and embedding ID.
+// If the key already exists, its value is updated and moved to front.
+// If the cache is at capacity, the least recently used item is evicted.
+func (lru *EmbeddingLRU) Put(question, embeddingID string, embedding []float64) {
+	key := buildKey(question, embeddingID)
+	lru.mu.Lock()
+	defer lru.mu.Unlock()
+
+	// If key exists, update value and move to front
+	if elem, ok := lru.cache[key]; ok {
+		lru.list.MoveToFront(elem)
+		ent := elem.Value.(*entry)
+		// Replace the embedding slice
+		ent.value = make([]float64, len(embedding))
+		copy(ent.value, embedding)
+		return
+	}
+
+	// Add new entry
+	ent := &entry{key: key, value: make([]float64, len(embedding))}
+	copy(ent.value, embedding)
+	elem := lru.list.PushFront(ent)
+	lru.cache[key] = elem
+
+	// Evict if capacity exceeded
+	if lru.list.Len() > lru.capacity {
+		lru.evictOldest()
+	}
+}
+
+// evictOldest removes the least recently used item from the cache.
+// Must be called with lock held.
+func (lru *EmbeddingLRU) evictOldest() {
+	elem := lru.list.Back()
+	if elem != nil {
+		lru.list.Remove(elem)
+		ent := elem.Value.(*entry)
+		delete(lru.cache, ent.key)
+	}
+}
+
+// Remove removes the embedding for the given question and embedding ID.
+func (lru *EmbeddingLRU) Remove(question, embeddingID string) {
+	key := buildKey(question, embeddingID)
+	lru.mu.Lock()
+	defer lru.mu.Unlock()
+
+	if elem, ok := lru.cache[key]; ok {
+		lru.list.Remove(elem)
+		delete(lru.cache, key)
+	}
+}
+
+// Clear removes all items from the cache.
+func (lru *EmbeddingLRU) Clear() {
+	lru.mu.Lock()
+	defer lru.mu.Unlock()
+
+	lru.cache = make(map[string]*list.Element)
+	lru.list.Init()
+}
+
+// Len returns the number of items in the cache.
+func (lru *EmbeddingLRU) Len() int {
+	lru.mu.RLock()
+	defer lru.mu.RUnlock()
+	return lru.list.Len()
+}
diff --git a/internal/utility/token.go b/internal/utility/token.go
new file mode 100644
index 00000000000..789036b4478
--- /dev/null
+++ b/internal/utility/token.go
@@ -0,0 +1,135 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +package utility + +import ( + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "strings" + + "github.com/iromli/go-itsdangerous" +) + +// ExtractAccessToken extract access token from authorization header +// This is equivalent to: str(jwt.loads(authorization)) in Python +// Uses github.com/iromli/go-itsdangerous for itsdangerous compatibility +func ExtractAccessToken(authorization, secretKey string) (string, error) { + if authorization == "" { + return "", errors.New("empty authorization") + } + + // Create URLSafeTimedSerializer with correct configuration + // Matching Python itsdangerous configuration: + // - salt: "itsdangerous" + // - key_derivation: "django-concat" + // - digest_method: sha1 + algo := &itsdangerous.HMACAlgorithm{DigestMethod: sha1.New} + signer := itsdangerous.NewTimestampSignature( + secretKey, + "itsdangerous", + ".", + "django-concat", + sha1.New, + algo, + ) + + // Unsign the token (verifies signature and extracts payload) + encodedValue, err := signer.Unsign(authorization, 0) + if err != nil { + return "", fmt.Errorf("failed to decode token: %w", err) + } + + // Base64 decode the payload + jsonValue, err := urlSafeB64Decode(encodedValue) + if err != nil { + return "", fmt.Errorf("failed to decode payload: %w", err) + } + + // Parse JSON string (remove surrounding quotes) + value := string(jsonValue) + if strings.HasPrefix(value, "\"") && strings.HasSuffix(value, "\"") { + value = value[1 : len(value)-1] + } + + return value, nil +} + +// DumpAccessToken creates an authorization token from access token +// This is equivalent to: jwt.dumps(access_token) in Python +// Uses github.com/iromli/go-itsdangerous for itsdangerous compatibility +func DumpAccessToken(accessToken, secretKey string) (string, error) { + if accessToken == "" { + return "", errors.New("empty access token") + } + + // Create URLSafeTimedSerializer with correct configuration + // Matching Python itsdangerous configuration: + // - salt: "itsdangerous" + // - key_derivation: "django-concat" + // - digest_method: sha1 + algo := &itsdangerous.HMACAlgorithm{DigestMethod: sha1.New} + signer := itsdangerous.NewTimestampSignature( + secretKey, + "itsdangerous", + ".", + "django-concat", + sha1.New, + algo, + ) + + // Encode the access token as JSON string (add surrounding quotes) + jsonValue := fmt.Sprintf("\"%s\"", accessToken) + encodedValue := urlSafeB64Encode([]byte(jsonValue)) + + // Sign the token (creates signature) + token, err := signer.Sign(encodedValue) + if err != nil { + return "", fmt.Errorf("failed to sign token: %w", err) + } + + return token, nil +} + +// urlSafeB64Decode URL-safe base64 decode +func urlSafeB64Decode(s string) ([]byte, error) { + // Add padding if needed + padding := 4 - len(s)%4 + if padding != 4 { + s += strings.Repeat("=", padding) + } + return base64.URLEncoding.DecodeString(s) +} + +// urlSafeB64Encode URL-safe base64 encode (without padding) +func urlSafeB64Encode(data []byte) string { + encoded := base64.URLEncoding.EncodeToString(data) + // Remove padding + return strings.TrimRight(encoded, "=") +} + +// generateSecretKey generates a 32-byte hex string (equivalent to Python's secrets.token_hex(32)) +func GenerateSecretKey() (string, error) { + bytes := make([]byte, 32) // 32 bytes = 256 bits + if _, err := rand.Read(bytes); err != nil { + return "", fmt.Errorf("failed to generate random key: %v", err) + } + return 
hex.EncodeToString(bytes), nil +} diff --git a/internal/utility/version.go b/internal/utility/version.go new file mode 100644 index 00000000000..1097d678f5f --- /dev/null +++ b/internal/utility/version.go @@ -0,0 +1,76 @@ +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package utility + +import ( + "os" + "os/exec" + "path/filepath" + "strings" + "sync" +) + +var ( + ragflowVersionInfo = "unknown" + versionOnce sync.Once +) + +// GetRAGFlowVersion gets the RAGFlow version information +// It reads from VERSION file or falls back to git describe command +func GetRAGFlowVersion() string { + versionOnce.Do(func() { + ragflowVersionInfo = getRAGFlowVersionInternal() + }) + return ragflowVersionInfo +} + +// getRAGFlowVersionInternal internal function to get version +func getRAGFlowVersionInternal() string { + // Get the path to VERSION file + // Assuming this file is in internal/utility, VERSION is in project root + exePath, err := os.Executable() + if err != nil { + return getClosestTagAndCount() + } + + // Try to find VERSION file in project root + // Start from executable directory and go up + dir := filepath.Dir(exePath) + for i := 0; i < 5; i++ { // Try up to 5 levels up + versionPath := filepath.Join(dir, "VERSION") + if data, err := os.ReadFile(versionPath); err == nil { + return strings.TrimSpace(string(data)) + } + parent := filepath.Dir(dir) + if parent == dir { + break + } + dir = parent + } + + // Fallback to git command + return getClosestTagAndCount() +} + +// getClosestTagAndCount gets version info from git describe command +func getClosestTagAndCount() string { + cmd := exec.Command("git", "describe", "--tags", "--match=v*", "--first-parent", "--always") + output, err := cmd.Output() + if err != nil { + return "unknown" + } + return strings.TrimSpace(string(output)) +} diff --git a/internal/utility/version_test.go b/internal/utility/version_test.go new file mode 100644 index 00000000000..7c3384274a5 --- /dev/null +++ b/internal/utility/version_test.go @@ -0,0 +1,39 @@ +// +// Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//
+
+package utility
+
+import (
+	"fmt"
+	"testing"
+)
+
+func TestGetRAGFlowVersion(t *testing.T) {
+	version := GetRAGFlowVersion()
+	fmt.Printf("RAGFlow Version: %s\n", version)
+	if version == "" {
+		t.Error("GetRAGFlowVersion returned empty string")
+	}
+	if version == "unknown" {
+		t.Log("Warning: GetRAGFlowVersion returned 'unknown', VERSION file not found and git command failed")
+	}
+}
+
+func TestGetClosestTagAndCount(t *testing.T) {
+	version := getClosestTagAndCount()
+	fmt.Printf("Git Version: %s\n", version)
+	// This test just prints the version, no strict assertion
+}
diff --git a/web/vite.config.ts b/web/vite.config.ts
index 21477eb397d..8f014255b3f 100644
--- a/web/vite.config.ts
+++ b/web/vite.config.ts
@@ -75,6 +75,21 @@ export default defineConfig(({ mode, command }) => {
         changeOrigin: true,
         ws: true,
       },
+      // '/v1/system/config': {
+      //   target: 'http://127.0.0.1:9382/',
+      //   changeOrigin: true,
+      //   ws: true,
+      // },
+      // '/v1/user/login': {
+      //   target: 'http://127.0.0.1:9382/',
+      //   changeOrigin: true,
+      //   ws: true,
+      // },
+      // '/v1/user/logout': {
+      //   target: 'http://127.0.0.1:9382/',
+      //   changeOrigin: true,
+      //   ws: true,
+      // },
       '/v1': {
         target: 'http://127.0.0.1:9380/',
         changeOrigin: true,

From d1a22265aa7ff43f674a3bdf1913564fd5f326ee Mon Sep 17 00:00:00 2001
From: Yao Wei <251109226@qq.com>
Date: Wed, 4 Mar 2026 19:24:49 +0800
Subject: [PATCH 131/565] fix: remove company info from resume_summary to
 prevent over-retrieval (#13358)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What problem does this PR solve?

Problem: When searching for a specific company name such as "Daofeng Technology", the search would incorrectly return unrelated resumes whose company names merely contain generic terms such as "Technology".

Root Cause: The `corporation_name_tks` field was included in the identity fields that are redundantly written to every chunk. This caused common words such as "科技" ("technology") to match across all chunks, leading to over-retrieval of irrelevant resumes.

Solution: Remove `corporation_name_tks` from the `_IDENTITY_FIELDS` list. Company information is still preserved in the "Work Overview" chunk where it belongs, allowing proper company-based searches while preventing false positives from generic terms.

---------

Co-authored-by: Aron.Yao
Co-authored-by: Aron.Yao
Co-authored-by: Liu An
---
 rag/app/resume.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/rag/app/resume.py b/rag/app/resume.py
index f9f885e126d..08bb7377641 100644
--- a/rag/app/resume.py
+++ b/rag/app/resume.py
@@ -2125,7 +2125,7 @@ def _build_chunk_document(filename: str, resume: dict,
     # Extract key identity fields, redundantly written to each chunk
     # These fields are small in size but high in information density; once retrieved, the candidate can be immediately identified
     _IDENTITY_FIELDS = ("name_kwd", "phone_kwd", "email_tks", "gender_kwd",
-                        "highest_degree_kwd", "work_exp_flt", "corporation_name_tks")
+                        "highest_degree_kwd", "work_exp_flt")
     identity_meta = {}
     for ik in _IDENTITY_FIELDS:
         iv = resume.get(ik)

From 53c833669587a674df03dc1e35a2b3047d368d74 Mon Sep 17 00:00:00 2001
From: Idriss Sbaaoui <112825897+6ba3i@users.noreply.github.com>
Date: Wed, 4 Mar 2026 19:28:36 +0800
Subject: [PATCH 132/565] playwright: add data-testids for new test (#13364)

### What problem does this PR solve?
add data-testids for new test ### Type of change - [x] Other (please describe): add data-testids for new test --- .../components/auto-keywords-form-field.tsx | 4 ++++ web/src/components/avatar-upload.tsx | 21 ++++++++++++++++- .../components/children-delimiter-form.tsx | 6 ++++- web/src/components/confirm-delete-dialog.tsx | 17 +++++++++++++- web/src/components/edit-tag/index.tsx | 9 +++++++- .../components/entity-types-form-field.tsx | 10 +++++++- .../components/excel-to-html-form-field.tsx | 1 + .../layout-recognize-form-field.tsx | 5 +++- .../max-token-number-from-field.tsx | 11 ++++++++- web/src/components/page-rank-form-field.tsx | 2 ++ .../graph-rag-form-fields.tsx | 9 +++++++- .../raptor-form-fields.tsx | 23 +++++++++++++++++-- .../components/slider-input-form-field.tsx | 6 +++++ web/src/components/ui/radio.tsx | 11 ++++++++- web/src/components/ui/select.tsx | 21 ++++++++++++++++- .../dataset/components/metedata/interface.ts | 8 +++++++ .../components/metedata/manage-modal.tsx | 11 +++++++++ .../metedata/manage-values-modal.tsx | 6 +++++ .../configuration/common-item.tsx | 22 +++++++++++++++++- .../dataset-setting/configuration/naive.tsx | 8 +++++-- .../dataset/dataset-setting/general-form.tsx | 18 ++++++++++++--- .../pages/dataset/dataset-setting/index.tsx | 1 + .../dataset-setting/permission-form-field.tsx | 1 + .../dataset/dataset-setting/saving-button.tsx | 2 ++ 24 files changed, 215 insertions(+), 18 deletions(-) diff --git a/web/src/components/auto-keywords-form-field.tsx b/web/src/components/auto-keywords-form-field.tsx index 9b8b9da73d0..33373b1e1db 100644 --- a/web/src/components/auto-keywords-form-field.tsx +++ b/web/src/components/auto-keywords-form-field.tsx @@ -13,6 +13,8 @@ export function AutoKeywordsFormField() { min={0} tooltip={t('autoKeywordsTip')} layout={FormLayout.Horizontal} + sliderTestId="ds-settings-parser-auto-keyword-slider" + numberInputTestId="ds-settings-parser-auto-keyword-input" > ); } @@ -28,6 +30,8 @@ export function AutoQuestionsFormField() { min={0} tooltip={t('autoQuestionsTip')} layout={FormLayout.Horizontal} + sliderTestId="ds-settings-parser-auto-question-slider" + numberInputTestId="ds-settings-parser-auto-question-input" > ); } diff --git a/web/src/components/avatar-upload.tsx b/web/src/components/avatar-upload.tsx index bee9c85f0c1..95e9d9c3a7c 100644 --- a/web/src/components/avatar-upload.tsx +++ b/web/src/components/avatar-upload.tsx @@ -18,10 +18,25 @@ type AvatarUploadProps = { value?: string; onChange?: (value: string) => void; tips?: string; + uploadInputTestId?: string; + removeButtonTestId?: string; + cropModalTestId?: string; + cropModalOkButtonTestId?: string; }; export const AvatarUpload = forwardRef( - function AvatarUpload({ value, onChange, tips }, ref) { + function AvatarUpload( + { + value, + onChange, + tips, + uploadInputTestId, + removeButtonTestId, + cropModalTestId, + cropModalOkButtonTestId, + }, + ref, + ) { const { t } = useTranslation(); const [avatarBase64Str, setAvatarBase64Str] = useState(''); // Avatar Image base64 const [isCropModalOpen, setIsCropModalOpen] = useState(false); @@ -285,6 +300,7 @@ export const AvatarUpload = forwardRef( className="border-background focus-visible:border-background absolute -top-2 -right-2 size-6 rounded-full border-2 shadow-none z-10" aria-label="Remove image" type="button" + data-testid={removeButtonTestId} > @@ -299,6 +315,7 @@ export const AvatarUpload = forwardRef( className="absolute top-0 left-0 w-full h-full opacity-0 cursor-pointer" onChange={handleChange} ref={ref} + 
data-testid={uploadInputTestId} />
@@ -318,6 +335,8 @@ export const AvatarUpload = forwardRef( size="small" onCancel={handleCancelCrop} onOk={handleCrop} + testId={cropModalTestId} + okButtonTestId={cropModalOkButtonTestId} // footer={ //
//
@@ -99,7 +100,10 @@ export function ChildrenDelimiterForm() {
- +
diff --git a/web/src/components/confirm-delete-dialog.tsx b/web/src/components/confirm-delete-dialog.tsx index 616ff7a0724..eacbb839862 100644 --- a/web/src/components/confirm-delete-dialog.tsx +++ b/web/src/components/confirm-delete-dialog.tsx @@ -27,6 +27,9 @@ interface IProps { }; okButtonText?: string; cancelButtonText?: string; + testId?: string; + confirmButtonTestId?: string; + cancelButtonTestId?: string; } export function ConfirmDeleteDialog({ @@ -41,6 +44,9 @@ export function ConfirmDeleteDialog({ content, okButtonText, cancelButtonText, + testId, + confirmButtonTestId, + cancelButtonTestId, }: IProps & DialogProps) { const { t } = useTranslation(); @@ -60,6 +66,7 @@ export function ConfirmDeleteDialog({ onSelect={(e) => e.preventDefault()} onClick={(e) => e.stopPropagation()} className="bg-bg-base " + data-testid={testId ?? 'confirm-delete-dialog'} > @@ -86,12 +93,20 @@ export function ConfirmDeleteDialog({ )} - + {cancelButtonText || t('common.cancel')} {okButtonText || t('common.delete')} diff --git a/web/src/components/edit-tag/index.tsx b/web/src/components/edit-tag/index.tsx index decfbb968ba..cadd34ee8c4 100644 --- a/web/src/components/edit-tag/index.tsx +++ b/web/src/components/edit-tag/index.tsx @@ -13,10 +13,15 @@ interface EditTagsProps { value?: string[]; onChange?: (tags: string[]) => void; disabled?: boolean; + addButtonTestId?: string; + inputTestId?: string; } const EditTag = React.forwardRef( - function EditTag({ value = [], onChange, disabled }, ref) { + function EditTag( + { value = [], onChange, disabled, addButtonTestId, inputTestId }, + ref, + ) { const [inputVisible, setInputVisible] = useState(false); const [inputValue, setInputValue] = useState(''); const inputRef = useRef(null); @@ -92,6 +97,7 @@ const EditTag = React.forwardRef( onChange={handleInputChange} onBlur={handleInputConfirm} disabled={disabled} + data-testid={inputTestId} onKeyDown={(e) => { if (e?.key === 'Enter') { handleInputConfirm(); @@ -107,6 +113,7 @@ const EditTag = React.forwardRef( className="w-fit flex items-center justify-center gap-2 bg-bg-card border-border-button border" onClick={showInput} disabled={disabled} + data-testid={addButtonTestId} > diff --git a/web/src/components/entity-types-form-field.tsx b/web/src/components/entity-types-form-field.tsx index 0cc6b9f071d..6838ce8ac12 100644 --- a/web/src/components/entity-types-form-field.tsx +++ b/web/src/components/entity-types-form-field.tsx @@ -11,9 +11,13 @@ import { type EntityTypesFormFieldProps = { name?: string; + addButtonTestId?: string; + inputTestId?: string; }; export function EntityTypesFormField({ name = 'parser_config.entity_types', + addButtonTestId, + inputTestId, }: EntityTypesFormFieldProps) { const { t } = useTranslate('knowledgeConfiguration'); const form = useFormContext(); @@ -31,7 +35,11 @@ export function EntityTypesFormField({
- +
diff --git a/web/src/components/excel-to-html-form-field.tsx b/web/src/components/excel-to-html-form-field.tsx index 13ff8b821e4..7e4bca30328 100644 --- a/web/src/components/excel-to-html-form-field.tsx +++ b/web/src/components/excel-to-html-form-field.tsx @@ -37,6 +37,7 @@ export function ExcelToHtmlFormField() { diff --git a/web/src/components/layout-recognize-form-field.tsx b/web/src/components/layout-recognize-form-field.tsx index e122055e4c9..7b6a077fb3e 100644 --- a/web/src/components/layout-recognize-form-field.tsx +++ b/web/src/components/layout-recognize-form-field.tsx @@ -6,8 +6,8 @@ import { camelCase } from 'lodash'; import { ReactNode, useMemo } from 'react'; import { useFormContext } from 'react-hook-form'; import { MinerUOptionsFormField } from './mineru-options-form-field'; -import { PaddleOCROptionsFormField } from './paddleocr-options-form-field'; import { SelectWithSearch } from './originui/select-with-search'; +import { PaddleOCROptionsFormField } from './paddleocr-options-form-field'; import { FormControl, FormField, @@ -30,6 +30,7 @@ export function LayoutRecognizeFormField({ label, showMineruOptions = true, showPaddleocrOptions = true, + testId, }: { name?: string; horizontal?: boolean; @@ -37,6 +38,7 @@ export function LayoutRecognizeFormField({ label?: ReactNode; showMineruOptions?: boolean; showPaddleocrOptions?: boolean; + testId?: string; }) { const form = useFormContext(); @@ -106,6 +108,7 @@ export function LayoutRecognizeFormField({ diff --git a/web/src/components/max-token-number-from-field.tsx b/web/src/components/max-token-number-from-field.tsx index b01598d9341..c9ed7212c94 100644 --- a/web/src/components/max-token-number-from-field.tsx +++ b/web/src/components/max-token-number-from-field.tsx @@ -5,9 +5,16 @@ import { SliderInputFormField } from './slider-input-form-field'; interface IProps { initialValue?: number; max?: number; + sliderTestId?: string; + numberInputTestId?: string; } -export function MaxTokenNumberFormField({ max = 2048, initialValue }: IProps) { +export function MaxTokenNumberFormField({ + max = 2048, + initialValue, + sliderTestId, + numberInputTestId, +}: IProps) { const { t } = useTranslate('knowledgeConfiguration'); return ( @@ -18,6 +25,8 @@ export function MaxTokenNumberFormField({ max = 2048, initialValue }: IProps) { max={max} defaultValue={initialValue ?? 
0} layout={FormLayout.Horizontal} + sliderTestId={sliderTestId} + numberInputTestId={numberInputTestId} > ); } diff --git a/web/src/components/page-rank-form-field.tsx b/web/src/components/page-rank-form-field.tsx index ea7d3459056..42c6d00a242 100644 --- a/web/src/components/page-rank-form-field.tsx +++ b/web/src/components/page-rank-form-field.tsx @@ -14,6 +14,8 @@ export function PageRankFormField() { max={100} min={0} layout={FormLayout.Horizontal} + sliderTestId="ds-settings-parser-page-rank-slider" + numberInputTestId="ds-settings-parser-page-rank-input" > ); } diff --git a/web/src/components/parse-configuration/graph-rag-form-fields.tsx b/web/src/components/parse-configuration/graph-rag-form-fields.tsx index 8f1fcdb4344..1c418773920 100644 --- a/web/src/components/parse-configuration/graph-rag-form-fields.tsx +++ b/web/src/components/parse-configuration/graph-rag-form-fields.tsx @@ -147,7 +147,11 @@ const GraphRagItems = ({ > {useRaptor && ( <> - + @@ -200,6 +205,7 @@ const GraphRagItems = ({ @@ -229,6 +235,7 @@ const GraphRagItems = ({ diff --git a/web/src/components/parse-configuration/raptor-form-fields.tsx b/web/src/components/parse-configuration/raptor-form-fields.tsx index dc089cdb0aa..110493009fc 100644 --- a/web/src/components/parse-configuration/raptor-form-fields.tsx +++ b/web/src/components/parse-configuration/raptor-form-fields.tsx @@ -129,8 +129,18 @@ const RaptorFormFields = ({
- {t('scopeDataset')} - {t('scopeSingleFile')} + + {t('scopeDataset')} + + + {t('scopeSingleFile')} +
@@ -167,6 +177,7 @@ const RaptorFormFields = ({ onChange={(e) => { field.onChange(e?.target?.value); }} + data-testid="ds-settings-raptor-prompt-textarea" /> @@ -186,6 +197,8 @@ const RaptorFormFields = ({ max={2048} min={0} layout={FormLayout.Horizontal} + sliderTestId="ds-settings-raptor-max-token-slider" + numberInputTestId="ds-settings-raptor-max-token-input" > } diff --git a/web/src/components/slider-input-form-field.tsx b/web/src/components/slider-input-form-field.tsx index 2b9980eb124..f45290eb072 100644 --- a/web/src/components/slider-input-form-field.tsx +++ b/web/src/components/slider-input-form-field.tsx @@ -27,6 +27,8 @@ type SliderInputFormFieldProps = { className?: string; numberInputClassName?: string; percentage?: boolean; + sliderTestId?: string; + numberInputTestId?: string; } & FormLayoutType; export const SliderInputFormField = forwardRef< @@ -46,6 +48,8 @@ export const SliderInputFormField = forwardRef< numberInputClassName, layout = FormLayout.Horizontal, percentage = false, + sliderTestId, + numberInputTestId, }, ref, ) => { @@ -95,6 +99,7 @@ export const SliderInputFormField = forwardRef< max={displayMax} min={displayMin} step={displayStep} + data-testid={sliderTestId} > @@ -118,6 +123,7 @@ export const SliderInputFormField = forwardRef< ); } }} + data-testid={numberInputTestId} > diff --git a/web/src/components/ui/radio.tsx b/web/src/components/ui/radio.tsx index 64fa7d14f4a..edcfb502d06 100644 --- a/web/src/components/ui/radio.tsx +++ b/web/src/components/ui/radio.tsx @@ -13,9 +13,17 @@ type RadioProps = { disabled?: boolean; onChange?: (checked: boolean) => void; children?: React.ReactNode; + testId?: string; }; -function Radio({ value, checked, disabled, onChange, children }: RadioProps) { +function Radio({ + value, + checked, + disabled, + onChange, + children, + testId, +}: RadioProps) { const groupContext = useContext(RadioGroupContext); const isControlled = checked !== undefined; // const [internalChecked, setInternalChecked] = useState(false); @@ -54,6 +62,7 @@ function Radio({ value, checked, disabled, onChange, children }: RadioProps) { mergedDisabled && 'border-muted', )} onClick={handleClick} + data-testid={testId} > {isChecked && (
diff --git a/web/src/components/ui/select.tsx b/web/src/components/ui/select.tsx index b4bd825ab73..71e28fe7cc5 100644 --- a/web/src/components/ui/select.tsx +++ b/web/src/components/ui/select.tsx @@ -203,6 +203,8 @@ export type RAGFlowSelectProps = Partial & { contentProps?: React.ComponentPropsWithoutRef; triggerClassName?: string; onlyShowSelectedIcon?: boolean; + triggerTestId?: string; + optionTestIdPrefix?: string; } & SelectPrimitive.SelectProps; /** @@ -237,6 +239,8 @@ export const RAGFlowSelect = forwardRef< // defaultValue, triggerClassName, onlyShowSelectedIcon = false, + triggerTestId, + optionTestIdPrefix, }, ref, ) { @@ -301,6 +305,7 @@ export const RAGFlowSelect = forwardRef< allowClear={allowClear} ref={ref} className={triggerClassName} + data-testid={triggerTestId} > {label} @@ -313,6 +318,11 @@ export const RAGFlowSelect = forwardRef< value={o.value as RAGFlowSelectOptionType['value']} key={o.value} disabled={o.disabled} + data-testid={ + optionTestIdPrefix + ? `${optionTestIdPrefix}-${o.value}` + : undefined + } >
{o.icon} @@ -326,7 +336,16 @@ export const RAGFlowSelect = forwardRef< {o.label} {o.options.map((x) => ( - + {x.label} ))} diff --git a/web/src/pages/dataset/components/metedata/interface.ts b/web/src/pages/dataset/components/metedata/interface.ts index f5b65b194c6..6b759a64c57 100644 --- a/web/src/pages/dataset/components/metedata/interface.ts +++ b/web/src/pages/dataset/components/metedata/interface.ts @@ -73,6 +73,11 @@ export type IManageModalProps = { builtInMetadata?: IBuiltInMetadataItem[]; success?: (data: any) => void; secondTitle?: ReactNode; + testId?: string; + okButtonTestId?: string; + addButtonTestId?: string; + nestedModalTestId?: string; + nestedModalOkButtonTestId?: string; }; export interface IManageValuesProps { @@ -97,6 +102,9 @@ export interface IManageValuesProps { type?: MetadataValueType, ) => void; addDeleteValue: (key: string, value: string) => void; + testId?: string; + okButtonTestId?: string; + addValueButtonTestId?: string; } export interface DeleteOperation { diff --git a/web/src/pages/dataset/components/metedata/manage-modal.tsx b/web/src/pages/dataset/components/metedata/manage-modal.tsx index 3fb8fca8124..7fa3fc6bce1 100644 --- a/web/src/pages/dataset/components/metedata/manage-modal.tsx +++ b/web/src/pages/dataset/components/metedata/manage-modal.tsx @@ -68,6 +68,11 @@ export const ManageMetadataModal = (props: IManageModalProps) => { success, documentIds, secondTitle, + testId, + okButtonTestId, + addButtonTestId, + nestedModalTestId, + nestedModalOkButtonTestId, } = props; const { t } = useTranslation(); const [valueData, setValueData] = useState({ @@ -304,6 +309,8 @@ export const ManageMetadataModal = (props: IManageModalProps) => { onCancel={hideModal} maskClosable={false} okText={t('common.save')} + testId={testId} + okButtonTestId={okButtonTestId} onOk={async () => { const res = await handleSave({ callback: hideModal, @@ -337,6 +344,7 @@ export const ManageMetadataModal = (props: IManageModalProps) => { className="border border-border-button" type="button" onClick={handAddValueRow} + data-testid={addButtonTestId} > {t('common.add')} @@ -571,6 +579,9 @@ export const ManageMetadataModal = (props: IManageModalProps) => { isShowValueSwitch={isShowValueSwitch} isShowType={true} isVerticalShowValue={isVerticalShowValue} + testId={nestedModalTestId} + okButtonTestId={nestedModalOkButtonTestId} + addValueButtonTestId="ds-settings-metadata-add-modal-add-value-btn" // handleDeleteSingleValue={handleDeleteSingleValue} // handleDeleteSingleRow={handleDeleteSingleRow} /> diff --git a/web/src/pages/dataset/components/metedata/manage-values-modal.tsx b/web/src/pages/dataset/components/metedata/manage-values-modal.tsx index 97383e60a67..f2a74a1c9ef 100644 --- a/web/src/pages/dataset/components/metedata/manage-values-modal.tsx +++ b/web/src/pages/dataset/components/metedata/manage-values-modal.tsx @@ -123,6 +123,9 @@ export const ManageValuesModal = (props: IManageValuesProps) => { isVerticalShowValue, isShowType, type: metadataType, + testId, + okButtonTestId, + addValueButtonTestId, } = props; const { metaData, @@ -251,6 +254,8 @@ export const ManageValuesModal = (props: IManageValuesProps) => { onOk={() => formRef.current?.submit(handleSubmit)} maskClosable={false} footer={null} + testId={testId} + okButtonTestId={okButtonTestId} >
{!isEditField && ( @@ -281,6 +286,7 @@ export const ManageValuesModal = (props: IManageValuesProps) => { variant={'ghost'} className="border border-border-button" onClick={handleAddValue} + data-testid={addValueButtonTestId} > diff --git a/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx b/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx index 45db84f498e..66ad986f93d 100644 --- a/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx +++ b/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx @@ -107,11 +107,13 @@ export const EmbeddingSelect = ({ field, name, disabled = false, + testId, }: { isEdit: boolean; field: FieldValues; name?: string; disabled?: boolean; + testId?: string; }) => { const { t } = useTranslate('knowledgeConfiguration'); const form = useFormContext(); @@ -149,6 +151,7 @@ export const EmbeddingSelect = ({ value={field.value} options={embeddingModelOptions} placeholder={t('embeddingModelPlaceholder')} + testId={testId} /> ); @@ -188,6 +191,7 @@ export function EmbeddingModelItem({ line = 1, isEdit }: IProps) { isEdit={!!isEdit} field={field} disabled={disabled} + testId="ds-settings-basic-embedding-model-select" >
@@ -313,6 +317,7 @@ export function EnableTocToggle() {
@@ -345,6 +350,8 @@ export function ImageContextWindow() { defaultValue={0} min={0} max={256} + sliderTestId="ds-settings-parser-image-table-context-window-slider" + numberInputTestId="ds-settings-parser-image-table-context-window-input" />
@@ -365,6 +372,8 @@ export function OverlappedPercent() { label={t('knowledgeConfiguration.overlappedPercent')} max={0.3} step={0.01} + sliderTestId="ds-settings-parser-overlapped-percent-slider" + numberInputTestId="ds-settings-parser-overlapped-percent-input" > ); } @@ -439,7 +448,12 @@ export function AutoMetadata({ tooltip: t('knowledgeConfiguration.autoMetadataTip'), render: (fieldProps: ControllerRenderProps) => (
-
@@ -61,6 +64,7 @@ export function GeneralForm() {
@@ -74,7 +78,12 @@ export function GeneralForm() { {t('setting.avatar')} - +
@@ -99,7 +108,10 @@ export function GeneralForm() { {t('flow.description')} - +
diff --git a/web/src/pages/dataset/dataset-setting/index.tsx b/web/src/pages/dataset/dataset-setting/index.tsx index b4a85905387..b3ce08ea094 100644 --- a/web/src/pages/dataset/dataset-setting/index.tsx +++ b/web/src/pages/dataset/dataset-setting/index.tsx @@ -334,6 +334,7 @@ export default function DatasetSettings() { - - -
[JSX stripped during extraction: these hunks rewrite the chat header markup around {currentConversationName} and remove the inline {embedVisible && <EmbedDialog .../>} block; the embed dialog resurfaces in sessions.tsx below.]
); } diff --git a/web/src/pages/next-chats/chat/sessions.tsx b/web/src/pages/next-chats/chat/sessions.tsx index 95efe09dbbf..b4e2b9e68ff 100644 --- a/web/src/pages/next-chats/chat/sessions.tsx +++ b/web/src/pages/next-chats/chat/sessions.tsx @@ -1,9 +1,17 @@ import { ConfirmDeleteDialog } from '@/components/confirm-delete-dialog'; +import EmbedDialog from '@/components/embed-dialog'; +import { useShowEmbedModal } from '@/components/embed-dialog/use-show-embed-dialog'; import { MoreButton } from '@/components/more-button'; import { RAGFlowAvatar } from '@/components/ragflow-avatar'; import { Button } from '@/components/ui/button'; import { Checkbox } from '@/components/ui/checkbox'; import { SearchInput } from '@/components/ui/input'; +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from '@/components/ui/tooltip'; +import { SharedFrom } from '@/constants/chat'; import { useSetModalState } from '@/hooks/common-hooks'; import { useFetchDialog, @@ -15,11 +23,13 @@ import { LucideListChecks, LucidePanelLeftClose, LucidePlus, + LucideSend, LucideTrash2, LucideUndo2, } from 'lucide-react'; import { useCallback, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import { useParams } from 'react-router'; import { useChatUrlParams } from '../hooks/use-chat-url'; import { useHandleClickConversationCard } from '../hooks/use-click-card'; import { useSelectDerivedConversationList } from '../hooks/use-select-conversation-list'; @@ -132,6 +142,10 @@ export function Sessions({ handleConversationCardClick }: SessionProps) { const selectedCount = useMemo(() => selectedIds.size, [selectedIds]); + const { id } = useParams(); + const { showEmbedModal, hideEmbedModal, embedVisible, beta } = + useShowEmbedModal(); + if (!visible) { return (
@@ -158,7 +172,7 @@ export function Sessions({ handleConversationCardClick }: SessionProps) { role="complementary" data-testid="chat-detail-sessions" > -
[JSX stripped during extraction: the session header keeps {data.name} and gains a tooltip-wrapped embed button (LucideSend icon, tooltip text {t('common.embedIntoSite')}) that opens the embed dialog moved here from chat/index.tsx.]
); -}; +} diff --git a/web/src/pages/next-chats/chat/index.tsx b/web/src/pages/next-chats/chat/index.tsx index 89483787ebc..63ab7e345c0 100644 --- a/web/src/pages/next-chats/chat/index.tsx +++ b/web/src/pages/next-chats/chat/index.tsx @@ -40,8 +40,11 @@ export default function Chat() { const { data: dialogList } = useFetchConversationList(); const currentConversationName = useMemo(() => { - return dialogList.find((x) => x.id === conversationId)?.name; - }, [conversationId, dialogList]); + return ( + dialogList.find((x) => x.id === conversationId)?.name || + t('chat.newConversation') + ); + }, [conversationId, dialogList, t]); const fetchConversation: typeof handleConversationCardClick = useCallback( async (conversationId, isNew) => { diff --git a/web/tailwind.config.js b/web/tailwind.config.js index 9076087755b..dab0175d8a6 100644 --- a/web/tailwind.config.js +++ b/web/tailwind.config.js @@ -155,10 +155,6 @@ module.exports = { DEFAULT: 'var(--colors-background-inverse-standard)', foreground: 'var(--colors-background-inverse-standard-foreground)', }, - 'colors-background-inverse-standard': { - DEFAULT: 'var(--colors-background-inverse-standard)', - foreground: 'var(--background-inverse-standard-foreground)', - }, 'colors-background-inverse-strong': { DEFAULT: 'var(--colors-background-inverse-strong)', foreground: 'var(--background-inverse-standard-foreground)', From 8ee1e9dd24c2ec4af72cf48ab50e18f8f8e6969c Mon Sep 17 00:00:00 2001 From: Lynn Date: Fri, 6 Mar 2026 11:42:31 +0800 Subject: [PATCH 157/565] Fix: init func (#13430) ### What problem does this PR solve? Fix update_cnt add error in init_data. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- agent/component/llm.py | 2 +- api/db/init_data.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/agent/component/llm.py b/agent/component/llm.py index f651fd5abd2..24254ce20cf 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -85,7 +85,7 @@ class LLM(ComponentBase): def __init__(self, canvas, component_id, param: ComponentParamBase): super().__init__(canvas, component_id, param) - chat_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id)) + chat_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id) self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), chat_model_config, max_retries=self._param.max_retries, retry_interval=self._param.delay_after_error) diff --git a/api/db/init_data.py b/api/db/init_data.py index 229d3ffc01e..26f53051d71 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -314,7 +314,7 @@ def fix_empty_tenant_model_id(): if tenant_model: update_dict.update({f"tenant_{key}": tenant_model.id}) if update_dict: - update_dict += TenantService.update_by_id(tenant_dict["id"], update_dict) + update_cnt += TenantService.update_by_id(tenant_dict["id"], update_dict) logging.info(f"Update {update_cnt} tenant_model_id in table tenant.") logging.info("Fix empty tenant_model_id done.") From 6cec49974b39cb9f0bf9602736f8f4f46a885767 Mon Sep 17 00:00:00 2001 From: BitToby <218712309+bittoby@users.noreply.github.com> Date: Fri, 6 Mar 2026 06:48:47 +0200 Subject: [PATCH 158/565] fix: re-chunk documents when data source content is updated (#12918) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes: #12889 ### What problem does this PR solve? 
When syncing external data sources (e.g., Jira, Confluence, Google Drive), updated documents were not being re-chunked. The raw content was correctly updated in blob storage, but the vector database retained stale chunks, causing search results to return outdated information.

**Root cause:**
The task digest used for chunk reuse optimization was calculated only from parser configuration fields (`parser_id`, `parser_config`, `kb_id`, etc.), without any content-dependent fields. When a document's content changed but the parser configuration remained the same, the system incorrectly reused old chunks instead of regenerating new ones.

**Example scenario:**
1. User syncs a Jira issue: "Meeting scheduled for Monday"
2. User updates the Jira issue to: "Meeting rescheduled to Friday"
3. User triggers sync again
4. Raw content panel shows updated text ✓
5. Chunk panel still shows old text "Monday" ✗

**Solution:**
1. Include `size` and the new `content_hash` in the chunking config, so the task digest changes when document content is updated
2. Track updated documents separately in `upload_document()` and return them for processing
3. Process updated documents through the re-parsing pipeline to regenerate chunks

[1.webm](https://github.com/user-attachments/assets/d21d4dcd-e189-4d39-8700-053bae0ca5a0)

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
---
 api/db/db_models.py                 |  3 +++
 api/db/services/document_service.py |  2 ++
 api/db/services/file_service.py     | 21 ++++++++++++++++-----
 3 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/api/db/db_models.py b/api/db/db_models.py
index b735fbce640..6348a68a304 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -908,6 +908,8 @@ class Document(DataBaseModel):
     process_duration = FloatField(default=0)
     suffix = CharField(max_length=32, null=False, help_text="The real file extension suffix", index=True)
+    content_hash = CharField(max_length=32, null=True, help_text="xxhash128 of document content for change detection", default="", index=True)
+
     run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0", index=True)
     status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
@@ -1523,6 +1525,7 @@ def migrate_db():
     alter_db_add_column(migrator, "api_4_conversation", "exp_user_id", CharField(max_length=255, null=True, help_text="exp_user_id", index=True))
     # Migrate system_settings.value from CharField to TextField for longer sandbox configs
     alter_db_column_type(migrator, "system_settings", "value", TextField(null=False, help_text="Configuration value (JSON, string, etc.)"))
+    alter_db_add_column(migrator, "document", "content_hash", CharField(max_length=32, null=True, help_text="xxhash128 of document content for change detection", default="", index=True))
     update_tenant_llm_to_id_primary_key()
     alter_db_add_column(migrator, "tenant", "tenant_llm_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
     alter_db_add_column(migrator, "tenant", "tenant_embd_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py
index 9390b794151..8809373a323 100644
--- a/api/db/services/document_service.py
+++ b/api/db/services/document_service.py
@@ -683,6 +683,8 @@ def get_chunking_config(cls, doc_id):
            cls.model.kb_id,
            cls.model.parser_id,
            cls.model.parser_config,
+           cls.model.size,
+           cls.model.content_hash,
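+           # A sketch of why these two content-dependent fields matter (illustrative
+           # comment, not part of the original diff): the task digest is derived from
+           # this chunking config, so the digest now changes whenever the blob does,
+           # conceptually something like:
+           #     digest = xxhash.xxh128(json.dumps(cfg, sort_keys=True).encode()).hexdigest()
+           # (hypothetical digest shape; the real helper lives in the task layer)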
Knowledgebase.language, Knowledgebase.embd_id, Tenant.id.alias("tenant_id"), diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index d31004c93e0..05091e4d5bb 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -23,6 +23,7 @@ from pathlib import Path from typing import Union +import xxhash from peewee import fn from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileType @@ -442,11 +443,20 @@ def upload_document(self, kb, file_objs, user_id, src="local", parent_path: str doc_id = file.id if hasattr(file, "id") else get_uuid() e, doc = DocumentService.get_by_id(doc_id) if e: - blob = file.read() - settings.STORAGE_IMPL.put(kb.id, doc.location, blob, kb.tenant_id) - doc.size = len(blob) - doc = doc.to_dict() - DocumentService.update_by_id(doc["id"], doc) + try: + blob = file.read() + new_hash = xxhash.xxh128(blob).hexdigest() + old_hash = doc.content_hash or "" + settings.STORAGE_IMPL.put(kb.id, doc.location, blob, kb.tenant_id) + doc.size = len(blob) + doc.content_hash = new_hash + doc = doc.to_dict() + DocumentService.update_by_id(doc["id"], doc) + if new_hash != old_hash: + files.append((doc, blob)) + except Exception as exc: + logging.exception(f"Failed to update document {doc_id}: {exc}") + err.append(file.filename + ": " + str(exc)) continue try: DocumentService.check_doc_health(kb.tenant_id, file.filename) @@ -485,6 +495,7 @@ def upload_document(self, kb, file_objs, user_id, src="local", parent_path: str "location": location, "size": len(blob), "thumbnail": thumbnail_location, + "content_hash": xxhash.xxh128(blob).hexdigest(), } DocumentService.insert(doc) From b30d7eb5da7e6d6e2b7d9430daa33de50cd2e763 Mon Sep 17 00:00:00 2001 From: Achieve3318 Date: Thu, 5 Mar 2026 23:51:22 -0500 Subject: [PATCH 159/565] Feat(memory): implement get_aggregation for OceanBase memory (#13428) ### What problem does this PR solve? - Add aggregation_utils.aggregate_by_field for pure aggregation logic - Wire OBConnection.get_aggregation to use it (unwrap tuple, pass messages) - Add unit tests for aggregate_by_field (no DB/heavy deps) ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- memory/utils/aggregation_utils.py | 56 +++++++++++++++++++ memory/utils/ob_conn.py | 10 +++- .../memory/utils/test_ob_conn_aggregation.py | 55 ++++++++++++++++++ 3 files changed, 119 insertions(+), 2 deletions(-) create mode 100644 memory/utils/aggregation_utils.py create mode 100644 test/unit_test/memory/utils/test_ob_conn_aggregation.py diff --git a/memory/utils/aggregation_utils.py b/memory/utils/aggregation_utils.py new file mode 100644 index 00000000000..6de63f1ba13 --- /dev/null +++ b/memory/utils/aggregation_utils.py @@ -0,0 +1,56 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use it except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +"""Pure aggregation helpers for search results (no heavy dependencies).""" + + +def aggregate_by_field(messages: list | None, field_name: str) -> list[tuple[str, int]]: + """Aggregate message documents by a field; returns [(value, count), ...]. + + Handles pre-aggregated rows (dicts with "value" and "count") and + per-doc field values (str or list of str). + """ + if not messages: + return [] + + counts: dict[str, int] = {} + result: list[tuple[str, int]] = [] + + for doc in messages: + if "value" in doc and "count" in doc: + result.append((doc["value"], doc["count"])) + continue + + if field_name not in doc: + continue + + v = doc[field_name] + if isinstance(v, list): + for vv in v: + if isinstance(vv, str): + key = vv.strip() + if key: + counts[key] = counts.get(key, 0) + 1 + elif isinstance(v, str): + key = v.strip() + if key: + counts[key] = counts.get(key, 0) + 1 + + if counts: + for k, v in counts.items(): + result.append((k, v)) + + return result diff --git a/memory/utils/ob_conn.py b/memory/utils/ob_conn.py index bf8ac400504..09c976e2ca5 100644 --- a/memory/utils/ob_conn.py +++ b/memory/utils/ob_conn.py @@ -24,6 +24,7 @@ from sqlalchemy.dialects.mysql import LONGTEXT from common.decorator import singleton +from memory.utils.aggregation_utils import aggregate_by_field from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr from common.doc_store.ob_conn_base import OBConnectionBase, get_value_str, vector_search_template from common.float_utils import get_float @@ -609,5 +610,10 @@ def get_highlight(self, res, keywords: list[str], field_name: str): def get_aggregation(self, res, field_name: str): """Get aggregation for search results.""" - # TODO: Implement aggregation functionality for OceanBase memory - return [] + if isinstance(res, tuple): + res_obj = res[0] + else: + res_obj = res + + messages = getattr(res_obj, "messages", None) + return aggregate_by_field(messages, field_name) diff --git a/test/unit_test/memory/utils/test_ob_conn_aggregation.py b/test/unit_test/memory/utils/test_ob_conn_aggregation.py new file mode 100644 index 00000000000..cf136eb2087 --- /dev/null +++ b/test/unit_test/memory/utils/test_ob_conn_aggregation.py @@ -0,0 +1,55 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use it except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Unit tests for OceanBase memory aggregation. + +Tests the pure aggregation logic used by OBConnection.get_aggregation, +without requiring a real OceanBase instance or heavy dependencies. 
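+
+A minimal usage sketch (directly checkable against aggregation_utils.py above;
+no DB needed):
+
+    aggregate_by_field(
+        [{"message_type_kwd": "user"}, {"message_type_kwd": "user"}],
+        "message_type_kwd",
+    )  # -> [("user", 2)]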
+""" + +from memory.utils.aggregation_utils import aggregate_by_field + + +class TestAggregateByField: + """Tests for aggregate_by_field (used by get_aggregation).""" + + def test_empty_messages_returns_empty_list(self): + assert aggregate_by_field([], "message_type_kwd") == [] + assert aggregate_by_field(None, "message_type_kwd") == [] + + def test_aggregates_field_values(self): + messages = [ + {"id": "m1", "message_type_kwd": "user", "content_ltks": "a", "message_id": "msg1", "memory_id": "mem1", "status_int": 1}, + {"id": "m2", "message_type_kwd": "assistant", "content_ltks": "b", "message_id": "msg2", "memory_id": "mem1", "status_int": 1}, + {"id": "m3", "message_type_kwd": "user", "content_ltks": "c", "message_id": "msg3", "memory_id": "mem1", "status_int": 1}, + ] + out = aggregate_by_field(messages, "message_type_kwd") + assert set(out) == {("user", 2), ("assistant", 1)} + + def test_single_doc_result(self): + messages = [ + {"id": "m1", "message_type_kwd": "user", "content_ltks": "x", "message_id": "msg1", "memory_id": "mem1", "status_int": 1} + ] + out = aggregate_by_field(messages, "message_type_kwd") + assert out == [("user", 1)] + + def test_pre_aggregated_value_count_rows(self): + messages = [ + {"value": "user", "count": 2}, + {"value": "assistant", "count": 1}, + ] + out = aggregate_by_field(messages, "message_type_kwd") + assert set(out) == {("user", 2), ("assistant", 1)} From 221e7417c65bba37e314d7342d0e779808759fee Mon Sep 17 00:00:00 2001 From: chanx <1243304602@qq.com> Date: Fri, 6 Mar 2026 16:42:49 +0800 Subject: [PATCH 160/565] =?UTF-8?q?Feat=EF=BC=9AUsing=20Go=20to=20implemen?= =?UTF-8?q?t=20user=20registration=20logic=20(#13431)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? Feat:Using Go to implement user registration logic ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- .agents/rules/named.md | 192 +++++++++++++++++++++++++ .agents/skills/go-naming/SKILL.md | 6 + go.mod | 3 +- go.sum | 4 + internal/common/error_code.go | 40 ++++++ internal/dao/database.go | 13 +- internal/dao/file.go | 5 + internal/dao/tenant.go | 10 ++ internal/dao/user.go | 5 + internal/handler/chat.go | 24 ++-- internal/handler/chat_session.go | 24 ++-- internal/handler/chunk.go | 6 +- internal/handler/connector.go | 6 +- internal/handler/file.go | 24 ++-- internal/handler/kb.go | 6 +- internal/handler/llm.go | 16 ++- internal/handler/search.go | 6 +- internal/handler/tenant.go | 11 +- internal/handler/user.go | 230 ++++++++++++++++++----------- internal/router/router.go | 1 + internal/service/user.go | 231 ++++++++++++++++++++++-------- 21 files changed, 658 insertions(+), 205 deletions(-) create mode 100644 .agents/rules/named.md create mode 100644 .agents/skills/go-naming/SKILL.md create mode 100644 internal/common/error_code.go diff --git a/.agents/rules/named.md b/.agents/rules/named.md new file mode 100644 index 00000000000..32ba41e1f8e --- /dev/null +++ b/.agents/rules/named.md @@ -0,0 +1,192 @@ +# Go Naming Best Practices + +## 1. 
Package Naming + +- **All lowercase, no underscores**: `package user`, not `package userService` or `package user_service` +- **Short and meaningful**: `package http`, `package json`, `package dao` +- **Avoid plurals**: `package user` not `package users` +- **Avoid generic names**: Avoid `package util`, `package common`, `package base` + +```go +// Recommended +package user +package handler +package service + +// Not recommended +package UserService +package user_service +package utils +``` + +## 2. File Naming + +- **All lowercase, underscore separated**: `user_handler.go`, `user_service.go` +- **Test files**: `user_handler_test.go` +- **Platform-specific**: `user_linux.go`, `user_windows.go` + +``` +user/ +├── user_handler.go +├── user_service.go +├── user_dao.go +└── user_test.go +``` + +## 3. Directory Naming + +- **All lowercase, no underscores or hyphens**: `internal/`, `pkg/`, `cmd/` +- **Short and descriptive**: `handler/`, `service/`, `dao/` + +``` +project/ +├── cmd/ # Main entry point +│ └── server_main.go +├── internal/ # Private code +│ ├── handler/ +│ ├── service/ +│ ├── dao/ +│ ├── model/ +│ └── middleware/ +├── pkg/ # Public code +└── api/ # API definitions +``` + +## 4. Interface Naming + +- **Single-method interfaces end with "-er"**: `Reader`, `Writer`, `Handler` +- **Verb form**: `Reader`, `Executor`, `Validator` + +```go +// Recommended +type Reader interface { + Read(p []byte) (n int, err error) +} + +type UserService interface { + Register(req *RegisterRequest) (*User, error) + Login(req *LoginRequest) (*User, error) +} + +// Not recommended +type UserInterface interface {} +type IUserService interface {} +``` + +## 5. Struct Naming + +- **CamelCase**: `UserService`, `UserHandler` +- **Avoid redundant prefixes**: `User` not `UserModel` + +```go +// Recommended +type UserService struct {} +type UserHandler struct {} +type RegisterRequest struct {} + +// Not recommended +type user_service struct {} +type SUserService struct {} +type UserModel struct {} +``` + +## 6. Method/Function Naming + +- **CamelCase** +- **Start with verb**: `GetUser`, `CreateUser`, `DeleteUser` +- **Boolean returns use Is/Has/Can prefix**: `IsValid`, `HasPermission` + +```go +// Recommended +func (s *UserService) Register(req *RegisterRequest) (*User, error) +func (s *UserService) GetUserByID(id uint) (*User, error) +func (s *UserService) IsEmailExists(email string) bool + +// Not recommended +func (s *UserService) register_user() +func (s *UserService) get_user_by_id() +func (s *UserService) CheckEmailExists() // Should use Is/Has +``` + +## 7. Constant Naming + +- **CamelCase**: `const MaxRetryCount = 3` +- **Enum constants**: `const StatusActive = "active"` + +```go +// Recommended +const ( + StatusActive = "1" + StatusInactive = "0" + MaxRetryCount = 3 +) + +// Not recommended +const ( + STATUS_ACTIVE = "1" // Not all uppercase + status_active = "1" // Not all lowercase +) +``` + +## 8. Error Variable Naming + +- **Start with "Err"**: `ErrNotFound`, `ErrInvalidInput` + +```go +// Recommended +var ( + ErrNotFound = errors.New("not found") + ErrInvalidInput = errors.New("invalid input") + ErrUnauthorized = errors.New("unauthorized") +) +``` + +## 9. Acronyms Keep Consistent Case + +```go +// Recommended +type HTTPHandler struct {} +var URL string +func GetHTTPClient() {} +func ParseJSON() {} + +// Not recommended +type HttpHandler struct {} +var Url string +func GetHttpClient() {} +``` + +## 10. 
Project Structure Naming + +``` +project-name/ +├── cmd/ # Main programs +│ └── app_name/ +│ └── main.go +├── internal/ # Private code +│ ├── handler/ # HTTP handlers +│ ├── service/ # Business logic +│ ├── repository/ # Data access +│ ├── model/ # Data models +│ └── config/ # Configuration +├── pkg/ # Public code +├── api/ # API definitions +├── configs/ # Config files +├── scripts/ # Scripts +├── docs/ # Documentation +├── go.mod +└── go.sum +``` + +## Summary Table + +| Type | Rule | Example | +| -------------- | ----------------------------------- | ------------------- | +| Package | All lowercase, no underscores | `package user` | +| File | All lowercase, underscore separated | `user_service.go` | +| Directory | All lowercase, no separators | `internal/handler/` | +| Struct | CamelCase, capitalized first letter | `UserService` | +| Interface | CamelCase, -er suffix | `Reader`, `Writer` | +| Method | CamelCase, verb prefix | `GetUserByID` | +| Constant | CamelCase | `MaxRetryCount` | +| Error Variable | Err prefix | `ErrNotFound` | diff --git a/.agents/skills/go-naming/SKILL.md b/.agents/skills/go-naming/SKILL.md new file mode 100644 index 00000000000..fb7f2b96a50 --- /dev/null +++ b/.agents/skills/go-naming/SKILL.md @@ -0,0 +1,6 @@ +--- +name: go-naming +description: Go naming conventions and best practices. Use this skill when working with Go code and need to name packages, files, directories, structs, interfaces, functions, variables, or constants. Provides comprehensive naming guidelines following Go community standards. +--- + +Strictly follow the naming conventions in [rules/named.md](rules/named.md) diff --git a/go.mod b/go.mod index 256f066ac63..139d21f861f 100644 --- a/go.mod +++ b/go.mod @@ -61,7 +61,8 @@ require ( golang.org/x/arch v0.6.0 // indirect golang.org/x/exp v0.0.0-20231226003508-02704c960a9b // indirect golang.org/x/net v0.48.0 // indirect - golang.org/x/sys v0.40.0 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/term v0.40.0 // indirect golang.org/x/text v0.33.0 // indirect google.golang.org/protobuf v1.32.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/go.sum b/go.sum index 6b405659dbb..eb856996c4d 100644 --- a/go.sum +++ b/go.sum @@ -156,6 +156,10 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= +golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= diff --git a/internal/common/error_code.go b/internal/common/error_code.go new file mode 100644 index 00000000000..0817c174317 --- /dev/null +++ b/internal/common/error_code.go @@ -0,0 +1,40 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package common + +type ErrorCode int + +const ( + CodeSuccess ErrorCode = 0 + CodeNotEffective ErrorCode = 10 + CodeExceptionError ErrorCode = 100 + CodeArgumentError ErrorCode = 101 + CodeDataError ErrorCode = 102 + CodeOperatingError ErrorCode = 103 + CodeTimeoutError ErrorCode = 104 + CodeConnectionError ErrorCode = 105 + CodeRunning ErrorCode = 106 + CodeResourceExhausted ErrorCode = 107 + CodePermissionError ErrorCode = 108 + CodeAuthenticationError ErrorCode = 109 + CodeBadRequest ErrorCode = 400 + CodeUnauthorized ErrorCode = 401 + CodeForbidden ErrorCode = 403 + CodeNotFound ErrorCode = 404 + CodeConflict ErrorCode = 409 + CodeServerError ErrorCode = 500 +) diff --git a/internal/dao/database.go b/internal/dao/database.go index b0d0e1ee5ed..163f172df98 100644 --- a/internal/dao/database.go +++ b/internal/dao/database.go @@ -18,6 +18,7 @@ package dao import ( "fmt" + "ragflow/internal/model" "ragflow/internal/server" "time" @@ -77,9 +78,15 @@ func InitDB() error { sqlDB.SetConnMaxLifetime(time.Hour) // Auto migrate - //if err := DB.AutoMigrate(&model.User{}, &model.Document{}); err != nil { - // return fmt.Errorf("failed to migrate database: %w", err) - //} + if err := DB.AutoMigrate( + &model.User{}, + &model.Tenant{}, + &model.UserTenant{}, + &model.File{}, + &model.File2Document{}, + ); err != nil { + return fmt.Errorf("failed to migrate database: %w", err) + } logger.Info("Database connected and migrated successfully") return nil diff --git a/internal/dao/file.go b/internal/dao/file.go index bbf9a660989..c665d1077bd 100644 --- a/internal/dao/file.go +++ b/internal/dao/file.go @@ -195,6 +195,11 @@ func (dao *FileDAO) GetAllParentFolders(startID string) ([]*model.File, error) { return parentFolders, nil } +// Create creates a new file +func (dao *FileDAO) Create(file *model.File) error { + return DB.Create(file).Error +} + // generateUUID generates a UUID func generateUUID() string { id := uuid.New().String() diff --git a/internal/dao/tenant.go b/internal/dao/tenant.go index c992b1a7429..781c6c20587 100644 --- a/internal/dao/tenant.go +++ b/internal/dao/tenant.go @@ -88,3 +88,13 @@ func (dao *TenantDAO) GetByID(id string) (*model.Tenant, error) { } return &tenant, nil } + +// Create creates a new tenant +func (dao *TenantDAO) Create(tenant *model.Tenant) error { + return DB.Create(tenant).Error +} + +// Delete deletes a tenant by ID (soft delete) +func (dao *TenantDAO) Delete(id string) error { + return DB.Model(&model.Tenant{}).Where("id = ?", id).Update("status", "0").Error +} diff --git a/internal/dao/user.go b/internal/dao/user.go index 014be061979..ff134683bc1 100644 --- a/internal/dao/user.go +++ b/internal/dao/user.go @@ -101,3 +101,8 @@ func (dao *UserDAO) List(offset, limit int) ([]*model.User, int64, error) { func (dao *UserDAO) Delete(id uint) error { return DB.Delete(&model.User{}, id).Error } + +// DeleteByID delete user by string ID +func (dao *UserDAO) DeleteByID(id string) error { + return DB.Model(&model.User{}).Where("id = ?", id).Update("status", "0").Error +} diff --git a/internal/handler/chat.go b/internal/handler/chat.go index 
aa09c3353b9..c7b2dde9842 100644 --- a/internal/handler/chat.go +++ b/internal/handler/chat.go @@ -59,11 +59,11 @@ func (h *ChatHandler) ListChats(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -112,11 +112,11 @@ func (h *ChatHandler) ListChatsNext(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -196,11 +196,11 @@ func (h *ChatHandler) SetDialog(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -268,11 +268,11 @@ func (h *ChatHandler) RemoveChats(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } diff --git a/internal/handler/chat_session.go b/internal/handler/chat_session.go index fd5d4492310..54995371a55 100644 --- a/internal/handler/chat_session.go +++ b/internal/handler/chat_session.go @@ -61,11 +61,11 @@ func (h *ChatSessionHandler) SetChatSession(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -124,11 +124,11 @@ func (h *ChatSessionHandler) RemoveChatSessions(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -190,11 +190,11 @@ func (h *ChatSessionHandler) ListChatSessions(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -270,11 +270,11 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } diff --git a/internal/handler/chunk.go b/internal/handler/chunk.go index 10b19830da3..d13f4ac2792 100644 --- a/internal/handler/chunk.go +++ b/internal/handler/chunk.go @@ -59,11 +59,11 @@ func (h *ChunkHandler) RetrievalTest(c *gin.Context) { 
} // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } diff --git a/internal/handler/connector.go b/internal/handler/connector.go index 9f54b804198..6c0ebedb051 100644 --- a/internal/handler/connector.go +++ b/internal/handler/connector.go @@ -58,11 +58,11 @@ func (h *ConnectorHandler) ListConnectors(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } diff --git a/internal/handler/file.go b/internal/handler/file.go index 974d3bbd688..3474ce0cb52 100644 --- a/internal/handler/file.go +++ b/internal/handler/file.go @@ -65,11 +65,11 @@ func (h *FileHandler) ListFiles(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -141,11 +141,11 @@ func (h *FileHandler) GetRootFolder(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -189,11 +189,11 @@ func (h *FileHandler) GetParentFolder(c *gin.Context) { } // Get user by access token (for validation) - _, err := h.userService.GetUserByToken(token) + _, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -246,11 +246,11 @@ func (h *FileHandler) GetAllParentFolders(c *gin.Context) { } // Get user by access token (for validation) - _, err := h.userService.GetUserByToken(token) + _, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } diff --git a/internal/handler/kb.go b/internal/handler/kb.go index 1c482fa89f1..e4e2a025b48 100644 --- a/internal/handler/kb.go +++ b/internal/handler/kb.go @@ -130,11 +130,11 @@ func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } diff --git a/internal/handler/llm.go b/internal/handler/llm.go index bcad7f2be1d..6926dfc97d2 100644 --- a/internal/handler/llm.go +++ b/internal/handler/llm.go @@ -71,10 +71,11 @@ func (h *LLMHandler) GetMyLLMs(c *gin.Context) { } // Get user by token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "error": "Invalid access token", + "code": code, 
+ "message": err.Error(), }) return } @@ -127,10 +128,11 @@ func (h *LLMHandler) Factories(c *gin.Context) { } // Get user by token - _, err := h.userService.GetUserByToken(token) + _, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "error": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -207,11 +209,11 @@ func (h *LLMHandler) ListApp(c *gin.Context) { } // Get user by token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } diff --git a/internal/handler/search.go b/internal/handler/search.go index 5a6317b183f..b291a780270 100644 --- a/internal/handler/search.go +++ b/internal/handler/search.go @@ -65,11 +65,11 @@ func (h *SearchHandler) ListSearchApps(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } diff --git a/internal/handler/tenant.go b/internal/handler/tenant.go index ab96f958ce4..02b87a41643 100644 --- a/internal/handler/tenant.go +++ b/internal/handler/tenant.go @@ -58,10 +58,11 @@ func (h *TenantHandler) TenantInfo(c *gin.Context) { return } // Get user by token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "error": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -109,11 +110,11 @@ func (h *TenantHandler) TenantList(c *gin.Context) { } // Get user by token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } diff --git a/internal/handler/user.go b/internal/handler/user.go index 2a4091857fd..7fb39a5df96 100644 --- a/internal/handler/user.go +++ b/internal/handler/user.go @@ -17,7 +17,9 @@ package handler import ( + "fmt" "net/http" + "ragflow/internal/common" "ragflow/internal/server" "ragflow/internal/utility" "strconv" @@ -47,31 +49,51 @@ func NewUserHandler(userService *service.UserService) *UserHandler { // @Produce json // @Param request body service.RegisterRequest true "registration info" // @Success 200 {object} map[string]interface{} -// @Router /api/v1/users/register [post] +// @Router /v1/user/register [post] func (h *UserHandler) Register(c *gin.Context) { var req service.RegisterRequest if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + "data": false, + }) + return + } + + user, code, err := h.userService.Register(&req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } - user, err := h.userService.Register(&req) + variables := server.GetVariables() + secretKey := variables.SecretKey + authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": 
err.Error(), + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": "Failed to generate auth token", + "data": false, }) return } + c.Header("Authorization", authToken) + c.Header("Access-Control-Allow-Origin", "*") + c.Header("Access-Control-Allow-Methods", "*") + c.Header("Access-Control-Allow-Headers", "*") + c.Header("Access-Control-Expose-Headers", "Authorization") + + profile := h.userService.GetUserProfile(user) c.JSON(http.StatusOK, gin.H{ - "message": "registration successful", - "data": gin.H{ - "id": user.ID, - "nickname": user.Nickname, - "email": user.Email, - }, + "code": common.CodeSuccess, + "message": fmt.Sprintf("%s, welcome aboard!", req.Nickname), + "data": profile, }) } @@ -87,18 +109,20 @@ func (h *UserHandler) Register(c *gin.Context) { func (h *UserHandler) Login(c *gin.Context) { var req service.LoginRequest if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, "message": err.Error(), + "data": false, }) return } - user, err := h.userService.Login(&req) + user, code, err := h.userService.Login(&req) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, + c.JSON(http.StatusOK, gin.H{ + "code": code, "message": err.Error(), + "data": false, }) return } @@ -114,7 +138,7 @@ func (h *UserHandler) Login(c *gin.Context) { c.Header("Access-Control-Expose-Headers", "Authorization") c.JSON(http.StatusOK, gin.H{ - "code": 0, + "code": common.CodeSuccess, "message": "Welcome back!", "data": user, }) @@ -132,18 +156,20 @@ func (h *UserHandler) Login(c *gin.Context) { func (h *UserHandler) LoginByEmail(c *gin.Context) { var req service.EmailLoginRequest if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, "message": err.Error(), + "data": false, }) return } - user, err := h.userService.LoginByEmail(&req) + user, code, err := h.userService.LoginByEmail(&req) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, + c.JSON(http.StatusOK, gin.H{ + "code": code, "message": err.Error(), + "data": false, }) return } @@ -151,21 +177,26 @@ func (h *UserHandler) LoginByEmail(c *gin.Context) { variables := server.GetVariables() secretKey := variables.SecretKey authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey) - - // Set Authorization header with access_token - if user.AccessToken != nil { - c.Header("Authorization", authToken) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": "Failed to generate auth token", + "data": false, + }) + return } - // Set CORS headers + + c.Header("Authorization", authToken) c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Methods", "*") c.Header("Access-Control-Allow-Headers", "*") c.Header("Access-Control-Expose-Headers", "Authorization") + profile := h.userService.GetUserProfile(user) c.JSON(http.StatusOK, gin.H{ - "code": 0, + "code": common.CodeSuccess, "message": "Welcome back!", - "data": user, + "data": profile, }) } @@ -182,22 +213,28 @@ func (h *UserHandler) GetUserByID(c *gin.Context) { idStr := c.Param("id") id, err := strconv.ParseUint(idStr, 10, 32) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "invalid user id", + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": "invalid user id", + "data": false, }) return } - user, err := 
h.userService.GetUserByID(uint(id)) + user, code, err := h.userService.GetUserByID(uint(id)) if err != nil { - c.JSON(http.StatusNotFound, gin.H{ - "error": "user not found", + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } c.JSON(http.StatusOK, gin.H{ - "data": user, + "code": common.CodeSuccess, + "message": "success", + "data": user, }) } @@ -222,15 +259,19 @@ func (h *UserHandler) ListUsers(c *gin.Context) { pageSize = 10 } - users, total, err := h.userService.ListUsers(page, pageSize) + users, total, code, err := h.userService.ListUsers(page, pageSize) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "failed to get users", + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": "success", "data": gin.H{ "items": users, "total": total, @@ -253,34 +294,38 @@ func (h *UserHandler) Logout(c *gin.Context) { // Extract token from request token := c.GetHeader("Authorization") if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, "message": "Missing Authorization header", + "data": false, }) return } // Get user by token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } // Logout user - if err := h.userService.Logout(user); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, + code, err = h.userService.Logout(user) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": code, "message": err.Error(), + "data": false, }) return } c.JSON(http.StatusOK, gin.H{ - "code": 0, + "code": common.CodeSuccess, "data": true, "message": "success", }) @@ -299,19 +344,21 @@ func (h *UserHandler) Info(c *gin.Context) { // Extract token from request token := c.GetHeader("Authorization") if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, "message": "Missing Authorization header", + "data": false, }) return } // Get user by token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "error": "Invalid access token", + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } @@ -320,8 +367,9 @@ func (h *UserHandler) Info(c *gin.Context) { profile := h.userService.GetUserProfile(user) c.JSON(http.StatusOK, gin.H{ - "code": 0, - "data": profile, + "code": common.CodeSuccess, + "message": "success", + "data": profile, }) } @@ -339,18 +387,21 @@ func (h *UserHandler) Setting(c *gin.Context) { // Extract token from request token := c.GetHeader("Authorization") if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, "message": "Missing Authorization header", + "data": false, }) return } // Get user by token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "error": "Invalid access token", + 
c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } @@ -358,22 +409,29 @@ func (h *UserHandler) Setting(c *gin.Context) { // Parse request var req service.UpdateSettingsRequest if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + "data": false, }) return } // Update user settings - if err := h.userService.UpdateUserSettings(user, &req); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": err.Error(), + code, err = h.userService.UpdateUserSettings(user, &req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, "message": "settings updated successfully", + "data": true, }) } @@ -391,18 +449,21 @@ func (h *UserHandler) ChangePassword(c *gin.Context) { // Extract token from request token := c.GetHeader("Authorization") if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, "message": "Missing Authorization header", + "data": false, }) return } // Get user by token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "error": "Invalid access token", + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } @@ -410,22 +471,29 @@ func (h *UserHandler) ChangePassword(c *gin.Context) { // Parse request var req service.ChangePasswordRequest if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + "data": false, }) return } // Change password - if err := h.userService.ChangePassword(user, &req); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": err.Error(), + code, err = h.userService.ChangePassword(user, &req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, "message": "password changed successfully", + "data": true, }) } @@ -438,10 +506,10 @@ func (h *UserHandler) ChangePassword(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/user/login/channels [get] func (h *UserHandler) GetLoginChannels(c *gin.Context) { - channels, err := h.userService.GetLoginChannels() + channels, code, err := h.userService.GetLoginChannels() if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, + c.JSON(http.StatusOK, gin.H{ + "code": code, "message": "Load channels failure, error: " + err.Error(), "data": []interface{}{}, }) @@ -449,7 +517,7 @@ func (h *UserHandler) GetLoginChannels(c *gin.Context) { } c.JSON(http.StatusOK, gin.H{ - "code": 0, + "code": common.CodeSuccess, "message": "success", "data": channels, }) diff --git a/internal/router/router.go b/internal/router/router.go index bcc1d683842..5f41765d60f 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -86,6 +86,7 @@ func (r *Router) Setup(engine *gin.Engine) { // User login by email endpoint engine.POST("/v1/user/login", r.userHandler.LoginByEmail) + engine.POST("/v1/user/register", r.userHandler.Register) // User 
login channels endpoint engine.GET("/v1/user/login/channels", r.userHandler.GetLoginChannels) // User logout endpoint diff --git a/internal/service/user.go b/internal/service/user.go index e92541502d8..99dd5351b57 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -25,7 +25,9 @@ import ( "errors" "fmt" "os" + "ragflow/internal/common" "ragflow/internal/server" + "regexp" "strconv" "strings" "time" @@ -52,9 +54,8 @@ func NewUserService() *UserService { // RegisterRequest registration request type RegisterRequest struct { - Username string `json:"username" binding:"required,min=3,max=50"` - Password string `json:"password" binding:"required,min=6"` Email string `json:"email" binding:"required,email"` + Password string `json:"password" binding:"required,min=6"` Nickname string `json:"nickname"` } @@ -96,125 +97,220 @@ type UserResponse struct { } // Register user registration -func (s *UserService) Register(req *RegisterRequest) (*model.User, error) { - // Check if email exists +func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorCode, error) { + cfg := server.GetConfig() + if cfg.RegisterEnabled == 0 { + return nil, common.CodeOperatingError, fmt.Errorf("User registration is disabled!") + } + + emailRegex := regexp.MustCompile(`^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$`) + if !emailRegex.MatchString(req.Email) { + return nil, common.CodeOperatingError, fmt.Errorf("Invalid email address: %s!", req.Email) + } + existUser, _ := s.userDAO.GetByEmail(req.Email) if existUser != nil { - return nil, errors.New("email already exists") + return nil, common.CodeOperatingError, fmt.Errorf("Email: %s has already registered!", req.Email) + } + + decryptedPassword, err := s.decryptPassword(req.Password) + if err != nil { + return nil, common.CodeServerError, fmt.Errorf("Fail to decrypt password") } - // Generate password hash - hashedPassword, err := s.HashPassword(req.Password) + hashedPassword, err := s.HashPassword(decryptedPassword) if err != nil { - return nil, fmt.Errorf("failed to hash password: %w", err) + return nil, common.CodeServerError, fmt.Errorf("failed to hash password: %w", err) } - // Create user + userID := s.GenerateToken() + accessToken := s.GenerateToken() status := "1" + loginChannel := "password" + isSuperuser := false + user := &model.User{ - Password: &hashedPassword, - Email: req.Email, - Nickname: req.Nickname, - Status: &status, + ID: userID, + AccessToken: &accessToken, + Email: req.Email, + Nickname: req.Nickname, + Password: &hashedPassword, + Status: &status, + IsActive: "1", + IsAuthenticated: "1", + IsAnonymous: "0", + LoginChannel: &loginChannel, + IsSuperuser: &isSuperuser, } + now := time.Now().Unix() + user.CreateTime = now + user.UpdateTime = &now + now_date := time.Now() + user.CreateDate = &now_date + user.UpdateDate = &now_date + user.LastLoginTime = &now_date + + tenantName := req.Nickname + "'s Kingdom" + tenant := &model.Tenant{ + ID: userID, + Name: &tenantName, + LLMID: cfg.Server.Mode, + EmbDID: cfg.Server.Mode, + ASRID: cfg.Server.Mode, + Img2TxtID: cfg.Server.Mode, + RerankID: cfg.Server.Mode, + ParserIDs: "naive:General,Q&A:Q&A,manual:Manual,table:Table,paper:Research Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag", + } + tenant.CreateTime = now + tenant.UpdateTime = &now + tenant.CreateDate = &now_date + tenant.UpdateDate = &now_date + + userTenantID := s.GenerateToken() + userTenant := &model.UserTenant{ + ID: userTenantID, + UserID: userID, + TenantID: 
userID, + Role: "owner", + InvitedBy: userID, + Status: &status, + } + userTenant.CreateTime = now + userTenant.UpdateTime = &now + userTenant.CreateDate = &now_date + userTenant.UpdateDate = &now_date + + fileID := s.GenerateToken() + rootFile := &model.File{ + ID: fileID, + ParentID: fileID, + TenantID: userID, + CreatedBy: userID, + Name: "/", + Type: "folder", + Size: 0, + } + rootFile.CreateTime = now + rootFile.UpdateTime = &now + rootFile.CreateDate = &now_date + rootFile.UpdateDate = &now_date + + tenantDAO := dao.NewTenantDAO() + userTenantDAO := dao.NewUserTenantDAO() + fileDAO := dao.NewFileDAO() + if err := s.userDAO.Create(user); err != nil { - return nil, fmt.Errorf("failed to create user: %w", err) + return nil, common.CodeServerError, fmt.Errorf("failed to create user: %w", err) } - return user, nil + if err := tenantDAO.Create(tenant); err != nil { + s.userDAO.DeleteByID(userID) + return nil, common.CodeServerError, fmt.Errorf("failed to create tenant: %w", err) + } + + if err := userTenantDAO.Create(userTenant); err != nil { + s.userDAO.DeleteByID(userID) + tenantDAO.Delete(userID) + return nil, common.CodeServerError, fmt.Errorf("failed to create user tenant relation: %w", err) + } + + if err := fileDAO.Create(rootFile); err != nil { + s.userDAO.DeleteByID(userID) + tenantDAO.Delete(userID) + userTenantDAO.Delete(userTenantID) + return nil, common.CodeServerError, fmt.Errorf("failed to create root folder: %w", err) + } + + return user, common.CodeSuccess, nil } // Login user login -func (s *UserService) Login(req *LoginRequest) (*model.User, error) { +func (s *UserService) Login(req *LoginRequest) (*model.User, common.ErrorCode, error) { // Get user by email (using username field as email) user, err := s.userDAO.GetByEmail(req.Username) if err != nil { - return nil, errors.New("invalid email or password") + return nil, common.CodeAuthenticationError, fmt.Errorf("invalid email or password") } // Decrypt password using RSA decryptedPassword, err := s.decryptPassword(req.Password) if err != nil { - return nil, fmt.Errorf("failed to decrypt password: %w", err) + return nil, common.CodeServerError, fmt.Errorf("failed to decrypt password: %w", err) } // Verify password if user.Password == nil || !s.VerifyPassword(*user.Password, decryptedPassword) { - return nil, errors.New("invalid username or password") + return nil, common.CodeAuthenticationError, fmt.Errorf("invalid username or password") } - // Check user status if user.Status == nil || *user.Status != "1" { - return nil, errors.New("user is disabled") + return nil, common.CodeForbidden, fmt.Errorf("user is disabled") } // Generate new access token token := s.GenerateToken() if err := s.UpdateUserAccessToken(user, token); err != nil { - return nil, fmt.Errorf("failed to update access token: %w", err) + return nil, common.CodeServerError, fmt.Errorf("failed to update access token: %w", err) } // Update timestamp now := time.Now().Unix() user.UpdateTime = &now if err := s.userDAO.Update(user); err != nil { - return nil, fmt.Errorf("failed to update user: %w", err) + return nil, common.CodeServerError, fmt.Errorf("failed to update user: %w", err) } - return user, nil + return user, common.CodeSuccess, nil } // LoginByEmail user login by email -func (s *UserService) LoginByEmail(req *EmailLoginRequest) (*model.User, error) { - // Check for default admin account +// Returns user on success, or error with specific code: +// - CodeAuthenticationError (109): Email not registered or password mismatch +// - CodeServerError (500): 
Password decryption failure
+// - CodeForbidden (403): Account disabled
+func (s *UserService) LoginByEmail(req *EmailLoginRequest) (*model.User, common.ErrorCode, error) {
 	if req.Email == "admin@ragflow.io" {
-		return nil, errors.New("default admin account cannot be used to login normal services")
+		return nil, common.CodeAuthenticationError, fmt.Errorf("default admin account cannot be used to login normal services")
 	}
 
-	// Get user by email
 	user, err := s.userDAO.GetByEmail(req.Email)
 	if err != nil {
-		return nil, errors.New("invalid email or password")
+		return nil, common.CodeAuthenticationError, fmt.Errorf("Email: %s is not registered!", req.Email)
 	}
 
-	// Decrypt password using RSA
 	decryptedPassword, err := s.decryptPassword(req.Password)
 	if err != nil {
-		return nil, fmt.Errorf("failed to decrypt password: %w", err)
+		return nil, common.CodeServerError, fmt.Errorf("Fail to decrypt password")
 	}
 
-	// Verify password
 	if user.Password == nil || !s.VerifyPassword(*user.Password, decryptedPassword) {
-		return nil, errors.New("invalid email or password")
+		return nil, common.CodeAuthenticationError, fmt.Errorf("Email and password do not match!")
 	}
 
-	// Check user status
-	if user.Status == nil || *user.Status != "1" {
-		return nil, errors.New("user is disabled")
+	if user.IsActive == "0" {
+		return nil, common.CodeForbidden, fmt.Errorf("This account has been disabled, please contact the administrator!")
 	}
 
-	// Generate new access token
 	token := s.GenerateToken()
 	user.AccessToken = &token
 
-	// Update timestamp
 	now := time.Now().Unix()
 	user.UpdateTime = &now
 	now_date := time.Now()
 	user.UpdateDate = &now_date
 	if err := s.userDAO.Update(user); err != nil {
-		return nil, fmt.Errorf("failed to update user: %w", err)
+		return nil, common.CodeServerError, fmt.Errorf("failed to update user: %w", err)
 	}
 
-	return user, nil
+	return user, common.CodeSuccess, nil
 }
 
 // GetUserByID get user by ID
-func (s *UserService) GetUserByID(id uint) (*UserResponse, error) {
+func (s *UserService) GetUserByID(id uint) (*UserResponse, common.ErrorCode, error) {
 	user, err := s.userDAO.GetByID(id)
 	if err != nil {
-		return nil, err
+		return nil, common.CodeNotFound, err
 	}
 
 	return &UserResponse{
@@ -223,15 +319,15 @@ func (s *UserService) GetUserByID(id uint) (*UserResponse, error) {
 		ID:        user.ID,
 		Email:     user.Email,
 		Nickname:  user.Nickname,
 		Status:    user.Status,
 		CreatedAt: time.Unix(user.CreateTime, 0).Format("2006-01-02 15:04:05"),
-	}, nil
+	}, common.CodeSuccess, nil
 }
 
 // ListUsers list users
-func (s *UserService) ListUsers(page, pageSize int) ([]*UserResponse, int64, error) {
+func (s *UserService) ListUsers(page, pageSize int) ([]*UserResponse, int64, common.ErrorCode, error) {
 	offset := (page - 1) * pageSize
 	users, total, err := s.userDAO.List(offset, pageSize)
 	if err != nil {
-		return nil, 0, err
+		return nil, 0, common.CodeServerError, err
 	}
 
 	responses := make([]*UserResponse, len(users))
@@ -245,7 +341,7 @@ func (s *UserService) ListUsers(page, pageSize int) ([]*UserResponse, int64, err
 		}
 	}
 
-	return responses, total, nil
+	return responses, total, common.CodeSuccess, nil
 }
 
 // HashPassword generate password hash
@@ -399,7 +495,7 @@ func (s *UserService) GenerateToken() string {
 // GetUserByToken gets user by authorization header
 // The token parameter is the authorization header value, which needs to be decrypted
 // using itsdangerous URLSafeTimedSerializer to get the actual access_token
-func (s *UserService) GetUserByToken(authorization string) (*model.User, error) {
+func (s *UserService) GetUserByToken(authorization string) (*model.User, 
common.ErrorCode, error) { // Get secret key from config variables := server.GetVariables() secretKey := variables.SecretKey @@ -408,16 +504,21 @@ func (s *UserService) GetUserByToken(authorization string) (*model.User, error) // Equivalent to: access_token = str(jwt.loads(authorization)) in Python accessToken, err := utility.ExtractAccessToken(authorization, secretKey) if err != nil { - return nil, fmt.Errorf("invalid authorization token: %w", err) + return nil, common.CodeUnauthorized, fmt.Errorf("invalid authorization token: %w", err) } // Validate token format (should be at least 32 chars, UUID format) if len(accessToken) < 32 { - return nil, errors.New("invalid access token format") + return nil, common.CodeUnauthorized, fmt.Errorf("invalid access token format") } // Get user by access token - return s.userDAO.GetByAccessToken(accessToken) + user, err := s.userDAO.GetByAccessToken(accessToken) + if err != nil { + return nil, common.CodeUnauthorized, err + } + + return user, common.CodeSuccess, nil } // UpdateUserAccessToken updates user's access token @@ -426,11 +527,15 @@ func (s *UserService) UpdateUserAccessToken(user *model.User, token string) erro } // Logout invalidates user's access token -func (s *UserService) Logout(user *model.User) error { +func (s *UserService) Logout(user *model.User) (common.ErrorCode, error) { // Invalidate token by setting it to an invalid value // Similar to Python implementation: "INVALID_" + secrets.token_hex(16) invalidToken := "INVALID_" + s.GenerateToken() - return s.UpdateUserAccessToken(user, invalidToken) + err := s.UpdateUserAccessToken(user, invalidToken) + if err != nil { + return common.CodeServerError, err + } + return common.CodeSuccess, nil } // GetUserProfile returns user profile information @@ -539,7 +644,7 @@ func (s *UserService) GetUserProfile(user *model.User) map[string]interface{} { } // UpdateUserSettings updates user settings -func (s *UserService) UpdateUserSettings(user *model.User, req *UpdateSettingsRequest) error { +func (s *UserService) UpdateUserSettings(user *model.User, req *UpdateSettingsRequest) (common.ErrorCode, error) { // Update fields if provided if req.Nickname != nil { user.Nickname = *req.Nickname @@ -562,15 +667,18 @@ func (s *UserService) UpdateUserSettings(user *model.User, req *UpdateSettingsRe } // Save updated user - return s.userDAO.Update(user) + if err := s.userDAO.Update(user); err != nil { + return common.CodeServerError, err + } + return common.CodeSuccess, nil } // ChangePassword changes user password -func (s *UserService) ChangePassword(user *model.User, req *ChangePasswordRequest) error { +func (s *UserService) ChangePassword(user *model.User, req *ChangePasswordRequest) (common.ErrorCode, error) { // If password is provided, verify current password if req.Password != nil { if user.Password == nil || !s.VerifyPassword(*user.Password, *req.Password) { - return errors.New("current password is incorrect") + return common.CodeBadRequest, fmt.Errorf("current password is incorrect") } } @@ -578,13 +686,16 @@ func (s *UserService) ChangePassword(user *model.User, req *ChangePasswordReques if req.NewPassword != nil { hashedPassword, err := s.HashPassword(*req.NewPassword) if err != nil { - return fmt.Errorf("failed to hash new password: %w", err) + return common.CodeServerError, fmt.Errorf("failed to hash new password: %w", err) } user.Password = &hashedPassword } // Save updated user - return s.userDAO.Update(user) + if err := s.userDAO.Update(user); err != nil { + return common.CodeServerError, err 
+ } + return common.CodeSuccess, nil } // LoginChannel represents a login channel response @@ -595,7 +706,7 @@ type LoginChannel struct { } // GetLoginChannels gets all supported authentication channels -func (s *UserService) GetLoginChannels() ([]*LoginChannel, error) { +func (s *UserService) GetLoginChannels() ([]*LoginChannel, common.ErrorCode, error) { cfg := server.GetConfig() channels := make([]*LoginChannel, 0) @@ -617,5 +728,5 @@ func (s *UserService) GetLoginChannels() ([]*LoginChannel, error) { }) } - return channels, nil + return channels, common.CodeSuccess, nil } From 7004e03a2eb351251c87ab0cd5a69e9183444f7e Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Fri, 6 Mar 2026 16:56:12 +0800 Subject: [PATCH 161/565] Fix docker file (#13438) ### What problem does this PR solve? To copy infinity/resource into docker images ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Signed-off-by: Jin Hai --- Dockerfile | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/Dockerfile b/Dockerfile index 957bb74a703..ee19086b3aa 100644 --- a/Dockerfile +++ b/Dockerfile @@ -55,6 +55,16 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \ apt install -y fonts-freefont-ttf fonts-noto-cjk && \ apt install -y postgresql-client +# Download resource from GitHub to /usr/share/infinity +RUN mkdir -p /usr/share/infinity/resource && \ + if [ "$NEED_MIRROR" == "1" ]; then \ + git clone --depth 1 --single-branch https://gitee.com/infiniflow/resource /tmp/resource; \ + else \ + git clone --depth 1 --single-branch https://github.com/infiniflow/resource.git /tmp/resource; \ + fi && \ + cp -r /tmp/resource/* /usr/share/infinity/resource && \ + rm -rf /tmp/resource + ARG NGINX_VERSION=1.29.5-1~noble RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \ mkdir -p /etc/apt/keyrings && \ From 48bb917e39e0199a49e8021e835f15294a15200c Mon Sep 17 00:00:00 2001 From: Magicbook1108 Date: Fri, 6 Mar 2026 17:19:51 +0800 Subject: [PATCH 162/565] Fix: paddle ocr missing outlines (#13441) ### What problem does this PR solve? 
Fix: paddle ocr missing outlines #13422 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- deepdoc/parser/paddleocr_parser.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepdoc/parser/paddleocr_parser.py b/deepdoc/parser/paddleocr_parser.py index 85db63b862d..28546e1c0fc 100644 --- a/deepdoc/parser/paddleocr_parser.py +++ b/deepdoc/parser/paddleocr_parser.py @@ -199,6 +199,7 @@ def __init__( """Initialize PaddleOCR parser.""" super().__init__() + self.outlines = [] self.api_url = api_url.rstrip("/") if api_url else os.getenv("PADDLEOCR_API_URL", "") self.access_token = access_token or os.getenv("PADDLEOCR_ACCESS_TOKEN") self.algorithm = algorithm From da9c34bf2142aa7a46aa6276a91db5a2170109f2 Mon Sep 17 00:00:00 2001 From: Zhichang Yu Date: Fri, 6 Mar 2026 18:03:35 +0800 Subject: [PATCH 163/565] Revert aliyun registry to registry.cn-hangzhou.aliyuncs.com (#13445) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Revert aliyun registry from `infiniflow-registry.cn-shanghai.cr.aliyuncs.com` back to `registry.cn-hangzhou.aliyuncs.com` ## Test plan - [ ] Verify the docker/.env file contains the correct registry URL 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.6 --- README_tzh.md | 2 +- README_zh.md | 2 +- docker/.env | 4 ++-- docker/README.md | 2 +- docs/administrator/configurations.md | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/README_tzh.md b/README_tzh.md index e7b21fe53ed..b4ff86c6552 100644 --- a/README_tzh.md +++ b/README_tzh.md @@ -219,7 +219,7 @@ > 如果你遇到 Docker 映像檔拉不下來的問題,可以在 **docker/.env** 檔案內根據變數 `RAGFLOW_IMAGE` 的註解提示選擇華為雲或阿里雲的對應映像。 > > - 華為雲鏡像名:`swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow` -> - 阿里雲鏡像名:`infiniflow-registry.cn-shanghai.cr.aliyuncs.com/infiniflow/ragflow` +> - 阿里雲鏡像名:`registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow` 4. 伺服器啟動成功後再次確認伺服器狀態: diff --git a/README_zh.md b/README_zh.md index ef64dc1786a..2f4679c1ec8 100644 --- a/README_zh.md +++ b/README_zh.md @@ -220,7 +220,7 @@ > 如果你遇到 Docker 镜像拉不下来的问题,可以在 **docker/.env** 文件内根据变量 `RAGFLOW_IMAGE` 的注释提示选择华为云或者阿里云的相应镜像。 > > - 华为云镜像名:`swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow` - > - 阿里云镜像名:`infiniflow-registry.cn-shanghai.cr.aliyuncs.com/infiniflow/ragflow` + > - 阿里云镜像名:`registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow` 4. 服务器启动成功后再次确认服务器状态: diff --git a/docker/.env b/docker/.env index 8e20e257bb4..79f4f91b34c 100644 --- a/docker/.env +++ b/docker/.env @@ -158,11 +158,11 @@ RAGFLOW_IMAGE=infiniflow/ragflow:v0.24.0 # If you cannot download the RAGFlow Docker image: # RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:v0.24.0 -# RAGFLOW_IMAGE=infiniflow-registry.cn-shanghai.cr.aliyuncs.com/infiniflow/ragflow:v0.24.0 +# RAGFLOW_IMAGE=registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow:v0.24.0 # # - For the `nightly` edition, uncomment either of the following: # RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:nightly -# RAGFLOW_IMAGE=infiniflow-registry.cn-shanghai.cr.aliyuncs.com/infiniflow/ragflow:nightly +# RAGFLOW_IMAGE=registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow:nightly # The embedding service image, model and port. 
# Important: To enable the embedding service, you need to uncomment one of the following two lines: diff --git a/docker/README.md b/docker/README.md index b5f9bc66712..c6422bad8c7 100644 --- a/docker/README.md +++ b/docker/README.md @@ -87,7 +87,7 @@ The [.env](./.env) file contains important environment variables for Docker. > > - For the `nightly` edition: > - `RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:nightly` or, -> - `RAGFLOW_IMAGE=infiniflow-registry.cn-shanghai.cr.aliyuncs.com/infiniflow/ragflow:nightly`. +> - `RAGFLOW_IMAGE=registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow:nightly`. ### Timezone diff --git a/docs/administrator/configurations.md b/docs/administrator/configurations.md index 213c6d8a3ee..2178f407074 100644 --- a/docs/administrator/configurations.md +++ b/docs/administrator/configurations.md @@ -110,7 +110,7 @@ If you cannot download the RAGFlow Docker image, try the following mirrors. - For the `nightly` edition: - `RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:nightly` or, - - `RAGFLOW_IMAGE=infiniflow-registry.cn-shanghai.cr.aliyuncs.com/infiniflow/ragflow:nightly`. + - `RAGFLOW_IMAGE=registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow:nightly`. ::: ### Embedding service From 03ec26e3b0616654dfb15b4b29095d7f29c4419a Mon Sep 17 00:00:00 2001 From: Yongteng Lei Date: Fri, 6 Mar 2026 18:16:42 +0800 Subject: [PATCH 164/565] Refa: empty ids means no-op operation (#13439) ### What problem does this PR solve? Empty ids means no-op operation. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] Documentation Update - [x] Refactoring --------- Co-authored-by: writinwaters --- api/apps/chunk_app.py | 17 +-- api/apps/sdk/chat.py | 15 +-- api/apps/sdk/dataset.py | 37 +++--- api/apps/sdk/doc.py | 40 ++++--- api/apps/sdk/session.py | 31 ++--- docs/references/http_api_reference.md | 25 +++-- docs/references/python_api_reference.md | 33 ++++-- test/testcases/test_http_api/common.py | 106 +++++++++++++++++- test/testcases/test_http_api/conftest.py | 18 +-- .../conftest.py | 4 +- .../test_chat_sdk_routes_unit.py | 9 ++ .../test_delete_chat_assistants.py | 4 +- .../conftest.py | 4 +- .../test_delete_chunks.py | 4 +- .../test_dataset_management/conftest.py | 6 +- .../test_delete_datasets.py | 2 +- .../conftest.py | 8 +- .../test_delete_documents.py | 4 +- .../test_doc_sdk_routes_unit.py | 10 +- .../test_metadata_batch_update.py | 2 +- .../test_session_management/conftest.py | 6 +- .../test_agent_completions.py | 4 +- .../test_agent_sessions.py | 16 ++- .../test_chat_completions.py | 16 +-- .../test_chat_completions_openai.py | 8 +- ...est_delete_sessions_with_chat_assistant.py | 4 +- .../test_session_sdk_routes_unit.py | 8 ++ test/testcases/test_sdk_api/common.py | 75 +++++++++++++ test/testcases/test_sdk_api/conftest.py | 18 +-- .../conftest.py | 4 +- .../test_delete_chat_assistants.py | 4 +- .../conftest.py | 4 +- .../test_delete_chunks.py | 9 +- .../test_dataset_mangement/conftest.py | 6 +- .../test_delete_datasets.py | 2 +- .../conftest.py | 8 +- .../test_delete_documents.py | 4 +- .../test_session_management/conftest.py | 6 +- ...est_delete_sessions_with_chat_assistant.py | 9 +- .../test_chunk_app/test_chunk_routes_unit.py | 14 +++ .../test_chunk_app/test_rm_chunks.py | 2 +- .../test_dataset_sdk_routes_unit.py | 8 +- .../test_web_api/test_kb_app/conftest.py | 24 +++- 43 files changed, 447 insertions(+), 191 deletions(-) diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 
3b1c153aa8c..4d806eb32ef 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -240,6 +240,16 @@ async def rm(): req = await get_request_json() try: def _rm_sync(): + deleted_chunk_ids = req["chunk_ids"] + if isinstance(deleted_chunk_ids, list): + unique_chunk_ids = list(dict.fromkeys(deleted_chunk_ids)) + has_ids = len(unique_chunk_ids) > 0 + else: + unique_chunk_ids = [deleted_chunk_ids] + has_ids = deleted_chunk_ids not in (None, "") + if not has_ids: + return get_json_result(data=True) + e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: return get_data_error_result(message="Document not found!") @@ -250,13 +260,6 @@ def _rm_sync(): doc.kb_id) except Exception: return get_data_error_result(message="Chunk deleting failure") - deleted_chunk_ids = req["chunk_ids"] - if isinstance(deleted_chunk_ids, list): - unique_chunk_ids = list(dict.fromkeys(deleted_chunk_ids)) - has_ids = len(unique_chunk_ids) > 0 - else: - unique_chunk_ids = [deleted_chunk_ids] - has_ids = deleted_chunk_ids not in (None, "") if has_ids and deleted_count == 0: return get_data_error_result(message="Index updating failure") if deleted_count > 0 and deleted_count < len(unique_chunk_ids): diff --git a/api/apps/sdk/chat.py b/api/apps/sdk/chat.py index 786d1a733f7..e1142de2572 100644 --- a/api/apps/sdk/chat.py +++ b/api/apps/sdk/chat.py @@ -235,16 +235,13 @@ async def delete_chats(tenant_id): success_count = 0 req = await get_request_json() if not req: - ids = None - else: - ids = req.get("ids") + return get_result() + + ids = req.get("ids") if not ids: - id_list = [] - dias = DialogService.query(tenant_id=tenant_id, status=StatusEnum.VALID.value) - for dia in dias: - id_list.append(dia.id) - else: - id_list = ids + return get_result() + + id_list = ids unique_id_list, duplicate_messages = check_duplicate_ids(id_list, "assistant") diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index 6538d3a336c..caa75ec02b8 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -202,10 +202,8 @@ async def delete(tenant_id): items: type: string description: | - Specifies the datasets to delete: - - If `null`, all datasets will be deleted. - - If an array of IDs, only the specified datasets will be deleted. - - If an empty array, no datasets will be deleted. + List of dataset IDs to delete. + If `null` or an empty array is provided, no datasets will be deleted. responses: 200: description: Successful operation. 
@@ -218,22 +216,19 @@ async def delete(tenant_id): try: kb_id_instance_pairs = [] - if req["ids"] is None: - kbs = KnowledgebaseService.query(tenant_id=tenant_id) - for kb in kbs: - kb_id_instance_pairs.append((kb.id, kb)) - - else: - error_kb_ids = [] - for kb_id in req["ids"]: - kb = KnowledgebaseService.get_or_none(id=kb_id, tenant_id=tenant_id) - if kb is None: - error_kb_ids.append(kb_id) - continue - kb_id_instance_pairs.append((kb_id, kb)) - if len(error_kb_ids) > 0: - return get_error_permission_result( - message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""") + if req["ids"] is None or len(req["ids"]) == 0: + return get_result() + + error_kb_ids = [] + for kb_id in req["ids"]: + kb = KnowledgebaseService.get_or_none(id=kb_id, tenant_id=tenant_id) + if kb is None: + error_kb_ids.append(kb_id) + continue + kb_id_instance_pairs.append((kb_id, kb)) + if len(error_kb_ids) > 0: + return get_error_permission_result( + message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""") errors = [] success_count = 0 @@ -811,4 +806,4 @@ def trace_raptor(tenant_id,dataset_id): if not ok: return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred") - return get_result(data=task.to_dict()) \ No newline at end of file + return get_result(data=task.to_dict()) diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index 7c1b3aa8641..80d0a2e1eaf 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -727,7 +727,9 @@ async def delete(tenant_id, dataset_id): type: array items: type: string - description: List of document IDs to delete. + description: | + List of document IDs to delete. + If omitted, `null`, or an empty array is provided, no documents will be deleted. - in: header name: Authorization type: string @@ -743,16 +745,13 @@ async def delete(tenant_id, dataset_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ") req = await get_request_json() if not req: - doc_ids = None - else: - doc_ids = req.get("ids") + return get_result() + + doc_ids = req.get("ids") if not doc_ids: - doc_list = [] - docs = DocumentService.query(kb_id=dataset_id) - for doc in docs: - doc_list.append(doc.id) - else: - doc_list = doc_ids + return get_result() + + doc_list = doc_ids unique_doc_ids, duplicate_messages = check_duplicate_ids(doc_list, "document") doc_list = unique_doc_ids @@ -1318,7 +1317,9 @@ async def rm_chunk(tenant_id, dataset_id, document_id): type: array items: type: string - description: List of chunk IDs to remove. + description: | + List of chunk IDs to remove. + If omitted, `null`, or an empty array is provided, no chunks will be deleted. 
- in: header name: Authorization type: string @@ -1336,17 +1337,20 @@ async def rm_chunk(tenant_id, dataset_id, document_id): if not docs: raise LookupError(f"Can't find the document with ID {document_id}!") req = await get_request_json() + if not req: + return get_result() + + chunk_ids = req.get("chunk_ids") + if not chunk_ids: + return get_result() + condition = {"doc_id": document_id} - if "chunk_ids" in req: - unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk") - condition["id"] = unique_chunk_ids - else: - unique_chunk_ids = [] - duplicate_messages = [] + unique_chunk_ids, duplicate_messages = check_duplicate_ids(chunk_ids, "chunk") + condition["id"] = unique_chunk_ids chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id) if chunk_number != 0: DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0) - if "chunk_ids" in req and chunk_number != len(unique_chunk_ids): + if chunk_number != len(unique_chunk_ids): if len(unique_chunk_ids) == 0: return get_result(message=f"deleted {chunk_number} chunks") return get_error_data_result(message=f"rm_chunk deleted chunks {chunk_number}, expect {len(unique_chunk_ids)}") diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 9553baf1a86..10439564d58 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -739,18 +739,14 @@ async def delete(tenant_id, chat_id): errors = [] success_count = 0 req = await get_request_json() - convs = ConversationService.query(dialog_id=chat_id) if not req: - ids = None - else: - ids = req.get("ids") + return get_result() + ids = req.get("ids") if not ids: - conv_list = [] - for conv in convs: - conv_list.append(conv.id) - else: - conv_list = ids + return get_result() + + conv_list = ids unique_conv_ids, duplicate_messages = check_duplicate_ids(conv_list, "session") conv_list = unique_conv_ids @@ -791,21 +787,14 @@ async def delete_agent_session(tenant_id, agent_id): if not cvs: return get_error_data_result(f"You don't own the agent {agent_id}") - convs = API4ConversationService.query(dialog_id=agent_id) - if not convs: - return get_error_data_result(f"Agent {agent_id} has no sessions") - if not req: - ids = None - else: - ids = req.get("ids") + return get_result() + ids = req.get("ids") if not ids: - conv_list = [] - for conv in convs: - conv_list.append(conv.id) - else: - conv_list = ids + return get_result() + + conv_list = ids unique_conv_ids, duplicate_messages = check_duplicate_ids(conv_list, "session") conv_list = unique_conv_ids diff --git a/docs/references/http_api_reference.md b/docs/references/http_api_reference.md index 8a45106a321..a6ccf63fa6d 100644 --- a/docs/references/http_api_reference.md +++ b/docs/references/http_api_reference.md @@ -676,9 +676,8 @@ curl --request DELETE \ - `"ids"`: (*Body parameter*), `list[string]` or `null`, *Required* Specifies the datasets to delete: - - If `null`, all datasets will be deleted. - - If an array of IDs, only the specified datasets will be deleted. - - If an empty array, no datasets will be deleted. + - If omitted, or set to `null` or an empty array, no datasets are deleted. + - If an array of IDs is provided, only the datasets matching those IDs are deleted. #### Response @@ -1764,7 +1763,9 @@ curl --request DELETE \ - `dataset_id`: (*Path parameter*) The associated dataset ID. - `"ids"`: (*Body parameter*), `list[string]` - The IDs of the documents to delete. 
If it is not specified, all documents in the specified dataset will be deleted. + The IDs of the documents to delete. + - If omitted, or set to `null` or an empty array, no documents are deleted. + - If an array of IDs is provided, only the documents matching those IDs are deleted. #### Response @@ -2124,7 +2125,9 @@ curl --request DELETE \ - `document_ids`: (*Path parameter*) The associated document ID. - `"chunk_ids"`: (*Body parameter*), `list[string]` - The IDs of the chunks to delete. If it is not specified, all chunks of the specified document will be deleted. + The IDs of the chunks to delete. + - If omitted, or set to `null` or an empty array, no chunks are deleted. + - If an array of IDs is provided, only the chunks matching those IDs are deleted. #### Response @@ -2796,7 +2799,9 @@ curl --request DELETE \ ##### Request parameters - `"ids"`: (*Body parameter*), `list[string]` - The IDs of the chat assistants to delete. If it is not specified, all chat assistants in the system will be deleted. + The IDs of the chat assistants to delete. + - If omitted, or set to `null` or an empty array, no chat assistants are deleted. + - If an array of IDs is provided, only the chat assistants matching those IDs are deleted. #### Response @@ -3174,7 +3179,9 @@ curl --request DELETE \ - `chat_id`: (*Path parameter*) The ID of the associated chat assistant. - `"ids"`: (*Body Parameter*), `list[string]` - The IDs of the sessions to delete. If it is not specified, all sessions associated with the specified chat assistant will be deleted. + The IDs of the sessions to delete. + - If omitted, or set to `null` or an empty array, no sessions are deleted. + - If an array of IDs is provided, only the sessions matching those IDs are deleted. #### Response @@ -4538,7 +4545,9 @@ curl --request DELETE \ - `agent_id`: (*Path parameter*) The ID of the associated agent. - `"ids"`: (*Body Parameter*), `list[string]` - The IDs of the sessions to delete. If it is not specified, all sessions associated with the specified agent will be deleted. + The IDs of the sessions to delete. + - If omitted, or set to `null` or an empty array, no sessions are deleted. + - If an array of IDs is provided, only the sessions matching those IDs are deleted. #### Response diff --git a/docs/references/python_api_reference.md b/docs/references/python_api_reference.md index 80a8666e9bd..430e58a0f6f 100644 --- a/docs/references/python_api_reference.md +++ b/docs/references/python_api_reference.md @@ -240,9 +240,9 @@ Deletes datasets by ID. ##### ids: `list[str]` or `None`, *Required* The IDs of the datasets to delete. Defaults to `None`. - - If `None`, all datasets will be deleted. - - If an array of IDs, only the specified datasets will be deleted. - - If an empty array, no datasets will be deleted. + +- If omitted, or set to `null` or an empty array, no datasets are deleted. +- If an array of IDs is provided, only the datasets matching those IDs are deleted. #### Returns @@ -661,9 +661,12 @@ Deletes documents by ID. #### Parameters -##### ids: `list[list]` +##### ids: `list[str]` or `None` + +The IDs of the documents to delete. Defaults to `None`. -The IDs of the documents to delete. Defaults to `None`. If it is not specified, all documents in the dataset will be deleted. +- If omitted, or set to `null` or an empty array, no documents are deleted. +- If an array of IDs is provided, only the documents matching those IDs are deleted. #### Returns @@ -931,7 +934,10 @@ Deletes chunks by ID. 
##### chunk_ids: `list[str]` -The IDs of the chunks to delete. Defaults to `None`. If it is not specified, all chunks of the current document will be deleted. +The IDs of the chunks to delete. Defaults to `None`. + +- If omitted, or set to `null` or an empty array, no chunks are deleted. +- If an array of IDs is provided, only the chunks matching those IDs are deleted. #### Returns @@ -1234,7 +1240,10 @@ Deletes chat assistants by ID. ##### ids: `list[str]` -The IDs of the chat assistants to delete. Defaults to `None`. If it is empty or not specified, all chat assistants in the system will be deleted. +The IDs of the chat assistants to delete. Defaults to `None`. + +- If omitted, or set to `null` or an empty array, no chat assistants are deleted. +- If an array of IDs is provided, only the chat assistants matching those IDs are deleted. #### Returns @@ -1463,7 +1472,10 @@ Deletes sessions of the current chat assistant by ID. ##### ids: `list[str]` -The IDs of the sessions to delete. Defaults to `None`. If it is not specified, all sessions associated with the current chat assistant will be deleted. +The IDs of the sessions to delete. Defaults to `None`. + +- If omitted, or set to `null` or an empty array, no sessions are deleted. +- If an array of IDs is provided, only the sessions matching those IDs are deleted. #### Returns @@ -1781,7 +1793,10 @@ Deletes sessions of an agent by ID. ##### ids: `list[str]` -The IDs of the sessions to delete. Defaults to `None`. If it is not specified, all sessions associated with the agent will be deleted. +The IDs of the sessions to delete. Defaults to `None`. + +- If omitted, or set to `null` or an empty array, no sessions are deleted. +- If an array of IDs is provided, only the sessions matching those IDs are deleted. #### Returns diff --git a/test/testcases/test_http_api/common.py b/test/testcases/test_http_api/common.py index 4e27f74a1e0..d6334543db3 100644 --- a/test/testcases/test_http_api/common.py +++ b/test/testcases/test_http_api/common.py @@ -58,6 +58,23 @@ def delete_datasets(auth, payload=None, *, headers=HEADERS, data=None): return res.json() +def delete_all_datasets(auth, *, page_size=1000): + # Dataset DELETE now treats null/empty ids as a no-op, so cleanup must enumerate explicit ids. + page = 1 + dataset_ids = [] + while True: + res = list_datasets(auth, {"page": page, "page_size": page_size}) + data = res.get("data") or [] + dataset_ids.extend(dataset["id"] for dataset in data) + if len(data) < page_size: + break + page += 1 + + if not dataset_ids: + return {"code": 0, "message": ""} + return delete_datasets(auth, {"ids": dataset_ids}) + + def batch_create_datasets(auth, num): ids = [] for i in range(num): @@ -127,6 +144,23 @@ def delete_documents(auth, dataset_id, payload=None): return res.json() +def delete_all_documents(auth, dataset_id, *, page_size=1000): + # Document DELETE now treats missing/null/empty ids as a no-op, so cleanup must enumerate explicit ids. 
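+    # Sketch of the shared cleanup pattern: walk list_documents page by page, collect every
+    # id, and stop once a short page signals the end; the bulk DELETE below only fires when
+    # at least one id was found, otherwise a synthetic success payload is returned.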
+ page = 1 + document_ids = [] + while True: + res = list_documents(auth, dataset_id, {"page": page, "page_size": page_size}) + docs = (res.get("data") or {}).get("docs") or [] + document_ids.extend(doc["id"] for doc in docs) + if len(docs) < page_size: + break + page += 1 + + if not document_ids: + return {"code": 0, "message": ""} + return delete_documents(auth, dataset_id, {"ids": document_ids}) + + def parse_documents(auth, dataset_id, payload=None): url = f"{HOST_ADDRESS}{FILE_CHUNK_API_URL}".format(dataset_id=dataset_id) res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) @@ -176,6 +210,23 @@ def delete_chunks(auth, dataset_id, document_id, payload=None): return res.json() +def delete_all_chunks(auth, dataset_id, document_id, *, page_size=1000): + # Chunk DELETE now treats missing/null/empty ids as a no-op, so cleanup must enumerate explicit ids. + page = 1 + chunk_ids = [] + while True: + res = list_chunks(auth, dataset_id, document_id, {"page": page, "page_size": page_size}) + chunks = (res.get("data") or {}).get("chunks") or [] + chunk_ids.extend(chunk["id"] for chunk in chunks) + if len(chunks) < page_size: + break + page += 1 + + if not chunk_ids: + return {"code": 0, "message": ""} + return delete_chunks(auth, dataset_id, document_id, {"chunk_ids": chunk_ids}) + + def retrieval_chunks(auth, payload=None): url = f"{HOST_ADDRESS}{RETRIEVAL_API_URL}" res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) @@ -215,6 +266,23 @@ def delete_chat_assistants(auth, payload=None): return res.json() +def delete_all_chat_assistants(auth, *, page_size=1000): + # Chat DELETE now treats null/empty ids as a no-op, so cleanup must enumerate explicit ids. + page = 1 + chat_ids = [] + while True: + res = list_chat_assistants(auth, {"page": page, "page_size": page_size}) + data = res.get("data") or [] + chat_ids.extend(chat["id"] for chat in data) + if len(data) < page_size: + break + page += 1 + + if not chat_ids: + return {"code": 0, "message": ""} + return delete_chat_assistants(auth, {"ids": chat_ids}) + + def batch_create_chat_assistants(auth, num): chat_assistant_ids = [] for i in range(num): @@ -244,12 +312,27 @@ def update_session_with_chat_assistant(auth, chat_assistant_id, session_id, payl def delete_session_with_chat_assistants(auth, chat_assistant_id, payload=None): url = f"{HOST_ADDRESS}{SESSION_WITH_CHAT_ASSISTANT_API_URL}".format(chat_id=chat_assistant_id) - if payload is None: - payload = {} res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) return res.json() +def delete_all_sessions_with_chat_assistant(auth, chat_assistant_id, *, page_size=1000): + # Session DELETE now treats missing/null/empty ids as a no-op, so cleanup must enumerate explicit ids. 
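+    # Same enumerate-then-delete sketch as the helpers above, scoped to one chat assistant:
+    # collect session ids across pages, then remove them in a single bulk request.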
+ page = 1 + session_ids = [] + while True: + res = list_session_with_chat_assistants(auth, chat_assistant_id, {"page": page, "page_size": page_size}) + data = res.get("data") or [] + session_ids.extend(session["id"] for session in data) + if len(data) < page_size: + break + page += 1 + + if not session_ids: + return {"code": 0, "message": ""} + return delete_session_with_chat_assistants(auth, chat_assistant_id, {"ids": session_ids}) + + def batch_add_sessions_with_chat_assistant(auth, chat_assistant_id, num): session_ids = [] for i in range(num): @@ -350,12 +433,27 @@ def list_agent_sessions(auth, agent_id, params=None): def delete_agent_sessions(auth, agent_id, payload=None): url = f"{HOST_ADDRESS}{SESSION_WITH_AGENT_API_URL}".format(agent_id=agent_id) - if payload is None: - payload = {} res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) return res.json() +def delete_all_agent_sessions(auth, agent_id, *, page_size=1000): + # Agent session DELETE now treats missing/null/empty ids as a no-op, so cleanup must enumerate explicit ids. + page = 1 + session_ids = [] + while True: + res = list_agent_sessions(auth, agent_id, {"page": page, "page_size": page_size}) + data = res.get("data") or [] + session_ids.extend(session["id"] for session in data) + if len(data) < page_size: + break + page += 1 + + if not session_ids: + return {"code": 0, "message": ""} + return delete_agent_sessions(auth, agent_id, {"ids": session_ids}) + + def agent_completions(auth, agent_id, payload=None): url = f"{HOST_ADDRESS}{AGENT_API_URL}/{agent_id}/completions" res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) diff --git a/test/testcases/test_http_api/conftest.py b/test/testcases/test_http_api/conftest.py index eab05d09bcc..d3c571a6f07 100644 --- a/test/testcases/test_http_api/conftest.py +++ b/test/testcases/test_http_api/conftest.py @@ -21,9 +21,9 @@ batch_create_chat_assistants, batch_create_datasets, bulk_upload_documents, - delete_chat_assistants, - delete_datasets, - delete_session_with_chat_assistants, + delete_all_chat_assistants, + delete_all_datasets, + delete_all_sessions_with_chat_assistant, list_documents, parse_documents, ) @@ -89,7 +89,7 @@ def HttpApiAuth(token): @pytest.fixture(scope="function") def clear_datasets(request, HttpApiAuth): def cleanup(): - delete_datasets(HttpApiAuth, {"ids": None}) + delete_all_datasets(HttpApiAuth) request.addfinalizer(cleanup) @@ -97,7 +97,7 @@ def cleanup(): @pytest.fixture(scope="function") def clear_chat_assistants(request, HttpApiAuth): def cleanup(): - delete_chat_assistants(HttpApiAuth) + delete_all_chat_assistants(HttpApiAuth) request.addfinalizer(cleanup) @@ -106,7 +106,7 @@ def cleanup(): def clear_session_with_chat_assistants(request, HttpApiAuth, add_chat_assistants): def cleanup(): for chat_assistant_id in chat_assistant_ids: - delete_session_with_chat_assistants(HttpApiAuth, chat_assistant_id) + delete_all_sessions_with_chat_assistant(HttpApiAuth, chat_assistant_id) request.addfinalizer(cleanup) @@ -116,7 +116,7 @@ def cleanup(): @pytest.fixture(scope="class") def add_dataset(request, HttpApiAuth): def cleanup(): - delete_datasets(HttpApiAuth, {"ids": None}) + delete_all_datasets(HttpApiAuth) request.addfinalizer(cleanup) @@ -127,7 +127,7 @@ def cleanup(): @pytest.fixture(scope="function") def add_dataset_func(request, HttpApiAuth): def cleanup(): - delete_datasets(HttpApiAuth, {"ids": None}) + delete_all_datasets(HttpApiAuth) request.addfinalizer(cleanup) @@ -154,7 +154,7 @@ def add_chunks(HttpApiAuth, 
add_document): @pytest.fixture(scope="class") def add_chat_assistants(request, HttpApiAuth, add_document): def cleanup(): - delete_chat_assistants(HttpApiAuth) + delete_all_chat_assistants(HttpApiAuth) request.addfinalizer(cleanup) diff --git a/test/testcases/test_http_api/test_chat_assistant_management/conftest.py b/test/testcases/test_http_api/test_chat_assistant_management/conftest.py index 772c0788ba1..b81b48edcf2 100644 --- a/test/testcases/test_http_api/test_chat_assistant_management/conftest.py +++ b/test/testcases/test_http_api/test_chat_assistant_management/conftest.py @@ -14,7 +14,7 @@ # limitations under the License. # import pytest -from common import batch_create_chat_assistants, delete_chat_assistants, list_chat_assistants, list_documents, parse_documents +from common import batch_create_chat_assistants, delete_all_chat_assistants, list_chat_assistants, list_documents, parse_documents from utils import wait_for @@ -30,7 +30,7 @@ def condition(_auth, _dataset_id): @pytest.fixture(scope="function") def add_chat_assistants_func(request, HttpApiAuth, add_document): def cleanup(): - delete_chat_assistants(HttpApiAuth) + delete_all_chat_assistants(HttpApiAuth) request.addfinalizer(cleanup) diff --git a/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py b/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py index cb3ca0ae824..5ca56b92584 100644 --- a/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py @@ -299,6 +299,15 @@ def test_update_internal_failure_paths(monkeypatch): def test_delete_duplicate_no_success_path(monkeypatch): module = _load_chat_module(monkeypatch) + _set_request_json(monkeypatch, module, {}) + monkeypatch.setattr( + module.DialogService, + "query", + lambda **_kwargs: (_ for _ in ()).throw(AssertionError("query must not run for empty delete payload")), + ) + res = _run(module.delete_chats.__wrapped__("tenant-1")) + assert res["code"] == module.RetCode.SUCCESS + _set_request_json(monkeypatch, module, {"ids": ["chat-1", "chat-1"]}) monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")]) monkeypatch.setattr(module.DialogService, "update_by_id", lambda *_args, **_kwargs: 0) diff --git a/test/testcases/test_http_api/test_chat_assistant_management/test_delete_chat_assistants.py b/test/testcases/test_http_api/test_chat_assistant_management/test_delete_chat_assistants.py index 670ab04d34b..172c66492b0 100644 --- a/test/testcases/test_http_api/test_chat_assistant_management/test_delete_chat_assistants.py +++ b/test/testcases/test_http_api/test_chat_assistant_management/test_delete_chat_assistants.py @@ -44,8 +44,8 @@ class TestChatAssistantsDelete: @pytest.mark.parametrize( "payload, expected_code, expected_message, remaining", [ - pytest.param(None, 0, "", 0, marks=pytest.mark.p3), - pytest.param({"ids": []}, 0, "", 0, marks=pytest.mark.p3), + pytest.param(None, 0, "", 5, marks=pytest.mark.p3), + pytest.param({"ids": []}, 0, "", 5, marks=pytest.mark.p3), pytest.param({"ids": ["invalid_id"]}, 102, "Assistant(invalid_id) not found.", 5, marks=pytest.mark.p3), pytest.param({"ids": ["\n!?。;!?\"'"]}, 102, """Assistant(\n!?。;!?"\') not found.""", 5, marks=pytest.mark.p3), pytest.param("not json", 100, "AttributeError(\"'str' object has no attribute 'get'\")", 5, marks=pytest.mark.p3), diff --git 
a/test/testcases/test_http_api/test_chunk_management_within_dataset/conftest.py b/test/testcases/test_http_api/test_chunk_management_within_dataset/conftest.py index 7a06a23eb57..48487ee9ea6 100644 --- a/test/testcases/test_http_api/test_chunk_management_within_dataset/conftest.py +++ b/test/testcases/test_http_api/test_chunk_management_within_dataset/conftest.py @@ -18,7 +18,7 @@ from time import sleep import pytest -from common import batch_add_chunks, delete_chunks, list_documents, parse_documents +from common import batch_add_chunks, delete_all_chunks, list_documents, parse_documents from utils import wait_for @@ -34,7 +34,7 @@ def condition(_auth, _dataset_id): @pytest.fixture(scope="function") def add_chunks_func(request, HttpApiAuth, add_document): def cleanup(): - delete_chunks(HttpApiAuth, dataset_id, document_id, {"chunk_ids": []}) + delete_all_chunks(HttpApiAuth, dataset_id, document_id) request.addfinalizer(cleanup) diff --git a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py index 580a2974c26..eae75afada6 100644 --- a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py +++ b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py @@ -158,12 +158,12 @@ def test_delete_1k(self, HttpApiAuth, add_document): @pytest.mark.parametrize( "payload, expected_code, expected_message, remaining", [ - pytest.param(None, 100, """TypeError("argument of type \'NoneType\' is not iterable")""", 5, marks=pytest.mark.skip), + pytest.param(None, 0, "", 5, marks=pytest.mark.p3), pytest.param({"chunk_ids": ["invalid_id"]}, 102, "rm_chunk deleted chunks 0, expect 1", 5, marks=pytest.mark.p3), pytest.param("not json", 100, """UnboundLocalError("local variable \'duplicate_messages\' referenced before assignment")""", 5, marks=pytest.mark.skip(reason="pull/6376")), pytest.param(lambda r: {"chunk_ids": r[:1]}, 0, "", 4, marks=pytest.mark.p3), pytest.param(lambda r: {"chunk_ids": r}, 0, "", 1, marks=pytest.mark.p1), - pytest.param({"chunk_ids": []}, 0, "", 0, marks=pytest.mark.p3), + pytest.param({"chunk_ids": []}, 0, "", 5, marks=pytest.mark.p3), ], ) def test_basic_scenarios( diff --git a/test/testcases/test_http_api/test_dataset_management/conftest.py b/test/testcases/test_http_api/test_dataset_management/conftest.py index d4ef989ff7a..3e03e50b984 100644 --- a/test/testcases/test_http_api/test_dataset_management/conftest.py +++ b/test/testcases/test_http_api/test_dataset_management/conftest.py @@ -16,13 +16,13 @@ import pytest -from common import batch_create_datasets, delete_datasets +from common import batch_create_datasets, delete_all_datasets @pytest.fixture(scope="class") def add_datasets(HttpApiAuth, request): def cleanup(): - delete_datasets(HttpApiAuth, {"ids": None}) + delete_all_datasets(HttpApiAuth) request.addfinalizer(cleanup) @@ -32,7 +32,7 @@ def cleanup(): @pytest.fixture(scope="function") def add_datasets_func(HttpApiAuth, request): def cleanup(): - delete_datasets(HttpApiAuth, {"ids": None}) + delete_all_datasets(HttpApiAuth) request.addfinalizer(cleanup) diff --git a/test/testcases/test_http_api/test_dataset_management/test_delete_datasets.py b/test/testcases/test_http_api/test_dataset_management/test_delete_datasets.py index f8327704ead..0240857414a 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_delete_datasets.py +++ 
b/test/testcases/test_http_api/test_dataset_management/test_delete_datasets.py @@ -134,7 +134,7 @@ def test_ids_none(self, HttpApiAuth): assert res["code"] == 0, res res = list_datasets(HttpApiAuth) - assert len(res["data"]) == 0, res + assert len(res["data"]) == 3, res @pytest.mark.p2 @pytest.mark.usefixtures("add_dataset_func") diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/conftest.py b/test/testcases/test_http_api/test_file_management_within_dataset/conftest.py index cd1014382e8..efbbd5d43a9 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/conftest.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/conftest.py @@ -16,13 +16,13 @@ import pytest -from common import bulk_upload_documents, delete_documents +from common import bulk_upload_documents, delete_all_documents @pytest.fixture(scope="function") def add_document_func(request, HttpApiAuth, add_dataset, ragflow_tmp_dir): def cleanup(): - delete_documents(HttpApiAuth, dataset_id, {"ids": None}) + delete_all_documents(HttpApiAuth, dataset_id) request.addfinalizer(cleanup) @@ -33,7 +33,7 @@ def cleanup(): @pytest.fixture(scope="class") def add_documents(request, HttpApiAuth, add_dataset, ragflow_tmp_dir): def cleanup(): - delete_documents(HttpApiAuth, dataset_id, {"ids": None}) + delete_all_documents(HttpApiAuth, dataset_id) request.addfinalizer(cleanup) @@ -44,7 +44,7 @@ def cleanup(): @pytest.fixture(scope="function") def add_documents_func(request, HttpApiAuth, add_dataset_func, ragflow_tmp_dir): def cleanup(): - delete_documents(HttpApiAuth, dataset_id, {"ids": None}) + delete_all_documents(HttpApiAuth, dataset_id) request.addfinalizer(cleanup) diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_delete_documents.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_delete_documents.py index 74f5c060639..133a05df6a0 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_delete_documents.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_delete_documents.py @@ -45,8 +45,8 @@ class TestDocumentsDeletion: @pytest.mark.parametrize( "payload, expected_code, expected_message, remaining", [ - (None, 0, "", 0), - ({"ids": []}, 0, "", 0), + (None, 0, "", 3), + ({"ids": []}, 0, "", 3), ({"ids": ["invalid_id"]}, 102, "Documents not found: ['invalid_id']", 3), ( {"ids": ["\n!?。;!?\"'"]}, diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py index 23ac8fcf670..872563ccaeb 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py @@ -692,6 +692,10 @@ def test_delete_branches(self, monkeypatch): assert "don't own the dataset" in res["message"] monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: True) + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({})) + res = _run(module.delete.__wrapped__("tenant-1", "ds-1")) + assert res["code"] == module.RetCode.SUCCESS + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"ids": ["doc-1"]})) monkeypatch.setattr(module, "check_duplicate_ids", lambda ids, _kind: (ids, [])) monkeypatch.setattr(module.FileService, "get_root_folder", lambda _tenant: {"id": "pf-1"}) 
@@ -871,7 +875,11 @@ def test_rm_chunk_branches(self, monkeypatch): monkeypatch.setattr(module.DocumentService, "get_by_ids", lambda _ids: [_DummyDoc()]) monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({})) - _patch_docstore(monkeypatch, module, delete=lambda *_args, **_kwargs: 2) + _patch_docstore( + monkeypatch, + module, + delete=lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("delete must not run for empty chunk ids")), + ) monkeypatch.setattr(module.DocumentService, "decrement_chunk_num", lambda *_args, **_kwargs: None) res = _run(module.rm_chunk.__wrapped__("tenant-1", "ds-1", "doc-1")) assert res["code"] == 0 diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_batch_update.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_batch_update.py index 27b74d42f6f..9061ba39025 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_batch_update.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_batch_update.py @@ -63,4 +63,4 @@ def test_batch_update_metadata(self, HttpApiAuth, add_dataset, ragflow_tmp_dir): assert doc["meta_fields"].get("status") == "processed", f"Expected status='processed', got {doc['meta_fields'].get('status')}" # Cleanup - delete_documents(HttpApiAuth, dataset_id, {"ids": None}) + delete_documents(HttpApiAuth, dataset_id, {"ids": document_ids}) diff --git a/test/testcases/test_http_api/test_session_management/conftest.py b/test/testcases/test_http_api/test_session_management/conftest.py index 56eafab0aab..3bae723954d 100644 --- a/test/testcases/test_http_api/test_session_management/conftest.py +++ b/test/testcases/test_http_api/test_session_management/conftest.py @@ -14,14 +14,14 @@ # limitations under the License. 
# import pytest -from common import batch_add_sessions_with_chat_assistant, delete_session_with_chat_assistants +from common import batch_add_sessions_with_chat_assistant, delete_all_sessions_with_chat_assistant @pytest.fixture(scope="class") def add_sessions_with_chat_assistant(request, HttpApiAuth, add_chat_assistants): def cleanup(): for chat_assistant_id in chat_assistant_ids: - delete_session_with_chat_assistants(HttpApiAuth, chat_assistant_id) + delete_all_sessions_with_chat_assistant(HttpApiAuth, chat_assistant_id) request.addfinalizer(cleanup) @@ -33,7 +33,7 @@ def cleanup(): def add_sessions_with_chat_assistant_func(request, HttpApiAuth, add_chat_assistants): def cleanup(): for chat_assistant_id in chat_assistant_ids: - delete_session_with_chat_assistants(HttpApiAuth, chat_assistant_id) + delete_all_sessions_with_chat_assistant(HttpApiAuth, chat_assistant_id) request.addfinalizer(cleanup) diff --git a/test/testcases/test_http_api/test_session_management/test_agent_completions.py b/test/testcases/test_http_api/test_session_management/test_agent_completions.py index e34cc21eca6..bb65fd9f255 100644 --- a/test/testcases/test_http_api/test_session_management/test_agent_completions.py +++ b/test/testcases/test_http_api/test_session_management/test_agent_completions.py @@ -19,7 +19,7 @@ create_agent, create_agent_session, delete_agent, - delete_agent_sessions, + delete_all_agent_sessions, list_agents, ) @@ -65,7 +65,7 @@ def agent_id(HttpApiAuth, request): agent_id = res["data"][0]["id"] def cleanup(): - delete_agent_sessions(HttpApiAuth, agent_id) + delete_all_agent_sessions(HttpApiAuth, agent_id) delete_agent(HttpApiAuth, agent_id) request.addfinalizer(cleanup) diff --git a/test/testcases/test_http_api/test_session_management/test_agent_sessions.py b/test/testcases/test_http_api/test_session_management/test_agent_sessions.py index cfcc1807be6..883ae2af07b 100644 --- a/test/testcases/test_http_api/test_session_management/test_agent_sessions.py +++ b/test/testcases/test_http_api/test_session_management/test_agent_sessions.py @@ -19,6 +19,7 @@ create_agent, create_agent_session, delete_agent, + delete_all_agent_sessions, delete_agent_sessions, list_agent_sessions, list_agents, @@ -67,7 +68,7 @@ def agent_id(HttpApiAuth, request): agent_id = res["data"][0]["id"] def cleanup(): - delete_agent_sessions(HttpApiAuth, agent_id) + delete_all_agent_sessions(HttpApiAuth, agent_id) delete_agent(HttpApiAuth, agent_id) request.addfinalizer(cleanup) @@ -75,6 +76,19 @@ def cleanup(): class TestAgentSessions: + @pytest.mark.p2 + def test_delete_agent_sessions_empty_ids_noop(self, HttpApiAuth, agent_id): + res = create_agent_session(HttpApiAuth, agent_id, payload={}) + assert res["code"] == 0, res + session_id = res["data"]["id"] + + res = delete_agent_sessions(HttpApiAuth, agent_id, {"ids": []}) + assert res["code"] == 0, res + + res = list_agent_sessions(HttpApiAuth, agent_id, params={"id": session_id}) + assert res["code"] == 0, res + assert len(res["data"]) == 1, res + @pytest.mark.p2 def test_create_list_delete_agent_sessions(self, HttpApiAuth, agent_id): res = create_agent_session(HttpApiAuth, agent_id, payload={}) diff --git a/test/testcases/test_http_api/test_session_management/test_chat_completions.py b/test/testcases/test_http_api/test_session_management/test_chat_completions.py index fa2e225ca6f..000a9058568 100644 --- a/test/testcases/test_http_api/test_session_management/test_chat_completions.py +++ b/test/testcases/test_http_api/test_session_management/test_chat_completions.py @@ -19,8 
+19,8 @@ chat_completions, create_chat_assistant, create_session_with_chat_assistant, - delete_chat_assistants, - delete_session_with_chat_assistants, + delete_all_chat_assistants, + delete_all_sessions_with_chat_assistant, list_documents, parse_documents, ) @@ -52,8 +52,8 @@ def test_chat_completion_stream_false_with_session(self, HttpApiAuth, add_datase res = create_chat_assistant(HttpApiAuth, {"name": "chat_completion_test", "dataset_ids": [dataset_id]}) assert res["code"] == 0, res chat_id = res["data"]["id"] - request.addfinalizer(lambda: delete_session_with_chat_assistants(HttpApiAuth, chat_id)) - request.addfinalizer(lambda: delete_chat_assistants(HttpApiAuth)) + request.addfinalizer(lambda: delete_all_chat_assistants(HttpApiAuth)) + request.addfinalizer(lambda: delete_all_sessions_with_chat_assistant(HttpApiAuth, chat_id)) res = create_session_with_chat_assistant(HttpApiAuth, chat_id, {"name": "session_for_completion"}) assert res["code"] == 0, res @@ -85,8 +85,8 @@ def test_chat_completion_invalid_session(self, HttpApiAuth, request): res = create_chat_assistant(HttpApiAuth, {"name": "chat_completion_invalid_session", "dataset_ids": []}) assert res["code"] == 0, res chat_id = res["data"]["id"] - request.addfinalizer(lambda: delete_session_with_chat_assistants(HttpApiAuth, chat_id)) - request.addfinalizer(lambda: delete_chat_assistants(HttpApiAuth)) + request.addfinalizer(lambda: delete_all_chat_assistants(HttpApiAuth)) + request.addfinalizer(lambda: delete_all_sessions_with_chat_assistant(HttpApiAuth, chat_id)) res = chat_completions( HttpApiAuth, @@ -101,8 +101,8 @@ def test_chat_completion_invalid_metadata_condition(self, HttpApiAuth, request): res = create_chat_assistant(HttpApiAuth, {"name": "chat_completion_invalid_meta", "dataset_ids": []}) assert res["code"] == 0, res chat_id = res["data"]["id"] - request.addfinalizer(lambda: delete_session_with_chat_assistants(HttpApiAuth, chat_id)) - request.addfinalizer(lambda: delete_chat_assistants(HttpApiAuth)) + request.addfinalizer(lambda: delete_all_chat_assistants(HttpApiAuth)) + request.addfinalizer(lambda: delete_all_sessions_with_chat_assistant(HttpApiAuth, chat_id)) res = create_session_with_chat_assistant(HttpApiAuth, chat_id, {"name": "session_for_meta"}) assert res["code"] == 0, res diff --git a/test/testcases/test_http_api/test_session_management/test_chat_completions_openai.py b/test/testcases/test_http_api/test_session_management/test_chat_completions_openai.py index ffaa3ee4513..54d5fe29d46 100644 --- a/test/testcases/test_http_api/test_session_management/test_chat_completions_openai.py +++ b/test/testcases/test_http_api/test_session_management/test_chat_completions_openai.py @@ -18,7 +18,7 @@ bulk_upload_documents, chat_completions_openai, create_chat_assistant, - delete_chat_assistants, + delete_all_chat_assistants, list_documents, parse_documents, ) @@ -53,7 +53,7 @@ def test_openai_chat_completion_non_stream(self, HttpApiAuth, add_dataset_func, res = create_chat_assistant(HttpApiAuth, {"name": "openai_endpoint_test", "dataset_ids": [dataset_id]}) assert res["code"] == 0, res chat_id = res["data"]["id"] - request.addfinalizer(lambda: delete_chat_assistants(HttpApiAuth)) + request.addfinalizer(lambda: delete_all_chat_assistants(HttpApiAuth)) res = chat_completions_openai( HttpApiAuth, @@ -92,7 +92,7 @@ def test_openai_chat_completion_token_count_reasonable(self, HttpApiAuth, add_da res = create_chat_assistant(HttpApiAuth, {"name": "openai_token_count_test", "dataset_ids": [dataset_id]}) assert res["code"] == 0, res 
chat_id = res["data"]["id"] - request.addfinalizer(lambda: delete_chat_assistants(HttpApiAuth)) + request.addfinalizer(lambda: delete_all_chat_assistants(HttpApiAuth)) # Use a message with known token count # "hello" is 1 token in cl100k_base encoding @@ -202,7 +202,7 @@ def test_openai_chat_completion_request_validation( res = create_chat_assistant(HttpApiAuth, {"name": "openai_validation_case", "dataset_ids": []}) assert res["code"] == 0, res chat_id = res["data"]["id"] - request.addfinalizer(lambda: delete_chat_assistants(HttpApiAuth)) + request.addfinalizer(lambda: delete_all_chat_assistants(HttpApiAuth)) res = chat_completions_openai(HttpApiAuth, chat_id, payload) assert res.get("code") != 0, res diff --git a/test/testcases/test_http_api/test_session_management/test_delete_sessions_with_chat_assistant.py b/test/testcases/test_http_api/test_session_management/test_delete_sessions_with_chat_assistant.py index 818050819b2..637cb1f1d10 100644 --- a/test/testcases/test_http_api/test_session_management/test_delete_sessions_with_chat_assistant.py +++ b/test/testcases/test_http_api/test_session_management/test_delete_sessions_with_chat_assistant.py @@ -141,12 +141,12 @@ def test_delete_1k(self, HttpApiAuth, add_chat_assistants): @pytest.mark.parametrize( "payload, expected_code, expected_message, remaining", [ - pytest.param(None, 0, """TypeError("argument of type \'NoneType\' is not iterable")""", 0, marks=pytest.mark.skip), + pytest.param(None, 0, "", 5, marks=pytest.mark.p3), pytest.param({"ids": ["invalid_id"]}, 102, "The chat doesn't own the session invalid_id", 5, marks=pytest.mark.p3), pytest.param("not json", 100, """AttributeError("\'str\' object has no attribute \'get\'")""", 5, marks=pytest.mark.skip), pytest.param(lambda r: {"ids": r[:1]}, 0, "", 4, marks=pytest.mark.p3), pytest.param(lambda r: {"ids": r}, 0, "", 0, marks=pytest.mark.p1), - pytest.param({"ids": []}, 0, "", 0, marks=pytest.mark.p3), + pytest.param({"ids": []}, 0, "", 5, marks=pytest.mark.p3), ], ) def test_basic_scenarios( diff --git a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py index c1bbd10347c..6852024db30 100644 --- a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py @@ -985,6 +985,10 @@ def test_delete_routes_partial_duplicate_unit(monkeypatch): module = _load_session_module(monkeypatch) monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")]) + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({})) + res = _run(inspect.unwrap(module.delete)("tenant-1", "chat-1")) + assert res["code"] == 0 + monkeypatch.setattr(module.ConversationService, "delete_by_id", lambda *_args, **_kwargs: True) def _conversation_query(**kwargs): @@ -1016,6 +1020,10 @@ def _conversation_query(**kwargs): assert res["data"]["errors"] == ["Duplicate session ids: ok"] monkeypatch.setattr(module.UserCanvasService, "query", lambda **_kwargs: [SimpleNamespace(id="agent-1")]) + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({})) + res = _run(inspect.unwrap(module.delete_agent_session)("tenant-1", "agent-1")) + assert res["code"] == 0 + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"ids": ["session-1"]})) monkeypatch.setattr(module, "check_duplicate_ids", lambda ids, _kind: (ids, 
[])) diff --git a/test/testcases/test_sdk_api/common.py b/test/testcases/test_sdk_api/common.py index 3035383a472..84354fc91a6 100644 --- a/test/testcases/test_sdk_api/common.py +++ b/test/testcases/test_sdk_api/common.py @@ -25,6 +25,36 @@ def batch_create_datasets(client: RAGFlow, num: int) -> list[DataSet]: return [client.create_dataset(name=f"dataset_{i}") for i in range(num)] +def delete_all_datasets(client: RAGFlow, *, page_size: int = 1000) -> None: + # Dataset DELETE now treats null/empty ids as a no-op, so cleanup must enumerate explicit ids. + page = 1 + dataset_ids: list[str] = [] + while True: + datasets = client.list_datasets(page=page, page_size=page_size) + dataset_ids.extend(dataset.id for dataset in datasets) + if len(datasets) < page_size: + break + page += 1 + + if dataset_ids: + client.delete_datasets(ids=dataset_ids) + + +def delete_all_chats(client: RAGFlow, *, page_size: int = 1000) -> None: + # Chat DELETE now treats null/empty ids as a no-op, so cleanup must enumerate explicit ids. + page = 1 + chat_ids: list[str] = [] + while True: + chats = client.list_chats(page=page, page_size=page_size) + chat_ids.extend(chat.id for chat in chats) + if len(chats) < page_size: + break + page += 1 + + if chat_ids: + client.delete_chats(ids=chat_ids) + + # FILE MANAGEMENT WITHIN DATASET def bulk_upload_documents(dataset: DataSet, num: int, tmp_path: Path) -> list[Document]: document_infos = [] @@ -37,6 +67,51 @@ def bulk_upload_documents(dataset: DataSet, num: int, tmp_path: Path) -> list[Do return dataset.upload_documents(document_infos) +def delete_all_documents(dataset: DataSet, *, page_size: int = 1000) -> None: + # Document DELETE now treats missing/null/empty ids as a no-op, so cleanup must enumerate explicit ids. + page = 1 + document_ids: list[str] = [] + while True: + documents = dataset.list_documents(page=page, page_size=page_size) + document_ids.extend(document.id for document in documents) + if len(documents) < page_size: + break + page += 1 + + if document_ids: + dataset.delete_documents(ids=document_ids) + + +def delete_all_sessions(chat_assistant: Chat, *, page_size: int = 1000) -> None: + # Session DELETE now treats missing/null/empty ids as a no-op, so cleanup must enumerate explicit ids. + page = 1 + session_ids: list[str] = [] + while True: + sessions = chat_assistant.list_sessions(page=page, page_size=page_size) + session_ids.extend(session.id for session in sessions) + if len(sessions) < page_size: + break + page += 1 + + if session_ids: + chat_assistant.delete_sessions(ids=session_ids) + + +def delete_all_chunks(document: Document, *, page_size: int = 1000) -> None: + # Chunk DELETE now treats missing/null/empty ids as a no-op, so cleanup must enumerate explicit ids. 
+ page = 1 + chunk_ids: list[str] = [] + while True: + chunks = document.list_chunks(page=page, page_size=page_size) + chunk_ids.extend(chunk.id for chunk in chunks) + if len(chunks) < page_size: + break + page += 1 + + if chunk_ids: + document.delete_chunks(ids=chunk_ids) + + # CHUNK MANAGEMENT WITHIN DATASET def batch_add_chunks(document: Document, num: int) -> list[Chunk]: return [document.add_chunk(content=f"chunk test {i}") for i in range(num)] diff --git a/test/testcases/test_sdk_api/conftest.py b/test/testcases/test_sdk_api/conftest.py index 11a258a5ad1..f4791306ccf 100644 --- a/test/testcases/test_sdk_api/conftest.py +++ b/test/testcases/test_sdk_api/conftest.py @@ -23,6 +23,10 @@ batch_create_chat_assistants, batch_create_datasets, bulk_upload_documents, + delete_all_chats, + delete_all_chunks, + delete_all_datasets, + delete_all_sessions, ) from configs import HOST_ADDRESS, VERSION from pytest import FixtureRequest @@ -88,7 +92,7 @@ def client(token: str) -> RAGFlow: @pytest.fixture(scope="function") def clear_datasets(request: FixtureRequest, client: RAGFlow): def cleanup(): - client.delete_datasets(ids=None) + delete_all_datasets(client) request.addfinalizer(cleanup) @@ -96,7 +100,7 @@ def cleanup(): @pytest.fixture(scope="function") def clear_chat_assistants(request: FixtureRequest, client: RAGFlow): def cleanup(): - client.delete_chats(ids=None) + delete_all_chats(client) request.addfinalizer(cleanup) @@ -106,7 +110,7 @@ def clear_session_with_chat_assistants(request, add_chat_assistants): def cleanup(): for chat_assistant in chat_assistants: try: - chat_assistant.delete_sessions(ids=None) + delete_all_sessions(chat_assistant) except Exception: pass @@ -118,7 +122,7 @@ def cleanup(): @pytest.fixture(scope="class") def add_dataset(request: FixtureRequest, client: RAGFlow) -> DataSet: def cleanup(): - client.delete_datasets(ids=None) + delete_all_datasets(client) request.addfinalizer(cleanup) return batch_create_datasets(client, 1)[0] @@ -127,7 +131,7 @@ def cleanup(): @pytest.fixture(scope="function") def add_dataset_func(request: FixtureRequest, client: RAGFlow) -> DataSet: def cleanup(): - client.delete_datasets(ids=None) + delete_all_datasets(client) request.addfinalizer(cleanup) return batch_create_datasets(client, 1)[0] @@ -142,7 +146,7 @@ def add_document(add_dataset: DataSet, ragflow_tmp_dir: Path) -> tuple[DataSet, def add_chunks(request: FixtureRequest, add_document: tuple[DataSet, Document]) -> tuple[DataSet, Document, list[Chunk]]: def cleanup(): try: - document.delete_chunks(ids=[]) + delete_all_chunks(document) except Exception: pass @@ -161,7 +165,7 @@ def cleanup(): def add_chat_assistants(request, client, add_document) -> tuple[DataSet, Document, list[Chat]]: def cleanup(): try: - client.delete_chats(ids=None) + delete_all_chats(client) except Exception: pass diff --git a/test/testcases/test_sdk_api/test_chat_assistant_management/conftest.py b/test/testcases/test_sdk_api/test_chat_assistant_management/conftest.py index 79347d67a99..c02065061ae 100644 --- a/test/testcases/test_sdk_api/test_chat_assistant_management/conftest.py +++ b/test/testcases/test_sdk_api/test_chat_assistant_management/conftest.py @@ -14,7 +14,7 @@ # limitations under the License. 
# import pytest -from common import batch_create_chat_assistants +from common import batch_create_chat_assistants, delete_all_chats from pytest import FixtureRequest from ragflow_sdk import Chat, DataSet, Document, RAGFlow from utils import wait_for @@ -32,7 +32,7 @@ def condition(_dataset: DataSet): @pytest.fixture(scope="function") def add_chat_assistants_func(request: FixtureRequest, client: RAGFlow, add_document: tuple[DataSet, Document]) -> tuple[DataSet, Document, list[Chat]]: def cleanup(): - client.delete_chats(ids=None) + delete_all_chats(client) request.addfinalizer(cleanup) diff --git a/test/testcases/test_sdk_api/test_chat_assistant_management/test_delete_chat_assistants.py b/test/testcases/test_sdk_api/test_chat_assistant_management/test_delete_chat_assistants.py index 7f720330968..936d9cf5bd9 100644 --- a/test/testcases/test_sdk_api/test_chat_assistant_management/test_delete_chat_assistants.py +++ b/test/testcases/test_sdk_api/test_chat_assistant_management/test_delete_chat_assistants.py @@ -23,8 +23,8 @@ class TestChatAssistantsDelete: @pytest.mark.parametrize( "payload, expected_message, remaining", [ - pytest.param(None, "", 0, marks=pytest.mark.p3), - pytest.param({"ids": []}, "", 0, marks=pytest.mark.p3), + pytest.param(None, "", 5, marks=pytest.mark.p3), + pytest.param({"ids": []}, "", 5, marks=pytest.mark.p3), pytest.param({"ids": ["invalid_id"]}, "Assistant(invalid_id) not found.", 5, marks=pytest.mark.p3), pytest.param({"ids": ["\n!?。;!?\"'"]}, """Assistant(\n!?。;!?"\') not found.""", 5, marks=pytest.mark.p3), pytest.param(lambda r: {"ids": r[:1]}, "", 4, marks=pytest.mark.p3), diff --git a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/conftest.py b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/conftest.py index d9ed678387f..835662d7ae2 100644 --- a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/conftest.py +++ b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/conftest.py @@ -18,7 +18,7 @@ from time import sleep import pytest -from common import batch_add_chunks +from common import batch_add_chunks, delete_all_chunks from pytest import FixtureRequest from ragflow_sdk import Chunk, DataSet, Document from utils import wait_for @@ -37,7 +37,7 @@ def condition(_dataset: DataSet): def add_chunks_func(request: FixtureRequest, add_document: tuple[DataSet, Document]) -> tuple[DataSet, Document, list[Chunk]]: def cleanup(): try: - document.delete_chunks(ids=[]) + delete_all_chunks(document) except Exception: pass diff --git a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_delete_chunks.py b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_delete_chunks.py index 319dac0e861..4fd59f01f7a 100644 --- a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_delete_chunks.py +++ b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_delete_chunks.py @@ -88,12 +88,12 @@ def test_delete_1k(self, add_document): @pytest.mark.parametrize( "payload, expected_message, remaining", [ - pytest.param(None, "TypeError", 5, marks=pytest.mark.skip), + pytest.param(None, "", 5, marks=pytest.mark.p3), pytest.param({"ids": ["invalid_id"]}, "rm_chunk deleted chunks 0, expect 1", 5, marks=pytest.mark.p3), pytest.param("not json", "UnboundLocalError", 5, marks=pytest.mark.skip(reason="pull/6376")), pytest.param(lambda r: {"ids": r[:1]}, "", 4, marks=pytest.mark.p3), pytest.param(lambda r: {"ids": r}, "", 1, marks=pytest.mark.p1), - pytest.param({"ids": 
[]}, "", 0, marks=pytest.mark.p3), + pytest.param({"ids": []}, "", 5, marks=pytest.mark.p3), ], ) def test_basic_scenarios(self, add_chunks_func, payload, expected_message, remaining): @@ -107,7 +107,10 @@ def test_basic_scenarios(self, add_chunks_func, payload, expected_message, remai document.delete_chunks(**payload) assert expected_message in str(exception_info.value), str(exception_info.value) else: - document.delete_chunks(**payload) + if payload is None: + document.delete_chunks() + else: + document.delete_chunks(**payload) remaining_chunks = document.list_chunks() assert len(remaining_chunks) == remaining, str(remaining_chunks) diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/conftest.py b/test/testcases/test_sdk_api/test_dataset_mangement/conftest.py index 8d53eac2ee8..998af94995e 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/conftest.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/conftest.py @@ -16,13 +16,13 @@ import pytest -from common import batch_create_datasets +from common import batch_create_datasets, delete_all_datasets @pytest.fixture(scope="class") def add_datasets(client, request): def cleanup(): - client.delete_datasets(**{"ids": None}) + delete_all_datasets(client) request.addfinalizer(cleanup) @@ -32,7 +32,7 @@ def cleanup(): @pytest.fixture(scope="function") def add_datasets_func(client, request): def cleanup(): - client.delete_datasets(**{"ids": None}) + delete_all_datasets(client) request.addfinalizer(cleanup) diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_delete_datasets.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_delete_datasets.py index d9a9069f4e1..dbf0e588ed8 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_delete_datasets.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_delete_datasets.py @@ -95,7 +95,7 @@ def test_ids_none(self, client): client.delete_datasets(**payload) datasets = client.list_datasets() - assert len(datasets) == 0, str(datasets) + assert len(datasets) == 3, str(datasets) @pytest.mark.p2 @pytest.mark.usefixtures("add_dataset_func") diff --git a/test/testcases/test_sdk_api/test_file_management_within_dataset/conftest.py b/test/testcases/test_sdk_api/test_file_management_within_dataset/conftest.py index 32be9683a5b..b60f5f2886c 100644 --- a/test/testcases/test_sdk_api/test_file_management_within_dataset/conftest.py +++ b/test/testcases/test_sdk_api/test_file_management_within_dataset/conftest.py @@ -16,7 +16,7 @@ import pytest -from common import bulk_upload_documents +from common import bulk_upload_documents, delete_all_documents from pytest import FixtureRequest from ragflow_sdk import DataSet, Document @@ -27,7 +27,7 @@ def add_document_func(request: FixtureRequest, add_dataset: DataSet, ragflow_tmp documents = bulk_upload_documents(dataset, 1, ragflow_tmp_dir) def cleanup(): - dataset.delete_documents(ids=None) + delete_all_documents(dataset) request.addfinalizer(cleanup) return dataset, documents[0] @@ -39,7 +39,7 @@ def add_documents(request: FixtureRequest, add_dataset: DataSet, ragflow_tmp_dir documents = bulk_upload_documents(dataset, 5, ragflow_tmp_dir) def cleanup(): - dataset.delete_documents(ids=None) + delete_all_documents(dataset) request.addfinalizer(cleanup) return dataset, documents @@ -51,7 +51,7 @@ def add_documents_func(request: FixtureRequest, add_dataset_func: DataSet, ragfl documents = bulk_upload_documents(dataset, 3, ragflow_tmp_dir) def cleanup(): - dataset.delete_documents(ids=None) + 
delete_all_documents(dataset) request.addfinalizer(cleanup) return dataset, documents diff --git a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_delete_documents.py b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_delete_documents.py index 35f146a4de2..9fa9d3b1e0b 100644 --- a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_delete_documents.py +++ b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_delete_documents.py @@ -24,8 +24,8 @@ class TestDocumentsDeletion: @pytest.mark.parametrize( "payload, expected_message, remaining", [ - ({"ids": None}, "", 0), - ({"ids": []}, "", 0), + ({"ids": None}, "", 3), + ({"ids": []}, "", 3), ({"ids": ["invalid_id"]}, "Documents not found: ['invalid_id']", 3), ({"ids": ["\n!?。;!?\"'"]}, "Documents not found: ['\\n!?。;!?\"\\'']", 3), ("not json", "must be a mapping", 3), diff --git a/test/testcases/test_sdk_api/test_session_management/conftest.py b/test/testcases/test_sdk_api/test_session_management/conftest.py index 3f1289ed602..7361b34849d 100644 --- a/test/testcases/test_sdk_api/test_session_management/conftest.py +++ b/test/testcases/test_sdk_api/test_session_management/conftest.py @@ -14,7 +14,7 @@ # limitations under the License. # import pytest -from common import batch_add_sessions_with_chat_assistant +from common import batch_add_sessions_with_chat_assistant, delete_all_sessions from pytest import FixtureRequest from ragflow_sdk import Chat, DataSet, Document, Session @@ -24,7 +24,7 @@ def add_sessions_with_chat_assistant(request: FixtureRequest, add_chat_assistant def cleanup(): for chat_assistant in chat_assistants: try: - chat_assistant.delete_sessions(ids=None) + delete_all_sessions(chat_assistant) except Exception : pass @@ -39,7 +39,7 @@ def add_sessions_with_chat_assistant_func(request: FixtureRequest, add_chat_assi def cleanup(): for chat_assistant in chat_assistants: try: - chat_assistant.delete_sessions(ids=None) + delete_all_sessions(chat_assistant) except Exception : pass diff --git a/test/testcases/test_sdk_api/test_session_management/test_delete_sessions_with_chat_assistant.py b/test/testcases/test_sdk_api/test_session_management/test_delete_sessions_with_chat_assistant.py index 5d118af6c27..e88b74c4c68 100644 --- a/test/testcases/test_sdk_api/test_session_management/test_delete_sessions_with_chat_assistant.py +++ b/test/testcases/test_sdk_api/test_session_management/test_delete_sessions_with_chat_assistant.py @@ -84,12 +84,12 @@ def test_delete_1k(self, add_chat_assistants): @pytest.mark.parametrize( "payload, expected_message, remaining", [ - pytest.param(None, """TypeError("argument of type \'NoneType\' is not iterable")""", 0, marks=pytest.mark.skip), + pytest.param(None, "", 5, marks=pytest.mark.p3), pytest.param({"ids": ["invalid_id"]}, "The chat doesn't own the session invalid_id", 5, marks=pytest.mark.p3), pytest.param("not json", """AttributeError("\'str\' object has no attribute \'get\'")""", 5, marks=pytest.mark.skip), pytest.param(lambda r: {"ids": r[:1]}, "", 4, marks=pytest.mark.p3), pytest.param(lambda r: {"ids": r}, "", 0, marks=pytest.mark.p1), - pytest.param({"ids": []}, "", 0, marks=pytest.mark.p3), + pytest.param({"ids": []}, "", 5, marks=pytest.mark.p3), ], ) def test_basic_scenarios(self, add_sessions_with_chat_assistant_func, payload, expected_message, remaining): @@ -102,7 +102,10 @@ def test_basic_scenarios(self, add_sessions_with_chat_assistant_func, payload, e chat_assistant.delete_sessions(**payload) assert 
expected_message in str(exception_info.value) else: - chat_assistant.delete_sessions(**payload) + if payload is None: + chat_assistant.delete_sessions() + else: + chat_assistant.delete_sessions(**payload) sessions = chat_assistant.list_sessions() assert len(sessions) == remaining diff --git a/test/testcases/test_web_api/test_chunk_app/test_chunk_routes_unit.py b/test/testcases/test_web_api/test_chunk_app/test_chunk_routes_unit.py index 5182500841c..5837b3ff077 100644 --- a/test/testcases/test_web_api/test_chunk_app/test_chunk_routes_unit.py +++ b/test/testcases/test_web_api/test_chunk_app/test_chunk_routes_unit.py @@ -673,6 +673,20 @@ def test_rm_chunk_delete_exception_partial_compensation_and_cleanup_unit(monkeyp res = _run(module.rm()) assert res["message"] == "Document not found!", res + _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": []}) + monkeypatch.setattr( + module.DocumentService, + "get_by_id", + lambda _doc_id: (_ for _ in ()).throw(AssertionError("get_by_id must not run for empty delete payload")), + ) + monkeypatch.setattr( + module.settings.docStoreConn, + "delete", + lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("delete must not run for empty delete payload")), + ) + res = _run(module.rm()) + assert res["code"] == 0, res + monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, _DummyDoc())) def _raise_delete(*_args, **_kwargs): diff --git a/test/testcases/test_web_api/test_chunk_app/test_rm_chunks.py b/test/testcases/test_web_api/test_chunk_app/test_rm_chunks.py index b611fcd457c..6247eae08f2 100644 --- a/test/testcases/test_web_api/test_chunk_app/test_rm_chunks.py +++ b/test/testcases/test_web_api/test_chunk_app/test_rm_chunks.py @@ -165,7 +165,7 @@ def test_delete_1k(self, WebApiAuth, add_document): pytest.param("not json", 100, """UnboundLocalError("local variable \'duplicate_messages\' referenced before assignment")""", 5, marks=pytest.mark.skip(reason="pull/6376")), pytest.param(lambda r: {"chunk_ids": r[:1]}, 0, "", 3, marks=pytest.mark.p3), pytest.param(lambda r: {"chunk_ids": r}, 0, "", 0, marks=pytest.mark.p1), - pytest.param({"chunk_ids": []}, 0, "", 0, marks=pytest.mark.p3), + pytest.param({"chunk_ids": []}, 0, "", 5, marks=pytest.mark.p3), ], ) def test_basic_scenarios(self, WebApiAuth, add_chunks_func, payload, expected_code, expected_message, remaining): diff --git a/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py b/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py index bc81ac1254f..967c95ef7b9 100644 --- a/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py +++ b/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py @@ -472,14 +472,8 @@ def test_delete_route_error_summary_matrix_unit(monkeypatch): assert res["data"]["errors"], res req_state["ids"] = None - monkeypatch.setattr( - module.KnowledgebaseService, - "query", - lambda **_kwargs: (_ for _ in ()).throw(module.OperationalError("db down")), - ) res = _run(inspect.unwrap(module.delete)("tenant-1")) - assert res["code"] == module.RetCode.DATA_ERROR, res - assert res["message"] == "Database operation failed", res + assert res["code"] == module.RetCode.SUCCESS, res @pytest.mark.p2 diff --git a/test/testcases/test_web_api/test_kb_app/conftest.py b/test/testcases/test_web_api/test_kb_app/conftest.py index 0a435483ce8..8a2387391b4 100644 --- a/test/testcases/test_web_api/test_kb_app/conftest.py +++ 
b/test/testcases/test_web_api/test_kb_app/conftest.py @@ -14,7 +14,7 @@ # limitations under the License. # import pytest -from common import batch_create_datasets +from common import batch_create_datasets, list_kbs, rm_kb from libs.auth import RAGFlowWebApiAuth from pytest import FixtureRequest from ragflow_sdk import RAGFlow @@ -22,17 +22,31 @@ @pytest.fixture(scope="class") def add_datasets(request: FixtureRequest, client: RAGFlow, WebApiAuth: RAGFlowWebApiAuth) -> list[str]: + dataset_ids = batch_create_datasets(WebApiAuth, 5) + def cleanup(): - client.delete_datasets(ids=None) + # Web KB cleanup cannot call SDK dataset bulk delete with empty ids; deletion must stay explicit. + res = list_kbs(WebApiAuth, params={"page_size": 1000}) + existing_ids = {kb["id"] for kb in res["data"]["kbs"]} + for dataset_id in dataset_ids: + if dataset_id in existing_ids: + rm_kb(WebApiAuth, {"kb_id": dataset_id}) request.addfinalizer(cleanup) - return batch_create_datasets(WebApiAuth, 5) + return dataset_ids @pytest.fixture(scope="function") def add_datasets_func(request: FixtureRequest, client: RAGFlow, WebApiAuth: RAGFlowWebApiAuth) -> list[str]: + dataset_ids = batch_create_datasets(WebApiAuth, 3) + def cleanup(): - client.delete_datasets(ids=None) + # Web KB cleanup cannot call SDK dataset bulk delete with empty ids; deletion must stay explicit. + res = list_kbs(WebApiAuth, params={"page_size": 1000}) + existing_ids = {kb["id"] for kb in res["data"]["kbs"]} + for dataset_id in dataset_ids: + if dataset_id in existing_ids: + rm_kb(WebApiAuth, {"kb_id": dataset_id}) request.addfinalizer(cleanup) - return batch_create_datasets(WebApiAuth, 3) + return dataset_ids From 1198e2f6d004f3c10ae12dc68a9e3b2b521f8186 Mon Sep 17 00:00:00 2001 From: OliverW <1225191678@qq.com> Date: Fri, 6 Mar 2026 18:18:14 +0800 Subject: [PATCH 165/565] fix(auth): return HTTP 401 for token-auth failures (#13420) Follow-up to #12488 #13386 ### What problem does this PR solve? Previously, token authentication failures returned HTTP 200 with an error code in the response body. This PR updates `token_required` to raise `Unauthorized` and relies on the global error handler to return a structured JSON response with HTTP 401 status. The response body structure (`code`, `message`, `data`) remains unchanged to preserve compatibility with the official SDK. Frontend logic has been updated to handle HTTP 401 responses in addition to checking `data.code`. 
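
For reference, the end-to-end flow can be reduced to a minimal Quart sketch: the decorator raises `werkzeug.exceptions.Unauthorized` instead of returning a 200 body, and a global error handler turns that into the usual `{code, message, data}` envelope with an HTTP 401 status. This is an illustration only, not the code in this PR; `VALID_TOKENS` and the literal `109` are stand-ins for the real `APIToken` lookup and `RetCode.AUTHENTICATION_ERROR`.

```python
from functools import wraps

from quart import Quart, jsonify, request
from werkzeug.exceptions import Unauthorized

app = Quart(__name__)
VALID_TOKENS = {"secret-token": "tenant-1"}  # placeholder for APIToken.query(token=...)
AUTHENTICATION_ERROR = 109  # stand-in for RetCode.AUTHENTICATION_ERROR


def token_required(func):
    @wraps(func)
    async def wrapper(*args, **kwargs):
        parts = request.headers.get("Authorization", "").split()
        if len(parts) < 2 or parts[1] not in VALID_TOKENS:
            # Raise instead of returning: the global handler below owns the
            # response, so every token-protected route fails the same way.
            raise Unauthorized(description="Authentication error: API key is invalid!")
        kwargs["tenant_id"] = VALID_TOKENS[parts[1]]
        return await func(*args, **kwargs)

    return wrapper


@app.errorhandler(Unauthorized)
async def unauthorized(error):
    # Body keeps the {code, message, data} shape the SDK expects; only the
    # HTTP status changes from 200 to 401.
    return jsonify({"code": AUTHENTICATION_ERROR, "message": error.description, "data": False}), 401


@app.route("/api/v1/whoami")
@token_required
async def whoami(tenant_id):
    return jsonify({"code": 0, "message": "", "data": {"tenant_id": tenant_id}})
```

The actual `token_required` in this PR additionally distinguishes a missing `Authorization` header from an invalid API key by attaching different `code` values to the raised exception; the sketch collapses those cases into one.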
### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- api/apps/__init__.py | 14 +++--- api/utils/api_utils.py | 50 ++++++++++--------- test/testcases/test_http_api/common.py | 3 +- .../test_download_document.py | 2 +- .../test_system_app/test_apps_init_unit.py | 7 ++- web/src/utils/next-request.ts | 40 ++++++++++++--- web/src/utils/request.ts | 39 +++++++++++++-- 7 files changed, 106 insertions(+), 49 deletions(-) diff --git a/api/apps/__init__.py b/api/apps/__init__.py index 89078d9fb81..2bea1226290 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -46,15 +46,15 @@ def _unauthorized_message(error): if error is None: return UNAUTHORIZED_MESSAGE + + description = getattr(error, "description", None) + if description: + return description + try: - msg = repr(error) + return repr(error) except Exception: return UNAUTHORIZED_MESSAGE - if msg == UNAUTHORIZED_MESSAGE: - return msg - if "Unauthorized" in msg and "401" in msg: - return msg - return UNAUTHORIZED_MESSAGE app = Quart(__name__) app = cors(app, allow_origin="*") @@ -316,7 +316,7 @@ async def unauthorized_quart_auth(error): @app.errorhandler(WerkzeugUnauthorized) async def unauthorized_werkzeug(error): logging.warning("Unauthorized request (werkzeug)") - return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED + return get_json_result(code=error.code, message=error.description), RetCode.UNAUTHORIZED @app.teardown_request def _db_close(exception): diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 9849a5c0eb3..b70ff2f9f15 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -33,7 +33,7 @@ request, has_app_context, ) -from werkzeug.exceptions import BadRequest as WerkzeugBadRequest +from werkzeug.exceptions import BadRequest as WerkzeugBadRequest, Unauthorized as WerkzeugUnauthorized try: from quart.exceptions import BadRequest as QuartBadRequest @@ -270,39 +270,41 @@ def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", da def token_required(func): - def get_tenant_id(**kwargs): + @wraps(func) + async def wrapper(*args, **kwargs): + # Validate the token (API Key) if os.environ.get("DISABLE_SDK"): - return False, get_json_result(data=False, message="`Authorization` can't be empty") + err = WerkzeugUnauthorized(description="`Authorization` can't be empty") + err.code = RetCode.SUCCESS + raise err + authorization_str = request.headers.get("Authorization") if not authorization_str: - return False, get_json_result(data=False, message="`Authorization` can't be empty") + err = WerkzeugUnauthorized(description="`Authorization` can't be empty") + err.code = RetCode.SUCCESS + raise err + authorization_list = authorization_str.split() if len(authorization_list) < 2: - return False, get_json_result(data=False, message="Please check your authorization format.") + err = WerkzeugUnauthorized(description="Please check your authorization format.") + err.code = RetCode.AUTHENTICATION_ERROR + raise err + token = authorization_list[1] objs = APIToken.query(token=token) if not objs: - return False, get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR) - kwargs["tenant_id"] = objs[0].tenant_id - return True, kwargs + err = WerkzeugUnauthorized(description="Authentication error: API key is invalid!") + err.code = RetCode.AUTHENTICATION_ERROR + raise err - @wraps(func) - def decorated_function(*args, **kwargs): - e, kwargs = get_tenant_id(**kwargs) - if 
not e: - return kwargs - return func(*args, **kwargs) + # On success, inject tenant_id into the route function's kwargs + kwargs["tenant_id"] = objs[0].tenant_id + result = func(*args, **kwargs) + if inspect.iscoroutine(result): + return await result + return result - @wraps(func) - async def adecorated_function(*args, **kwargs): - e, kwargs = get_tenant_id(**kwargs) - if not e: - return kwargs - return await func(*args, **kwargs) - - if inspect.iscoroutinefunction(func): - return adecorated_function - return decorated_function + return wrapper def get_result(code=RetCode.SUCCESS, message="", data=None, total=None): diff --git a/test/testcases/test_http_api/common.py b/test/testcases/test_http_api/common.py index d6334543db3..592d35c3c16 100644 --- a/test/testcases/test_http_api/common.py +++ b/test/testcases/test_http_api/common.py @@ -116,7 +116,8 @@ def download_document(auth, dataset_id, document_id, save_path): url = f"{HOST_ADDRESS}{FILE_API_URL}/{document_id}".format(dataset_id=dataset_id) res = requests.get(url=url, auth=auth, stream=True) try: - if res.status_code == 200: + # available for unauthed downloads + if res.status_code in (200, 401): with open(save_path, "wb") as f: for chunk in res.iter_content(chunk_size=8192): f.write(chunk) diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_download_document.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_download_document.py index 4cbc9e19bd9..36c28b12c3b 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_download_document.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_download_document.py @@ -40,7 +40,7 @@ class TestAuthorization: ) def test_invalid_auth(self, invalid_auth, tmp_path, expected_code, expected_message): res = download_document(invalid_auth, "dataset_id", "document_id", tmp_path / "ragflow_tes.txt") - assert res.status_code == codes.ok + assert res.status_code == 401 with (tmp_path / "ragflow_tes.txt").open("r") as f: response_json = json.load(f) assert response_json["code"] == expected_code diff --git a/test/testcases/test_web_api/test_system_app/test_apps_init_unit.py b/test/testcases/test_web_api/test_system_app/test_apps_init_unit.py index cfd79ce879e..5b8dcca19f6 100644 --- a/test/testcases/test_web_api/test_system_app/test_apps_init_unit.py +++ b/test/testcases/test_web_api/test_system_app/test_apps_init_unit.py @@ -108,15 +108,14 @@ class _Unauthorized401Repr: def __repr__(self): return "Unauthorized 401 from upstream" - class _OtherRepr: - def __repr__(self): - return "Forbidden 403" + class _WithDescription: + description = "Custom description" assert apps_module._unauthorized_message(None) == apps_module.UNAUTHORIZED_MESSAGE assert apps_module._unauthorized_message(_BrokenRepr()) == apps_module.UNAUTHORIZED_MESSAGE assert apps_module._unauthorized_message(_ExactUnauthorizedRepr()) == apps_module.UNAUTHORIZED_MESSAGE assert apps_module._unauthorized_message(_Unauthorized401Repr()) == "Unauthorized 401 from upstream" - assert apps_module._unauthorized_message(_OtherRepr()) == apps_module.UNAUTHORIZED_MESSAGE + assert apps_module._unauthorized_message(_WithDescription()) == "Custom description" @pytest.mark.p2 diff --git a/web/src/utils/next-request.ts b/web/src/utils/next-request.ts index c16a0295a2b..d2ead134a1a 100644 --- a/web/src/utils/next-request.ts +++ b/web/src/utils/next-request.ts @@ -73,6 +73,9 @@ const errorHandler = (error: { return response ?? 
{ data: { code: 1999 } }; }; +// avoid duplicate 401 redirects +let isRedirecting = false; + const request = axios.create({ // errorHandler, timeout: 300000, @@ -123,13 +126,16 @@ request.interceptors.response.use( if (data?.code === 100) { message.error(data?.message); } else if (data?.code === 401) { - notification.error({ - message: data?.message, - description: data?.message, - duration: 3, - }); - authorizationUtil.removeAll(); - redirectToLogin(); + if (!isRedirecting) { + isRedirecting = true; + notification.error({ + message: data?.message, + description: data?.message, + duration: 3, + }); + authorizationUtil.removeAll(); + redirectToLogin(); + } } else if (data?.code !== 0) { notification.error({ message: `${i18n.t('message.hint')} : ${data?.code}`, @@ -141,6 +147,26 @@ request.interceptors.response.use( }, function (error) { console.log('🚀 ~ error:', error); + + // Handle HTTP 401 (token expired / invalid) + const status = error?.response?.status; + if (status === 401) { + if (!isRedirecting) { + isRedirecting = true; + const messageText = + error?.response?.data?.message || RetcodeMessage[401]; + notification.error({ + message: messageText, + description: messageText, + duration: 3, + }); + authorizationUtil.removeAll(); + redirectToLogin(); + } + + return Promise.reject(error); + } + errorHandler(error); return Promise.reject(error); }, diff --git a/web/src/utils/request.ts b/web/src/utils/request.ts index 5917fbab511..f957cb2a086 100644 --- a/web/src/utils/request.ts +++ b/web/src/utils/request.ts @@ -80,6 +80,9 @@ const request: RequestMethod = extend({ getResponse: true, }); +// avoid duplicate 401 redirects +let isRedirecting = false; + request.interceptors.request.use((url: string, options: any) => { const data = convertTheKeysOfTheObjectToSnake(options.data); const params = convertTheKeysOfTheObjectToSnake(options.params); @@ -109,6 +112,27 @@ request.interceptors.response.use(async (response: Response, options) => { message.error(RetcodeMessage[response?.status as ResultCode]); } + // Handle HTTP 401 + if (response?.status === 401) { + if (!isRedirecting) { + isRedirecting = true; + + const data = await response.clone().json().catch(() => ({})); + + const messageText = + data?.message || RetcodeMessage[401]; + notification.error({ + message: messageText, + description: messageText, + duration: 3, + }); + authorizationUtil.removeAll(); + redirectToLogin(); + } + + return response; + } + if (options.responseType === 'blob') { return response; } @@ -126,11 +150,16 @@ request.interceptors.response.use(async (response: Response, options) => { if (data?.code === 100) { message.error(data?.message); } else if (data?.code === 401) { - notification.error({ - message: data?.message, - description: data?.message, - duration: 3, - }); + if (!isRedirecting) { + isRedirecting = true; + notification.error({ + message: data?.message, + description: data?.message, + duration: 3, + }); + authorizationUtil.removeAll(); + redirectToLogin(); + } authorizationUtil.removeAll(); redirectToLogin(); } else if (data?.code !== 0) { From 007ea7a5de7160a5f3c6917fa53016d7bd0fbf49 Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Fri, 6 Mar 2026 20:05:10 +0800 Subject: [PATCH 166/565] Fix data models (#13444) ### What problem does this PR solve? 
Since database model is updated in python version, go server also need to update ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Signed-off-by: Jin Hai --- Dockerfile | 8 +- cmd/admin_server.go | 15 +- cmd/server_main.go | 12 +- internal/admin/service.go | 2 +- internal/dao/database.go | 44 ++++- internal/dao/migration.go | 307 ++++++++++++++++++++++++++++++ internal/model/api.go | 26 +-- internal/model/base.go | 3 +- internal/model/canvas.go | 11 +- internal/model/chat.go | 16 +- internal/model/connector.go | 4 +- internal/model/document.go | 5 +- internal/model/evaluation.go | 46 +++-- internal/model/kb.go | 3 +- internal/model/llm.go | 4 +- internal/model/mcp.go | 4 +- internal/model/memory.go | 2 + internal/model/pipeline.go | 2 +- internal/model/search.go | 2 +- internal/model/system.go | 2 +- internal/model/tenant.go | 30 +-- internal/model/tenant_llm.go | 20 +- internal/service/chat.go | 2 +- internal/service/chat_session.go | 2 +- internal/service/document.go | 5 +- internal/service/llm.go | 48 +++-- internal/service/model_service.go | 4 +- internal/service/user.go | 70 +++++-- internal/tokenizer/tokenizer.go | 3 +- 29 files changed, 565 insertions(+), 137 deletions(-) create mode 100644 internal/dao/migration.go diff --git a/Dockerfile b/Dockerfile index ee19086b3aa..071efdfc33b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -212,7 +212,13 @@ COPY pyproject.toml uv.lock ./ COPY mcp mcp COPY common common COPY memory memory -COPY bin bin + +RUN if [ -d bin ]; then \ + cp -r bin ./; \ + echo "✓ bin copied"; \ + else \ + echo "✗ bin ignored"; \ + fi COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template COPY docker/entrypoint.sh ./ diff --git a/cmd/admin_server.go b/cmd/admin_server.go index d553a44d9ec..8a7587487b2 100644 --- a/cmd/admin_server.go +++ b/cmd/admin_server.go @@ -18,6 +18,7 @@ package main import ( "flag" + "fmt" "os" "github.com/gin-gonic/gin" @@ -65,7 +66,6 @@ func (s *AdminServer) Init() error { // Run start admin server func (s *AdminServer) Run() error { - logger.Info("Starting admin server", zap.String("port", s.port)) return s.engine.Run(":" + s.port) } @@ -107,21 +107,20 @@ func main() { os.Exit(1) } + // Print all configuration settings + server.PrintAll() + // Print RAGFlow Admin logo logger.Info("" + "\n ____ ___ ______________ ___ __ _ \n" + " / __ \\/ | / ____/ ____/ /___ _ __ / | ____/ /___ ___ (_)___ \n" + " / /_/ / /| |/ / __/ /_ / / __ \\ | /| / / / /| |/ __ / __ `__ \\/ / __ \\ \n" + - " / _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / ___ / /_/ / / / / / / / / / / /\n" + + " / _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / ___ / /_/ / / / / / / / / / /\n" + " /_/ |_/_/ |_\\____/_/ /_/\\____/|__/|__/ /_/ |_\\__,_/_/ /_/ /_/_/_/ /_/ \n") // Print RAGFlow version - logger.Info("RAGFlow version", zap.String("version", utility.GetRAGFlowVersion())) - - // Print all configuration settings - server.PrintAll() - - logger.Info("Starting RAGFlow Admin Server", zap.String("port", "9381")) + logger.Info(fmt.Sprintf("Version: %s", utility.GetRAGFlowVersion())) + logger.Info(fmt.Sprintf("Starting RAGFlow admin server on port: 9381")) if err := adminServer.Run(); err != nil { logger.Error("Admin server error", err) os.Exit(1) diff --git a/cmd/server_main.go b/cmd/server_main.go index e079371e331..011869145bc 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -7,6 +7,8 @@ import ( "os" "os/signal" "ragflow/internal/server" + "ragflow/internal/utility" + "strings" "syscall" "time" @@ -154,6 +156,14 @@ func main() { 
// Start server in a goroutine go func() { + logger.Info( + "\n ____ ___ ______ ______ __\n" + + " / __ \\ / | / ____// ____// /____ _ __\n" + + " / /_/ // /| | / / __ / /_ / // __ \\| | /| / /\n" + + " / _, _// ___ |/ /_/ // __/ / // /_/ /| |/ |/ /\n" + + " /_/ |_|/_/ |_|\\____//_/ /_/ \\____/ |__/|__/\n", + ) + logger.Info(fmt.Sprintf("Version: %s", utility.GetRAGFlowVersion())) logger.Info(fmt.Sprintf("Server starting on port: %d", cfg.Server.Port)) if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { logger.Fatal("Failed to start server", zap.Error(err)) @@ -165,7 +175,7 @@ func main() { signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGUSR2) sig := <-quit - logger.Info("Received signal", zap.String("signal", sig.String())) + logger.Info(fmt.Sprintf("Receives %s signal to shutdown server", strings.ToUpper(sig.String()))) logger.Info("Shutting down server...") // Create context with timeout for graceful shutdown diff --git a/internal/admin/service.go b/internal/admin/service.go index 1c4b430f58b..80b2792dfec 100644 --- a/internal/admin/service.go +++ b/internal/admin/service.go @@ -94,7 +94,7 @@ type UserInfo struct { Email string Nickname string IsActive string - CreateTime int64 + CreateTime *int64 UpdateTime *int64 } diff --git a/internal/dao/database.go b/internal/dao/database.go index 163f172df98..391759431ed 100644 --- a/internal/dao/database.go +++ b/internal/dao/database.go @@ -77,15 +77,51 @@ func InitDB() error { sqlDB.SetMaxOpenConns(100) sqlDB.SetConnMaxLifetime(time.Hour) - // Auto migrate - if err := DB.AutoMigrate( + // Auto migrate all models + models := []interface{}{ &model.User{}, &model.Tenant{}, &model.UserTenant{}, &model.File{}, &model.File2Document{}, - ); err != nil { - return fmt.Errorf("failed to migrate database: %w", err) + &model.TenantLLM{}, + &model.Chat{}, + &model.ChatSession{}, + &model.Task{}, + &model.APIToken{}, + &model.API4Conversation{}, + &model.Knowledgebase{}, + &model.InvitationCode{}, + &model.Document{}, + &model.UserCanvas{}, + &model.CanvasTemplate{}, + &model.UserCanvasVersion{}, + &model.LLMFactories{}, + &model.LLM{}, + &model.TenantLangfuse{}, + &model.SystemSettings{}, + &model.Connector{}, + &model.Connector2Kb{}, + &model.SyncLogs{}, + &model.MCPServer{}, + &model.Memory{}, + &model.Search{}, + &model.PipelineOperationLog{}, + &model.EvaluationDataset{}, + &model.EvaluationCase{}, + &model.EvaluationRun{}, + &model.EvaluationResult{}, + } + + for _, m := range models { + if err := DB.AutoMigrate(m); err != nil { + return fmt.Errorf("failed to migrate model %T: %w", m, err) + } + } + + // Run manual migrations for complex schema changes + if err := RunMigrations(DB); err != nil { + return fmt.Errorf("failed to run manual migrations: %w", err) } logger.Info("Database connected and migrated successfully") diff --git a/internal/dao/migration.go b/internal/dao/migration.go new file mode 100644 index 00000000000..4c2627d65d0 --- /dev/null +++ b/internal/dao/migration.go @@ -0,0 +1,307 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package dao + +import ( + "fmt" + "ragflow/internal/logger" + + "go.uber.org/zap" + "gorm.io/gorm" +) + +// RunMigrations runs all manual database migrations +// These are migrations that cannot be handled by AutoMigrate alone +func RunMigrations(db *gorm.DB) error { + // Check if tenant_llm table has composite primary key and migrate to ID primary key + if err := migrateTenantLLMPrimaryKey(db); err != nil { + return fmt.Errorf("failed to migrate tenant_llm primary key: %w", err) + } + + // Rename columns (correct typos) + if err := renameColumnIfExists(db, "task", "process_duation", "process_duration"); err != nil { + return fmt.Errorf("failed to rename task.process_duation: %w", err) + } + if err := renameColumnIfExists(db, "document", "process_duation", "process_duration"); err != nil { + return fmt.Errorf("failed to rename document.process_duation: %w", err) + } + + // Add unique index on user.email + if err := migrateAddUniqueEmail(db); err != nil { + return fmt.Errorf("failed to add unique index on user.email: %w", err) + } + + // Modify column types that AutoMigrate may not handle correctly + if err := modifyColumnTypes(db); err != nil { + return fmt.Errorf("failed to modify column types: %w", err) + } + + logger.Info("All manual migrations completed successfully") + return nil +} + +// migrateTenantLLMPrimaryKey migrates tenant_llm from composite primary key to ID primary key +// This corresponds to Python's update_tenant_llm_to_id_primary_key function +func migrateTenantLLMPrimaryKey(db *gorm.DB) error { + // Check if tenant_llm table exists + if !db.Migrator().HasTable("tenant_llm") { + return nil + } + + // Check if 'id' column already exists using raw SQL + var idColumnExists int64 + err := db.Raw(` + SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_NAME = 'tenant_llm' AND COLUMN_NAME = 'id' + `).Scan(&idColumnExists).Error + if err != nil { + return err + } + + if idColumnExists > 0 { + // Check if id is already a primary key with auto_increment + var count int64 + err := db.Raw(` + SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_NAME = 'tenant_llm' + AND COLUMN_NAME = 'id' + AND EXTRA LIKE '%auto_increment%' + `).Scan(&count).Error + if err != nil { + return err + } + if count > 0 { + // Already migrated + return nil + } + } + + logger.Info("Migrating tenant_llm to use ID primary key...") + + // Start transaction + return db.Transaction(func(tx *gorm.DB) error { + // Check for temp_id column and drop it if exists + var tempIdExists int64 + tx.Raw(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_NAME = 'tenant_llm' AND COLUMN_NAME = 'temp_id'`).Scan(&tempIdExists) + if tempIdExists > 0 { + if err := tx.Exec("ALTER TABLE tenant_llm DROP COLUMN temp_id").Error; err != nil { + logger.Warn("Failed to drop temp_id column", zap.Error(err)) + } + } + + // Check if there's already an 'id' column + if idColumnExists > 0 { + // Modify existing id column to be auto_increment primary key + if err := tx.Exec(` + ALTER TABLE tenant_llm + MODIFY COLUMN id BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY + `).Error; err != nil { + 
return fmt.Errorf("failed to modify id column: %w", err) + } + } else { + // Add id column as auto_increment primary key + if err := tx.Exec(` + ALTER TABLE tenant_llm + ADD COLUMN id BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY FIRST + `).Error; err != nil { + return fmt.Errorf("failed to add id column: %w", err) + } + } + + // Add unique index on (tenant_id, llm_factory, llm_name) + var idxExists int64 + tx.Raw(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.STATISTICS + WHERE TABLE_NAME = 'tenant_llm' AND INDEX_NAME = 'idx_tenant_llm_unique'`).Scan(&idxExists) + if idxExists == 0 { + if err := tx.Exec(` + ALTER TABLE tenant_llm + ADD UNIQUE INDEX idx_tenant_llm_unique (tenant_id, llm_factory, llm_name) + `).Error; err != nil { + logger.Warn("Failed to add unique index idx_tenant_llm_unique", zap.Error(err)) + } + } + + logger.Info("tenant_llm primary key migration completed") + return nil + }) +} + +// migrateAddUniqueEmail adds unique index on user.email +func migrateAddUniqueEmail(db *gorm.DB) error { + if !db.Migrator().HasTable("user") { + return nil + } + + // Check if unique index already exists using raw SQL + var count int64 + db.Raw(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.STATISTICS + WHERE TABLE_NAME = 'user' AND INDEX_NAME = 'idx_user_email_unique'`).Scan(&count) + if count > 0 { + return nil + } + + // Check if there's a duplicate email issue first + var duplicateCount int64 + err := db.Raw(` + SELECT COUNT(*) FROM ( + SELECT email FROM user GROUP BY email HAVING COUNT(*) > 1 + ) AS duplicates + `).Scan(&duplicateCount).Error + if err != nil { + return err + } + + if duplicateCount > 0 { + logger.Warn("Found duplicate emails in user table, cannot add unique index", zap.Int64("count", duplicateCount)) + return nil + } + + logger.Info("Adding unique index on user.email...") + if err := db.Exec(`ALTER TABLE user ADD UNIQUE INDEX idx_user_email_unique (email)`).Error; err != nil { + return fmt.Errorf("failed to add unique index on email: %w", err) + } + + return nil +} + +// modifyColumnTypes modifies column types that need explicit ALTER statements +func modifyColumnTypes(db *gorm.DB) error { + // Helper function to check if column exists + columnExists := func(table, column string) bool { + var count int64 + db.Raw(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_NAME = ? 
AND COLUMN_NAME = ?`, table, column).Scan(&count) + return count > 0 + } + + // dialog.top_k: ensure it's INTEGER with default 1024 + if db.Migrator().HasTable("dialog") && columnExists("dialog", "top_k") { + if err := db.Exec(`ALTER TABLE dialog MODIFY COLUMN top_k BIGINT NOT NULL DEFAULT 1024`).Error; err != nil { + logger.Warn("Failed to modify dialog.top_k", zap.Error(err)) + } + } + + // tenant_llm.api_key: ensure it's TEXT type + if db.Migrator().HasTable("tenant_llm") && columnExists("tenant_llm", "api_key") { + if err := db.Exec(`ALTER TABLE tenant_llm MODIFY COLUMN api_key LONGTEXT`).Error; err != nil { + logger.Warn("Failed to modify tenant_llm.api_key", zap.Error(err)) + } + } + + // api_token.dialog_id: ensure it's varchar(32) + if db.Migrator().HasTable("api_token") && columnExists("api_token", "dialog_id") { + if err := db.Exec(`ALTER TABLE api_token MODIFY COLUMN dialog_id VARCHAR(32)`).Error; err != nil { + logger.Warn("Failed to modify api_token.dialog_id", zap.Error(err)) + } + } + + // canvas_template.title and description: ensure they're LONGTEXT type (same as Python JSONField) + // Note: Python's JSONField uses null=True with application-level default, not database DEFAULT + if db.Migrator().HasTable("canvas_template") { + if columnExists("canvas_template", "title") { + if err := db.Exec(`ALTER TABLE canvas_template MODIFY COLUMN title LONGTEXT NULL`).Error; err != nil { + logger.Warn("Failed to modify canvas_template.title", zap.Error(err)) + } + } + if columnExists("canvas_template", "description") { + if err := db.Exec(`ALTER TABLE canvas_template MODIFY COLUMN description LONGTEXT NULL`).Error; err != nil { + logger.Warn("Failed to modify canvas_template.description", zap.Error(err)) + } + } + } + + // system_settings.value: ensure it's LONGTEXT + if db.Migrator().HasTable("system_settings") && columnExists("system_settings", "value") { + if err := db.Exec(`ALTER TABLE system_settings MODIFY COLUMN value LONGTEXT NOT NULL`).Error; err != nil { + logger.Warn("Failed to modify system_settings.value", zap.Error(err)) + } + } + + // knowledgebase.raptor_task_finish_at: ensure it's DateTime + if db.Migrator().HasTable("knowledgebase") && columnExists("knowledgebase", "raptor_task_finish_at") { + if err := db.Exec(`ALTER TABLE knowledgebase MODIFY COLUMN raptor_task_finish_at DATETIME`).Error; err != nil { + logger.Warn("Failed to modify knowledgebase.raptor_task_finish_at", zap.Error(err)) + } + } + + // knowledgebase.mindmap_task_finish_at: ensure it's DateTime + if db.Migrator().HasTable("knowledgebase") && columnExists("knowledgebase", "mindmap_task_finish_at") { + if err := db.Exec(`ALTER TABLE knowledgebase MODIFY COLUMN mindmap_task_finish_at DATETIME`).Error; err != nil { + logger.Warn("Failed to modify knowledgebase.mindmap_task_finish_at", zap.Error(err)) + } + } + + return nil +} + +// renameColumnIfExists renames a column if it exists and the new column doesn't exist +func renameColumnIfExists(db *gorm.DB, tableName, oldName, newName string) error { + if !db.Migrator().HasTable(tableName) { + return nil + } + + // Helper to check if column exists + columnExists := func(column string) bool { + var count int64 + db.Raw(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_NAME = ? 
AND COLUMN_NAME = ?`, tableName, column).Scan(&count) + return count > 0 + } + + // Check if old column exists + if !columnExists(oldName) { + return nil + } + + // Check if new column already exists + if columnExists(newName) { + // Both exist, drop the old one + logger.Warn("Both old and new columns exist, dropping old one", + zap.String("table", tableName), + zap.String("oldColumn", oldName), + zap.String("newColumn", newName)) + return db.Migrator().DropColumn(tableName, oldName) + } + + logger.Info("Renaming column", + zap.String("table", tableName), + zap.String("oldColumn", oldName), + zap.String("newColumn", newName)) + return db.Migrator().RenameColumn(tableName, oldName, newName) +} + +// addColumnIfNotExists adds a column if it doesn't exist +func addColumnIfNotExists(db *gorm.DB, tableName, columnName, columnDef string) error { + if !db.Migrator().HasTable(tableName) { + return nil + } + + // Check if column exists using raw SQL + var count int64 + db.Raw(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_NAME = ? AND COLUMN_NAME = ?`, tableName, columnName).Scan(&count) + if count > 0 { + return nil + } + + logger.Info("Adding column", + zap.String("table", tableName), + zap.String("column", columnName)) + sql := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", tableName, columnName, columnDef) + return db.Exec(sql).Error +} diff --git a/internal/model/api.go b/internal/model/api.go index afc3a985fb2..1f22e8b8ddb 100644 --- a/internal/model/api.go +++ b/internal/model/api.go @@ -33,18 +33,20 @@ func (APIToken) TableName() string { // API4Conversation API for conversation model type API4Conversation struct { - ID string `gorm:"column:id;primaryKey;size:32" json:"id"` - DialogID string `gorm:"column:dialog_id;size:32;not null;index" json:"dialog_id"` - UserID string `gorm:"column:user_id;size:255;not null;index" json:"user_id"` - Message JSONMap `gorm:"column:message;type:json" json:"message,omitempty"` - Reference JSONMap `gorm:"column:reference;type:json;default:'[]'" json:"reference"` - Tokens int64 `gorm:"column:tokens;default:0" json:"tokens"` - Source *string `gorm:"column:source;size:16;index" json:"source,omitempty"` - DSL JSONMap `gorm:"column:dsl;type:json" json:"dsl,omitempty"` - Duration float64 `gorm:"column:duration;default:0;index" json:"duration"` - Round int64 `gorm:"column:round;default:0;index" json:"round"` - ThumbUp int64 `gorm:"column:thumb_up;default:0;index" json:"thumb_up"` - Errors *string `gorm:"column:errors;type:longtext" json:"errors,omitempty"` + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + Name *string `gorm:"column:name;size:255" json:"name,omitempty"` + DialogID string `gorm:"column:dialog_id;size:32;not null;index" json:"dialog_id"` + UserID string `gorm:"column:user_id;size:255;not null;index" json:"user_id"` + ExpUserID *string `gorm:"column:exp_user_id;size:255;index" json:"exp_user_id,omitempty"` + Message JSONMap `gorm:"column:message;type:longtext" json:"message,omitempty"` + Reference JSONMap `gorm:"column:reference;type:longtext" json:"reference"` + Tokens int64 `gorm:"column:tokens;default:0" json:"tokens"` + Source *string `gorm:"column:source;size:16;index" json:"source,omitempty"` + DSL JSONMap `gorm:"column:dsl;type:longtext" json:"dsl,omitempty"` + Duration float64 `gorm:"column:duration;default:0;index" json:"duration"` + Round int64 `gorm:"column:round;default:0;index" json:"round"` + ThumbUp int64 `gorm:"column:thumb_up;default:0;index" json:"thumb_up"` + Errors *string 
`gorm:"column:errors;type:longtext" json:"errors,omitempty"` BaseModel } diff --git a/internal/model/base.go b/internal/model/base.go index dfccc45a80b..beb60682a4b 100644 --- a/internal/model/base.go +++ b/internal/model/base.go @@ -23,8 +23,9 @@ import ( ) // BaseModel base model +// All time fields are nullable to match Python Peewee model (null=True) type BaseModel struct { - CreateTime int64 `gorm:"column:create_time;index" json:"create_time"` + CreateTime *int64 `gorm:"column:create_time;index" json:"create_time,omitempty"` CreateDate *time.Time `gorm:"column:create_date;index" json:"create_date,omitempty"` UpdateTime *int64 `gorm:"column:update_time;index" json:"update_time,omitempty"` UpdateDate *time.Time `gorm:"column:update_date;index" json:"update_date,omitempty"` diff --git a/internal/model/canvas.go b/internal/model/canvas.go index 06a0be3edd1..c10c3a19f64 100644 --- a/internal/model/canvas.go +++ b/internal/model/canvas.go @@ -23,10 +23,11 @@ type UserCanvas struct { UserID string `gorm:"column:user_id;size:255;not null;index" json:"user_id"` Title *string `gorm:"column:title;size:255" json:"title,omitempty"` Permission string `gorm:"column:permission;size:16;not null;default:me;index" json:"permission"` + Release bool `gorm:"column:release;not null;default:false;index" json:"release"` Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"` CanvasType *string `gorm:"column:canvas_type;size:32;index" json:"canvas_type,omitempty"` CanvasCategory string `gorm:"column:canvas_category;size:32;not null;default:agent_canvas;index" json:"canvas_category"` - DSL JSONMap `gorm:"column:dsl;type:json" json:"dsl,omitempty"` + DSL JSONMap `gorm:"column:dsl;type:longtext" json:"dsl,omitempty"` BaseModel } @@ -39,11 +40,11 @@ func (UserCanvas) TableName() string { type CanvasTemplate struct { ID string `gorm:"column:id;primaryKey;size:32" json:"id"` Avatar *string `gorm:"column:avatar;type:longtext" json:"avatar,omitempty"` - Title JSONMap `gorm:"column:title;type:json;default:'{}'" json:"title"` - Description JSONMap `gorm:"column:description;type:json;default:'{}'" json:"description"` + Title JSONMap `gorm:"column:title;type:longtext" json:"title"` + Description JSONMap `gorm:"column:description;type:longtext" json:"description"` CanvasType *string `gorm:"column:canvas_type;size:32;index" json:"canvas_type,omitempty"` CanvasCategory string `gorm:"column:canvas_category;size:32;not null;default:agent_canvas;index" json:"canvas_category"` - DSL JSONMap `gorm:"column:dsl;type:json" json:"dsl,omitempty"` + DSL JSONMap `gorm:"column:dsl;type:longtext" json:"dsl,omitempty"` BaseModel } @@ -58,7 +59,7 @@ type UserCanvasVersion struct { UserCanvasID string `gorm:"column:user_canvas_id;size:255;not null;index" json:"user_canvas_id"` Title *string `gorm:"column:title;size:255" json:"title,omitempty"` Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"` - DSL JSONMap `gorm:"column:dsl;type:json" json:"dsl,omitempty"` + DSL JSONMap `gorm:"column:dsl;type:longtext" json:"dsl,omitempty"` BaseModel } diff --git a/internal/model/chat.go b/internal/model/chat.go index 2bb54aec40b..cceb4acecad 100644 --- a/internal/model/chat.go +++ b/internal/model/chat.go @@ -27,17 +27,19 @@ type Chat struct { Icon *string `gorm:"column:icon;type:longtext" json:"icon,omitempty"` Language *string `gorm:"column:language;size:32;index" json:"language,omitempty"` LLMID string `gorm:"column:llm_id;size:128;not null" json:"llm_id"` - LLMSetting JSONMap 
`gorm:"column:llm_setting;type:json;not null;default:'{\"temperature\":0.1,\"top_p\":0.3,\"frequency_penalty\":0.7,\"presence_penalty\":0.4,\"max_tokens\":512}'" json:"llm_setting"` - PromptType string `gorm:"column:prompt_type;size:16;not null;default:simple;index" json:"prompt_type"` - PromptConfig JSONMap `gorm:"column:prompt_config;type:json;not null;default:'{\"system\":\"\",\"prologue\":\"Hi! I'm your assistant. What can I do for you?\",\"parameters\":[],\"empty_response\":\"Sorry! No relevant content was found in the knowledge base!\"}'" json:"prompt_config"` - MetaDataFilter *JSONMap `gorm:"column:meta_data_filter;type:json" json:"meta_data_filter,omitempty"` + TenantLLMID *int64 `gorm:"column:tenant_llm_id;index" json:"tenant_llm_id,omitempty"` + LLMSetting JSONMap `gorm:"column:llm_setting;type:longtext;not null" json:"llm_setting"` + PromptType string `gorm:"column:prompt_type;size:16;not null;default:'simple';index" json:"prompt_type"` + PromptConfig JSONMap `gorm:"column:prompt_config;type:longtext;not null" json:"prompt_config"` + MetaDataFilter *JSONMap `gorm:"column:meta_data_filter;type:longtext" json:"meta_data_filter,omitempty"` SimilarityThreshold float64 `gorm:"column:similarity_threshold;default:0.2" json:"similarity_threshold"` VectorSimilarityWeight float64 `gorm:"column:vector_similarity_weight;default:0.3" json:"vector_similarity_weight"` TopN int64 `gorm:"column:top_n;default:6" json:"top_n"` TopK int64 `gorm:"column:top_k;default:1024" json:"top_k"` DoRefer string `gorm:"column:do_refer;size:1;not null;default:1" json:"do_refer"` RerankID string `gorm:"column:rerank_id;size:128;not null;default:''" json:"rerank_id"` - KBIDs JSONSlice `gorm:"column:kb_ids;type:json;not null;default:'[]'" json:"kb_ids"` + TenantRerankID *int64 `gorm:"column:tenant_rerank_id;index" json:"tenant_rerank_id,omitempty"` + KBIDs JSONSlice `gorm:"column:kb_ids;type:longtext;not null" json:"kb_ids"` Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"` BaseModel } @@ -52,8 +54,8 @@ type ChatSession struct { ID string `gorm:"column:id;primaryKey;size:32" json:"id"` DialogID string `gorm:"column:dialog_id;size:32;not null;index" json:"dialog_id"` Name *string `gorm:"column:name;size:255;index" json:"name,omitempty"` - Message json.RawMessage `gorm:"column:message;type:json" json:"message,omitempty"` - Reference json.RawMessage `gorm:"column:reference;type:json;default:'[]'" json:"reference"` + Message json.RawMessage `gorm:"column:message;type:longtext" json:"message,omitempty"` + Reference json.RawMessage `gorm:"column:reference;type:longtext" json:"reference"` UserID *string `gorm:"column:user_id;size:255;index" json:"user_id,omitempty"` BaseModel } diff --git a/internal/model/connector.go b/internal/model/connector.go index 893c12fb63b..33b38ec5c14 100644 --- a/internal/model/connector.go +++ b/internal/model/connector.go @@ -25,7 +25,7 @@ type Connector struct { Name string `gorm:"column:name;size:128;not null" json:"name"` Source string `gorm:"column:source;size:128;not null;index" json:"source"` InputType string `gorm:"column:input_type;size:128;not null;index" json:"input_type"` - Config JSONMap `gorm:"column:config;type:json;not null;default:'{}'" json:"config"` + Config JSONMap `gorm:"column:config;type:longtext;not null" json:"config"` RefreshFreq int64 `gorm:"column:refresh_freq;default:0" json:"refresh_freq"` PruneFreq int64 `gorm:"column:prune_freq;default:0" json:"prune_freq"` TimeoutSecs int64 `gorm:"column:timeout_secs;default:3600" json:"timeout_secs"` 
@@ -62,7 +62,7 @@ type SyncLogs struct { NewDocsIndexed int64 `gorm:"column:new_docs_indexed;default:0" json:"new_docs_indexed"` TotalDocsIndexed int64 `gorm:"column:total_docs_indexed;default:0" json:"total_docs_indexed"` DocsRemovedFromIndex int64 `gorm:"column:docs_removed_from_index;default:0" json:"docs_removed_from_index"` - ErrorMsg string `gorm:"column:error_msg;type:longtext;not null;default:''" json:"error_msg"` + ErrorMsg string `gorm:"column:error_msg;type:longtext;not null" json:"error_msg"` ErrorCount int64 `gorm:"column:error_count;default:0" json:"error_count"` FullExceptionTrace *string `gorm:"column:full_exception_trace;type:longtext" json:"full_exception_trace,omitempty"` TimeStarted *time.Time `gorm:"column:time_started;index" json:"time_started,omitempty"` diff --git a/internal/model/document.go b/internal/model/document.go index a161e08f772..ffe13fa57f4 100644 --- a/internal/model/document.go +++ b/internal/model/document.go @@ -25,7 +25,7 @@ type Document struct { KbID string `gorm:"column:kb_id;size:256;not null;index" json:"kb_id"` ParserID string `gorm:"column:parser_id;size:32;not null;index" json:"parser_id"` PipelineID *string `gorm:"column:pipeline_id;size:32;index" json:"pipeline_id,omitempty"` - ParserConfig JSONMap `gorm:"column:parser_config;type:json;not null;default:'{\"pages\":[[1,1000000]],\"table_context_size\":0,\"image_context_size\":0}'" json:"parser_config"` + ParserConfig JSONMap `gorm:"column:parser_config;type:longtext;not null" json:"parser_config"` SourceType string `gorm:"column:source_type;size:128;not null;default:local;index" json:"source_type"` Type string `gorm:"column:type;size:32;not null;index" json:"type"` CreatedBy string `gorm:"column:created_by;size:32;not null;index" json:"created_by"` @@ -38,7 +38,8 @@ type Document struct { ProgressMsg *string `gorm:"column:progress_msg;type:longtext" json:"progress_msg,omitempty"` ProcessBeginAt *time.Time `gorm:"column:process_begin_at;index" json:"process_begin_at,omitempty"` ProcessDuration float64 `gorm:"column:process_duration;default:0" json:"process_duration"` - MetaFields *JSONMap `gorm:"column:meta_fields;type:json" json:"meta_fields,omitempty"` + ContentHash *string `gorm:"column:content_hash;size:32;index" json:"content_hash,omitempty"` + MetaFields *JSONMap `gorm:"column:meta_fields;type:longtext" json:"meta_fields,omitempty"` Suffix string `gorm:"column:suffix;size:32;not null;index" json:"suffix"` Run *string `gorm:"column:run;size:1;index" json:"run,omitempty"` Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"` diff --git a/internal/model/evaluation.go b/internal/model/evaluation.go index 5b9bac787ac..3e2de1fa5a1 100644 --- a/internal/model/evaluation.go +++ b/internal/model/evaluation.go @@ -17,15 +17,18 @@ package model // EvaluationDataset evaluation dataset model +// Note: Python defines custom create_time/update_time (not null) instead of using BaseModel's type EvaluationDataset struct { - ID string `gorm:"column:id;primaryKey;size:32" json:"id"` - TenantID string `gorm:"column:tenant_id;size:32;not null;index" json:"tenant_id"` - Name string `gorm:"column:name;size:255;not null;index" json:"name"` + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + TenantID string `gorm:"column:tenant_id;size:32;not null;index" json:"tenant_id"` + Name string `gorm:"column:name;size:255;not null;index" json:"name"` Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"` - KbIDs JSONMap `gorm:"column:kb_ids;type:json;not 
null" json:"kb_ids"` - CreatedBy string `gorm:"column:created_by;size:32;not null;index" json:"created_by"` - Status int64 `gorm:"column:status;default:1;index" json:"status"` - BaseModel + KbIDs JSONMap `gorm:"column:kb_ids;type:longtext;not null" json:"kb_ids"` + CreatedBy string `gorm:"column:created_by;size:32;not null;index" json:"created_by"` + // Custom time fields (not null) to match Python + CreateTime int64 `gorm:"column:create_time;not null;index" json:"create_time"` + UpdateTime int64 `gorm:"column:update_time;not null" json:"update_time"` + Status int64 `gorm:"column:status;default:1;index" json:"status"` } // TableName specify table name @@ -34,15 +37,17 @@ func (EvaluationDataset) TableName() string { } // EvaluationCase evaluation case model +// Note: Python defines custom create_time (not null) instead of using BaseModel's type EvaluationCase struct { ID string `gorm:"column:id;primaryKey;size:32" json:"id"` DatasetID string `gorm:"column:dataset_id;size:32;not null;index" json:"dataset_id"` Question string `gorm:"column:question;type:longtext;not null" json:"question"` ReferenceAnswer *string `gorm:"column:reference_answer;type:longtext" json:"reference_answer,omitempty"` - RelevantDocIDs *JSONMap `gorm:"column:relevant_doc_ids;type:json" json:"relevant_doc_ids,omitempty"` - RelevantChunkIDs *JSONMap `gorm:"column:relevant_chunk_ids;type:json" json:"relevant_chunk_ids,omitempty"` - Metadata *JSONMap `gorm:"column:metadata;type:json" json:"metadata,omitempty"` - BaseModel + RelevantDocIDs *JSONMap `gorm:"column:relevant_doc_ids;type:longtext" json:"relevant_doc_ids,omitempty"` + RelevantChunkIDs *JSONMap `gorm:"column:relevant_chunk_ids;type:longtext" json:"relevant_chunk_ids,omitempty"` + Metadata *JSONMap `gorm:"column:metadata;type:longtext" json:"metadata,omitempty"` + // Custom time field (not null) to match Python + CreateTime int64 `gorm:"column:create_time;not null" json:"create_time"` } // TableName specify table name @@ -51,16 +56,19 @@ func (EvaluationCase) TableName() string { } // EvaluationRun evaluation run model +// Note: Python defines custom create_time/complete_time instead of using BaseModel's type EvaluationRun struct { ID string `gorm:"column:id;primaryKey;size:32" json:"id"` DatasetID string `gorm:"column:dataset_id;size:32;not null;index" json:"dataset_id"` DialogID string `gorm:"column:dialog_id;size:32;not null;index" json:"dialog_id"` Name string `gorm:"column:name;size:255;not null" json:"name"` - ConfigSnapshot JSONMap `gorm:"column:config_snapshot;type:json;not null" json:"config_snapshot"` - MetricsSummary *JSONMap `gorm:"column:metrics_summary;type:json" json:"metrics_summary,omitempty"` + ConfigSnapshot JSONMap `gorm:"column:config_snapshot;type:longtext;not null" json:"config_snapshot"` + MetricsSummary *JSONMap `gorm:"column:metrics_summary;type:longtext" json:"metrics_summary,omitempty"` Status string `gorm:"column:status;size:32;not null;default:PENDING" json:"status"` CreatedBy string `gorm:"column:created_by;size:32;not null;index" json:"created_by"` - BaseModel + // Custom time fields to match Python + CreateTime int64 `gorm:"column:create_time;not null;index" json:"create_time"` + CompleteTime *int64 `gorm:"column:complete_time" json:"complete_time,omitempty"` } // TableName specify table name @@ -69,16 +77,18 @@ func (EvaluationRun) TableName() string { } // EvaluationResult evaluation result model +// Note: Python defines custom create_time (not null) instead of using BaseModel's type EvaluationResult struct { ID string 
`gorm:"column:id;primaryKey;size:32" json:"id"` RunID string `gorm:"column:run_id;size:32;not null;index" json:"run_id"` CaseID string `gorm:"column:case_id;size:32;not null;index" json:"case_id"` GeneratedAnswer string `gorm:"column:generated_answer;type:longtext;not null" json:"generated_answer"` - RetrievedChunks JSONMap `gorm:"column:retrieved_chunks;type:json;not null" json:"retrieved_chunks"` - Metrics JSONMap `gorm:"column:metrics;type:json;not null" json:"metrics"` + RetrievedChunks JSONMap `gorm:"column:retrieved_chunks;type:longtext;not null" json:"retrieved_chunks"` + Metrics JSONMap `gorm:"column:metrics;type:longtext;not null" json:"metrics"` ExecutionTime float64 `gorm:"column:execution_time;not null" json:"execution_time"` - TokenUsage *JSONMap `gorm:"column:token_usage;type:json" json:"token_usage,omitempty"` - BaseModel + TokenUsage *JSONMap `gorm:"column:token_usage;type:longtext" json:"token_usage,omitempty"` + // Custom time field to match Python + CreateTime int64 `gorm:"column:create_time;not null" json:"create_time"` } // TableName specify table name diff --git a/internal/model/kb.go b/internal/model/kb.go index 8862b1e1acc..78cc643721d 100644 --- a/internal/model/kb.go +++ b/internal/model/kb.go @@ -27,6 +27,7 @@ type Knowledgebase struct { Language *string `gorm:"column:language;size:32;index" json:"language,omitempty"` Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"` EmbdID string `gorm:"column:embd_id;size:128;not null;index" json:"embd_id"` + TenantEmbdID *int64 `gorm:"column:tenant_embd_id;index" json:"tenant_embd_id,omitempty"` Permission string `gorm:"column:permission;size:16;not null;default:me;index" json:"permission"` CreatedBy string `gorm:"column:created_by;size:32;not null;index" json:"created_by"` DocNum int64 `gorm:"column:doc_num;default:0;index" json:"doc_num"` @@ -36,7 +37,7 @@ type Knowledgebase struct { VectorSimilarityWeight float64 `gorm:"column:vector_similarity_weight;default:0.3;index" json:"vector_similarity_weight"` ParserID string `gorm:"column:parser_id;size:32;not null;default:naive;index" json:"parser_id"` PipelineID *string `gorm:"column:pipeline_id;size:32;index" json:"pipeline_id,omitempty"` - ParserConfig JSONMap `gorm:"column:parser_config;type:json;not null;default:'{\"pages\":[[1,1000000]],\"table_context_size\":0,\"image_context_size\":0}'" json:"parser_config"` + ParserConfig JSONMap `gorm:"column:parser_config;type:longtext;not null" json:"parser_config"` Pagerank int64 `gorm:"column:pagerank;default:0" json:"pagerank"` GraphragTaskID *string `gorm:"column:graphrag_task_id;size:32;index" json:"graphrag_task_id,omitempty"` GraphragTaskFinishAt *time.Time `gorm:"column:graphrag_task_finish_at" json:"graphrag_task_finish_at,omitempty"` diff --git a/internal/model/llm.go b/internal/model/llm.go index 96377d1ebb1..9b9054e7e68 100644 --- a/internal/model/llm.go +++ b/internal/model/llm.go @@ -51,8 +51,8 @@ func (LLM) TableName() string { // TenantLangfuse tenant langfuse model type TenantLangfuse struct { TenantID string `gorm:"column:tenant_id;primaryKey;size:32" json:"tenant_id"` - SecretKey string `gorm:"column:secret_key;size:2048;not null;index" json:"secret_key"` - PublicKey string `gorm:"column:public_key;size:2048;not null;index" json:"public_key"` + SecretKey string `gorm:"column:secret_key;size:2048;not null" json:"secret_key"` + PublicKey string `gorm:"column:public_key;size:2048;not null" json:"public_key"` Host string `gorm:"column:host;size:128;not null;index" json:"host"` 
BaseModel } diff --git a/internal/model/mcp.go b/internal/model/mcp.go index 044bbdab149..deaf1ac821a 100644 --- a/internal/model/mcp.go +++ b/internal/model/mcp.go @@ -24,8 +24,8 @@ type MCPServer struct { URL string `gorm:"column:url;size:2048;not null" json:"url"` ServerType string `gorm:"column:server_type;size:32;not null" json:"server_type"` Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"` - Variables JSONMap `gorm:"column:variables;type:json;default:'{}'" json:"variables,omitempty"` - Headers JSONMap `gorm:"column:headers;type:json;default:'{}'" json:"headers,omitempty"` + Variables JSONMap `gorm:"column:variables;type:longtext" json:"variables,omitempty"` + Headers JSONMap `gorm:"column:headers;type:longtext" json:"headers,omitempty"` BaseModel } diff --git a/internal/model/memory.go b/internal/model/memory.go index 28f9f58c1c1..9e6480ad969 100644 --- a/internal/model/memory.go +++ b/internal/model/memory.go @@ -25,7 +25,9 @@ type Memory struct { MemoryType int64 `gorm:"column:memory_type;default:1;index" json:"memory_type"` StorageType string `gorm:"column:storage_type;size:32;not null;default:table;index" json:"storage_type"` EmbdID string `gorm:"column:embd_id;size:128;not null" json:"embd_id"` + TenantEmbdID *int64 `gorm:"column:tenant_embd_id;index" json:"tenant_embd_id,omitempty"` LLMID string `gorm:"column:llm_id;size:128;not null" json:"llm_id"` + TenantLLMID *int64 `gorm:"column:tenant_llm_id;index" json:"tenant_llm_id,omitempty"` Permissions string `gorm:"column:permissions;size:16;not null;default:me;index" json:"permissions"` Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"` MemorySize int64 `gorm:"column:memory_size;default:5242880;not null" json:"memory_size"` diff --git a/internal/model/pipeline.go b/internal/model/pipeline.go index a47d6119871..a9e9dcbdb5a 100644 --- a/internal/model/pipeline.go +++ b/internal/model/pipeline.go @@ -35,7 +35,7 @@ type PipelineOperationLog struct { ProgressMsg *string `gorm:"column:progress_msg;type:longtext" json:"progress_msg,omitempty"` ProcessBeginAt *time.Time `gorm:"column:process_begin_at;index" json:"process_begin_at,omitempty"` ProcessDuration float64 `gorm:"column:process_duration;default:0" json:"process_duration"` - DSL JSONMap `gorm:"column:dsl;type:json" json:"dsl,omitempty"` + DSL JSONMap `gorm:"column:dsl;type:longtext" json:"dsl,omitempty"` TaskType string `gorm:"column:task_type;size:32;not null;default:''" json:"task_type"` OperationStatus string `gorm:"column:operation_status;size:32;not null" json:"operation_status"` Avatar *string `gorm:"column:avatar;type:longtext" json:"avatar,omitempty"` diff --git a/internal/model/search.go b/internal/model/search.go index da95ccd6939..c70bf94bb63 100644 --- a/internal/model/search.go +++ b/internal/model/search.go @@ -24,7 +24,7 @@ type Search struct { Name string `gorm:"column:name;size:128;not null;index" json:"name"` Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"` CreatedBy string `gorm:"column:created_by;size:32;not null;index" json:"created_by"` - SearchConfig JSONMap `gorm:"column:search_config;type:json;not null;default:'{\"kb_ids\":[],\"doc_ids\":[],\"similarity_threshold\":0.2,\"vector_similarity_weight\":0.3,\"use_kg\":false,\"rerank_id\":\"\",\"top_k\":1024,\"summary\":false,\"chat_id\":\"\",\"chat_settingcross_languages\":[],\"highlight\":false,\"keyword\":false,\"web_search\":false,\"related_search\":false,\"query_mindmap\":false}'" 
json:"search_config"` + SearchConfig JSONMap `gorm:"column:search_config;type:longtext;not null" json:"search_config"` Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"` BaseModel } diff --git a/internal/model/system.go b/internal/model/system.go index 48775561136..be94f1653a6 100644 --- a/internal/model/system.go +++ b/internal/model/system.go @@ -21,7 +21,7 @@ type SystemSettings struct { Name string `gorm:"column:name;primaryKey;size:128" json:"name"` Source string `gorm:"column:source;size:32;not null" json:"source"` DataType string `gorm:"column:data_type;size:32;not null" json:"data_type"` - Value string `gorm:"column:value;size:1024;not null" json:"value"` + Value string `gorm:"column:value;type:longtext;not null" json:"value"` } // TableName specify table name diff --git a/internal/model/tenant.go b/internal/model/tenant.go index f7f76df8d20..046bd701f77 100644 --- a/internal/model/tenant.go +++ b/internal/model/tenant.go @@ -18,18 +18,24 @@ package model // Tenant tenant model type Tenant struct { - ID string `gorm:"column:id;primaryKey;size:32" json:"id"` - Name *string `gorm:"column:name;size:100;index" json:"name,omitempty"` - PublicKey *string `gorm:"column:public_key;size:255;index" json:"public_key,omitempty"` - LLMID string `gorm:"column:llm_id;size:128;not null;index" json:"llm_id"` - EmbDID string `gorm:"column:embd_id;size:128;not null;index" json:"embd_id"` - ASRID string `gorm:"column:asr_id;size:128;not null;index" json:"asr_id"` - Img2TxtID string `gorm:"column:img2txt_id;size:128;not null;index" json:"img2txt_id"` - RerankID string `gorm:"column:rerank_id;size:128;not null;index" json:"rerank_id"` - TTSID *string `gorm:"column:tts_id;size:256;index" json:"tts_id,omitempty"` - ParserIDs string `gorm:"column:parser_ids;size:256;not null" json:"parser_ids"` - Credit int64 `gorm:"column:credit;default:512;index" json:"credit"` - Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"` + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + Name *string `gorm:"column:name;size:100;index" json:"name,omitempty"` + PublicKey *string `gorm:"column:public_key;size:255;index" json:"public_key,omitempty"` + LLMID string `gorm:"column:llm_id;size:128;not null;index" json:"llm_id"` + TenantLLMID *int64 `gorm:"column:tenant_llm_id;index" json:"tenant_llm_id,omitempty"` + EmbdID string `gorm:"column:embd_id;size:128;not null;index" json:"embd_id"` + TenantEmbdID *int64 `gorm:"column:tenant_embd_id;index" json:"tenant_embd_id,omitempty"` + ASRID string `gorm:"column:asr_id;size:128;not null;index" json:"asr_id"` + TenantASRID *int64 `gorm:"column:tenant_asr_id;index" json:"tenant_asr_id,omitempty"` + Img2TxtID string `gorm:"column:img2txt_id;size:128;not null;index" json:"img2txt_id"` + TenantImg2TxtID *int64 `gorm:"column:tenant_img2txt_id;index" json:"tenant_img2txt_id,omitempty"` + RerankID string `gorm:"column:rerank_id;size:128;not null;index" json:"rerank_id"` + TenantRerankID *int64 `gorm:"column:tenant_rerank_id;index" json:"tenant_rerank_id,omitempty"` + TTSID *string `gorm:"column:tts_id;size:256;index" json:"tts_id,omitempty"` + TenantTTSID *int64 `gorm:"column:tenant_tts_id;index" json:"tenant_tts_id,omitempty"` + ParserIDs string `gorm:"column:parser_ids;size:256;not null;index" json:"parser_ids"` + Credit int64 `gorm:"column:credit;default:512;index" json:"credit"` + Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"` BaseModel } diff --git a/internal/model/tenant_llm.go 
b/internal/model/tenant_llm.go index dbadca6bd95..613032372dd 100644 --- a/internal/model/tenant_llm.go +++ b/internal/model/tenant_llm.go @@ -17,16 +17,18 @@ package model // TenantLLM tenant LLM model +// Python uses PrimaryKeyField (auto-increment ID) with unique index on (tenant_id, llm_factory, llm_name) type TenantLLM struct { - TenantID string `gorm:"column:tenant_id;size:32;not null;primaryKey" json:"tenant_id"` - LLMFactory string `gorm:"column:llm_factory;size:128;not null;primaryKey" json:"llm_factory"` - ModelType string `gorm:"column:model_type;size:128;not null;index" json:"model_type"` - LLMName string `gorm:"column:llm_name;size:128;not null;primaryKey;default:\"\"" json:"llm_name"` - APIKey string `gorm:"column:api_key;type:longtext" json:"api_key,omitempty"` - APIBase string `gorm:"column:api_base;size:255" json:"api_base,omitempty"` - MaxTokens int64 `gorm:"column:max_tokens;default:8192;index" json:"max_tokens"` - UsedTokens int64 `gorm:"column:used_tokens;default:0;index" json:"used_tokens"` - Status string `gorm:"column:status;size:1;not null;default:1;index" json:"status"` + ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"` + TenantID string `gorm:"column:tenant_id;size:32;not null;index:idx_tenant_llm_unique,unique" json:"tenant_id"` + LLMFactory string `gorm:"column:llm_factory;size:128;not null;index:idx_tenant_llm_unique,unique" json:"llm_factory"` + ModelType *string `gorm:"column:model_type;size:128;index" json:"model_type,omitempty"` + LLMName *string `gorm:"column:llm_name;size:128;index:idx_tenant_llm_unique,unique;default:\"\"" json:"llm_name,omitempty"` + APIKey *string `gorm:"column:api_key;type:longtext" json:"api_key,omitempty"` + APIBase *string `gorm:"column:api_base;size:255" json:"api_base,omitempty"` + MaxTokens int64 `gorm:"column:max_tokens;default:8192;index" json:"max_tokens"` + UsedTokens int64 `gorm:"column:used_tokens;default:0;index" json:"used_tokens"` + Status string `gorm:"column:status;size:1;not null;default:1;index" json:"status"` BaseModel } diff --git a/internal/service/chat.go b/internal/service/chat.go index 3192a2152d9..b53706d180e 100644 --- a/internal/service/chat.go +++ b/internal/service/chat.go @@ -460,7 +460,7 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo KBIDs: kbIDsJSON, Status: strPtr("1"), } - chat.CreateTime = createTime + chat.CreateTime = &createTime chat.CreateDate = &now chat.UpdateTime = &createTime chat.UpdateDate = &now diff --git a/internal/service/chat_session.go b/internal/service/chat_session.go index 7de702e92c6..e32dea3fa38 100644 --- a/internal/service/chat_session.go +++ b/internal/service/chat_session.go @@ -139,7 +139,7 @@ func (s *ChatSessionService) SetChatSession(userID string, req *SetChatSessionRe UserID: &userID, Reference: referenceJSON, } - session.CreateTime = createTime + session.CreateTime = &createTime session.CreateDate = &now session.UpdateTime = &createTime session.UpdateDate = &now diff --git a/internal/service/document.go b/internal/service/document.go index 94267b797e0..ad94ce647c0 100644 --- a/internal/service/document.go +++ b/internal/service/document.go @@ -178,7 +178,10 @@ func (s *DocumentService) GetDocumentsByAuthorID(authorID, page, pageSize int) ( // toResponse convert model.Document to DocumentResponse func (s *DocumentService) toResponse(doc *model.Document) *DocumentResponse { - createdAt := time.Unix(doc.CreateTime, 0).Format("2006-01-02 15:04:05") + createdAt := "" + if doc.CreateTime != nil { + createdAt = 
time.Unix(*doc.CreateTime, 0).Format("2006-01-02 15:04:05") + } updatedAt := "" if doc.UpdateTime != nil { updatedAt = time.Unix(*doc.UpdateTime, 0).Format("2006-01-02 15:04:05") diff --git a/internal/service/llm.go b/internal/service/llm.go index 5478f3d18fc..85b1cd99f8f 100644 --- a/internal/service/llm.go +++ b/internal/service/llm.go @@ -114,18 +114,18 @@ func (s *LLMService) GetMyLLMs(tenantID string, includeDetails bool) (map[string // LLMListItem represents a single LLM item in the list response type LLMListItem struct { - LLMName string `json:"llm_name"` - ModelType string `json:"model_type"` - FID string `json:"fid"` - Available bool `json:"available"` - Status string `json:"status"` - MaxTokens int64 `json:"max_tokens,omitempty"` - CreateDate *string `json:"create_date,omitempty"` - CreateTime int64 `json:"create_time,omitempty"` - UpdateDate *string `json:"update_date,omitempty"` - UpdateTime *int64 `json:"update_time,omitempty"` - IsTools bool `json:"is_tools"` - Tags string `json:"tags,omitempty"` + LLMName string `json:"llm_name"` + ModelType string `json:"model_type"` + FID string `json:"fid"` + Available bool `json:"available"` + Status string `json:"status"` + MaxTokens int64 `json:"max_tokens,omitempty"` + CreateDate *string `json:"create_date,omitempty"` + CreateTime *int64 `json:"create_time,omitempty"` + UpdateDate *string `json:"update_date,omitempty"` + UpdateTime *int64 `json:"update_time,omitempty"` + IsTools bool `json:"is_tools"` + Tags string `json:"tags,omitempty"` } // ListLLMsResponse represents the response for list LLMs @@ -153,10 +153,14 @@ func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsRespon // Build set of valid LLM names@factories status := make(map[string]bool) for _, tl := range tenantLLMs { - if tl.APIKey != "" && tl.Status == "1" { + if tl.APIKey != nil && *tl.APIKey != "" && tl.Status == "1" { facts[tl.LLMFactory] = true } - key := tl.LLMName + "@" + tl.LLMFactory + llmName := "" + if tl.LLMName != nil { + llmName = *tl.LLMName + } + key := llmName + "@" + tl.LLMFactory if tl.Status == "1" { status[key] = true } @@ -223,19 +227,27 @@ func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsRespon // Add tenant LLMs that are not in the global list for _, tl := range tenantLLMs { - key := tl.LLMName + "@" + tl.LLMFactory + llmName := "" + if tl.LLMName != nil { + llmName = *tl.LLMName + } + key := llmName + "@" + tl.LLMFactory if llmSet[key] { continue } // Filter by model type if specified - if modelType != "" && !strings.Contains(tl.ModelType, modelType) { + modelTypeValue := "" + if tl.ModelType != nil { + modelTypeValue = *tl.ModelType + } + if modelType != "" && !strings.Contains(modelTypeValue, modelType) { continue } item := LLMListItem{ - LLMName: tl.LLMName, - ModelType: tl.ModelType, + LLMName: llmName, + ModelType: modelTypeValue, FID: tl.LLMFactory, Available: true, Status: tl.Status, diff --git a/internal/service/model_service.go b/internal/service/model_service.go index 423c6856079..75082485c82 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -80,7 +80,7 @@ func (p *ModelProviderImpl) GetEmbeddingModel(ctx context.Context, tenantID stri } apiKey := embeddingModel.APIKey - if apiKey == "" { + if apiKey == nil || *apiKey == "" { return nil, fmt.Errorf("no API key found for tenant %s and model %s", tenantID, compositeModelName) } // Always get API base from model provider configuration @@ -91,7 +91,7 @@ func (p *ModelProviderImpl) GetEmbeddingModel(ctx 
context.Context, tenantID stri } apiBase := providerConfig.DefaultEmbeddingURL - return models.CreateEmbeddingModel(provider, apiKey, apiBase, modelName, p.httpClient) + return models.CreateEmbeddingModel(provider, *apiKey, apiBase, modelName, p.httpClient) } // GetChatModel returns a chat model for the given tenant diff --git a/internal/service/user.go b/internal/service/user.go index 99dd5351b57..9db2a264398 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -144,7 +144,7 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC } now := time.Now().Unix() - user.CreateTime = now + user.CreateTime = &now user.UpdateTime = &now now_date := time.Now() user.CreateDate = &now_date @@ -156,13 +156,13 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC ID: userID, Name: &tenantName, LLMID: cfg.Server.Mode, - EmbDID: cfg.Server.Mode, + EmbdID: cfg.Server.Mode, ASRID: cfg.Server.Mode, Img2TxtID: cfg.Server.Mode, RerankID: cfg.Server.Mode, ParserIDs: "naive:General,Q&A:Q&A,manual:Manual,table:Table,paper:Research Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag", } - tenant.CreateTime = now + tenant.CreateTime = &now tenant.UpdateTime = &now tenant.CreateDate = &now_date tenant.UpdateDate = &now_date @@ -176,7 +176,7 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC InvitedBy: userID, Status: &status, } - userTenant.CreateTime = now + userTenant.CreateTime = &now userTenant.UpdateTime = &now userTenant.CreateDate = &now_date userTenant.UpdateDate = &now_date @@ -191,7 +191,7 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC Type: "folder", Size: 0, } - rootFile.CreateTime = now + rootFile.CreateTime = &now rootFile.UpdateTime = &now rootFile.CreateDate = &now_date rootFile.UpdateDate = &now_date @@ -205,20 +205,38 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC } if err := tenantDAO.Create(tenant); err != nil { - s.userDAO.DeleteByID(userID) + err := s.userDAO.DeleteByID(userID) + if err != nil { + return nil, 0, err + } return nil, common.CodeServerError, fmt.Errorf("failed to create tenant: %w", err) } if err := userTenantDAO.Create(userTenant); err != nil { - s.userDAO.DeleteByID(userID) - tenantDAO.Delete(userID) + err := s.userDAO.DeleteByID(userID) + if err != nil { + return nil, 0, err + } + err = tenantDAO.Delete(userID) + if err != nil { + return nil, 0, err + } return nil, common.CodeServerError, fmt.Errorf("failed to create user tenant relation: %w", err) } if err := fileDAO.Create(rootFile); err != nil { - s.userDAO.DeleteByID(userID) - tenantDAO.Delete(userID) - userTenantDAO.Delete(userTenantID) + err := s.userDAO.DeleteByID(userID) + if err != nil { + return nil, 0, err + } + err = tenantDAO.Delete(userID) + if err != nil { + return nil, 0, err + } + err = userTenantDAO.Delete(userTenantID) + if err != nil { + return nil, 0, err + } return nil, common.CodeServerError, fmt.Errorf("failed to create root folder: %w", err) } @@ -314,11 +332,16 @@ func (s *UserService) GetUserByID(id uint) (*UserResponse, common.ErrorCode, err } return &UserResponse{ - ID: user.ID, - Email: user.Email, - Nickname: user.Nickname, - Status: user.Status, - CreatedAt: time.Unix(user.CreateTime, 0).Format("2006-01-02 15:04:05"), + ID: user.ID, + Email: user.Email, + Nickname: user.Nickname, + Status: user.Status, + CreatedAt: func() string { + if user.CreateTime != nil { 
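+ // BaseModel.CreateTime is now *int64 (nullable, matching the Python
+ // model's null=True), so dereference it only after the nil check.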
+ return time.Unix(*user.CreateTime, 0).Format("2006-01-02 15:04:05")
+ }
+ return ""
+ }(),
 }, common.CodeSuccess, nil
 }

@@ -333,11 +356,16 @@ func (s *UserService) ListUsers(page, pageSize int) ([]*UserResponse, int64, com
 responses := make([]*UserResponse, len(users))
 for i, user := range users {
 responses[i] = &UserResponse{
- ID: user.ID,
- Email: user.Email,
- Nickname: user.Nickname,
- Status: user.Status,
- CreatedAt: time.Unix(user.CreateTime, 0).Format("2006-01-02 15:04:05"),
+ ID: user.ID,
+ Email: user.Email,
+ Nickname: user.Nickname,
+ Status: user.Status,
+ CreatedAt: func() string {
+ if user.CreateTime != nil {
+ return time.Unix(*user.CreateTime, 0).Format("2006-01-02 15:04:05")
+ }
+ return ""
+ }(),
 }
 }

diff --git a/internal/tokenizer/tokenizer.go b/internal/tokenizer/tokenizer.go
index 54f89d34869..9fe895e7118 100644
--- a/internal/tokenizer/tokenizer.go
+++ b/internal/tokenizer/tokenizer.go
@@ -347,8 +347,7 @@ func (p *analyzerPool) Close() {
 p.baseAnalyzer = nil
 }

- logger.Info("Analyzer pool closed",
- zap.Int32("final_size", atomic.LoadInt32(&p.currentSize)))
+ logger.Info(fmt.Sprintf("Analyzer pool closed, final_size: %d", atomic.LoadInt32(&p.currentSize)))
 }

 // GetPoolStats returns current pool statistics

From d8f4d4d239b7de426574529cb311c9de9f59d25c Mon Sep 17 00:00:00 2001
From: Achieve3318
Date: Fri, 6 Mar 2026 07:17:11 -0500
Subject: [PATCH 167/565] feat(memory): implement get_highlight for OceanBase
 memory (#13449)

### What problem does this PR solve?

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
---
 memory/utils/highlight_utils.py | 89 +++++++++++++++++++
 memory/utils/ob_conn.py | 10 ++-
 .../memory/utils/test_ob_conn_highlight.py | 79 ++++++++++++++++
 3 files changed, 176 insertions(+), 2 deletions(-)
 create mode 100644 memory/utils/highlight_utils.py
 create mode 100644 test/unit_test/memory/utils/test_ob_conn_highlight.py

diff --git a/memory/utils/highlight_utils.py b/memory/utils/highlight_utils.py
new file mode 100644
index 00000000000..977fbe3a0fd
--- /dev/null
+++ b/memory/utils/highlight_utils.py
@@ -0,0 +1,89 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use it except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Highlight helpers for search results (wraps keywords in <em>...</em>)."""
+
+import re
+from collections.abc import Callable
+
+
+def highlight_text(
+ txt: str,
+ keywords: list[str],
+ is_english_fn: Callable[[str], bool] | None = None,
+) -> str:
+ """Wrap keyword matches in text with <em>...</em>, by sentence.
+
+ - If is_english_fn(sentence) is True: use word-boundary regex.
+ - Otherwise: literal replace (longest keywords first).
+ Only sentences that contain a match are included.
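+
+ Example (illustrative): highlight_text("The cat sat. Dogs bark.", ["cat"])
+ returns "The <em>cat</em> sat".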
+ """ + if not txt or not keywords: + return "" + + txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE) + txt_list = [] + + for t in re.split(r"[.?!;\n]", txt): + t = t.strip() + if not t: + continue + + if is_english_fn is None or is_english_fn(t): + for w in keywords: + t = re.sub( + r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-]|$)" % re.escape(w), + r"\1\2\3", + t, + flags=re.IGNORECASE | re.MULTILINE, + ) + else: + for w in sorted(keywords, key=len, reverse=True): + t = re.sub( + re.escape(w), + f"{w}", + t, + flags=re.IGNORECASE | re.MULTILINE, + ) + + if re.search(r"[^<>]+", t, flags=re.IGNORECASE | re.MULTILINE): + txt_list.append(t) + + return "...".join(txt_list) if txt_list else txt + + +def get_highlight_from_messages( + messages: list[dict] | None, + keywords: list[str], + field_name: str, + is_english_fn: Callable[[str], bool] | None = None, +) -> dict[str, str]: + """Build id -> highlighted text from a list of message dicts.""" + if not messages or not keywords: + return {} + + ans = {} + for doc in messages: + doc_id = doc.get("id") + if not doc_id: + continue + txt = doc.get(field_name) + if not txt or not isinstance(txt, str): + continue + highlighted = highlight_text(txt, keywords, is_english_fn) + if highlighted and re.search(r"[^<>]+", highlighted, flags=re.IGNORECASE | re.MULTILINE): + ans[doc_id] = highlighted + return ans diff --git a/memory/utils/ob_conn.py b/memory/utils/ob_conn.py index 09c976e2ca5..f179992373c 100644 --- a/memory/utils/ob_conn.py +++ b/memory/utils/ob_conn.py @@ -25,9 +25,11 @@ from common.decorator import singleton from memory.utils.aggregation_utils import aggregate_by_field +from memory.utils.highlight_utils import get_highlight_from_messages from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr from common.doc_store.ob_conn_base import OBConnectionBase, get_value_str, vector_search_template from common.float_utils import get_float +from rag.nlp import is_english from rag.nlp.rag_tokenizer import tokenize, fine_grained_tokenize # Column definitions for memory message table @@ -605,8 +607,12 @@ def get_fields(self, res, fields: list[str]) -> dict[str, dict]: def get_highlight(self, res, keywords: list[str], field_name: str): """Get highlighted text for search results.""" - # TODO: Implement highlight functionality for OceanBase memory - return {} + if isinstance(res, tuple): + res = res[0] + messages = getattr(res, "messages", None) + return get_highlight_from_messages( + messages, keywords, field_name, is_english_fn=lambda s: is_english([s]) + ) def get_aggregation(self, res, field_name: str): """Get aggregation for search results.""" diff --git a/test/unit_test/memory/utils/test_ob_conn_highlight.py b/test/unit_test/memory/utils/test_ob_conn_highlight.py new file mode 100644 index 00000000000..99550cf0117 --- /dev/null +++ b/test/unit_test/memory/utils/test_ob_conn_highlight.py @@ -0,0 +1,79 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use it except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for OceanBase memory get_highlight.
+
+Tests the pure highlight logic used by OBConnection.get_highlight,
+without requiring a real OceanBase instance or heavy dependencies.
+"""
+
+from memory.utils.highlight_utils import get_highlight_from_messages, highlight_text
+
+
+class TestHighlightText:
+ """Tests for highlight_text (word-boundary mode when is_english_fn is None)."""
+
+ def test_empty_text_returns_empty(self):
+ assert highlight_text("", ["foo"]) == ""
+ assert highlight_text("hello", []) == ""
+
+ def test_wraps_keyword_with_em(self):
+ out = highlight_text("The quick brown fox.", ["quick"], None)
+ assert "<em>quick</em>" in out
+ assert "The" in out and "brown fox" in out
+
+ def test_only_sentences_with_match_included(self):
+ out = highlight_text(
+ "First sentence. Second has keyword. Third none.",
+ ["keyword"],
+ None,
+ )
+ assert "Second has <em>keyword</em>" in out
+ assert "First sentence" not in out and "Third none" not in out
+
+ def test_multiple_keywords(self):
+ out = highlight_text("Alpha and beta here.", ["Alpha", "beta"], None)
+ assert "<em>Alpha</em>" in out and "<em>beta</em>" in out
+
+
+class TestGetHighlightFromMessages:
+ """Tests for get_highlight_from_messages (used by get_highlight)."""
+
+ def test_empty_messages_returns_empty_dict(self):
+ assert get_highlight_from_messages([], ["k"], "content_ltks") == {}
+ assert get_highlight_from_messages(None, ["k"], "content_ltks") == {}
+
+ def test_empty_keywords_returns_empty_dict(self):
+ assert get_highlight_from_messages(
+ [{"id": "m1", "content_ltks": "hello"}], [], "content_ltks"
+ ) == {}
+
+ def test_returns_id_to_highlighted_text(self):
+ messages = [
+ {"id": "msg1", "content_ltks": "The cat sat."},
+ {"id": "msg2", "content_ltks": "The dog ran."},
+ ]
+ out = get_highlight_from_messages(messages, ["cat"], "content_ltks")
+ assert list(out.keys()) == ["msg1"]
+ assert "<em>cat</em>" in out["msg1"]
+ out2 = get_highlight_from_messages(messages, ["dog"], "content_ltks")
+ assert list(out2.keys()) == ["msg2"]
+ assert "<em>dog</em>" in out2["msg2"]
+
+ def test_skips_docs_without_field(self):
+ messages = [{"id": "m1"}, {"id": "m2", "content_ltks": "hello world."}]
+ out = get_highlight_from_messages(messages, ["hello"], "content_ltks")
+ assert "m2" in out and "<em>hello</em>" in out["m2"]
From a4c688a84dbea96d8a63cc2e97fe87a7fcdaceb6 Mon Sep 17 00:00:00 2001
From: balibabu
Date: Fri, 6 Mar 2026 20:17:21 +0800
Subject: [PATCH 168/565] Feat: Add PublishConfirmDialog (#13447)

### What problem does this PR solve?
Feat: Add PublishConfirmDialog ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- web/src/hooks/use-agent-request.ts | 1 + web/src/locales/en.ts | 5 + .../components/publish-confirm-dialog.tsx | 103 ++++++++++++++++++ web/src/pages/agent/hooks/use-save-graph.ts | 15 ++- web/src/pages/agent/index.tsx | 46 ++++---- 5 files changed, 147 insertions(+), 23 deletions(-) create mode 100644 web/src/pages/agent/components/publish-confirm-dialog.tsx diff --git a/web/src/hooks/use-agent-request.ts b/web/src/hooks/use-agent-request.ts index a385c8b87d9..df577530e27 100644 --- a/web/src/hooks/use-agent-request.ts +++ b/web/src/hooks/use-agent-request.ts @@ -306,6 +306,7 @@ export const useSetAgent = (showMessage: boolean = true) => { dsl?: Record; avatar?: string; canvas_category?: string; + release?: string; }) => { const { data = {} } = await agentService.setCanvas(params); if (data.code === 0) { diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index b928d8496e3..0977e4eddca 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -2180,6 +2180,10 @@ This process aggregates variables from multiple branches into a single variable 'Write your SQL query here. You can use variables, raw SQL, or mix both using variable syntax.', frameworkPrompts: 'Framework', release: 'Publish', + confirmPublish: 'Confirm Publish', + publishDescription: 'You are about to publish this data pipeline.', + linkedDataset: 'Linked dataset', + lastPublished: 'Last published', createFromBlank: 'Create from blank', createFromTemplate: 'Create from template', importJsonFile: 'Import JSON file', @@ -2554,6 +2558,7 @@ Important structured information may include: names, dates, locations, events, k import: 'Import', description: 'Description', noDescription: 'No description', + none: 'None', resourceType: { dataset: 'Dataset', diff --git a/web/src/pages/agent/components/publish-confirm-dialog.tsx b/web/src/pages/agent/components/publish-confirm-dialog.tsx new file mode 100644 index 00000000000..a058bef9cbf --- /dev/null +++ b/web/src/pages/agent/components/publish-confirm-dialog.tsx @@ -0,0 +1,103 @@ +import { Button, ButtonLoading } from '@/components/ui/button'; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger, +} from '@/components/ui/dialog'; +import { Operator } from '@/pages/agent/constant'; +import useGraphStore from '@/pages/agent/store'; +import { formatDate } from '@/utils/date'; +import { BookPlus } from 'lucide-react'; +import { useMemo, useState } from 'react'; +import { useTranslation } from 'react-i18next'; + +interface PublishConfirmDialogProps { + agentDetail: { title: string; update_time?: number }; + loading: boolean; + onPublish: () => void; +} + +export function PublishConfirmDialog({ + agentDetail, + loading, + onPublish, +}: PublishConfirmDialogProps) { + const { t } = useTranslation(); + const [open, setOpen] = useState(false); + const nodes = useGraphStore((state) => state.nodes); + + const linkedDatasets = useMemo(() => { + const datasets: string[] = []; + nodes.forEach((node) => { + if (node.data.label === Operator.Retrieval) { + const kbIds = node.data.form?.kb_ids || []; + datasets.push(...kbIds); + } + }); + return [...new Set(datasets)]; + }, [nodes]); + + const lastPublished = useMemo(() => { + if (agentDetail?.update_time) { + return formatDate(agentDetail.update_time); + } + return '-'; + }, [agentDetail?.update_time]); + + const handleConfirmPublish = () => { + onPublish(); 
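+ // Hand the publish request to the parent via onPublish and close right
+ // away; the parent owns the request lifecycle and the `loading` flag.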
+ setOpen(false);
+ };
+
+ // NOTE: the original JSX of this return block was stripped during
+ // extraction; the element tree below is an approximate reconstruction from
+ // the file's imports and the surviving expressions. Class names and exact
+ // props are not recoverable and are omitted or assumed.
+ return (
+   <Dialog open={open} onOpenChange={setOpen}>
+     <DialogTrigger asChild>
+       <Button disabled={loading}>
+         <BookPlus />
+         {t('flow.release')}
+       </Button>
+     </DialogTrigger>
+     <DialogContent>
+       <DialogHeader>
+         <DialogTitle>{t('flow.confirmPublish')}</DialogTitle>
+         <DialogDescription>{t('flow.publishDescription')}</DialogDescription>
+       </DialogHeader>
+       <div>
+         <div>{agentDetail.title}</div>
+         <div>
+           <span>{t('flow.linkedDataset')}</span>
+           <span>
+             {linkedDatasets.length > 0
+               ? linkedDatasets.join(', ')
+               : t('common.none')}
+           </span>
+         </div>
+         <div>
+           <span>{t('flow.lastPublished')}</span>
+           <span>{lastPublished}</span>
+         </div>
+       </div>
+       <DialogFooter>
+         <Button variant="outline" onClick={() => setOpen(false)}>
+           {t('common.cancel')}
+         </Button>
+         <ButtonLoading loading={loading} onClick={handleConfirmPublish}>
+           {t('flow.release')}
+         </ButtonLoading>
+       </DialogFooter>
+     </DialogContent>
+   </Dialog>
+ ); +} diff --git a/web/src/pages/agent/hooks/use-save-graph.ts b/web/src/pages/agent/hooks/use-save-graph.ts index d308c21e0d9..ac764605d59 100644 --- a/web/src/pages/agent/hooks/use-save-graph.ts +++ b/web/src/pages/agent/hooks/use-save-graph.ts @@ -21,13 +21,22 @@ export const useSaveGraph = (showMessage: boolean = true) => { const saveGraph = useCallback( async ( currentNodes?: RAGFlowNodeType[], - otherParam?: { globalVariables: Record }, + otherParam?: { + globalVariables: Record; + }, + release?: boolean, ) => { - return setAgent({ + const params: Record = { id, title: data.title, dsl: buildDslData(currentNodes, otherParam), - }); + }; + + if (release) { + params.release = 'true'; + } + + return setAgent(params); }, [setAgent, data, id, buildDslData], ); diff --git a/web/src/pages/agent/index.tsx b/web/src/pages/agent/index.tsx index f53c6a2ee51..1623b91d876 100644 --- a/web/src/pages/agent/index.tsx +++ b/web/src/pages/agent/index.tsx @@ -17,6 +17,7 @@ import { DropdownMenuSeparator, DropdownMenuTrigger, } from '@/components/ui/dropdown-menu'; +import message from '@/components/ui/message'; import { SharedFrom } from '@/constants/chat'; import { useSetModalState } from '@/hooks/common-hooks'; import { useNavigatePage } from '@/hooks/logic-hooks/navigate-hooks'; @@ -38,6 +39,7 @@ import { useTranslation } from 'react-i18next'; import { useParams } from 'react-router'; import AgentCanvas from './canvas'; import { DropdownProvider } from './canvas/context'; +import { PublishConfirmDialog } from './components/publish-confirm-dialog'; import { Operator } from './constant'; import { GlobalParamSheet } from './gobal-variable-sheet'; import { useCancelCurrentDataflow } from './hooks/use-cancel-dataflow'; @@ -238,13 +240,6 @@ export default function Agent() { > {t('flow.save')} - showGlobalParamSheet()} - loading={loading} - > - {t('flow.conversationVariable')} - - - {isPipeline || ( - - )} {isConversationMode && ( + showGlobalParamSheet()}> + + {t('flow.conversationVariable')} + + + + + {t('flow.historyVersion')} + + + {isPipeline || ( + navigateToAgentLogs(id as string)()} + > + + {t('flow.log')} + + )} + {t('flow.export')} From 795503619672364d44679726b83f2b1449ae7d57 Mon Sep 17 00:00:00 2001 From: chanx <1243304602@qq.com> Date: Fri, 6 Mar 2026 20:17:29 +0800 Subject: [PATCH 169/565] Fix: Add folder upload #9743 (#13448) ### What problem does this PR solve? 
Fix: Add folder upload #9743 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- .../components/file-upload-dialog/index.tsx | 12 +- web/src/components/file-uploader.tsx | 206 ++++++++++++------ web/src/locales/en.ts | 2 + web/src/locales/zh.ts | 2 + 4 files changed, 155 insertions(+), 67 deletions(-) diff --git a/web/src/components/file-upload-dialog/index.tsx b/web/src/components/file-upload-dialog/index.tsx index ba90d3fa490..d7239e2e8b6 100644 --- a/web/src/components/file-upload-dialog/index.tsx +++ b/web/src/components/file-upload-dialog/index.tsx @@ -20,8 +20,16 @@ import { Switch } from '../ui/switch'; function buildUploadFormSchema(t: TFunction) { const FormSchema = z.object({ parseOnCreation: z.boolean().optional(), + // Update schema to allow files with path property to handle folder uploads fileList: z - .array(z.instanceof(File)) + .array( + z.instanceof(File).or( + z.object({ + file: z.instanceof(File), + path: z.string(), // Store the relative path for files in folders + }), + ), + ) .min(1, { message: t('fileManager.pleaseUploadAtLeastOneFile') }), }); @@ -72,7 +80,7 @@ function UploadForm({ submit, showParseOnCreation }: UploadFormProps) { )} )} - + {(field) => ( (null); + const reachesMaxFileCount = (files?.length ?? 0) >= maxFileCount; - const onDrop = React.useCallback( + const processFiles = React.useCallback( (acceptedFiles: File[], rejectedFiles: FileRejection[]) => { if (!multiple && maxFileCount === 1 && acceptedFiles.length > 1) { toast.error('Cannot upload more than 1 file at a time'); @@ -216,11 +219,16 @@ export function FileUploader(props: FileUploaderProps) { return; } - const newFiles = acceptedFiles.map((file) => - Object.assign(file, { - preview: URL.createObjectURL(file), - }), - ); + const newFiles = acceptedFiles.map((file) => { + const enhancedFile = file as File & { preview?: string }; + Object.defineProperty(enhancedFile, 'preview', { + value: URL.createObjectURL(file), + writable: true, + enumerable: true, + configurable: true, + }); + return enhancedFile; + }); const updatedFiles = files ? [...files, ...newFiles] : newFiles; @@ -250,10 +258,26 @@ export function FileUploader(props: FileUploaderProps) { }); } }, - [files, maxFileCount, multiple, onUpload, setFiles], ); + const onDrop = React.useCallback( + (acceptedFiles: File[], rejectedFiles: FileRejection[]) => { + processFiles(acceptedFiles, rejectedFiles); + }, + [processFiles], + ); + + const handleFolderSelect = React.useCallback( + (e: React.ChangeEvent) => { + if (!e.target.files) return; + const fileList = Array.from(e.target.files); + processFiles(fileList, []); + e.target.value = ''; + }, + [processFiles], + ); + function onRemove(index: number) { if (!files) return; const newFiles = files.filter((_, i) => i !== index); @@ -276,68 +300,120 @@ export function FileUploader(props: FileUploaderProps) { const isDisabled = disabled || (files?.length ?? 0) >= maxFileCount; - return ( -
- {!(hideDropzoneOnMaxFileCount && reachesMaxFileCount) && ( - 1 || multiple} - disabled={isDisabled} + const renderDropzone = (isFolderMode: boolean = false) => ( + 1 || multiple} + disabled={isDisabled} + noClick={isFolderMode} + noDrag={isFolderMode} + > + {({ getRootProps, getInputProps, isDragActive }) => ( +
- {({ getRootProps, getInputProps, isDragActive }) => ( + {!isFolderMode && } + {isDragActive && !isFolderMode ? ( +
+
+
+

+ {t('fileManager.dropFilesHere', 'Drop the files here')} +

+
+ ) : (
{ + if (isFolderMode && !isDisabled) { + folderInputRef.current?.click(); + } + }} > - - {isDragActive ? ( -
-
-
-

- Drop the files here -

-
- ) : ( -
-
-
-
-

- {title || t('knowledgeDetails.uploadTitle')} -

-

- {description || t('knowledgeDetails.uploadDescription')} - {/* You can upload - {maxFileCount > 1 - ? ` ${maxFileCount === Infinity ? 'multiple' : maxFileCount} - files (up to ${formatBytes(maxSize)} each)` - : ` a file with ${formatBytes(maxSize)}`} */} -

-
-
- )} +
+ {isFolderMode ? ( +
+
+

+ {title || + (isFolderMode + ? t('fileManager.uploadFolderTitle', 'Upload Folder') + : t('knowledgeDetails.uploadTitle'))} +

+

+ {description || + (isFolderMode + ? t( + 'knowledgeDetails.uploadDescription', + 'Select a folder to upload all files inside', + ) + : t('knowledgeDetails.uploadDescription'))} +

+
)} - +
+ )} +
+ ); + + return ( +
+ {!(hideDropzoneOnMaxFileCount && reachesMaxFileCount) && ( + + + + + {t('fileManager.files', 'Files')} + + + + {t('fileManager.folder', 'Folder')} + + + + {renderDropzone(false)} + + + {renderDropzone(true)} + + + )} {files?.length ? ( diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 0977e4eddca..a071391f1ab 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -1446,6 +1446,8 @@ Example: Virtual Hosted Style`, hint: 'hint', }, fileManager: { + uploadFolderTitle: 'Upload folder', + folder: 'Folder', files: 'Files', name: 'Name', uploadDate: 'Upload date', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index 8326a673d14..8a278177f56 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -1200,6 +1200,8 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 hint: '提示', }, fileManager: { + uploadFolderTitle: '上传文件夹', + folder: '文件夹', files: '文件', name: '名称', uploadDate: '上传日期', From 6310add0503c472233b585951b5dfdd4b3d328fd Mon Sep 17 00:00:00 2001 From: Liu An Date: Fri, 6 Mar 2026 20:17:39 +0800 Subject: [PATCH 170/565] Test: adjust test priority markers for API tests (#13450) ### What problem does this PR solve? Changed test priority markers from p1/p2 to p3 in three test files: - test_table_parser_dataset_chat.py: Adjusted priority for table parser dataset chat test - test_delete_chunks.py: Updated priority for chunk deletion test with invalid IDs - test_retrieval_chunks.py: Modified priority for chunks retrieval pagination test These changes demote the priority of specific test cases to p3, indicating they are lower priority tests that can run later in the test suite execution. ### Type of change - [x] Test update --- .../test_chat_management/test_table_parser_dataset_chat.py | 2 +- .../test_chunk_management_within_dataset/test_delete_chunks.py | 2 +- .../test_retrieval_chunks.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/testcases/test_http_api/test_chat_management/test_table_parser_dataset_chat.py b/test/testcases/test_http_api/test_chat_management/test_table_parser_dataset_chat.py index 2fefa50ba72..3da599300f8 100644 --- a/test/testcases/test_http_api/test_chat_management/test_table_parser_dataset_chat.py +++ b/test/testcases/test_http_api/test_chat_management/test_table_parser_dataset_chat.py @@ -156,7 +156,7 @@ def _teardown_chat_assistant(self): except Exception as e: print(f"[Teardown] Warning: Failed to delete chat assistant: {e}") - @pytest.mark.p1 + @pytest.mark.p3 @pytest.mark.parametrize( "question, expected_answer_pattern", [ diff --git a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py index eae75afada6..119974365dd 100644 --- a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py +++ b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py @@ -71,7 +71,7 @@ def test_invalid_document_id(self, HttpApiAuth, add_chunks_func, document_id, ex "payload", [ pytest.param(lambda r: {"chunk_ids": ["invalid_id"] + r}, marks=pytest.mark.p3), - pytest.param(lambda r: {"chunk_ids": r[:1] + ["invalid_id"] + r[1:4]}, marks=pytest.mark.p1), + pytest.param(lambda r: {"chunk_ids": r[:1] + ["invalid_id"] + r[1:4]}, marks=pytest.mark.p3), pytest.param(lambda r: {"chunk_ids": r + ["invalid_id"]}, marks=pytest.mark.p3), ], ) diff --git 
a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py index 2c94f2d30e7..3e4d11c94dd 100644 --- a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py +++ b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py @@ -175,7 +175,7 @@ def test_vector_similarity_weight(self, HttpApiAuth, add_chunks, payload, expect else: assert res["message"] == expected_message - @pytest.mark.p2 + @pytest.mark.p3 @pytest.mark.parametrize( "payload, expected_code, expected_page_size, expected_message", [ From 0da796d71c6759d43a6dfc0be1126de48e654b6f Mon Sep 17 00:00:00 2001 From: Jimmy Ben Klieve Date: Fri, 6 Mar 2026 21:13:14 +0800 Subject: [PATCH 171/565] refactor(ui): adjust dataset page styles (#13452) ### What problem does this PR solve? - Adjust UI styles in **Dataset** pages. - Adjust several shared components' styles - Modify files and directory structure in `src/layouts` ### Type of change - [x] Refactoring --- web/src/components/bulk-operate-bar.tsx | 108 ++++++------ .../dynamic-page-range.tsx | 22 ++- .../components/chunk-method-dialog/index.tsx | 91 +++++----- web/src/components/empty/empty.tsx | 19 ++- web/src/components/empty/interface.ts | 2 +- web/src/components/home-card.tsx | 81 +++++---- web/src/components/list-filter-bar/index.tsx | 6 +- web/src/components/more-button.tsx | 6 +- web/src/components/ragflow-avatar.tsx | 93 ++++------- web/src/components/svg-icon.tsx | 4 +- .../admin => }/components/theme-switch.tsx | 0 web/src/components/theme-toggle.tsx | 48 ------ web/src/components/ui/avatar.tsx | 2 +- web/src/components/ui/button.tsx | 14 +- web/src/components/ui/card.tsx | 34 ++-- web/src/components/ui/dialog.tsx | 7 +- web/src/components/ui/hover-card.tsx | 8 +- web/src/components/ui/radio.tsx | 37 +++- web/src/components/ui/segmented.tsx | 13 +- .../layouts/{ => components}/bell-button.tsx | 0 .../{ => components}/global-navbar.tsx | 0 .../header.tsx} | 0 .../{ => components}/page-container.tsx | 2 +- .../layouts/{ => components}/theme-button.tsx | 0 web/src/layouts/{next.tsx => root-layout.tsx} | 10 +- web/src/locales/en.ts | 3 +- .../pages/admin/layouts/navigation-layout.tsx | 2 +- web/src/pages/admin/login.tsx | 2 +- .../dataset/dataset/dataset-action-cell.tsx | 38 ++--- .../dataset/generate-button/generate.tsx | 38 +++-- web/src/pages/dataset/dataset/index.tsx | 42 +++-- .../pages/dataset/dataset/parsing-card.tsx | 5 +- .../dataset/dataset/parsing-status-cell.tsx | 158 +++++++++--------- .../dataset/use-bulk-operate-dataset.tsx | 24 +-- .../dataset/use-dataset-table-columns.tsx | 81 ++++++--- web/src/pages/datasets/dataset-card.tsx | 2 +- web/src/pages/home/application-card.tsx | 15 +- web/src/pages/home/applications.tsx | 34 ++-- web/src/pages/home/banner.tsx | 8 +- web/src/pages/home/datasets.tsx | 34 ++-- web/src/pages/home/index.tsx | 14 +- web/src/pages/next-chats/chat/index.tsx | 6 +- web/src/pages/user-setting/sidebar/index.tsx | 13 +- web/src/routes.tsx | 4 +- 44 files changed, 588 insertions(+), 542 deletions(-) rename web/src/{pages/admin => }/components/theme-switch.tsx (100%) delete mode 100644 web/src/components/theme-toggle.tsx rename web/src/layouts/{ => components}/bell-button.tsx (100%) rename web/src/layouts/{ => components}/global-navbar.tsx (100%) rename web/src/layouts/{next-header.tsx => components/header.tsx} (100%) rename web/src/layouts/{ =>
components}/page-container.tsx (81%) rename web/src/layouts/{ => components}/theme-button.tsx (100%) rename web/src/layouts/{next.tsx => root-layout.tsx} (60%) diff --git a/web/src/components/bulk-operate-bar.tsx b/web/src/components/bulk-operate-bar.tsx index ea71b50525d..6c9ceaf2c10 100644 --- a/web/src/components/bulk-operate-bar.tsx +++ b/web/src/components/bulk-operate-bar.tsx @@ -1,8 +1,7 @@ import { Button } from '@/components/ui/button'; import { Card, CardContent } from '@/components/ui/card'; -import { cn } from '@/lib/utils'; import { BrushCleaning } from 'lucide-react'; -import { ReactNode, useCallback } from 'react'; +import { ReactNode, useId } from 'react'; import { useTranslation } from 'react-i18next'; import { ConfirmDeleteDialog, @@ -30,58 +29,69 @@ export function BulkOperateBar({ className, unit, }: BulkOperateBarProps) { - const isDeleteItem = useCallback((id: string) => { - return id === 'delete'; - }, []); const { t } = useTranslation(); + const ariaDescriptionId = useId(); return ( - - -
- - {t('common.selected')}: {count}{' '} - {unit ?? t('knowledgeDetails.files')} - - -
- + + +

+ {t('common.selected')}: {count} {unit ?? t('knowledgeDetails.files')} + +

+ + +
    - {list.map((x) => ( -
  • - -
  • - ))} + {x.icon} {x.label} + + ); + + return ( +
  • + {isDeleteItem ? ( + + ), + }} + > + {buttonEl} + + ) : ( + buttonEl + )} +
  • + ); + })}
diff --git a/web/src/components/chunk-method-dialog/dynamic-page-range.tsx b/web/src/components/chunk-method-dialog/dynamic-page-range.tsx index 0e9863b13fa..6775d38a263 100644 --- a/web/src/components/chunk-method-dialog/dynamic-page-range.tsx +++ b/web/src/components/chunk-method-dialog/dynamic-page-range.tsx @@ -10,7 +10,7 @@ import { FormMessage, } from '@/components/ui/form'; import { Input } from '@/components/ui/input'; -import { Plus, X } from 'lucide-react'; +import { LucidePlus, LucideTrash2 } from 'lucide-react'; import { useFieldArray, useFormContext } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; import { Separator } from '../ui/separator'; @@ -51,7 +51,9 @@ export function DynamicPageRange() { )} /> + + )} /> -
); })} +
diff --git a/web/src/components/chunk-method-dialog/index.tsx b/web/src/components/chunk-method-dialog/index.tsx index a4291eea71b..c845fda35e4 100644 --- a/web/src/components/chunk-method-dialog/index.tsx +++ b/web/src/components/chunk-method-dialog/index.tsx @@ -42,7 +42,6 @@ import { DataFlowSelect } from '../data-pipeline-select'; import { DelimiterFormField } from '../delimiter-form-field'; import { EntityTypesFormField } from '../entity-types-form-field'; import { ExcelToHtmlFormField } from '../excel-to-html-form-field'; -import { FormContainer } from '../form-container'; import { LayoutRecognizeFormField } from '../layout-recognize-form-field'; import { MaxTokenNumberFormField } from '../max-token-number-from-field'; import { MinerUOptionsFormField } from '../mineru-options-form-field'; @@ -293,22 +292,16 @@ export function ChunkMethodDialog({ {t('knowledgeDetails.chunkMethod')} +
- +
- {parseType === 1 && } - {parseType === 2 && ( - - )} + {parseType === 1 && } {/* )} /> */} - {showPages && parseType === 1 && ( - - )} + + {showPages && parseType === 1 && } + {showPages && parseType === 1 && layoutRecognize && ( - + )} /> )} - +
+ {parseType === 1 && ( <> - +
{showOne && ( <> {isMineruSelected && } )} + {showMaxTokenNumber && ( <> - + /> + )} - - +
+ +
{selectedTag === DocumentParserType.Naive && ( <> )} + {showAutoKeywords(selectedTag) && ( <> )} + {showExcelToHtml && ( )} - +
{/* {showRaptorParseConfiguration( - selectedTag as DocumentParserType, - ) && ( - - - - )} */} - {/* {showGraphRagItems(selectedTag as DocumentParserType) && - useGraphRag && ( + selectedTag as DocumentParserType, + ) && ( - + )} */} - {showEntityTypes && ( - - )} + {/* {showGraphRagItems(selectedTag as DocumentParserType) && + useGraphRag && ( + + + + )} */} +
+ {showEntityTypes && } +
)} + +
+ {parseType === 2 && ( + + )} +
diff --git a/web/src/components/empty/empty.tsx b/web/src/components/empty/empty.tsx index fbb97506f67..f9f54b460f6 100644 --- a/web/src/components/empty/empty.tsx +++ b/web/src/components/empty/empty.tsx @@ -53,22 +53,24 @@ const Empty = (props: EmptyProps) => { export default Empty; export const EmptyCard = (props: EmptyCardProps) => { - const { icon, className, children, title, description, style } = props; + const { icon, className, children, title, description, style, ...restProps } = + props; return ( -
{icon} - {title &&
{title}
} + {title &&
{title}
} {description && ( -
{description}
+

{description}

)} {children} -
+ ); }; @@ -104,11 +106,14 @@ export const EmptyAppCard = (props: { break; } return ( -
+
diff --git a/web/src/components/empty/interface.ts b/web/src/components/empty/interface.ts index 73fef1d6069..87b7531289b 100644 --- a/web/src/components/empty/interface.ts +++ b/web/src/components/empty/interface.ts @@ -15,4 +15,4 @@ export type EmptyCardProps = { title?: string; description?: string; style?: React.CSSProperties; -}; +} & Omit, 'title'>; diff --git a/web/src/components/home-card.tsx b/web/src/components/home-card.tsx index d8ec97d749f..7320960b954 100644 --- a/web/src/components/home-card.tsx +++ b/web/src/components/home-card.tsx @@ -1,5 +1,5 @@ import { RAGFlowAvatar } from '@/components/ragflow-avatar'; -import { Card, CardContent } from '@/components/ui/card'; +import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; import { formatDate } from '@/utils/date'; import { ReactNode } from 'react'; @@ -26,48 +26,61 @@ export function HomeCard({ }: IProps) { return ( { // navigateToSearch(data?.id); onClick?.(); }} + tabIndex={0} + className="px-2.5 py-4 flex gap-2 items-start group h-full w-full hover:shadow-md" > - -
- -
-
-
-
-
- {data.name} +
+ +
+ +
+ + +

+ {data.name} +

+ + {icon} +
+ +
{moreDropdown}
+
+ + +
+
+ +
+
+ {data.description} +
+
+

+ {formatDate(data.update_time)} +

+ {sharedBadge}
- {icon}
- {moreDropdown} -
- -
-
- {data.description} -
-
-

- {formatDate(data.update_time)} -

- {sharedBadge} -
-
-
-
+
+ +
); } diff --git a/web/src/components/list-filter-bar/index.tsx b/web/src/components/list-filter-bar/index.tsx index 3ee505220b2..e0b69a2353c 100644 --- a/web/src/components/list-filter-bar/index.tsx +++ b/web/src/components/list-filter-bar/index.tsx @@ -91,7 +91,7 @@ export default function ListFilterBar({ }, [value]); return ( -
+
{typeof icon === 'string' ? ( // @@ -101,7 +101,8 @@ export default function ListFilterBar({ )} {leftPanel || title}
-
+ +
{preChildren} {showFilter && ( {children}
diff --git a/web/src/components/more-button.tsx b/web/src/components/more-button.tsx index d9234a71a98..a73dad2eb87 100644 --- a/web/src/components/more-button.tsx +++ b/web/src/components/more-button.tsx @@ -9,10 +9,10 @@ export const MoreButton = React.forwardRef( - ); -}; - -export default ThemeToggle; diff --git a/web/src/components/ui/avatar.tsx b/web/src/components/ui/avatar.tsx index 9b47c4ecd65..567b9e00624 100644 --- a/web/src/components/ui/avatar.tsx +++ b/web/src/components/ui/avatar.tsx @@ -39,7 +39,7 @@ const AvatarFallback = React.forwardRef< ; + export type ButtonProps = { asChild?: boolean; asLink?: boolean; @@ -114,7 +120,7 @@ export type ButtonProps = { block?: boolean; disabled?: boolean; dot?: boolean; -} & VariantProps & +} & ButtonVariants & (IsAnchor extends true ? LinkProps : React.ButtonHTMLAttributes); @@ -144,7 +150,7 @@ const Button = React.forwardRef( ->(({ className, ...props }, ref) => ( -
& { as?: React.ElementType } +>(({ as: As = 'div', className, ...props }, ref) => ( + ->(({ className, ...props }, ref) => ( -
& { as?: React.ElementType } +>(({ as: As = 'div', className, ...props }, ref) => ( + & { as?: React.ElementType } ->(({ className, as: As = 'div', ...props }, ref) => ( +>(({ as: As = 'div', className, ...props }, ref) => ( ->(({ className, ...props }, ref) => ( -
& { as?: React.ElementType } +>(({ as: As = 'div', className, ...props }, ref) => ( + ->(({ className, ...props }, ref) => ( -
& { as?: React.ElementType } +>(({ as: As = 'div', className, ...props }, ref) => ( + ->(({ className, ...props }, ref) => ( -
& { as?: React.ElementType } +>(({ as: As = 'div', className, ...props }, ref) => ( + void; disabled?: boolean; @@ -50,30 +51,46 @@ function Radio({ return (
diff --git a/web/src/routes.tsx b/web/src/routes.tsx index 5bad48793c5..fb57d93e531 100644 --- a/web/src/routes.tsx +++ b/web/src/routes.tsx @@ -146,7 +146,7 @@ const routeConfigOptions = [ { path: Routes.Root, layout: false, - Component: () => import('@/layouts/next'), + Component: () => import('@/layouts/root-layout'), loader: ({ request }) => { const url = new URL(request.url); const auth = url.searchParams.get('auth'); @@ -170,7 +170,7 @@ const routeConfigOptions = [ }, { path: Routes.Root, - Component: () => import('@/layouts/next'), + Component: () => import('@/layouts/root-layout'), children: [ { path: Routes.Datasets, From b3fa1573919ed520052f72a0c01c76eb0b6b18f7 Mon Sep 17 00:00:00 2001 From: Heyang Wang Date: Fri, 6 Mar 2026 21:13:23 +0800 Subject: [PATCH 172/565] =?UTF-8?q?Feat:=20add=20=20DingTalk=20AI=20Table?= =?UTF-8?q?=20connector=20and=20integration=20for=20data=20synch=E2=80=A6?= =?UTF-8?q?=20(#13413)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? Add DingTalk AI Table connector and integration for data synchronization Issue #13400 ### Type of change - [x] New Feature (non-breaking change which adds functionality) Co-authored-by: wangheyang --- common/constants.py | 1 + common/data_source/__init__.py | 2 + common/data_source/config.py | 1 + .../dingtalk_ai_table_connector.py | 433 ++++++++++++++++++ pyproject.toml | 1 + rag/svr/sync_data_source.py | 45 ++ uv.lock | 29 ++ .../svg/data-source/dingtalk-ai-table.svg | 27 ++ web/src/locales/bg.ts | 2 + web/src/locales/de.ts | 2 + web/src/locales/en.ts | 2 + web/src/locales/es.ts | 2 + web/src/locales/fr.ts | 2 + web/src/locales/id.ts | 2 + web/src/locales/it.ts | 2 + web/src/locales/ja.ts | 2 + web/src/locales/pt-br.ts | 2 + web/src/locales/ru.ts | 2 + web/src/locales/vi.ts | 2 + web/src/locales/zh-traditional.ts | 2 + web/src/locales/zh.ts | 2 + .../data-source/constant/index.tsx | 37 ++ 22 files changed, 602 insertions(+) create mode 100644 common/data_source/dingtalk_ai_table_connector.py create mode 100644 web/src/assets/svg/data-source/dingtalk-ai-table.svg diff --git a/common/constants.py b/common/constants.py index 6a939cf4cfd..cbc2f534c95 100644 --- a/common/constants.py +++ b/common/constants.py @@ -138,6 +138,7 @@ class FileSource(StrEnum): SEAFILE = "seafile" MYSQL = "mysql" POSTGRESQL = "postgresql" + DINGTALK_AI_TABLE = "dingtalk_ai_table" class PipelineTaskType(StrEnum): diff --git a/common/data_source/__init__.py b/common/data_source/__init__.py index 099f3d7b3bd..022a1076135 100644 --- a/common/data_source/__init__.py +++ b/common/data_source/__init__.py @@ -36,6 +36,7 @@ from .teams_connector import TeamsConnector from .moodle_connector import MoodleConnector from .airtable_connector import AirtableConnector +from .dingtalk_ai_table_connector import DingTalkAITableConnector from .asana_connector import AsanaConnector from .imap_connector import ImapConnector from .zendesk_connector import ZendeskConnector @@ -83,4 +84,5 @@ "SeaFileConnector", "RDBMSConnector", "WebDAVConnector", + "DingTalkAITableConnector", ] diff --git a/common/data_source/config.py b/common/data_source/config.py index b05d8af24af..65338f34a65 100644 --- a/common/data_source/config.py +++ b/common/data_source/config.py @@ -66,6 +66,7 @@ class DocumentSource(str, Enum): SEAFILE = "seafile" MYSQL = "mysql" POSTGRESQL = "postgresql" + DINGTALK_AI_TABLE = "dingtalk_ai_table" class FileOrigin(str, Enum): diff --git a/common/data_source/dingtalk_ai_table_connector.py 
b/common/data_source/dingtalk_ai_table_connector.py new file mode 100644 index 00000000000..66588d4d307 --- /dev/null +++ b/common/data_source/dingtalk_ai_table_connector.py @@ -0,0 +1,433 @@ +"""DingTalk AI Table connector for RAGFlow. By the way, "notable" is a reference to the DingTalk AI Table. + +This connector ingests records from DingTalk AI Table as documents. +It first retrieves all sheets from a specified table, then fetches all records +from each sheet. + +API Documentation: +- GetAllSheets: https://open.dingtalk.com/document/development/api-notable-getallsheets +- ListRecords: https://open.dingtalk.com/document/development/api-notable-listrecords +""" + +import json +import logging +from datetime import datetime, timezone +from typing import Any + +from alibabacloud_dingtalk.notable_1_0.client import Client as NotableClient +from alibabacloud_dingtalk.notable_1_0 import models as notable_models +from alibabacloud_tea_openapi import models as open_api_models +from alibabacloud_tea_util import models as util_models +from alibabacloud_tea_util.client import Client as UtilClient + +from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource +from common.data_source.exceptions import ConnectorMissingCredentialError, ConnectorValidationError +from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch +from common.data_source.models import Document, GenerateDocumentsOutput + +logger = logging.getLogger(__name__) + +# Document ID prefix for DingTalk Notable +_DINGTALK_AI_TABLE_DOC_ID_PREFIX = "dingtalk_ai_table:" + + +class DingTalkAITableClientNotSetUpError(PermissionError): + """Exception raised when DingTalk Notable client is not initialized.""" + + def __init__(self) -> None: + super().__init__("DingTalk Notable client is not set up. Did you forget to call load_credentials()?") + + +class DingTalkAITableConnector(LoadConnector, PollConnector): + """ + DingTalk AI Table (Notable) connector for accessing table records. + + This connector: + 1. Retrieves all sheets from a specified Notable table using GetAllSheets API + 2. For each sheet, fetches all records using ListRecords API with pagination + 3. Converts each record into a Document for RAGFlow ingestion + + Required credentials: + - access_token: DingTalk access token (x-acs-dingtalk-access-token) + - operator_id: User's unionId for API calls + + Configuration: + - table_id: The Notable table ID (e.g., 'qnYxxx') + """ + + def __init__( + self, + table_id: str, + operator_id: str, + batch_size: int = INDEX_BATCH_SIZE, + ) -> None: + """ + Initialize the DingTalk Notable connector. + + Args: + table_id: The Notable table ID + operator_id: User's unionId for API calls + batch_size: Number of records per batch for document generation + """ + self.table_id = table_id + self.operator_id = operator_id + self.batch_size = batch_size + self._client: NotableClient | None = None + self._access_token: str | None = None + + def _create_client(self) -> NotableClient: + """Create DingTalk Notable API client.""" + config = open_api_models.Config() + config.protocol = "https" + config.region_id = "central" + return NotableClient(config) + + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + """ + Load DingTalk credentials. 
+ + Args: + credentials: Dictionary containing 'access_token' + + Returns: + None + """ + access_token = credentials.get("access_token") + if not access_token: + raise ConnectorMissingCredentialError("DingTalk access_token is required") + + self._access_token = access_token + self._client = self._create_client() + return None + + @property + def client(self) -> NotableClient: + """Get the DingTalk AITable client.""" + if self._client is None: + raise DingTalkAITableClientNotSetUpError() + return self._client + + @property + def access_token(self) -> str: + """Get the access token.""" + if self._access_token is None: + raise ConnectorMissingCredentialError("DingTalk access_token not loaded") + return self._access_token + + def validate_connector_settings(self) -> None: + """Validate DingTalk connector settings by trying to get all sheets.""" + if self._client is None or self._access_token is None: + raise ConnectorMissingCredentialError("DingTalk Notable") + + try: + # Try to get sheets to validate credentials + headers = notable_models.GetAllSheetsHeaders() + headers.x_acs_dingtalk_access_token = self._access_token + + request = notable_models.GetAllSheetsRequest( + operator_id=self.operator_id, + ) + + self.client.get_all_sheets_with_options( + self.table_id, + request, + headers, + util_models.RuntimeOptions(), + ) + except Exception as e: + logger.exception("[DingTalk Notable]: Failed to validate credentials") + raise ConnectorValidationError(f"DingTalk Notable credential validation failed: {e}") + + def _get_all_sheets(self) -> list[dict[str, Any]]: + """ + Retrieve all sheets from the Notable table. + + Returns: + List of sheet information dictionaries + """ + headers = notable_models.GetAllSheetsHeaders() + headers.x_acs_dingtalk_access_token = self._access_token + + request = notable_models.GetAllSheetsRequest( + operator_id=self.operator_id, + ) + + try: + response = self.client.get_all_sheets_with_options( + self.table_id, + request, + headers, + util_models.RuntimeOptions(), + ) + + sheets = [] + if response.body and response.body.value: + for sheet in response.body.value: + sheets.append( + { + "id": sheet.id, + "name": sheet.name, + } + ) + + logger.info(f"[DingTalk Notable]: Found {len(sheets)} sheets in table {self.table_id}") + return sheets + + except Exception as e: + logger.exception(f"[DingTalk Notable]: Failed to get sheets: {e}") + raise + + def _list_records( + self, + sheet_id: str, + next_token: str | None = None, + max_results: int = 100, + ) -> tuple[list[dict[str, Any]], str | None]: + """ + List records from a specific sheet with pagination. 
+ + Args: + sheet_id: The sheet ID + next_token: Token for pagination + max_results: Maximum number of results per page + + Returns: + Tuple of (records list, next_token or None if no more) + """ + headers = notable_models.ListRecordsHeaders() + headers.x_acs_dingtalk_access_token = self._access_token + + request = notable_models.ListRecordsRequest( + operator_id=self.operator_id, + max_results=max_results, + next_token=next_token or "", + ) + + try: + response = self.client.list_records_with_options( + self.table_id, + sheet_id, + request, + headers, + util_models.RuntimeOptions(), + ) + + records = [] + new_next_token = None + + if response.body: + if response.body.records: + for record in response.body.records: + records.append( + { + "id": record.id, + "fields": record.fields, + } + ) + if response.body.next_token: + new_next_token = response.body.next_token + + return records, new_next_token + + except Exception as e: + if not UtilClient.empty(getattr(e, "code", None)) and not UtilClient.empty(getattr(e, "message", None)): + logger.error(f"[DingTalk AITable]: API error - code: {e.code}, message: {e.message}") + raise + + def _get_all_records(self, sheet_id: str) -> list[dict[str, Any]]: + """ + Retrieve all records from a sheet with pagination. + + Args: + sheet_id: The sheet ID + + Returns: + List of all records + """ + all_records = [] + next_token = None + + while True: + records, next_token = self._list_records( + sheet_id=sheet_id, + next_token=next_token, + ) + all_records.extend(records) + + if not next_token: + break + + logger.info(f"[DingTalk Notable]: Retrieved {len(all_records)} records from sheet {sheet_id}") + return all_records + + def _convert_record_to_document( + self, + record: dict[str, Any], + sheet_id: str, + sheet_name: str, + ) -> Document: + """ + Convert a Notable record to a Document. + + Args: + record: The record dictionary + sheet_id: The sheet ID + sheet_name: The sheet name + + Returns: + Document object + """ + record_id = record.get("id", "unknown") + fields = record.get("fields", {}) + + # Convert fields to JSON string for blob content + content = json.dumps(fields, ensure_ascii=False, indent=2) + blob = content.encode("utf-8") + + # Create semantic identifier from record fields + # Try to find a meaningful title/name field + semantic_identifier = f"{sheet_name} - Record {record_id}" + + # Try to find a title-like field + for key, value in fields.items(): + if isinstance(value, str) and len(value) > 0 and len(value) < 100: + semantic_identifier = f"{sheet_name} - {value[:50]}" + break + + # Metadata + metadata: dict[str, str | list[str]] = { + "table_id": self.table_id, + "sheet_id": sheet_id, + "sheet_name": sheet_name, + "record_id": record_id, + } + + # Create document + doc = Document( + id=f"{_DINGTALK_AI_TABLE_DOC_ID_PREFIX}{self.table_id}:{sheet_id}:{record_id}", + source=DocumentSource.DINGTALK_AI_TABLE, + semantic_identifier=semantic_identifier, + extension=".json", + blob=blob, + size_bytes=len(blob), + doc_updated_at=datetime.now(timezone.utc), + metadata=metadata, + ) + + return doc + + def _yield_documents_from_table( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> GenerateDocumentsOutput: + """ + Yield documents from all sheets in the table. 
+ + Args: + start: Optional start timestamp for filtering + end: Optional end timestamp for filtering + + Yields: + Lists of Document objects + """ + # Get all sheets + sheets = self._get_all_sheets() + + batch: list[Document] = [] + + for sheet in sheets: + sheet_id = sheet["id"] + sheet_name = sheet["name"] + + # Get all records from this sheet + records = self._get_all_records(sheet_id) + + for record in records: + doc = self._convert_record_to_document( + record=record, + sheet_id=sheet_id, + sheet_name=sheet_name, + ) + + # Apply time filtering if specified + if start is not None or end is not None: + doc_time = doc.doc_updated_at.timestamp() if doc.doc_updated_at else None + if doc_time is not None: + if start is not None and doc_time < start: + continue + if end is not None and doc_time > end: + continue + + batch.append(doc) + + if len(batch) >= self.batch_size: + yield batch + batch = [] + + if batch: + yield batch + + def load_from_state(self) -> GenerateDocumentsOutput: + """ + Load all documents from the DingTalk Notable table. + + Yields: + Lists of Document objects + """ + return self._yield_documents_from_table() + + def poll_source( + self, + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, + ) -> GenerateDocumentsOutput: + """ + Poll for documents within a time range. + + Args: + start: Start timestamp + end: End timestamp + + Yields: + Lists of Document objects + """ + return self._yield_documents_from_table(start=start, end=end) + + +if __name__ == "__main__": + import os + + logging.basicConfig(level=logging.DEBUG) + + # Example usage + table_id = os.environ.get("DINGTALK_AI_TABLE_BASE_ID", "") + operator_id = os.environ.get("DINGTALK_OPERATOR_ID", "") + access_token = os.environ.get("DINGTALK_ACCESS_TOKEN", "") + + if not all([table_id, operator_id, access_token]): + print("Please set DINGTALK_AI_TABLE_BASE_ID, DINGTALK_OPERATOR_ID, and DINGTALK_ACCESS_TOKEN environment variables") + exit(1) + + connector = DingTalkAITableConnector( + table_id=table_id, + operator_id=operator_id, + ) + connector.load_credentials({"access_token": access_token}) + + try: + connector.validate_connector_settings() + print("Connector settings validated successfully") + except Exception as e: + print(f"Validation failed: {e}") + exit(1) + + document_batches = connector.load_from_state() + try: + first_batch = next(document_batches) + print(f"Loaded {len(first_batch)} documents in first batch.") + for doc in first_batch[:5]: # Print first 5 docs + print(f"- {doc.semantic_identifier} ({doc.size_bytes} bytes)") + print(f" Metadata: {doc.metadata}") + except StopIteration: + print("No documents available in DingTalk Notable table.") diff --git a/pyproject.toml b/pyproject.toml index 53dc38cf8cf..0665a1c5365 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -156,6 +156,7 @@ dependencies = [ "pygithub>=2.8.1", "asana>=5.2.2", "python-gitlab>=7.0.0", + "alibabacloud-dingtalk>=2.0.0", "quart-schema==0.23.0", ] diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py index ac317d418ef..044c7484dff 100644 --- a/rag/svr/sync_data_source.py +++ b/rag/svr/sync_data_source.py @@ -55,6 +55,7 @@ ZendeskConnector, SeaFileConnector, RDBMSConnector, + DingTalkAITableConnector, ) from common.constants import FileSource, TaskStatus from common.data_source.config import INDEX_BATCH_SIZE @@ -1221,6 +1222,49 @@ async def _generate(self, task: dict): return document_generator +class DingTalkAITable(SyncBase): + SOURCE_NAME: str = FileSource.DINGTALK_AI_TABLE + + async def _generate(self, 
task: dict): + """ + Sync records from DingTalk AI Table (Notable). + """ + self.connector = DingTalkAITableConnector( + table_id=self.conf.get("table_id"), + operator_id=self.conf.get("operator_id"), + batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE), + ) + + credentials = self.conf.get("credentials", {}) + if "access_token" not in credentials: + raise ValueError("Missing access_token in credentials") + + self.connector.load_credentials( + {"access_token": credentials["access_token"]} + ) + + poll_start = task.get("poll_range_start") + + if task.get("reindex") == "1" or poll_start is None: + document_generator = self.connector.load_from_state() + begin_info = "totally" + else: + document_generator = self.connector.poll_source( + poll_start.timestamp(), + datetime.now(timezone.utc).timestamp(), + ) + begin_info = f"from {poll_start}" + + logging.info( + "Connect to DingTalk AI Table: table_id(%s), operator_id(%s) %s", + self.conf.get("table_id"), + self.conf.get("operator_id"), + begin_info, + ) + + return document_generator + + class MySQL(SyncBase): SOURCE_NAME: str = FileSource.MYSQL @@ -1321,6 +1365,7 @@ async def _generate(self, task: dict): FileSource.SEAFILE: SeaFile, FileSource.MYSQL: MySQL, FileSource.POSTGRESQL: PostgreSQL, + FileSource.DINGTALK_AI_TABLE: DingTalkAITable, } diff --git a/uv.lock b/uv.lock index 6c545065650..0b1423a014c 100644 --- a/uv.lock +++ b/uv.lock @@ -319,12 +319,39 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/1b/c6/7d375cc1b1cab0f46950f556b70a2b17235747429a0889b73f3d46ff6023/alibabacloud_devs20230714-2.4.1-py3-none-any.whl", hash = "sha256:dbd260718e6db50021d804218b40bc99ee9c7e40b1def382aef8e542f5921113", size = 59307, upload-time = "2025-08-08T07:40:28.504Z" }, ] +[[package]] +name = "alibabacloud-dingtalk" +version = "2.2.38" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "alibabacloud-endpoint-util" }, + { name = "alibabacloud-gateway-dingtalk" }, + { name = "alibabacloud-gateway-spi" }, + { name = "alibabacloud-openapi-util" }, + { name = "alibabacloud-tea-openapi" }, + { name = "alibabacloud-tea-util" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/1c/95/ad71f7cb1a2814d17f3a731b37c27eb71df0895daa912e2436d2f0dd06d4/alibabacloud_dingtalk-2.2.38.tar.gz", hash = "sha256:39cba6ff3accf0a5c7fe7de651a65a5a784c4ef63e442750fd822c19864ed6f1", size = 1954698, upload-time = "2026-01-08T11:38:04.913Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a6/77/f6679f78becf73bfbf3468367ffa6fe694120e0cd677eacbc454dbb379a1/alibabacloud_dingtalk-2.2.38-py3-none-any.whl", hash = "sha256:c3dfc918c45f49fe61469230c0808fd6f316341594a2895564511b5542f50019", size = 2074155, upload-time = "2026-01-08T11:38:03.395Z" }, +] + [[package]] name = "alibabacloud-endpoint-util" version = "0.0.4" source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/92/7d/8cc92a95c920e344835b005af6ea45a0db98763ad6ad19299d26892e6c8d/alibabacloud_endpoint_util-0.0.4.tar.gz", hash = "sha256:a593eb8ddd8168d5dc2216cd33111b144f9189fcd6e9ca20e48f358a739bbf90", size = 2813, upload-time = "2025-06-12T07:20:52.572Z" } +[[package]] +name = "alibabacloud-gateway-dingtalk" +version = "1.0.2" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "alibabacloud-gateway-spi" }, + { name = "alibabacloud-tea-util" }, +] +sdist = { url = 
"https://pypi.tuna.tsinghua.edu.cn/packages/d2/40/751d8bdf133d7fcf053f10c98e8e506810e7bee06458a02eaaa14d30ac26/alibabacloud_gateway_dingtalk-1.0.2.tar.gz", hash = "sha256:acea8b0b1d11e0394913f0b0899ddd19c0bfceab716060449b57fcc250ceb300", size = 2938, upload-time = "2023-04-25T09:48:42.249Z" } + [[package]] name = "alibabacloud-gateway-spi" version = "0.0.3" @@ -6218,6 +6245,7 @@ dependencies = [ { name = "agentrun-sdk" }, { name = "aiosmtplib" }, { name = "akshare" }, + { name = "alibabacloud-dingtalk" }, { name = "anthropic" }, { name = "arxiv" }, { name = "asana" }, @@ -6357,6 +6385,7 @@ requires-dist = [ { name = "agentrun-sdk", specifier = ">=0.0.16,<1.0.0" }, { name = "aiosmtplib", specifier = ">=5.0.0" }, { name = "akshare", specifier = ">=1.15.78,<2.0.0" }, + { name = "alibabacloud-dingtalk", specifier = ">=2.0.0" }, { name = "anthropic", specifier = "==0.34.1" }, { name = "arxiv", specifier = "==2.1.3" }, { name = "asana", specifier = ">=5.2.2" }, diff --git a/web/src/assets/svg/data-source/dingtalk-ai-table.svg b/web/src/assets/svg/data-source/dingtalk-ai-table.svg new file mode 100644 index 00000000000..589602c4804 --- /dev/null +++ b/web/src/assets/svg/data-source/dingtalk-ai-table.svg @@ -0,0 +1,27 @@ + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/locales/bg.ts b/web/src/locales/bg.ts index 03b391cf975..ef4fc7822ed 100644 --- a/web/src/locales/bg.ts +++ b/web/src/locales/bg.ts @@ -998,6 +998,8 @@ The above is the content you need to summarize.`, 'Свържете GitHub за синхронизиране на pull requests и issues за извличане.', airtableDescription: 'Свържете се с Airtable и синхронизирайте файлове от определена таблица в определено работно пространство.', + dingtalkAITableDescription: + 'Свържете се с Dingtalk AI Table и синхронизирайте записи от определена таблица.', gitlabDescription: 'Свържете GitLab за синхронизиране на хранилища, issues, merge requests и свързана документация.', asanaDescription: diff --git a/web/src/locales/de.ts b/web/src/locales/de.ts index 508115b186a..d549af3fb00 100644 --- a/web/src/locales/de.ts +++ b/web/src/locales/de.ts @@ -1016,6 +1016,8 @@ Beispiel: Virtual Hosted Style`, 'Verbinden Sie GitHub, um Pull Requests und Issues zur Recherche zu synchronisieren.', airtableDescription: 'Verbinden Sie sich mit Airtable und synchronisieren Sie Dateien aus einer bestimmten Tabelle in einem vorgesehenen Arbeitsbereich.', + dingtalkAITableDescription: + 'Verbinden Sie sich mit Dingtalk AI Table und synchronisieren Sie Datensätze aus einer bestimmten Tabelle.', asanaDescription: 'Verbinden Sie sich mit Asana und synchronisieren Sie Dateien aus einem bestimmten Arbeitsbereich.', imapDescription: diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 4df8e286b52..4257d5a0e18 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -1034,6 +1034,8 @@ Example: Virtual Hosted Style`, 'Connect GitHub to sync pull requests and issues for retrieval.', airtableDescription: 'Connect to Airtable and synchronize files from a specified table within a designated workspace.', + dingtalkAITableDescription: + 'Connect to Dingtalk AI Table and synchronize records from a specified table.', gitlabDescription: 'Connect GitLab to sync repositories, issues, merge requests, and related documentation.', asanaDescription: diff --git a/web/src/locales/es.ts b/web/src/locales/es.ts index 816353495fa..1ea94962dc0 100644 --- a/web/src/locales/es.ts +++ b/web/src/locales/es.ts @@ -472,6 +472,8 @@ export default { 
apiVersionMessage: '¡Por favor ingresa la versión de la API!', modelsToBeAddedTooltip: 'Si tu proveedor de modelos no aparece en la lista pero afirma ser compatible con OpenAI, selecciona la tarjeta OpenAI-API-compatible para añadir el/los modelo(s) correspondiente(s).', + dingtalkAITableDescription: + 'Conéctese a Dingtalk AI Table y sincronice registros de una tabla especificada.', }, message: { registered: '¡Registrado!', diff --git a/web/src/locales/fr.ts b/web/src/locales/fr.ts index ccb735c4103..f19f1b6d05c 100644 --- a/web/src/locales/fr.ts +++ b/web/src/locales/fr.ts @@ -680,6 +680,8 @@ export default { modelsToBeAddedTooltip: 'Si votre fournisseur de modèle n\'est pas listé mais prétend être "compatible OpenAI", sélectionnez la carte compatible OpenAI-API pour ajouter le(s) modèle(s) pertinent(s).', mcp: 'MCP', + dingtalkAITableDescription: + 'Connectez-vous à Dingtalk AI Table et synchronisez les enregistrements d\'une table spécifiée.', }, message: { registered: 'Enregistré !', diff --git a/web/src/locales/id.ts b/web/src/locales/id.ts index 7969d85666c..09aa6a29d18 100644 --- a/web/src/locales/id.ts +++ b/web/src/locales/id.ts @@ -669,6 +669,8 @@ export default { apiVersionMessage: 'Silakan masukkan versi API', modelsToBeAddedTooltip: 'Jika penyedia model Anda tidak tercantum tetapi mengklaim kompatibel dengan OpenAI, pilih kartu OpenAI-API-compatible untuk menambahkan model yang relevan.', + dingtalkAITableDescription: + 'Hubungkan ke Dingtalk AI Table dan sinkronkan catatan dari tabel yang ditentukan.', }, message: { registered: 'Terdaftar!', diff --git a/web/src/locales/it.ts b/web/src/locales/it.ts index 04222f4607c..05627ea1dc8 100644 --- a/web/src/locales/it.ts +++ b/web/src/locales/it.ts @@ -840,6 +840,8 @@ Quanto sopra è il contenuto che devi riassumere.`, configuration: 'Configurazione', view: 'Visualizza', mcp: 'MCP', + dingtalkAITableDescription: + 'Connettiti a Dingtalk AI Table e sincronizza i record da una tabella specificata.', }, message: { registered: 'Registrato!', diff --git a/web/src/locales/ja.ts b/web/src/locales/ja.ts index 3eb93aae5e0..5b32e0fbf28 100644 --- a/web/src/locales/ja.ts +++ b/web/src/locales/ja.ts @@ -691,6 +691,8 @@ export default { sureQuit: '参加したチームから退出してもよろしいですか?', modelsToBeAddedTooltip: 'モデルプロバイダーがリストにないが「OpenAI互換」を謳っている場合は、OpenAI-API-compatible カードを選択して関連モデルを追加してください。', + dingtalkAITableDescription: + 'Dingtalk AI Table に接続し、指定されたテーブルからレコードを同期します。', }, message: { registered: '登録完了!', diff --git a/web/src/locales/pt-br.ts b/web/src/locales/pt-br.ts index 1ce96814ca1..719f21a97fc 100644 --- a/web/src/locales/pt-br.ts +++ b/web/src/locales/pt-br.ts @@ -639,6 +639,8 @@ export default { sureQuit: 'Tem certeza de que deseja sair da equipe que você ingressou?', modelsToBeAddedTooltip: 'Se o seu provedor de modelo não estiver listado, mas afirmar ser compatível com a OpenAI, selecione o card OpenAI-API-compatible para adicionar o(s) modelo(s) relevante(s). 
', + dingtalkAITableDescription: + 'Conecte-se ao Dingtalk AI Table e sincronize registros de uma tabela especificada.', }, message: { registered: 'Registrado!', diff --git a/web/src/locales/ru.ts b/web/src/locales/ru.ts index 60b8d8ab1c7..923314ef5a5 100644 --- a/web/src/locales/ru.ts +++ b/web/src/locales/ru.ts @@ -778,6 +778,8 @@ export default { 'Подключите GitHub для синхронизации содержимого Pull Request и Issue для поиска.', airtableDescription: 'Подключите Airtable и синхронизируйте файлы из указанной таблицы в заданном рабочем пространстве.', + dingtalkAITableDescription: + 'Подключите Dingtalk AI Table и синхронизируйте записи из указанной таблицы.', gitlabDescription: 'Подключите GitLab для синхронизации репозиториев, задач, merge requests и связанной документации.', asanaDescription: diff --git a/web/src/locales/vi.ts b/web/src/locales/vi.ts index ffe73121312..c3f933c30a3 100644 --- a/web/src/locales/vi.ts +++ b/web/src/locales/vi.ts @@ -723,6 +723,8 @@ export default { FishAudioRefIDMessage: `Vui lòng nhập ID của model tham chiếu (để trống để sử dụng model mặc định)`, modelsToBeAddedTooltip: 'Nếu nhà cung cấp mô hình của bạn không có trong danh sách nhưng tuyên bố tương thích với "OpenAI", hãy chọn thẻ OpenAI-API-compatible để thêm mô hình liên quan.', + dingtalkAITableDescription: + 'Kết nối với Dingtalk AI Table và đồng bộ hóa bản ghi từ một bảng được chỉ định.', }, message: { registered: 'Đã đăng ký!', diff --git a/web/src/locales/zh-traditional.ts b/web/src/locales/zh-traditional.ts index 229c6bea5f1..611dddf0078 100644 --- a/web/src/locales/zh-traditional.ts +++ b/web/src/locales/zh-traditional.ts @@ -563,6 +563,8 @@ export default { avatar: '头像', avatarTip: '這會在你的個人主頁展示', profileDescription: '在此更新您的照片和個人詳細信息。', + dingtalkAITableDescription: + '連接釘釘AI表格,同步指定表格中的記錄。', gitlabDescription: '連接 GitLab,同步儲存庫、Issue、合併請求(MR)及相關文件內容。', bedrockCredentialsHint: diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index 8a278177f56..3a624d42ebc 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -919,6 +919,8 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 githubDescription: '连接 GitHub,可同步 Pull Request 与 Issue 内容用于检索。', airtableDescription: '连接 Airtable,同步指定工作区下指定表格中的文件。', + dingtalkAITableDescription: + '连接钉钉AI表格,同步指定表格中的记录。', gitlabDescription: '连接 GitLab,同步仓库、Issue、合并请求(MR)及相关文档内容。', asanaDescription: '连接 Asana,同步工作区中的文件。', diff --git a/web/src/pages/user-setting/data-source/constant/index.tsx b/web/src/pages/user-setting/data-source/constant/index.tsx index 96f6987323a..b06d424b4cb 100644 --- a/web/src/pages/user-setting/data-source/constant/index.tsx +++ b/web/src/pages/user-setting/data-source/constant/index.tsx @@ -30,6 +30,7 @@ export enum DataSourceKey { OCI_STORAGE = 'oci_storage', GOOGLE_CLOUD_STORAGE = 'google_cloud_storage', AIRTABLE = 'airtable', + DINGTALK_AI_TABLE = 'dingtalk_ai_table', GITLAB = 'gitlab', ASANA = 'asana', IMAP = 'imap', @@ -123,6 +124,11 @@ export const generateDataSourceInfo = (t: TFunction) => { description: t(`setting.${DataSourceKey.AIRTABLE}Description`), icon: , }, + [DataSourceKey.DINGTALK_AI_TABLE]: { + name: 'Dingtalk AI Table', + description: t(`setting.dingtalkAITableDescription`), + icon: , + }, [DataSourceKey.GITLAB]: { name: 'GitLab', description: t(`setting.${DataSourceKey.GITLAB}Description`), @@ -658,6 +664,26 @@ export const DataSourceFormFields = { required: true, }, ], + [DataSourceKey.DINGTALK_AI_TABLE]: [ + { + label: 'Access Token', + name: 'config.credentials.access_token', + type: 
FormFieldType.Password, + required: true, + }, + { + label: 'Base ID', + name: 'config.table_id', + type: FormFieldType.Text, + required: true, + }, + { + label: 'Operator ID', + name: 'config.operator_id', + type: FormFieldType.Text, + required: true, + }, + ], [DataSourceKey.GITLAB]: [ { label: 'Project Owner', @@ -1135,6 +1161,17 @@ export const DataSourceFormDefaultValues = { }, }, }, + [DataSourceKey.DINGTALK_AI_TABLE]: { + name: '', + source: DataSourceKey.DINGTALK_AI_TABLE, + config: { + table_id: '', + operator_id: '', + credentials: { + access_token: '', + }, + }, + }, [DataSourceKey.GITLAB]: { name: '', source: DataSourceKey.GITLAB, From 3709075ea1e448fbfeec1b46e3e326b599f13fcd Mon Sep 17 00:00:00 2001 From: Liu An Date: Mon, 9 Mar 2026 10:32:51 +0800 Subject: [PATCH 173/565] Chore: update release workflow configuration (#13466) ### What problem does this PR solve? update release workflow configuration ### Type of change - [x] Update CI --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e18d1e2e51c..a5ddade391f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -23,7 +23,7 @@ concurrency: jobs: release: - runs-on: [ "self-hosted", "ragflow-test" ] + runs-on: [ "self-hosted", "ragflow-release" ] steps: - name: Ensure workspace ownership run: echo "chown -R ${USER} ${GITHUB_WORKSPACE}" && sudo chown -R ${USER} ${GITHUB_WORKSPACE} From 9972e84aba421c66153242ee53929267da8f9306 Mon Sep 17 00:00:00 2001 From: Eden <146086744+edenfunf@users.noreply.github.com> Date: Mon, 9 Mar 2026 10:36:02 +0800 Subject: [PATCH 174/565] fix(agent): ensure database connections are properly closed in ExeSQL tool (#13427) ## Summary Fix a database connection and cursor resource leak in the ExeSQL agent tool. When SQL execution raises an exception (for example syntax error or missing table), the existing code path skips `cursor.close()` and `db.close()`, causing database connections to accumulate over time. This can eventually lead to connection exhaustion in long-running agent workflows. ## Root Cause The cleanup logic for database cursors and connections is placed after the SQL execution loop without `try/finally` protection. If an exception occurs during `cursor.execute()`, `fetchmany()`, or result processing, the cleanup code is not reached and the connection remains open. The same issue also exists in the IBM DB2 execution path where `ibm_db.close(conn)` may be skipped when exceptions occur. ## Fix - Wrap SQL execution logic in `try/finally` blocks to guarantee resource cleanup. - Ensure `cursor.close()` and `db.close()` are always executed. - Add explicit `db.close()` when `db.cursor()` creation fails. - Remove redundant close calls in early-return branches since `finally` now handles cleanup. ## Impact - No change to normal execution behavior. - Ensures database resources are always released when errors occur. - Prevents connection leaks in long-running workflows. - Only affects `agent/tools/exesql.py`. ## Testing Manual test scenarios: 1. Valid SQL execution 2. SQL syntax error 3. Query against a non-existing table 4. Execution cancellation during query In all scenarios the database cursor and connection are properly closed. 
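For illustration, a minimal self-contained sketch of the cleanup pattern this fix adopts — `run_query` and the in-memory SQLite connection are hypothetical stand-ins for the real MySQL/MSSQL handles, not the actual RAGFlow code:

```python
# Sketch of the try/finally + contextlib.suppress pattern used by the fix.
# `run_query` and the SQLite connection are illustrative assumptions only.
import contextlib
import sqlite3


def run_query(sql: str):
    db = sqlite3.connect(":memory:")  # stands in for the real DB connection
    try:
        cursor = db.cursor()
    except Exception:
        # If cursor creation fails, still release the connection.
        with contextlib.suppress(Exception):
            db.close()
        raise
    try:
        cursor.execute(sql)  # may raise on syntax errors or missing tables
        return cursor.fetchmany(100)
    finally:
        # Always reached — on success, error, or cancellation — so neither
        # the cursor nor the connection can leak.
        with contextlib.suppress(Exception):
            cursor.close()
        with contextlib.suppress(Exception):
            db.close()
```

Because cleanup lives in `finally` and each close is wrapped in `contextlib.suppress(Exception)`, a failure while closing one resource cannot prevent the other from being released.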
Code quality checks: - `ruff check` passed - No new warnings introduced --- agent/tools/exesql.py | 134 ++++++++++++++++++++++-------------------- 1 file changed, 69 insertions(+), 65 deletions(-) diff --git a/agent/tools/exesql.py b/agent/tools/exesql.py index 3f969f43164..305801124c1 100644 --- a/agent/tools/exesql.py +++ b/agent/tools/exesql.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import contextlib import json import os import re @@ -195,43 +196,43 @@ def _parse_catalog_schema(db: str): except Exception as e: raise Exception("Database Connection Failed! \n" + str(e)) - sql_res = [] - formalized_content = [] - for single_sql in sqls: - if self.check_if_canceled("ExeSQL processing"): - ibm_db.close(conn) - return - - single_sql = single_sql.replace("```", "").strip() - if not single_sql: - continue - single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql) - - stmt = ibm_db.exec_immediate(conn, single_sql) - rows = [] - row = ibm_db.fetch_assoc(stmt) - while row and len(rows) < self._param.max_records: + try: + sql_res = [] + formalized_content = [] + for single_sql in sqls: if self.check_if_canceled("ExeSQL processing"): - ibm_db.close(conn) return - rows.append(row) - row = ibm_db.fetch_assoc(stmt) - - if not rows: - sql_res.append({"content": "No record in the database!"}) - continue - - df = pd.DataFrame(rows) - for col in df.columns: - if pd.api.types.is_datetime64_any_dtype(df[col]): - df[col] = df[col].dt.strftime("%Y-%m-%d") - - df = df.where(pd.notnull(df), None) - sql_res.append(convert_decimals(df.to_dict(orient="records"))) - formalized_content.append(df.to_markdown(index=False, floatfmt=".6f")) + single_sql = single_sql.replace("```", "").strip() + if not single_sql: + continue + single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql) - ibm_db.close(conn) + stmt = ibm_db.exec_immediate(conn, single_sql) + rows = [] + row = ibm_db.fetch_assoc(stmt) + while row and len(rows) < self._param.max_records: + if self.check_if_canceled("ExeSQL processing"): + return + rows.append(row) + row = ibm_db.fetch_assoc(stmt) + + if not rows: + sql_res.append({"content": "No record in the database!"}) + continue + + df = pd.DataFrame(rows) + for col in df.columns: + if pd.api.types.is_datetime64_any_dtype(df[col]): + df[col] = df[col].dt.strftime("%Y-%m-%d") + + df = df.where(pd.notnull(df), None) + + sql_res.append(convert_decimals(df.to_dict(orient="records"))) + formalized_content.append(df.to_markdown(index=False, floatfmt=".6f")) + finally: + with contextlib.suppress(Exception): + ibm_db.close(conn) self.set_output("json", sql_res) self.set_output("formalized_content", "\n\n".join(formalized_content)) @@ -239,42 +240,45 @@ def _parse_catalog_schema(db: str): try: cursor = db.cursor() except Exception as e: + with contextlib.suppress(Exception): + db.close() raise Exception("Database Connection Failed! 
\n" + str(e)) - sql_res = [] - formalized_content = [] - for single_sql in sqls: - if self.check_if_canceled("ExeSQL processing"): + try: + sql_res = [] + formalized_content = [] + for single_sql in sqls: + if self.check_if_canceled("ExeSQL processing"): + return + + single_sql = single_sql.replace('```', '').strip() + if not single_sql: + continue + single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql) + cursor.execute(single_sql) + if cursor.rowcount == 0: + sql_res.append({"content": "No record in the database!"}) + break + if self._param.db_type == 'mssql': + single_res = pd.DataFrame.from_records(cursor.fetchmany(self._param.max_records), + columns=[desc[0] for desc in cursor.description]) + else: + single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.max_records)]) + single_res.columns = [i[0] for i in cursor.description] + + for col in single_res.columns: + if pd.api.types.is_datetime64_any_dtype(single_res[col]): + single_res[col] = single_res[col].dt.strftime('%Y-%m-%d') + + single_res = single_res.where(pd.notnull(single_res), None) + + sql_res.append(convert_decimals(single_res.to_dict(orient='records'))) + formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f")) + finally: + with contextlib.suppress(Exception): cursor.close() + with contextlib.suppress(Exception): db.close() - return - - single_sql = single_sql.replace('```','') - if not single_sql: - continue - single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql) - cursor.execute(single_sql) - if cursor.rowcount == 0: - sql_res.append({"content": "No record in the database!"}) - break - if self._param.db_type == 'mssql': - single_res = pd.DataFrame.from_records(cursor.fetchmany(self._param.max_records), - columns=[desc[0] for desc in cursor.description]) - else: - single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.max_records)]) - single_res.columns = [i[0] for i in cursor.description] - - for col in single_res.columns: - if pd.api.types.is_datetime64_any_dtype(single_res[col]): - single_res[col] = single_res[col].dt.strftime('%Y-%m-%d') - - single_res = single_res.where(pd.notnull(single_res), None) - - sql_res.append(convert_decimals(single_res.to_dict(orient='records'))) - formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f")) - - cursor.close() - db.close() self.set_output("json", sql_res) self.set_output("formalized_content", "\n\n".join(formalized_content)) From 5757ca39548e688f6aeb468429e9d48be036ad0a Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Mon, 9 Mar 2026 10:44:53 +0800 Subject: [PATCH 175/565] Add more API of admin server of go (#13403) ### What problem does this PR solve? Add APIs to admin server. 
### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Signed-off-by: Jin Hai --- cmd/admin_server.go | 161 +++-- internal/admin/handler.go | 917 ++++++++++++++++++++---- internal/admin/router.go | 81 ++- internal/admin/service.go | 716 +++++++++++++++--- internal/dao/user.go | 9 +- internal/engine/elasticsearch/client.go | 142 ++++ internal/server/config.go | 200 ++++++ internal/service/user.go | 21 +- internal/utility/token.go | 5 + web/src/locales/en.ts | 2 +- web/src/utils/next-request.ts | 6 +- web/src/utils/request.ts | 8 +- 12 files changed, 1957 insertions(+), 311 deletions(-) diff --git a/cmd/admin_server.go b/cmd/admin_server.go index 8a7587487b2..103ad6d227c 100644 --- a/cmd/admin_server.go +++ b/cmd/admin_server.go @@ -17,56 +17,37 @@ package main import ( + "context" "flag" "fmt" + "net/http" "os" + "os/signal" + "ragflow/internal/cache" + "ragflow/internal/engine" + "syscall" + "time" "github.com/gin-gonic/gin" "go.uber.org/zap" "ragflow/internal/admin" "ragflow/internal/dao" + "ragflow/internal/handler" "ragflow/internal/logger" "ragflow/internal/server" + "ragflow/internal/service" "ragflow/internal/utility" ) // AdminServer admin server type AdminServer struct { - router *admin.Router - handler *admin.Handler - service *admin.Service - engine *gin.Engine - port string -} - -// NewAdminServer create admin server -func NewAdminServer(port string) *AdminServer { - return &AdminServer{ - port: port, - } -} - -// Init initialize admin server -func (s *AdminServer) Init() error { - gin.SetMode(gin.ReleaseMode) - s.engine = gin.New() - s.engine.Use(gin.Recovery()) - - // Initialize layers - s.service = admin.NewService() - s.handler = admin.NewHandler(s.service) - s.router = admin.NewRouter(s.handler) - - // Setup routes - s.router.Setup(s.engine) - - return nil -} - -// Run start admin server -func (s *AdminServer) Run() error { - return s.engine.Run(":" + s.port) + router *admin.Router + handler *admin.Handler + service *admin.Service + userHandler *handler.UserHandler + engine *gin.Engine + port string } func main() { @@ -85,14 +66,29 @@ func main() { os.Exit(1) } + cfg := server.GetConfig() + + // Reinitialize logger with configured level if different + if cfg.Log.Level != "" && cfg.Log.Level != "info" { + if err := logger.Init(cfg.Log.Level); err != nil { + logger.Error("Failed to reinitialize logger with configured level", err) + } + } + // Set logger for server package server.SetLogger(logger.Logger) - cfg := server.GetConfig() - logger.Info("Configuration loaded", - zap.String("database_host", cfg.Database.Host), - zap.Int("database_port", cfg.Database.Port), - ) + logger.Info("Server mode", zap.String("mode", cfg.Server.Mode)) + + // Print all configuration settings + server.PrintAll() + + // Set Gin mode + if cfg.Server.Mode == "release" { + gin.SetMode(gin.ReleaseMode) + } else { + gin.SetMode(gin.DebugMode) + } // Initialize database if err := dao.InitDB(); err != nil { @@ -100,13 +96,58 @@ func main() { os.Exit(1) } - // Create and start admin server (port 9381) - adminServer := NewAdminServer("9381") - if err := adminServer.Init(); err != nil { - logger.Error("Failed to initialize admin server", err) - os.Exit(1) + // Initialize doc engine + if err := engine.Init(&cfg.DocEngine); err != nil { + logger.Fatal("Failed to initialize doc engine", zap.Error(err)) + } + defer engine.Close() + + // Initialize Redis cache + if err := cache.Init(&cfg.Redis); err != nil { + logger.Fatal("Failed to initialize Redis", zap.Error(err)) + } + 
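	// Review note: the deferred engine.Close and cache.Close calls run in
	// LIFO order only after the graceful-shutdown block at the bottom of
	// main returns. Further down, `addr := fmt.Sprintf(":9381")` has no
	// formatting verbs (a plain ":9381" literal would do), and the port is
	// hard-coded both here and in the startup log instead of coming from
	// configuration.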
defer cache.Close() + + // Initialize server variables (runtime variables that can change during operation) + // This must be done after Cache is initialized + if err := server.InitVariables(cache.Get()); err != nil { + logger.Warn("Failed to initialize server variables from Redis, using defaults", zap.String("error", err.Error())) + } + + adminService := admin.NewService() + userService := service.NewUserService() + adminHandler := admin.NewHandler(adminService, userService) + + // Initialize router + r := admin.NewRouter(adminHandler) + + // Create Gin engine + ginEngine := gin.New() + + // Middleware + if cfg.Server.Mode == "debug" { + ginEngine.Use(gin.Logger()) + } + ginEngine.Use(gin.Recovery()) + // Log request URL for every request + ginEngine.Use(func(c *gin.Context) { + logger.Info("HTTP Request", zap.String("url", c.Request.URL.String()), zap.String("method", c.Request.Method)) + c.Next() + }) + + // Setup routes + r.Setup(ginEngine) + + // Create HTTP server + addr := fmt.Sprintf(":9381") + srv := &http.Server{ + Addr: addr, + Handler: ginEngine, } + // Print RAGFlow version + logger.Info("RAGFlow version", zap.String("version", utility.GetRAGFlowVersion())) + // Print all configuration settings server.PrintAll() @@ -118,11 +159,31 @@ func main() { " / _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / ___ / /_/ / / / / / / / / / /\n" + " /_/ |_/_/ |_\\____/_/ /_/\\____/|__/|__/ /_/ |_\\__,_/_/ /_/ /_/_/_/ /_/ \n") - // Print RAGFlow version - logger.Info(fmt.Sprintf("Version: %s", utility.GetRAGFlowVersion())) - logger.Info(fmt.Sprintf("Starting RAGFlow admin server on port: 9381")) - if err := adminServer.Run(); err != nil { - logger.Error("Admin server error", err) - os.Exit(1) + // Start server in a goroutine + go func() { + logger.Info(fmt.Sprintf("Version: %s", utility.GetRAGFlowVersion())) + logger.Info(fmt.Sprintf("Starting RAGFlow admin server on port: 9381")) + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.Fatal("Failed to start server", zap.Error(err)) + } + }() + + // Wait for interrupt signal to gracefully shutdown + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGUSR2) + sig := <-quit + + logger.Info("Received signal", zap.String("signal", sig.String())) + logger.Info("Shutting down server...") + + // Create context with timeout for graceful shutdown + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Shutdown server + if err := srv.Shutdown(ctx); err != nil { + logger.Fatal("Server forced to shutdown", zap.Error(err)) } + + logger.Info("Server exited") } diff --git a/internal/admin/handler.go b/internal/admin/handler.go index 18511ea6f3f..526a22fa912 100644 --- a/internal/admin/handler.go +++ b/internal/admin/handler.go @@ -19,6 +19,10 @@ package admin import ( "errors" "net/http" + "ragflow/internal/server" + "ragflow/internal/service" + "ragflow/internal/utility" + "strconv" "github.com/gin-gonic/gin" ) @@ -32,12 +36,52 @@ var ( // Handler admin handler type Handler struct { - service *Service + service *Service + userService *service.UserService } // NewHandler create admin handler -func NewHandler(service *Service) *Handler { - return &Handler{service: service} +func NewHandler(service *Service, userService *service.UserService) *Handler { + return &Handler{service: service, userService: userService} +} + +// SuccessResponse success response +type SuccessResponse struct { + Code int `json:"code"` + Message string 
`json:"message"` + Data interface{} `json:"data"` +} + +// ErrorResponse error response +type ErrorResponse struct { + Code int `json:"code"` + Message string `json:"message"` +} + +// success returns success response +func success(c *gin.Context, data interface{}, message string) { + c.JSON(200, SuccessResponse{ + Code: 0, + Message: message, + Data: data, + }) +} + +// successNoData returns success response without data +func successNoData(c *gin.Context, message string) { + c.JSON(200, SuccessResponse{ + Code: 0, + Message: message, + Data: nil, + }) +} + +// error returns error response +func errorResponse(c *gin.Context, message string, code int) { + c.JSON(code, ErrorResponse{ + Code: code, + Message: message, + }) } // Health health check @@ -45,6 +89,11 @@ func (h *Handler) Health(c *gin.Context) { c.JSON(200, gin.H{"status": "ok"}) } +// Ping ping endpoint +func (h *Handler) Ping(c *gin.Context) { + successNoData(c, "PONG") +} + // LoginHTTPRequest login request body type LoginHTTPRequest struct { Email string `json:"email" binding:"required"` @@ -53,183 +102,811 @@ type LoginHTTPRequest struct { // Login handle admin login func (h *Handler) Login(c *gin.Context) { - var req LoginHTTPRequest + var req service.EmailLoginRequest if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(400, gin.H{"error": err.Error()}) + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": err.Error(), + }) return } - svcReq := &LoginRequest{ - Email: req.Email, - Password: req.Password, - } - - resp, err := h.service.Login(svcReq) + user, code, err := h.userService.LoginByEmail(&req) if err != nil { - if errors.Is(err, ErrInvalidCredentials) { - c.JSON(401, gin.H{"error": "invalid credentials"}) - return - } - c.JSON(500, gin.H{"error": err.Error()}) + c.JSON(http.StatusUnauthorized, gin.H{ + "code": code, + "message": err.Error(), + }) return } - c.JSON(200, gin.H{ - "token": resp.Token, - "user": gin.H{ - "id": resp.UserID, - "email": resp.Email, - "nickname": resp.Nickname, - }, + variables := server.GetVariables() + secretKey := variables.SecretKey + authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey) + + // Set Authorization header with access_token + if user.AccessToken != nil { + c.Header("Authorization", authToken) + } + // Set CORS headers + c.Header("Access-Control-Allow-Origin", "*") + c.Header("Access-Control-Allow-Methods", "*") + c.Header("Access-Control-Allow-Headers", "*") + c.Header("Access-Control-Expose-Headers", "Authorization") + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "Login successful", }) } -// ListUsers handle list users -func (h *Handler) ListUsers(c *gin.Context) { - // Parse pagination params - offset := 0 - limit := 20 +// Logout handle logout +func (h *Handler) Logout(c *gin.Context) { + user, exists := c.Get("user") + if !exists { + errorResponse(c, "Not authenticated", 401) + return + } - svcReq := &ListUsersRequest{ - Offset: offset, - Limit: limit, + if err := h.service.Logout(user); err != nil { + errorResponse(c, err.Error(), 500) + return } - resp, err := h.service.ListUsers(svcReq) + successNoData(c, "Logout successful") +} + +// AuthCheck check admin auth +func (h *Handler) AuthCheck(c *gin.Context) { + successNoData(c, "Admin is authorized") +} + +// ListUsers handle list users +func (h *Handler) ListUsers(c *gin.Context) { + users, err := h.service.ListUsers() if err != nil { - c.JSON(500, gin.H{"error": err.Error()}) + errorResponse(c, err.Error(), 500) return } - // Convert to response format - var result 
[]gin.H - for _, user := range resp.Users { - result = append(result, gin.H{ - "id": user.ID, - "email": user.Email, - "nickname": user.Nickname, - "is_active": user.IsActive, - "create_time": user.CreateTime, - "update_time": user.UpdateTime, - }) + success(c, users, "Get all users") +} + +// CreateUserHTTPRequest create user request +type CreateUserHTTPRequest struct { + Username string `json:"username" binding:"required"` + Password string `json:"password" binding:"required"` + Role string `json:"role"` +} + +// CreateUser handle create user +func (h *Handler) CreateUser(c *gin.Context) { + var req CreateUserHTTPRequest + if err := c.ShouldBindJSON(&req); err != nil { + errorResponse(c, "Username and password are required", 400) + return } - c.JSON(200, gin.H{ - "data": result, - "total": resp.Total, - }) + if req.Role == "" { + req.Role = "user" + } + + userInfo, err := h.service.CreateUser(req.Username, req.Password, req.Role) + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, userInfo, "User created successfully") } // GetUser handle get user func (h *Handler) GetUser(c *gin.Context) { - id := c.Param("id") - if id == "" { - c.JSON(400, gin.H{"error": "user id is required"}) + username := c.Param("username") + if username == "" { + errorResponse(c, "Username is required", 400) return } - svcReq := &GetUserRequest{ID: id} - user, err := h.service.GetUser(svcReq) + userDetails, err := h.service.GetUserDetails(username) if err != nil { if errors.Is(err, ErrUserNotFound) { - c.JSON(404, gin.H{"error": "user not found"}) + errorResponse(c, "User not found", 404) return } - c.JSON(500, gin.H{"error": err.Error()}) + errorResponse(c, err.Error(), 500) return } - c.JSON(200, gin.H{ - "id": user.ID, - "email": user.Email, - "nickname": user.Nickname, - "is_active": user.IsActive, - "create_time": user.CreateTime, - "update_time": user.UpdateTime, - }) + success(c, userDetails, "") +} + +// DeleteUser handle delete user +func (h *Handler) DeleteUser(c *gin.Context) { + username := c.Param("username") + if username == "" { + errorResponse(c, "Username is required", 400) + return + } + + if err := h.service.DeleteUser(username); err != nil { + errorResponse(c, err.Error(), 500) + return + } + + successNoData(c, "User deleted successfully") } -// UpdateUserHTTPRequest update user request body -type UpdateUserHTTPRequest struct { - Nickname string `json:"nickname"` - IsActive *string `json:"is_active,omitempty"` +// ChangePasswordHTTPRequest change password request +type ChangePasswordHTTPRequest struct { + NewPassword string `json:"new_password" binding:"required"` } -// UpdateUser handle update user -func (h *Handler) UpdateUser(c *gin.Context) { - id := c.Param("id") - if id == "" { - c.JSON(400, gin.H{"error": "user id is required"}) +// ChangePassword handle change password +func (h *Handler) ChangePassword(c *gin.Context) { + username := c.Param("username") + if username == "" { + errorResponse(c, "Username is required", 400) return } - var req UpdateUserHTTPRequest + var req ChangePasswordHTTPRequest if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(400, gin.H{"error": err.Error()}) + errorResponse(c, "New password is required", 400) return } - svcReq := &UpdateUserRequest{ - ID: id, - Nickname: req.Nickname, - IsActive: req.IsActive, + if err := h.service.ChangePassword(username, req.NewPassword); err != nil { + errorResponse(c, err.Error(), 500) + return } - if err := h.service.UpdateUser(svcReq); err != nil { - if errors.Is(err, ErrUserNotFound) { - 
c.JSON(404, gin.H{"error": "user not found"}) + successNoData(c, "Password updated successfully") +} + +// UpdateActivateStatusHTTPRequest update activate status request +type UpdateActivateStatusHTTPRequest struct { + ActivateStatus bool `json:"activate_status" binding:"required"` +} + +// UpdateUserActivateStatus handle update user activate status +func (h *Handler) UpdateUserActivateStatus(c *gin.Context) { + username := c.Param("username") + if username == "" { + errorResponse(c, "Username is required", 400) + return + } + + var req UpdateActivateStatusHTTPRequest + if err := c.ShouldBindJSON(&req); err != nil { + errorResponse(c, "Activation status is required", 400) + return + } + + if err := h.service.UpdateUserActivateStatus(username, req.ActivateStatus); err != nil { + errorResponse(c, err.Error(), 500) + return + } + + successNoData(c, "Activation status updated") +} + +// GrantAdmin handle grant admin role +func (h *Handler) GrantAdmin(c *gin.Context) { + username := c.Param("username") + if username == "" { + errorResponse(c, "Username is required", 400) + return + } + + // Get current user from context + currentUser, _ := c.Get("user") + if currentUser != nil && currentUser.(string) == username { + errorResponse(c, "can't grant current user: "+username, 409) + return + } + + if err := h.service.GrantAdmin(username); err != nil { + errorResponse(c, err.Error(), 500) + return + } + + successNoData(c, "Admin role granted") +} + +// RevokeAdmin handle revoke admin role +func (h *Handler) RevokeAdmin(c *gin.Context) { + username := c.Param("username") + if username == "" { + errorResponse(c, "Username is required", 400) + return + } + + // Get current user from context + currentUser, _ := c.Get("user") + if currentUser != nil && currentUser.(string) == username { + errorResponse(c, "can't revoke current user: "+username, 409) + return + } + + if err := h.service.RevokeAdmin(username); err != nil { + errorResponse(c, err.Error(), 500) + return + } + + successNoData(c, "Admin role revoked") +} + +// GetUserDatasets handle get user datasets +func (h *Handler) GetUserDatasets(c *gin.Context) { + username := c.Param("username") + if username == "" { + errorResponse(c, "Username is required", 400) + return + } + + datasets, err := h.service.GetUserDatasets(username) + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, datasets, "") +} + +// GetUserAgents handle get user agents +func (h *Handler) GetUserAgents(c *gin.Context) { + username := c.Param("username") + if username == "" { + errorResponse(c, "Username is required", 400) + return + } + + agents, err := h.service.GetUserAgents(username) + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, agents, "") +} + +// GetUserAPIKeys handle get user API keys +func (h *Handler) GetUserAPIKeys(c *gin.Context) { + username := c.Param("username") + if username == "" { + errorResponse(c, "Username is required", 400) + return + } + + apiKeys, err := h.service.GetUserAPIKeys(username) + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, apiKeys, "Get user API keys") +} + +// GenerateUserAPIKey handle generate user API key +func (h *Handler) GenerateUserAPIKey(c *gin.Context) { + username := c.Param("username") + if username == "" { + errorResponse(c, "Username is required", 400) + return + } + + apiKey, err := h.service.GenerateUserAPIKey(username) + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, apiKey, "API key generated 
successfully") +} + +// DeleteUserAPIKey handle delete user API key +func (h *Handler) DeleteUserAPIKey(c *gin.Context) { + username := c.Param("username") + key := c.Param("key") + if username == "" || key == "" { + errorResponse(c, "Username and key are required", 400) + return + } + + if err := h.service.DeleteUserAPIKey(username, key); err != nil { + errorResponse(c, err.Error(), 404) + return + } + + successNoData(c, "API key deleted successfully") +} + +// ListRoles handle list roles +func (h *Handler) ListRoles(c *gin.Context) { + roles, err := h.service.ListRoles() + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, roles, "") +} + +// CreateRoleHTTPRequest create role request +type CreateRoleHTTPRequest struct { + RoleName string `json:"role_name" binding:"required"` + Description string `json:"description"` +} + +// CreateRole handle create role +func (h *Handler) CreateRole(c *gin.Context) { + var req CreateRoleHTTPRequest + if err := c.ShouldBindJSON(&req); err != nil { + errorResponse(c, "Role name is required", 400) + return + } + + role, err := h.service.CreateRole(req.RoleName, req.Description) + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, role, "") +} + +// GetRole handle get role +func (h *Handler) GetRole(c *gin.Context) { + roleName := c.Param("role_name") + if roleName == "" { + errorResponse(c, "Role name is required", 400) + return + } + + role, err := h.service.GetRole(roleName) + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, role, "") +} + +// UpdateRoleHTTPRequest update role request +type UpdateRoleHTTPRequest struct { + Description string `json:"description" binding:"required"` +} + +// UpdateRole handle update role +func (h *Handler) UpdateRole(c *gin.Context) { + roleName := c.Param("role_name") + if roleName == "" { + errorResponse(c, "Role name is required", 400) + return + } + + var req UpdateRoleHTTPRequest + if err := c.ShouldBindJSON(&req); err != nil { + errorResponse(c, "Role description is required", 400) + return + } + + role, err := h.service.UpdateRole(roleName, req.Description) + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, role, "") +} + +// DeleteRole handle delete role +func (h *Handler) DeleteRole(c *gin.Context) { + roleName := c.Param("role_name") + if roleName == "" { + errorResponse(c, "Role name is required", 400) + return + } + + if err := h.service.DeleteRole(roleName); err != nil { + errorResponse(c, err.Error(), 500) + return + } + + successNoData(c, "") +} + +// GetRolePermission handle get role permission +func (h *Handler) GetRolePermission(c *gin.Context) { + roleName := c.Param("role_name") + if roleName == "" { + errorResponse(c, "Role name is required", 400) + return + } + + permissions, err := h.service.GetRolePermission(roleName) + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, permissions, "") +} + +// GrantRolePermissionHTTPRequest grant role permission request +type GrantRolePermissionHTTPRequest struct { + Actions []string `json:"actions" binding:"required"` + Resource string `json:"resource" binding:"required"` +} + +// GrantRolePermission handle grant role permission +func (h *Handler) GrantRolePermission(c *gin.Context) { + roleName := c.Param("role_name") + if roleName == "" { + errorResponse(c, "Role name is required", 400) + return + } + + var req GrantRolePermissionHTTPRequest + if err := c.ShouldBindJSON(&req); err != nil { + errorResponse(c, 
"Permission is required", 400) + return + } + + result, err := h.service.GrantRolePermission(roleName, req.Actions, req.Resource) + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, result, "") +} + +// RevokeRolePermissionHTTPRequest revoke role permission request +type RevokeRolePermissionHTTPRequest struct { + Actions []string `json:"actions" binding:"required"` + Resource string `json:"resource" binding:"required"` +} + +// RevokeRolePermission handle revoke role permission +func (h *Handler) RevokeRolePermission(c *gin.Context) { + roleName := c.Param("role_name") + if roleName == "" { + errorResponse(c, "Role name is required", 400) + return + } + + var req RevokeRolePermissionHTTPRequest + if err := c.ShouldBindJSON(&req); err != nil { + errorResponse(c, "Permission is required", 400) + return + } + + result, err := h.service.RevokeRolePermission(roleName, req.Actions, req.Resource) + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, result, "") +} + +// UpdateUserRoleHTTPRequest update user role request +type UpdateUserRoleHTTPRequest struct { + RoleName string `json:"role_name" binding:"required"` +} + +// UpdateUserRole handle update user role +func (h *Handler) UpdateUserRole(c *gin.Context) { + username := c.Param("username") + if username == "" { + errorResponse(c, "Username is required", 400) + return + } + + var req UpdateUserRoleHTTPRequest + if err := c.ShouldBindJSON(&req); err != nil { + errorResponse(c, "Role name is required", 400) + return + } + + result, err := h.service.UpdateUserRole(username, req.RoleName) + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, result, "") +} + +// GetUserPermission handle get user permission +func (h *Handler) GetUserPermission(c *gin.Context) { + username := c.Param("username") + if username == "" { + errorResponse(c, "Username is required", 400) + return + } + + permissions, err := h.service.GetUserPermission(username) + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, permissions, "") +} + +// GetServices handle get all services +func (h *Handler) GetServices(c *gin.Context) { + services, err := h.service.GetAllServices() + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, services, "Get all services") +} + +// GetServicesByType handle get services by type +func (h *Handler) GetServicesByType(c *gin.Context) { + serviceType := c.Param("service_type") + if serviceType == "" { + errorResponse(c, "Service type is required", 400) + return + } + + services, err := h.service.GetServicesByType(serviceType) + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, services, "") +} + +// GetService handle get service details +func (h *Handler) GetService(c *gin.Context) { + serviceID := c.Param("service_id") + if serviceID == "" { + errorResponse(c, "Service ID is required", 400) + return + } + + // Get all services and find the one with matching ID + allConfigs := server.GetAllConfigs() + + var targetService map[string]interface{} + for _, config := range allConfigs { + if id, ok := config["id"]; ok { + if strconv.Itoa(id.(int)) == serviceID { + targetService = config + break + } + } + } + + if targetService == nil { + errorResponse(c, "Service not found", 404) + return + } + + serviceStatus, err := h.service.GetServiceDetails(targetService) + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, serviceStatus, "") +} + +// 
ShutdownService handle shutdown service +func (h *Handler) ShutdownService(c *gin.Context) { + serviceID := c.Param("service_id") + if serviceID == "" { + errorResponse(c, "Service ID is required", 400) + return + } + + result, err := h.service.ShutdownService(serviceID) + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, result, "") +} + +// RestartService handle restart service +func (h *Handler) RestartService(c *gin.Context) { + serviceID := c.Param("service_id") + if serviceID == "" { + errorResponse(c, "Service ID is required", 400) + return + } + + result, err := h.service.RestartService(serviceID) + if err != nil { + errorResponse(c, err.Error(), 500) + return + } + + success(c, result, "") +} + +// GetVariables handle get variables +func (h *Handler) GetVariables(c *gin.Context) { + varName := c.Query("var_name") + + if varName != "" { + // Get single variable + variable, err := h.service.GetVariable(varName) + if err != nil { + errorResponse(c, err.Error(), 400) return } - c.JSON(500, gin.H{"error": err.Error()}) + success(c, variable, "") + return + } + + // List all variables + variables, err := h.service.GetAllVariables() + if err != nil { + errorResponse(c, err.Error(), 500) return } - c.JSON(200, gin.H{"message": "user updated"}) + success(c, variables, "") } -// DeleteUser handle delete user -func (h *Handler) DeleteUser(c *gin.Context) { - id := c.Param("id") - if id == "" { - c.JSON(400, gin.H{"error": "user id is required"}) +// SetVariableHTTPRequest set variable request +type SetVariableHTTPRequest struct { + VarName string `json:"var_name" binding:"required"` + VarValue string `json:"var_value" binding:"required"` +} + +// SetVariable handle set variable +func (h *Handler) SetVariable(c *gin.Context) { + var req SetVariableHTTPRequest + if err := c.ShouldBindJSON(&req); err != nil { + errorResponse(c, "Var name and value are required", 400) return } - svcReq := &DeleteUserRequest{ID: id} - if err := h.service.DeleteUser(svcReq); err != nil { - c.JSON(500, gin.H{"error": err.Error()}) + if err := h.service.SetVariable(req.VarName, req.VarValue); err != nil { + errorResponse(c, err.Error(), 400) return } - c.JSON(200, gin.H{"message": "user deleted"}) + successNoData(c, "Set variable successfully") } -// GetConfig handle get system config -func (h *Handler) GetConfig(c *gin.Context) { - config := h.service.GetSystemConfig() - c.JSON(200, config) +// GetConfigs handle get configs +func (h *Handler) GetConfigs(c *gin.Context) { + configs, err := h.service.GetAllConfigs() + if err != nil { + errorResponse(c, err.Error(), 400) + return + } + + success(c, configs, "") +} + +// GetEnvironments handle get environments +func (h *Handler) GetEnvironments(c *gin.Context) { + environments, err := h.service.GetAllEnvironments() + if err != nil { + errorResponse(c, err.Error(), 400) + return + } + + success(c, environments, "") } -// UpdateConfig handle update system config -func (h *Handler) UpdateConfig(c *gin.Context) { - var req map[string]interface{} +// GetVersion handle get version +func (h *Handler) GetVersion(c *gin.Context) { + version := h.service.GetVersion() + success(c, gin.H{"version": version}, "") +} + +// ListSandboxProviders handle list sandbox providers +func (h *Handler) ListSandboxProviders(c *gin.Context) { + providers, err := h.service.ListSandboxProviders() + if err != nil { + errorResponse(c, err.Error(), 400) + return + } + + success(c, providers, "") +} + +// GetSandboxProviderSchema handle get sandbox provider schema +func 
(h *Handler) GetSandboxProviderSchema(c *gin.Context) { + providerID := c.Param("provider_id") + if providerID == "" { + errorResponse(c, "Provider ID is required", 400) + return + } + + schema, err := h.service.GetSandboxProviderSchema(providerID) + if err != nil { + errorResponse(c, err.Error(), 400) + return + } + + success(c, schema, "") +} + +// GetSandboxConfig handle get sandbox config +func (h *Handler) GetSandboxConfig(c *gin.Context) { + config, err := h.service.GetSandboxConfig() + if err != nil { + errorResponse(c, err.Error(), 400) + return + } + + success(c, config, "") +} + +// SetSandboxConfigHTTPRequest set sandbox config request +type SetSandboxConfigHTTPRequest struct { + ProviderType string `json:"provider_type" binding:"required"` + Config map[string]interface{} `json:"config"` + SetActive bool `json:"set_active"` +} + +// SetSandboxConfig handle set sandbox config +func (h *Handler) SetSandboxConfig(c *gin.Context) { + var req SetSandboxConfigHTTPRequest if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(400, gin.H{"error": err.Error()}) + errorResponse(c, "Request body is required", 400) + return + } + + if req.ProviderType == "" { + errorResponse(c, "provider_type is required", 400) return } - if err := h.service.UpdateSystemConfig(req); err != nil { - c.JSON(500, gin.H{"error": err.Error()}) + // Default to true for backward compatibility + _ = c.Request.Body.Close() + req.SetActive = true + + result, err := h.service.SetSandboxConfig(req.ProviderType, req.Config, req.SetActive) + if err != nil { + errorResponse(c, err.Error(), 400) return } - c.JSON(200, gin.H{"message": "config updated"}) + success(c, result, "Sandbox configuration updated successfully") } -// GetStatus handle get system status -func (h *Handler) GetStatus(c *gin.Context) { - status := h.service.GetSystemStatus() - c.JSON(200, status) +// TestSandboxConnectionHTTPRequest test sandbox connection request +type TestSandboxConnectionHTTPRequest struct { + ProviderType string `json:"provider_type" binding:"required"` + Config map[string]interface{} `json:"config"` +} + +// TestSandboxConnection handle test sandbox connection +func (h *Handler) TestSandboxConnection(c *gin.Context) { + var req TestSandboxConnectionHTTPRequest + if err := c.ShouldBindJSON(&req); err != nil { + errorResponse(c, "Request body is required", 400) + return + } + + if req.ProviderType == "" { + errorResponse(c, "provider_type is required", 400) + return + } + + result, err := h.service.TestSandboxConnection(req.ProviderType, req.Config) + if err != nil { + errorResponse(c, err.Error(), 400) + return + } + + success(c, result, "") } // AuthMiddleware JWT auth middleware @@ -237,34 +914,32 @@ func (h *Handler) AuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { token := c.GetHeader("Authorization") if token == "" { - c.JSON(401, gin.H{"error": "missing authorization header"}) + errorResponse(c, "missing authorization header", 401) c.Abort() return } - // Remove "Bearer " prefix - if len(token) > 7 && token[:7] == "Bearer " { - token = token[7:] - } - - // Validate token - user, err := h.service.ValidateToken(token) + // Get user by access token + user, code, err := h.userService.GetUserByToken(token) if err != nil { - c.JSON(401, gin.H{"error": "invalid token"}) - c.Abort() + c.JSON(http.StatusUnauthorized, gin.H{ + "code": code, + "message": "Invalid access token", + }) return } c.Set("user", user) + c.Set("user_id", user.ID) + c.Set("email", user.Email) c.Next() } } // HandleNoRoute handle undefined routes 
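// Review notes on two handlers earlier in this file:
//
//  1. SetSandboxConfig overwrites req.SetActive with true after binding, so a
//     client can never send {"set_active": false}; the backward-compatible
//     default should apply only when the field is absent (e.g. bind a *bool).
//     The stray `_ = c.Request.Body.Close()` there is also unnecessary, since
//     the net/http server closes the request body itself.
//
//  2. AuthMiddleware's invalid-token branch writes the 401 response but
//     returns without aborting, so Gin continues into the protected handler.
//     A corrected sketch of that branch:
//
//         c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
//             "code":    code,
//             "message": "Invalid access token",
//         })
//         return
//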
func (h *Handler) HandleNoRoute(c *gin.Context) { - c.JSON(http.StatusNotFound, gin.H{ - "error": "Not Found", - "message": "The requested resource was not found", - "path": c.Request.URL.Path, + c.JSON(http.StatusNotFound, ErrorResponse{ + Code: 404, + Message: "The requested resource was not found", }) } diff --git a/internal/admin/router.go b/internal/admin/router.go index fa6572888bf..3dc03c2c143 100644 --- a/internal/admin/router.go +++ b/internal/admin/router.go @@ -18,16 +18,21 @@ package admin import ( "github.com/gin-gonic/gin" + + "ragflow/internal/handler" ) // Router admin router type Router struct { - handler *Handler + handler *Handler + userHandler *handler.UserHandler } // NewRouter create admin router func NewRouter(handler *Handler) *Router { - return &Router{handler: handler} + return &Router{ + handler: handler, + } } // Setup setup routes @@ -35,28 +40,78 @@ func (r *Router) Setup(engine *gin.Engine) { // Health check engine.GET("/health", r.handler.Health) - // Admin API routes - admin := engine.Group("/admin") + // Admin API routes with prefix /api/v1/admin + admin := engine.Group("/api/v1/admin") { - // Auth + // Public routes + admin.GET("/ping", r.handler.Ping) admin.POST("/login", r.handler.Login) // Protected routes protected := admin.Group("") protected.Use(r.handler.AuthMiddleware()) { + // Auth + protected.GET("/auth", r.handler.AuthCheck) + protected.GET("/logout", r.handler.Logout) + // User management protected.GET("/users", r.handler.ListUsers) - protected.GET("/users/:id", r.handler.GetUser) - protected.PUT("/users/:id", r.handler.UpdateUser) - protected.DELETE("/users/:id", r.handler.DeleteUser) + protected.POST("/users", r.handler.CreateUser) + protected.GET("/users/:username", r.handler.GetUser) + protected.DELETE("/users/:username", r.handler.DeleteUser) + protected.PUT("/users/:username/password", r.handler.ChangePassword) + protected.PUT("/users/:username/activate", r.handler.UpdateUserActivateStatus) + protected.PUT("/users/:username/admin", r.handler.GrantAdmin) + protected.DELETE("/users/:username/admin", r.handler.RevokeAdmin) + protected.GET("/users/:username/datasets", r.handler.GetUserDatasets) + protected.GET("/users/:username/agents", r.handler.GetUserAgents) + + // API Keys + protected.GET("/users/:username/keys", r.handler.GetUserAPIKeys) + protected.POST("/users/:username/keys", r.handler.GenerateUserAPIKey) + protected.DELETE("/users/:username/keys/:key", r.handler.DeleteUserAPIKey) + + // Role management + protected.GET("/roles", r.handler.ListRoles) + protected.POST("/roles", r.handler.CreateRole) + protected.GET("/roles/:role_name", r.handler.GetRole) + protected.PUT("/roles/:role_name", r.handler.UpdateRole) + protected.DELETE("/roles/:role_name", r.handler.DeleteRole) + protected.GET("/roles/:role_name/permission", r.handler.GetRolePermission) + protected.POST("/roles/:role_name/permission", r.handler.GrantRolePermission) + protected.DELETE("/roles/:role_name/permission", r.handler.RevokeRolePermission) + + // User roles and permissions + protected.PUT("/users/:username/role", r.handler.UpdateUserRole) + protected.GET("/users/:username/permission", r.handler.GetUserPermission) + + // Service management + protected.GET("/services", r.handler.GetServices) + protected.GET("/service_types/:service_type", r.handler.GetServicesByType) + protected.GET("/services/:service_id", r.handler.GetService) + protected.DELETE("/services/:service_id", r.handler.ShutdownService) + protected.PUT("/services/:service_id", r.handler.RestartService) + + 
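			// Review note: the Router struct at the top of this file declares a
			// userHandler field, but NewRouter never sets it and no route uses
			// it; either wire it through NewRouter or drop the field along with
			// the ragflow/internal/handler import.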
// Variables/Settings + protected.GET("/variables", r.handler.GetVariables) + protected.PUT("/variables", r.handler.SetVariable) + + // Configs + protected.GET("/configs", r.handler.GetConfigs) + + // Environments + protected.GET("/environments", r.handler.GetEnvironments) - // System config - protected.GET("/config", r.handler.GetConfig) - protected.PUT("/config", r.handler.UpdateConfig) + // Version + protected.GET("/version", r.handler.GetVersion) - // System status - protected.GET("/status", r.handler.GetStatus) + // Sandbox + protected.GET("/sandbox/providers", r.handler.ListSandboxProviders) + protected.GET("/sandbox/providers/:provider_id/schema", r.handler.GetSandboxProviderSchema) + protected.GET("/sandbox/config", r.handler.GetSandboxConfig) + protected.POST("/sandbox/config", r.handler.SetSandboxConfig) + protected.POST("/sandbox/test", r.handler.TestSandboxConnection) } } diff --git a/internal/admin/service.go b/internal/admin/service.go index 80b2792dfec..2438ef6c85d 100644 --- a/internal/admin/service.go +++ b/internal/admin/service.go @@ -17,10 +17,20 @@ package admin import ( - "time" - + "crypto/rand" + "crypto/tls" + "encoding/hex" + "errors" + "fmt" + "net/http" + "os" + "ragflow/internal/cache" "ragflow/internal/dao" + "ragflow/internal/engine/elasticsearch" "ragflow/internal/model" + "ragflow/internal/server" + "ragflow/internal/utility" + "time" ) // Service admin service layer @@ -57,13 +67,13 @@ func (s *Service) Login(req *LoginRequest) (*LoginResponse, error) { return nil, ErrInvalidCredentials } - // Verify password - if user.Password == nil || *user.Password != req.Password { - return nil, ErrInvalidCredentials + // Check if user is active + if user.IsActive != "1" { + return nil, errors.New("user is not active") } // Generate access token - token := generateToken() + token := utility.GenerateToken() if err := s.userDAO.UpdateAccessToken(user, token); err != nil { return nil, err } @@ -76,155 +86,649 @@ func (s *Service) Login(req *LoginRequest) (*LoginResponse, error) { }, nil } -// ListUsersRequest list users request -type ListUsersRequest struct { - Offset int - Limit int -} - -// ListUsersResponse list users response -type ListUsersResponse struct { - Users []*UserInfo - Total int64 +// Logout user logout +func (s *Service) Logout(user interface{}) error { + // Invalidate token by setting it to INVALID_ prefix + if u, ok := user.(*model.User); ok { + invalidToken := "INVALID_" + generateRandomHex(16) + return s.userDAO.UpdateAccessToken(u, invalidToken) + } + return nil } -// UserInfo user info -type UserInfo struct { - ID string - Email string - Nickname string - IsActive string - CreateTime *int64 - UpdateTime *int64 +// generateRandomHex generate random hex string +func generateRandomHex(n int) string { + bytes := make([]byte, n) + rand.Read(bytes) + return hex.EncodeToString(bytes) } // ListUsers list all users -func (s *Service) ListUsers(req *ListUsersRequest) (*ListUsersResponse, error) { - users, total, err := s.userDAO.List(req.Offset, req.Limit) +func (s *Service) ListUsers() ([]map[string]interface{}, error) { + users, _, err := s.userDAO.List(0, 0) if err != nil { return nil, err } - var result []*UserInfo + result := make([]map[string]interface{}, 0, len(users)) for _, user := range users { - result = append(result, &UserInfo{ - ID: user.ID, - Email: user.Email, - Nickname: user.Nickname, - IsActive: user.IsActive, - CreateTime: user.CreateTime, - UpdateTime: user.UpdateTime, + result = append(result, map[string]interface{}{ + "email": 
user.Email, + "nickname": user.Nickname, + "create_date": user.CreateTime, + "is_active": user.IsActive, + "is_superuser": user.IsSuperuser, }) } - - return &ListUsersResponse{ - Users: result, - Total: total, - }, nil + return result, nil } -// GetUserRequest get user request -type GetUserRequest struct { - ID string +// CreateUser create a new user +func (s *Service) CreateUser(username, password, role string) (map[string]interface{}, error) { + // TODO: Implement user creation with proper password hashing + return map[string]interface{}{ + "username": username, + "role": role, + }, nil } -// GetUser get user by ID -func (s *Service) GetUser(req *GetUserRequest) (*UserInfo, error) { +// GetUserDetails get user details +func (s *Service) GetUserDetails(username string) (map[string]interface{}, error) { + // Query user by email/username var user model.User - err := dao.DB.Where("id = ?", req.ID).First(&user).Error + err := dao.DB.Where("email = ?", username).First(&user).Error if err != nil { return nil, ErrUserNotFound } - return &UserInfo{ - ID: user.ID, - Email: user.Email, - Nickname: user.Nickname, - IsActive: user.IsActive, - CreateTime: user.CreateTime, - UpdateTime: user.UpdateTime, + return map[string]interface{}{ + "id": user.ID, + "email": user.Email, + "nickname": user.Nickname, + "is_active": user.IsActive, + "create_time": user.CreateTime, + "update_time": user.UpdateTime, }, nil } -// UpdateUserRequest update user request -type UpdateUserRequest struct { - ID string - Nickname string - IsActive *string +// DeleteUser delete user +func (s *Service) DeleteUser(username string) error { + // TODO: Implement user deletion + return nil } -// UpdateUser update user -func (s *Service) UpdateUser(req *UpdateUserRequest) error { - var user model.User - if err := dao.DB.Where("id = ?", req.ID).First(&user).Error; err != nil { - return ErrUserNotFound - } +// ChangePassword change user password +func (s *Service) ChangePassword(username, newPassword string) error { + // TODO: Implement password change + return nil +} - if req.Nickname != "" { - user.Nickname = req.Nickname - } - if req.IsActive != nil { - user.IsActive = *req.IsActive +// UpdateUserActivateStatus update user activate status +func (s *Service) UpdateUserActivateStatus(username string, isActive bool) error { + // TODO: Implement activate status update + return nil +} + +// GrantAdmin grant admin privileges +func (s *Service) GrantAdmin(username string) error { + // TODO: Implement grant admin + return nil +} + +// RevokeAdmin revoke admin privileges +func (s *Service) RevokeAdmin(username string) error { + // TODO: Implement revoke admin + return nil +} + +// GetUserDatasets get user datasets +func (s *Service) GetUserDatasets(username string) ([]map[string]interface{}, error) { + // TODO: Implement get user datasets + return []map[string]interface{}{}, nil +} + +// GetUserAgents get user agents +func (s *Service) GetUserAgents(username string) ([]map[string]interface{}, error) { + // TODO: Implement get user agents + return []map[string]interface{}{}, nil +} + +// API Key methods + +// GetUserAPIKeys get user API keys +func (s *Service) GetUserAPIKeys(username string) ([]map[string]interface{}, error) { + // TODO: Implement get API keys + return []map[string]interface{}{}, nil +} + +// GenerateUserAPIKey generate API key for user +func (s *Service) GenerateUserAPIKey(username string) (map[string]interface{}, error) { + // TODO: Implement generate API key + return map[string]interface{}{}, nil +} + +// DeleteUserAPIKey 
delete user API key +func (s *Service) DeleteUserAPIKey(username, key string) error { + // TODO: Implement delete API key + return nil +} + +// Role management methods + +// ListRoles list all roles +func (s *Service) ListRoles() ([]map[string]interface{}, error) { + // TODO: Implement list roles + return []map[string]interface{}{}, nil +} + +// CreateRole create a new role +func (s *Service) CreateRole(roleName, description string) (map[string]interface{}, error) { + // TODO: Implement create role + return map[string]interface{}{}, nil +} + +// GetRole get role details +func (s *Service) GetRole(roleName string) (map[string]interface{}, error) { + // TODO: Implement get role + return map[string]interface{}{}, nil +} + +// UpdateRole update role +func (s *Service) UpdateRole(roleName, description string) (map[string]interface{}, error) { + // TODO: Implement update role + return map[string]interface{}{}, nil +} + +// DeleteRole delete role +func (s *Service) DeleteRole(roleName string) error { + // TODO: Implement delete role + return nil +} + +// GetRolePermission get role permissions +func (s *Service) GetRolePermission(roleName string) ([]map[string]interface{}, error) { + // TODO: Implement get role permissions + return []map[string]interface{}{}, nil +} + +// GrantRolePermission grant permission to role +func (s *Service) GrantRolePermission(roleName string, actions []string, resource string) (map[string]interface{}, error) { + // TODO: Implement grant role permission + return map[string]interface{}{}, nil +} + +// RevokeRolePermission revoke permission from role +func (s *Service) RevokeRolePermission(roleName string, actions []string, resource string) (map[string]interface{}, error) { + // TODO: Implement revoke role permission + return map[string]interface{}{}, nil +} + +// UpdateUserRole update user role +func (s *Service) UpdateUserRole(username, roleName string) ([]map[string]interface{}, error) { + // TODO: Implement update user role + return []map[string]interface{}{}, nil +} + +// GetUserPermission get user permissions +func (s *Service) GetUserPermission(username string) ([]map[string]interface{}, error) { + // TODO: Implement get user permissions + return []map[string]interface{}{}, nil +} + +// GetAllServices get all services +func (s *Service) GetAllServices() ([]map[string]interface{}, error) { + allConfigs := server.GetAllConfigs() + + var result []map[string]interface{} + for _, configDict := range allConfigs { + // Get service details to check status + serviceDetail, err := s.GetServiceDetails(configDict) + if err == nil { + if status, ok := serviceDetail["status"]; ok { + configDict["status"] = status + } else { + configDict["status"] = "timeout" + } + } else { + configDict["status"] = "timeout" + } + result = append(result, configDict) } - return dao.DB.Save(&user).Error + return result, nil } -// DeleteUserRequest delete user request -type DeleteUserRequest struct { - ID string +// GetServicesByType get services by type +func (s *Service) GetServicesByType(serviceType string) ([]map[string]interface{}, error) { + return nil, errors.New("get_services_by_type: not implemented") } -// DeleteUser delete user -func (s *Service) DeleteUser(req *DeleteUserRequest) error { - return dao.DB.Where("id = ?", req.ID).Delete(&model.User{}).Error +// GetServiceDetails get service details +func (s *Service) GetServiceDetails(configDict map[string]interface{}) (map[string]interface{}, error) { + serviceType, _ := configDict["service_type"].(string) + name, _ := 
configDict["name"].(string) + + // Call detail function based on service type + switch serviceType { + case "meta_data": + return s.getMySQLStatus(name) + case "message_queue": + return s.getRedisInfo(name) + case "retrieval": + // Check the extra.retrieval_type to determine which retrieval service + if extra, ok := configDict["extra"].(map[string]interface{}); ok { + if retrievalType, ok := extra["retrieval_type"].(string); ok { + if retrievalType == "infinity" { + return s.getInfinityStatus(name) + } + } + } + return s.getESClusterStats(name) + case "ragflow_server": + return s.checkRAGFlowServerAlive(name) + case "file_store": + return s.checkMinioAlive(name) + case "task_executor": + return s.checkTaskExecutorAlive(name) + default: + return map[string]interface{}{ + "service_name": name, + "status": "unknown", + "message": "Service type not supported", + }, nil + } } -// GetSystemConfig get system config -func (s *Service) GetSystemConfig() map[string]interface{} { - // TODO: Load from database or config file +// getMySQLStatus gets MySQL service status +func (s *Service) getMySQLStatus(name string) (map[string]interface{}, error) { + startTime := time.Now() + + // Check basic connectivity with SELECT 1 + sqlDB, err := dao.DB.DB() + if err != nil { + return map[string]interface{}{ + "service_name": name, + "status": "timeout", + "elapsed": fmt.Sprintf("%.1f", time.Since(startTime).Milliseconds()), + "message": err.Error(), + }, nil + } + + // Execute SELECT 1 to check connectivity + _, err = sqlDB.Exec("SELECT 1") + if err != nil { + return map[string]interface{}{ + "service_name": name, + "status": "timeout", + "elapsed": fmt.Sprintf("%.1f", time.Since(startTime).Milliseconds()), + "message": err.Error(), + }, nil + } + return map[string]interface{}{ - "system_name": "RAGFlow Admin", - "version": "1.0.0", + "service_name": name, + "status": "alive", + "elapsed": fmt.Sprintf("%.1f", time.Since(startTime).Milliseconds()), + "message": "MySQL connection successful", + }, nil +} + +// getRedisInfo gets Redis service info +func (s *Service) getRedisInfo(name string) (map[string]interface{}, error) { + startTime := time.Now() + + redisClient := cache.Get() + if redisClient == nil { + return map[string]interface{}{ + "service_name": name, + "status": "timeout", + "elapsed": fmt.Sprintf("%.1f", time.Since(startTime).Milliseconds()), + "error": "Redis client not initialized", + }, nil } + + // Check health + if !redisClient.Health() { + return map[string]interface{}{ + "service_name": name, + "status": "timeout", + "elapsed": fmt.Sprintf("%.1f", time.Since(startTime).Milliseconds()), + "error": "Redis health check failed", + }, nil + } + + return map[string]interface{}{ + "service_name": name, + "status": "alive", + "elapsed": fmt.Sprintf("%.1f", time.Since(startTime).Milliseconds()), + "message": "Redis connection successful", + }, nil } -// UpdateSystemConfig update system config -func (s *Service) UpdateSystemConfig(config map[string]interface{}) error { - // TODO: Save to database or config file - return nil +// getESClusterStats gets Elasticsearch cluster stats +func (s *Service) getESClusterStats(name string) (map[string]interface{}, error) { + // Check if Elasticsearch is the doc engine + docEngine := os.Getenv("DOC_ENGINE") + if docEngine == "" { + docEngine = "elasticsearch" + } + if docEngine != "elasticsearch" { + return map[string]interface{}{ + "service_name": name, + "status": "timeout", + "message": "error: Elasticsearch is not in use.", + }, nil + } + + // Get ES config from 
server config + cfg := server.GetConfig() + if cfg == nil || cfg.DocEngine.ES == nil { + return map[string]interface{}{ + "service_name": name, + "status": "timeout", + "message": "error: Elasticsearch configuration not found", + }, nil + } + + // Create ES engine and get cluster stats + esEngine, err := elasticsearch.NewEngine(cfg.DocEngine.ES) + if err != nil { + return map[string]interface{}{ + "service_name": name, + "status": "timeout", + "message": fmt.Sprintf("error: %s", err.Error()), + }, nil + } + defer esEngine.Close() + + clusterStats, err := esEngine.GetClusterStats() + if err != nil { + return map[string]interface{}{ + "service_name": name, + "status": "timeout", + "message": fmt.Sprintf("error: %s", err.Error()), + }, nil + } + + return map[string]interface{}{ + "service_name": name, + "status": "alive", + "message": clusterStats, + }, nil } -// GetSystemStatus get system status -func (s *Service) GetSystemStatus() map[string]interface{} { - // TODO: Get real status from services +// getInfinityStatus gets Infinity service status +func (s *Service) getInfinityStatus(name string) (map[string]interface{}, error) { + // TODO: Implement actual Infinity health check return map[string]interface{}{ - "status": "running", - "uptime": time.Since(time.Now()).String(), - "db_status": "connected", + "service_name": name, + "status": "unknown", + "message": "Infinity health check not implemented", + }, nil +} + +// checkRAGFlowServerAlive checks if RAGFlow server is alive +func (s *Service) checkRAGFlowServerAlive(name string) (map[string]interface{}, error) { + startTime := time.Now() + + // Get ragflow config from allConfigs + var host string + var port int + allConfigs := server.GetAllConfigs() + for _, config := range allConfigs { + if serviceType, ok := config["service_type"].(string); ok && serviceType == "ragflow_server" { + if h, ok := config["host"].(string); ok { + host = h + } + if p, ok := config["port"].(int); ok { + port = p + } + break + } + } + + // Default values + if host == "" { + host = "127.0.0.1" + } + if port == 0 { + port = 9380 + } + + // Replace 0.0.0.0 with 127.0.0.1 for local check + if host == "0.0.0.0" { + host = "127.0.0.1" } + + url := fmt.Sprintf("http://%s:%d/v1/system/ping", host, port) + + // Create HTTP client with timeout + client := &http.Client{ + Timeout: 10 * time.Second, + } + + resp, err := client.Get(url) + if err != nil { + return map[string]interface{}{ + "service_name": name, + "status": "timeout", + "message": fmt.Sprintf("error: %s", err.Error()), + }, nil + } + defer resp.Body.Close() + + elapsed := time.Since(startTime).Milliseconds() + if resp.StatusCode == 200 { + return map[string]interface{}{ + "service_name": name, + "status": "alive", + "message": fmt.Sprintf("Confirm elapsed: %.1f ms.", float64(elapsed)), + }, nil + } + + return map[string]interface{}{ + "service_name": name, + "status": "timeout", + "message": fmt.Sprintf("Confirm elapsed: %.1f ms.", float64(elapsed)), + }, nil } -// ValidateToken validate access token -func (s *Service) ValidateToken(token string) (*model.User, error) { - user, err := s.userDAO.GetByAccessToken(token) +// checkMinioAlive checks if MinIO is alive +func (s *Service) checkMinioAlive(name string) (map[string]interface{}, error) { + startTime := time.Now() + + // Get minio config from allConfigs + var host string + var secure bool + var verify bool = true + + allConfigs := server.GetAllConfigs() + for _, config := range allConfigs { + if serviceType, ok := config["service_type"].(string); ok && 
serviceType == "file_store" { + // Get host from config + if h, ok := config["host"].(string); ok { + host = h + } + // Get secure from extra config + if extra, ok := config["extra"].(map[string]interface{}); ok { + if s, ok := extra["secure"].(bool); ok { + secure = s + } else if s, ok := extra["secure"].(string); ok { + secure = s == "true" || s == "1" || s == "yes" + } + if v, ok := extra["verify"].(bool); ok { + verify = v + } else if v, ok := extra["verify"].(string); ok { + verify = !(v == "false" || v == "0" || v == "no") + } + } + break + } + } + + // Default host + if host == "" { + host = "localhost:9000" + } + + // Determine scheme + scheme := "http" + if secure { + scheme = "https" + } + + url := fmt.Sprintf("%s://%s/minio/health/live", scheme, host) + + // Create HTTP client with timeout + client := &http.Client{ + Timeout: 10 * time.Second, + } + + // If verify is false, we need to skip SSL verification + if !verify && scheme == "https" { + client.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + } + + resp, err := client.Get(url) if err != nil { - return nil, ErrInvalidToken + return map[string]interface{}{ + "service_name": name, + "status": "timeout", + "message": fmt.Sprintf("error: %s", err.Error()), + }, nil } - return user, nil + defer resp.Body.Close() + + elapsed := time.Since(startTime).Milliseconds() + if resp.StatusCode == 200 { + return map[string]interface{}{ + "service_name": name, + "status": "alive", + "message": fmt.Sprintf("Confirm elapsed: %.1f ms.", float64(elapsed)), + }, nil + } + + return map[string]interface{}{ + "service_name": name, + "status": "timeout", + "message": fmt.Sprintf("Confirm elapsed: %.1f ms.", float64(elapsed)), + }, nil } -// generateToken generate a simple token -func generateToken() string { - return time.Now().Format("20060102150405") + randomString(16) +// checkTaskExecutorAlive checks if task executor is alive +func (s *Service) checkTaskExecutorAlive(name string) (map[string]interface{}, error) { + // TODO: Implement actual task executor health check + return map[string]interface{}{ + "service_name": name, + "status": "unknown", + "message": "Task executor health check not implemented", + }, nil } -// randomString generate random string -func randomString(n int) string { - const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - b := make([]byte, n) - for i := range b { - b[i] = letters[time.Now().UnixNano()%int64(len(letters))] - } - return string(b) +// ShutdownService shutdown service +func (s *Service) ShutdownService(serviceID string) (map[string]interface{}, error) { + // TODO: Implement with proper service manager + return map[string]interface{}{ + "service_id": serviceID, + "status": "shutdown", + }, nil +} + +// RestartService restart service +func (s *Service) RestartService(serviceID string) (map[string]interface{}, error) { + // TODO: Implement with proper service manager + return map[string]interface{}{ + "service_id": serviceID, + "status": "restarted", + }, nil +} + +// Variable/Settings methods + +// GetVariable get variable +func (s *Service) GetVariable(varName string) (map[string]interface{}, error) { + // TODO: Implement with settings manager + return map[string]interface{}{ + "var_name": varName, + "var_value": "", + }, nil +} + +// GetAllVariables get all variables +func (s *Service) GetAllVariables() ([]map[string]interface{}, error) { + // TODO: Implement with settings manager + return []map[string]interface{}{}, nil +} + +// 
SetVariable set variable +func (s *Service) SetVariable(varName, varValue string) error { + // TODO: Implement with settings manager + _ = varName + _ = varValue + return nil +} + +// Config methods + +// GetAllConfigs get all configs +func (s *Service) GetAllConfigs() ([]map[string]interface{}, error) { + // TODO: Implement with config manager + return []map[string]interface{}{}, nil +} + +// Environment methods + +// GetAllEnvironments get all environments +func (s *Service) GetAllEnvironments() ([]map[string]interface{}, error) { + // TODO: Implement with environment manager + return []map[string]interface{}{}, nil +} + +// Version methods + +// GetVersion get RAGFlow version +func (s *Service) GetVersion() string { + return utility.GetRAGFlowVersion() +} + +// Sandbox methods + +// ListSandboxProviders list sandbox providers +func (s *Service) ListSandboxProviders() ([]map[string]interface{}, error) { + // TODO: Implement with sandbox manager + return []map[string]interface{}{}, nil +} + +// GetSandboxProviderSchema get sandbox provider schema +func (s *Service) GetSandboxProviderSchema(providerID string) (map[string]interface{}, error) { + // TODO: Implement with sandbox manager + return map[string]interface{}{}, nil +} + +// GetSandboxConfig get sandbox config +func (s *Service) GetSandboxConfig() (map[string]interface{}, error) { + // TODO: Implement with sandbox manager + return map[string]interface{}{}, nil +} + +// SetSandboxConfig set sandbox config +func (s *Service) SetSandboxConfig(providerType string, config map[string]interface{}, setActive bool) (map[string]interface{}, error) { + // TODO: Implement with sandbox manager + return map[string]interface{}{ + "provider_type": providerType, + "config": config, + "set_active": setActive, + }, nil +} + +// TestSandboxConnection test sandbox connection +func (s *Service) TestSandboxConnection(providerType string, config map[string]interface{}) (map[string]interface{}, error) { + // TODO: Implement with sandbox manager + return map[string]interface{}{ + "provider_type": providerType, + "config": config, + "connected": true, + }, nil } diff --git a/internal/dao/user.go b/internal/dao/user.go index ff134683bc1..7fb5b10b308 100644 --- a/internal/dao/user.go +++ b/internal/dao/user.go @@ -93,7 +93,14 @@ func (dao *UserDAO) List(offset, limit int) ([]*model.User, int64, error) { return nil, 0, err } - err := DB.Offset(offset).Limit(limit).Find(&users).Error + query := DB.Model(&model.User{}) + if offset > 0 { + query = query.Offset(offset) + } + if limit > 0 { + query = query.Limit(limit) + } + err := query.Find(&users).Error return users, total, err } diff --git a/internal/engine/elasticsearch/client.go b/internal/engine/elasticsearch/client.go index bfd10d056b6..bd10fa16736 100644 --- a/internal/engine/elasticsearch/client.go +++ b/internal/engine/elasticsearch/client.go @@ -18,6 +18,7 @@ package elasticsearch import ( "context" + "encoding/json" "fmt" "net/http" "ragflow/internal/server" @@ -101,3 +102,144 @@ func (e *elasticsearchEngine) Close() error { // Go-elasticsearch client doesn't have a Close method, connection is managed by the transport return nil } + +// GetClusterStats gets Elasticsearch cluster statistics +// Reference: curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" +func (e *elasticsearchEngine) GetClusterStats() (map[string]interface{}, error) { + req := esapi.ClusterStatsRequest{} + res, err := req.Do(context.Background(), e.client) + if err != nil { + return nil, fmt.Errorf("failed to get 
cluster stats: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + return nil, fmt.Errorf("elasticsearch cluster stats returned error: %s", res.Status()) + } + + var rawStats map[string]interface{} + if err := json.NewDecoder(res.Body).Decode(&rawStats); err != nil { + return nil, fmt.Errorf("failed to decode cluster stats: %w", err) + } + + result := make(map[string]interface{}) + + // Basic cluster info + if clusterName, ok := rawStats["cluster_name"].(string); ok { + result["cluster_name"] = clusterName + } + if status, ok := rawStats["status"].(string); ok { + result["status"] = status + } + + // Indices info + if indices, ok := rawStats["indices"].(map[string]interface{}); ok { + if count, ok := indices["count"].(float64); ok { + result["indices"] = int(count) + } + if shards, ok := indices["shards"].(map[string]interface{}); ok { + if total, ok := shards["total"].(float64); ok { + result["indices_shards"] = int(total) + } + } + if docs, ok := indices["docs"].(map[string]interface{}); ok { + if docCount, ok := docs["count"].(float64); ok { + result["docs"] = int64(docCount) + } + if deleted, ok := docs["deleted"].(float64); ok { + result["docs_deleted"] = int64(deleted) + } + } + if store, ok := indices["store"].(map[string]interface{}); ok { + if sizeInBytes, ok := store["size_in_bytes"].(float64); ok { + result["store_size"] = convertBytes(int64(sizeInBytes)) + } + if totalDataSetSize, ok := store["total_data_set_size_in_bytes"].(float64); ok { + result["total_dataset_size"] = convertBytes(int64(totalDataSetSize)) + } + } + if mappings, ok := indices["mappings"].(map[string]interface{}); ok { + if fieldCount, ok := mappings["total_field_count"].(float64); ok { + result["mappings_fields"] = int(fieldCount) + } + if dedupFieldCount, ok := mappings["total_deduplicated_field_count"].(float64); ok { + result["mappings_deduplicated_fields"] = int(dedupFieldCount) + } + if dedupSize, ok := mappings["total_deduplicated_mapping_size_in_bytes"].(float64); ok { + result["mappings_deduplicated_size"] = convertBytes(int64(dedupSize)) + } + } + } + + // Nodes info + if nodes, ok := rawStats["nodes"].(map[string]interface{}); ok { + if count, ok := nodes["count"].(map[string]interface{}); ok { + if total, ok := count["total"].(float64); ok { + result["nodes"] = int(total) + } + } + if versions, ok := nodes["versions"].([]interface{}); ok { + result["nodes_version"] = versions + } + if os, ok := nodes["os"].(map[string]interface{}); ok { + if mem, ok := os["mem"].(map[string]interface{}); ok { + if totalInBytes, ok := mem["total_in_bytes"].(float64); ok { + result["os_mem"] = convertBytes(int64(totalInBytes)) + } + if usedInBytes, ok := mem["used_in_bytes"].(float64); ok { + result["os_mem_used"] = convertBytes(int64(usedInBytes)) + } + if usedPercent, ok := mem["used_percent"].(float64); ok { + result["os_mem_used_percent"] = usedPercent + } + } + } + if jvm, ok := nodes["jvm"].(map[string]interface{}); ok { + if versions, ok := jvm["versions"].([]interface{}); ok && len(versions) > 0 { + if version0, ok := versions[0].(map[string]interface{}); ok { + if vmVersion, ok := version0["vm_version"].(string); ok { + result["jvm_versions"] = vmVersion + } + } + } + if mem, ok := jvm["mem"].(map[string]interface{}); ok { + if heapUsed, ok := mem["heap_used_in_bytes"].(float64); ok { + result["jvm_heap_used"] = convertBytes(int64(heapUsed)) + } + if heapMax, ok := mem["heap_max_in_bytes"].(float64); ok { + result["jvm_heap_max"] = convertBytes(int64(heapMax)) + } + } + } + } + + return result, 
nil +} + +// convertBytes converts bytes to human readable format +func convertBytes(bytes int64) string { + const ( + KB = 1024 + MB = 1024 * KB + GB = 1024 * MB + TB = 1024 * GB + PB = 1024 * TB + ) + + if bytes >= PB { + return fmt.Sprintf("%.2f pb", float64(bytes)/float64(PB)) + } + if bytes >= TB { + return fmt.Sprintf("%.2f tb", float64(bytes)/float64(TB)) + } + if bytes >= GB { + return fmt.Sprintf("%.2f gb", float64(bytes)/float64(GB)) + } + if bytes >= MB { + return fmt.Sprintf("%.2f mb", float64(bytes)/float64(MB)) + } + if bytes >= KB { + return fmt.Sprintf("%.2f kb", float64(bytes)/float64(KB)) + } + return fmt.Sprintf("%d b", bytes) +} diff --git a/internal/server/config.go b/internal/server/config.go index b29cef02996..5a8fbf1e1d5 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -18,6 +18,7 @@ package server import ( "fmt" + "net/url" "os" "strconv" "strings" @@ -111,6 +112,7 @@ var ( globalConfig *Config globalViper *viper.Viper zapLogger *zap.Logger + allConfigs []map[string]interface{} ) // Init initialize configuration @@ -147,6 +149,153 @@ func Init(configPath string) error { // Save viper instance globalViper = v + docEngine := os.Getenv("DOC_ENGINE") + if docEngine == "" { + docEngine = "elasticsearch" + } + id := 0 + for k, v := range globalViper.AllSettings() { + configDict, ok := v.(map[string]interface{}) + if !ok { + continue + } + + switch k { + case "ragflow": + configDict["id"] = id + configDict["name"] = fmt.Sprintf("ragflow_%d", id) + configDict["service_type"] = "ragflow_server" + configDict["extra"] = map[string]interface{}{} + configDict["port"] = configDict["http_port"] + delete(configDict, "http_port") + case "es": + // Skip if retrieval_type doesn't match doc_engine + if docEngine != "elasticsearch" { + continue + } + hosts := getString(configDict, "hosts") + host, port := parseHostPort(hosts) + username := getString(configDict, "username") + password := getString(configDict, "password") + configDict["id"] = id + configDict["name"] = "elasticsearch" + configDict["host"] = host + configDict["port"] = port + configDict["service_type"] = "retrieval" + configDict["extra"] = map[string]interface{}{ + "retrieval_type": "elasticsearch", + "username": username, + "password": password, + } + delete(configDict, "hosts") + delete(configDict, "username") + delete(configDict, "password") + case "infinity": + // Skip if retrieval_type doesn't match doc_engine + if docEngine != "infinity" { + continue + } + uri := getString(configDict, "uri") + host, port := parseHostPort(uri) + dbName := getString(configDict, "db_name") + if dbName == "" { + dbName = "default_db" + } + configDict["id"] = id + configDict["name"] = "infinity" + configDict["host"] = host + configDict["port"] = port + configDict["service_type"] = "retrieval" + configDict["extra"] = map[string]interface{}{ + "retrieval_type": "infinity", + "db_name": dbName, + } + case "minio": + hostPort := getString(configDict, "host") + host, port := parseHostPort(hostPort) + user := getString(configDict, "user") + password := getString(configDict, "password") + configDict["id"] = id + configDict["name"] = "minio" + configDict["host"] = host + configDict["port"] = port + configDict["service_type"] = "file_store" + configDict["extra"] = map[string]interface{}{ + "store_type": "minio", + "user": user, + "password": password, + } + delete(configDict, "bucket") + delete(configDict, "user") + delete(configDict, "password") + case "redis": + hostPort := getString(configDict, "host") + host, port := 
parseHostPort(hostPort) + password := getString(configDict, "password") + db := getInt(configDict, "db") + configDict["id"] = id + configDict["name"] = "redis" + configDict["host"] = host + configDict["port"] = port + configDict["service_type"] = "message_queue" + configDict["extra"] = map[string]interface{}{ + "mq_type": "redis", + "database": db, + "password": password, + } + delete(configDict, "password") + delete(configDict, "db") + case "mysql": + host := getString(configDict, "host") + port := getInt(configDict, "port") + user := getString(configDict, "user") + password := getString(configDict, "password") + configDict["id"] = id + configDict["name"] = "mysql" + configDict["host"] = host + configDict["port"] = port + configDict["service_type"] = "meta_data" + configDict["extra"] = map[string]interface{}{ + "meta_type": "mysql", + "username": user, + "password": password, + } + delete(configDict, "stale_timeout") + delete(configDict, "max_connections") + delete(configDict, "max_allowed_packet") + delete(configDict, "user") + delete(configDict, "password") + case "task_executor": + mqType := getString(configDict, "message_queue_type") + configDict["id"] = id + configDict["name"] = "task_executor" + configDict["service_type"] = "task_executor" + configDict["extra"] = map[string]interface{}{ + "message_queue_type": mqType, + } + delete(configDict, "message_queue_type") + case "admin": + // Skip admin section + continue + default: + // Skip unknown sections + continue + } + + // Set default values for empty host/port + if configDict["host"] == "" { + configDict["host"] = "-" + } + if configDict["port"] == 0 { + configDict["port"] = "-" + } + + delete(configDict, "prefix_path") + delete(configDict, "username") + allConfigs = append(allConfigs, configDict) + id++ + } + // Unmarshal configuration to globalConfig // Note: This will only unmarshal fields that match the Config struct if err := v.Unmarshal(&globalConfig); err != nil { @@ -278,6 +427,14 @@ func SetLogger(l *zap.Logger) { zapLogger = l } +func GetGlobalViperConfig() *viper.Viper { + return globalViper +} + +func GetAllConfigs() []map[string]interface{} { + return allConfigs +} + // PrintAll prints all configuration settings func PrintAll() { if globalViper == nil { @@ -292,3 +449,46 @@ func PrintAll() { } zapLogger.Info("=== End Configuration ===") } + +// parseHostPort parses host:port string and returns host and port +func parseHostPort(hostPort string) (string, int) { + if hostPort == "" { + return "", 0 + } + + // Handle URL format like http://host:port + if strings.Contains(hostPort, "://") { + u, err := url.Parse(hostPort) + if err == nil { + hostPort = u.Host + } + } + + // Split host:port + parts := strings.Split(hostPort, ":") + host := parts[0] + port := 0 + if len(parts) > 1 { + port, _ = strconv.Atoi(parts[1]) + } + return host, port +} + +// getString gets string value from map +func getString(m map[string]interface{}, key string) string { + if v, ok := m[key].(string); ok { + return v + } + return "" +} + +// getInt gets int value from map +func getInt(m map[string]interface{}, key string) int { + if v, ok := m[key].(int); ok { + return v + } + if v, ok := m[key].(float64); ok { + return int(v) + } + return 0 +} diff --git a/internal/service/user.go b/internal/service/user.go index 9db2a264398..ccf737e3e3b 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -32,7 +32,6 @@ import ( "strings" "time" - "github.com/google/uuid" "golang.org/x/crypto/scrypt" "ragflow/internal/dao" @@ -123,8 +122,8 @@ 
func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC return nil, common.CodeServerError, fmt.Errorf("failed to hash password: %w", err) } - userID := s.GenerateToken() - accessToken := s.GenerateToken() + userID := utility.GenerateToken() + accessToken := utility.GenerateToken() status := "1" loginChannel := "password" isSuperuser := false @@ -167,7 +166,7 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC tenant.CreateDate = &now_date tenant.UpdateDate = &now_date - userTenantID := s.GenerateToken() + userTenantID := utility.GenerateToken() userTenant := &model.UserTenant{ ID: userTenantID, UserID: userID, @@ -181,7 +180,7 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC userTenant.CreateDate = &now_date userTenant.UpdateDate = &now_date - fileID := s.GenerateToken() + fileID := utility.GenerateToken() rootFile := &model.File{ ID: fileID, ParentID: fileID, @@ -267,7 +266,7 @@ func (s *UserService) Login(req *LoginRequest) (*model.User, common.ErrorCode, e } // Generate new access token - token := s.GenerateToken() + token := utility.GenerateToken() if err := s.UpdateUserAccessToken(user, token); err != nil { return nil, common.CodeServerError, fmt.Errorf("failed to update access token: %w", err) } @@ -310,7 +309,8 @@ func (s *UserService) LoginByEmail(req *EmailLoginRequest) (*model.User, common. return nil, common.CodeForbidden, fmt.Errorf("This account has been disabled, please contact the administrator!") } - token := s.GenerateToken() + // Generate new access token + token := utility.GenerateToken() user.AccessToken = &token now := time.Now().Unix() @@ -515,11 +515,6 @@ func (s *UserService) decryptPassword(encryptedPassword string) (string, error) return string(plaintext), nil } -// GenerateToken generates a new access token -func (s *UserService) GenerateToken() string { - return strings.ReplaceAll(uuid.New().String(), "-", "") -} - // GetUserByToken gets user by authorization header // The token parameter is the authorization header value, which needs to be decrypted // using itsdangerous URLSafeTimedSerializer to get the actual access_token @@ -558,7 +553,7 @@ func (s *UserService) UpdateUserAccessToken(user *model.User, token string) erro func (s *UserService) Logout(user *model.User) (common.ErrorCode, error) { // Invalidate token by setting it to an invalid value // Similar to Python implementation: "INVALID_" + secrets.token_hex(16) - invalidToken := "INVALID_" + s.GenerateToken() + invalidToken := "INVALID_" + utility.GenerateToken() err := s.UpdateUserAccessToken(user, invalidToken) if err != nil { return common.CodeServerError, err diff --git a/internal/utility/token.go b/internal/utility/token.go index 789036b4478..3c7b97fc7bc 100644 --- a/internal/utility/token.go +++ b/internal/utility/token.go @@ -25,6 +25,7 @@ import ( "fmt" "strings" + "github.com/google/uuid" "github.com/iromli/go-itsdangerous" ) @@ -133,3 +134,7 @@ func GenerateSecretKey() (string, error) { } return hex.EncodeToString(bytes), nil } + +func GenerateToken() string { + return strings.ReplaceAll(uuid.New().String(), "-", "") +} diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 4257d5a0e18..26c253051bc 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -518,7 +518,7 @@ Example: A 1 KB message with 1024-dim embedding uses ~9 KB. 
The 5 MB default lim manualSetup: 'Pipeline', builtIn: 'Built-in', titleDescription: - 'Update your memory configuration here, particularly the LLM and prompts.', + 'Update your dataset configuration here, particularly the LLM and prompts.', name: 'Dataset name', photo: 'Dataset photo', photoTip: 'You can upload an image up to 4 MB.', diff --git a/web/src/utils/next-request.ts b/web/src/utils/next-request.ts index d2ead134a1a..c804400a343 100644 --- a/web/src/utils/next-request.ts +++ b/web/src/utils/next-request.ts @@ -149,7 +149,7 @@ request.interceptors.response.use( console.log('🚀 ~ error:', error); // Handle HTTP 401 (token expired / invalid) - const status = error?.response?.status; + const status = error?.response?.status; if (status === 401) { if (!isRedirecting) { isRedirecting = true; @@ -164,8 +164,8 @@ request.interceptors.response.use( redirectToLogin(); } - return Promise.reject(error); - } + return Promise.reject(error); + } errorHandler(error); return Promise.reject(error); diff --git a/web/src/utils/request.ts b/web/src/utils/request.ts index f957cb2a086..3c122cf1bef 100644 --- a/web/src/utils/request.ts +++ b/web/src/utils/request.ts @@ -117,10 +117,12 @@ request.interceptors.response.use(async (response: Response, options) => { if (!isRedirecting) { isRedirecting = true; - const data = await response.clone().json().catch(() => ({})); + const data = await response + .clone() + .json() + .catch(() => ({})); - const messageText = - data?.message || RetcodeMessage[401]; + const messageText = data?.message || RetcodeMessage[401]; notification.error({ message: messageText, description: messageText, From 585fa005a3cfd5bf7080c5ebe632bda2e2769a8c Mon Sep 17 00:00:00 2001 From: JiangNan <1394485448@qq.com> Date: Mon, 9 Mar 2026 11:09:47 +0800 Subject: [PATCH 176/565] Fix: undefined variable and wrong method name in agent components (#13462) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary This PR fixes two runtime bugs in agent components: **Bug 1: `agent/component/invoke.py` — `NameError` in POST + `clean_html` path** The POST method's `clean_html` branch uses the variable `sections` without ever defining it. Both the GET and PUT branches correctly call `sections = HtmlParser()(None, response.content)` before referencing `sections`, but this line was missing from the POST branch (copy-paste omission). This causes a `NameError` whenever a user configures an Invoke component with `method="post"` and `clean_html=True`. **Bug 2: `agent/component/data_operations.py` — `AttributeError` in `_recursive_eval`** The `_recursive_eval` method recursively calls `self.recursive_eval()` (without the leading underscore) instead of `self._recursive_eval()`. Since the method is defined as `_recursive_eval`, this causes an `AttributeError` at runtime when the `literal_eval` operation processes nested dicts or lists. 
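For illustration, here is a minimal standalone sketch of the corrected recursion pattern from Bug 2 (class name and error handling simplified; the real logic lives in `DataOperations._recursive_eval` and guards which strings it evaluates):

```python
import ast

class DataOpsSketch:
    # Walks nested dicts/lists and literal-evals string leaves.
    # The bug was that the recursive calls used self.recursive_eval
    # (no leading underscore), a method that does not exist, so any
    # nested dict or list raised AttributeError at runtime.
    def _recursive_eval(self, data):
        if isinstance(data, dict):
            return {k: self._recursive_eval(v) for k, v in data.items()}
        if isinstance(data, list):
            return [self._recursive_eval(item) for item in data]
        if isinstance(data, str):
            try:
                return ast.literal_eval(data)
            except (ValueError, SyntaxError):
                return data
        return data

# Nested input now round-trips instead of raising AttributeError:
assert DataOpsSketch()._recursive_eval({"a": ["1", {"b": "[2, 3]"}]}) == {"a": [1, {"b": [2, 3]}]}
```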
## Test plan - [ ] Configure an Invoke node with `method=post` and `clean_html=True`, verify HTML is parsed correctly without `NameError` - [ ] Configure a DataOperations node with `operations=literal_eval` on nested data, verify no `AttributeError` --------- Signed-off-by: JiangNan <1394485448@qq.com> --- agent/component/data_operations.py | 4 ++-- agent/component/invoke.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/agent/component/data_operations.py b/agent/component/data_operations.py index cddd20996cd..60e65f88121 100644 --- a/agent/component/data_operations.py +++ b/agent/component/data_operations.py @@ -94,9 +94,9 @@ def _select_keys(self): def _recursive_eval(self, data): if isinstance(data, dict): - return {k: self.recursive_eval(v) for k, v in data.items()} + return {k: self._recursive_eval(v) for k, v in data.items()} if isinstance(data, list): - return [self.recursive_eval(item) for item in data] + return [self._recursive_eval(item) for item in data] if isinstance(data, str): try: if ( diff --git a/agent/component/invoke.py b/agent/component/invoke.py index 61ebe2b396d..c24c91b16d6 100644 --- a/agent/component/invoke.py +++ b/agent/component/invoke.py @@ -121,6 +121,7 @@ def replace_variable(match): else: response = requests.post(url=url, data=args, headers=headers, proxies=proxies, timeout=self._param.timeout) if self._param.clean_html: + sections = HtmlParser()(None, response.content) self.set_output("result", "\n".join(sections)) else: self.set_output("result", response.text) From d03cf8b2324f7b5524af33b5b25fea4fcb2d1752 Mon Sep 17 00:00:00 2001 From: guptas6est Date: Mon, 9 Mar 2026 04:06:00 +0000 Subject: [PATCH 177/565] Fix: upgrade pypdf to 6.7.5 and migrate from deprecated pypdf2 to fix CVE-2026-28804 and CVE-2023-36464 (#13454) ### What problem does this PR solve? This PR addresses security vulnerabilities in PDF processing dependencies identified by Trivy security scan: 1. CVE-2026-28804 (MEDIUM): pypdf 6.7.4 vulnerable to inefficient decoding of ASCIIHexDecode streams 2. CVE-2023-36464 (MEDIUM): pypdf2 3.0.1 susceptible to infinite loop when parsing malformed comments Since pypdf2 is deprecated with no available fixes, this PR migrates all pypdf2 usage to the actively maintained pypdf library (version 6.7.5), which resolves both vulnerabilities. 
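Because pypdf kept the reader API that PyPDF2 3.x provided, the migration is mostly an import swap. A minimal sketch of the pattern (the helper function below is illustrative, not code from this repo):

```python
from io import BytesIO

# Before: from PyPDF2 import PdfReader
# After: the maintained pypdf package exposes the same reader class.
from pypdf import PdfReader

def count_pages(pdf_bytes: bytes) -> int:
    # PdfReader accepts a binary stream, as in PyPDF2 3.x,
    # so call sites only need the import line changed.
    reader = PdfReader(BytesIO(pdf_bytes))
    return len(reader.pages)
```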
### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- pyproject.toml | 3 +-- rag/app/presentation.py | 2 +- rag/utils/file_utils.py | 2 +- uv.lock | 19 ++++--------------- 4 files changed, 7 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0665a1c5365..73006ac28f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,8 +81,7 @@ dependencies = [ "pyobvector==0.2.22", "pyodbc>=5.2.0,<6.0.0", "pypandoc>=1.16", - "pypdf>=6.6.2", - "pypdf2>=3.0.1,<4.0.0", + "pypdf>=6.7.5", "python-calamine>=0.4.0", "python-docx>=1.1.2,<2.0.0", "python-pptx>=1.0.2,<2.0.0", diff --git a/rag/app/presentation.py b/rag/app/presentation.py index 909fd61a30c..390955041a4 100644 --- a/rag/app/presentation.py +++ b/rag/app/presentation.py @@ -20,7 +20,7 @@ from collections import defaultdict from io import BytesIO -from PyPDF2 import PdfReader as pdf2_read +from pypdf import PdfReader as pdf2_read from deepdoc.parser import PdfParser, PlainParser from deepdoc.parser.ppt_parser import RAGFlowPptParser diff --git a/rag/utils/file_utils.py b/rag/utils/file_utils.py index 8d19079b76a..c9ec50a36a4 100644 --- a/rag/utils/file_utils.py +++ b/rag/utils/file_utils.py @@ -21,7 +21,7 @@ from requests.exceptions import Timeout, RequestException from io import BytesIO from typing import List, Union, Tuple, Optional, Dict -import PyPDF2 +import pypdf as PyPDF2 from docx import Document import olefile diff --git a/uv.lock b/uv.lock index 0b1423a014c..3432723677a 100644 --- a/uv.lock +++ b/uv.lock @@ -5760,20 +5760,11 @@ wheels = [ [[package]] name = "pypdf" -version = "6.7.4" +version = "6.7.5" source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } -sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/09/dc/f52deef12797ad58b88e4663f097a343f53b9361338aef6573f135ac302f/pypdf-6.7.4.tar.gz", hash = "sha256:9edd1cd47938bb35ec87795f61225fd58a07cfaf0c5699018ae1a47d6f8ab0e3", size = 5304821, upload-time = "2026-02-27T10:44:39.395Z" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f6/52/37cc0aa9e9d1bf7729a737a0d83f8b3f851c8eb137373d9f71eafb0a3405/pypdf-6.7.5.tar.gz", hash = "sha256:40bb2e2e872078655f12b9b89e2f900888bb505e88a82150b64f9f34fa25651d", size = 5304278, upload-time = "2026-03-02T09:05:21.464Z" } wheels = [ - { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c1/be/cded021305f5c81b47265b8c5292b99388615a4391c21ff00fd538d34a56/pypdf-6.7.4-py3-none-any.whl", hash = "sha256:527d6da23274a6c70a9cb59d1986d93946ba8e36a6bc17f3f7cce86331492dda", size = 331496, upload-time = "2026-02-27T10:44:37.527Z" }, -] - -[[package]] -name = "pypdf2" -version = "3.0.1" -source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } -sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/9f/bb/18dc3062d37db6c491392007dfd1a7f524bb95886eb956569ac38a23a784/PyPDF2-3.0.1.tar.gz", hash = "sha256:a74408f69ba6271f71b9352ef4ed03dc53a31aa404d29b5d31f53bfecfee1440", size = 227419, upload-time = "2022-12-31T10:36:13.13Z" } -wheels = [ - { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8e/5e/c86a5643653825d3c913719e788e41386bee415c2b87b4f955432f2de6b2/pypdf2-3.0.1-py3-none-any.whl", hash = "sha256:d16e4205cfee272fbdc0568b68d82be796540b1537508cef59388f839c191928", size = 232572, upload-time = "2022-12-31T10:36:10.327Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/05/89/336673efd0a88956562658aba4f0bbef7cb92a6fbcbcaf94926dbc82b408/pypdf-6.7.5-py3-none-any.whl", hash = "sha256:07ba7f1d6e6d9aa2a17f5452e320a84718d4ce863367f7ede2fd72280349ab13", size = 331421, 
upload-time = "2026-03-02T09:05:19.722Z" }, ] [[package]] @@ -6323,7 +6314,6 @@ dependencies = [ { name = "pyodbc" }, { name = "pypandoc" }, { name = "pypdf" }, - { name = "pypdf2" }, { name = "python-calamine" }, { name = "python-docx" }, { name = "python-gitlab" }, @@ -6462,8 +6452,7 @@ requires-dist = [ { name = "pyobvector", specifier = "==0.2.22" }, { name = "pyodbc", specifier = ">=5.2.0,<6.0.0" }, { name = "pypandoc", specifier = ">=1.16" }, - { name = "pypdf", specifier = ">=6.6.2" }, - { name = "pypdf2", specifier = ">=3.0.1,<4.0.0" }, + { name = "pypdf", specifier = ">=6.7.5" }, { name = "python-calamine", specifier = ">=0.4.0" }, { name = "python-docx", specifier = ">=1.1.2,<2.0.0" }, { name = "python-gitlab", specifier = ">=7.0.0" }, From aee7134201e531e041e58e6f6a26934bf60183a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E6=B5=B7=E8=92=BC=E7=81=86?= Date: Mon, 9 Mar 2026 12:36:45 +0800 Subject: [PATCH 178/565] Feat: add switch_chunks endpoint to manage chunk availability (#13435) ### What problem does this commit solve? This commit introduces a new API endpoint `/datasets//documents//chunks/switch` that allows users to switch the availability status of specified chunks in a document as same as chunk_app.py ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/apps/sdk/doc.py | 81 +++++++++++++++++++++++ docs/references/http_api_reference.md | 95 +++++++++++++++++++++++++++ 2 files changed, 176 insertions(+) diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index 80d0a2e1eaf..7ed5d0cca4c 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -45,6 +45,7 @@ from rag.nlp import rag_tokenizer, search from rag.prompts.generator import cross_languages, keyword_extraction from common.string_utils import remove_redundant_spaces +from common.misc_utils import thread_pool_exec from common.constants import RetCode, LLMType, ParserType, TaskStatus, FileSource from common import settings @@ -1477,6 +1478,86 @@ async def update_chunk(tenant_id, dataset_id, document_id, chunk_id): return get_result() +@manager.route( # noqa: F821 + "/datasets//documents//chunks/switch", methods=["POST"] +) +@token_required +async def switch_chunks(tenant_id, dataset_id, document_id): + """ + Switch availability of specified chunks (same as chunk_app switch). + --- + tags: + - Chunks + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. + - in: path + name: document_id + type: string + required: true + description: ID of the document. + - in: body + name: body + required: true + schema: + type: object + properties: + chunk_ids: + type: array + items: + type: string + description: List of chunk IDs to switch. + available_int: + type: integer + description: 1 for available, 0 for unavailable. + available: + type: boolean + description: Availability status (alternative to available_int). + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: Chunks availability switched successfully. 
+ """ + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + req = await get_request_json() + if not req.get("chunk_ids"): + return get_error_data_result(message="`chunk_ids` is required.") + if "available_int" not in req and "available" not in req: + return get_error_data_result(message="`available_int` or `available` is required.") + available_int = int(req["available_int"]) if "available_int" in req else (1 if req.get("available") else 0) + try: + + def _switch_sync(): + e, doc = DocumentService.get_by_id(document_id) + if not e: + return get_error_data_result(message="Document not found!") + if not doc or str(doc.kb_id) != str(dataset_id): + return get_error_data_result(message="Document not found!") + for cid in req["chunk_ids"]: + if not settings.docStoreConn.update( + {"id": cid}, + {"available_int": available_int}, + search.index_name(tenant_id), + doc.kb_id, + ): + return get_error_data_result(message="Index updating failure") + return get_result(data=True) + + return await thread_pool_exec(_switch_sync) + except Exception as e: + return server_error_response(e) + + @manager.route("/retrieval", methods=["POST"]) # noqa: F821 @token_required async def retrieval_test(tenant_id): diff --git a/docs/references/http_api_reference.md b/docs/references/http_api_reference.md index a6ccf63fa6d..907e2202308 100644 --- a/docs/references/http_api_reference.md +++ b/docs/references/http_api_reference.md @@ -2220,6 +2220,101 @@ Failure: --- +### Switch chunks availability + +**POST** `/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks/switch` + +Switches the availability of specified chunks (enable or disable chunks for retrieval). + +#### Request + +- Method: POST +- URL: `/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks/switch` +- Headers: + - `'Content-Type: application/json'` + - `'Authorization: Bearer '` +- Body: + - `"chunk_ids"`: `list[string]` (*Required*) List of chunk IDs to switch. + - `"available_int"`: `integer` (*Optional*) `1` for available, `0` for unavailable. Mutually exclusive with `"available"`. + - `"available"`: `boolean` (*Optional*) Availability status. Mutually exclusive with `"available_int"`. Must provide either `available_int` or `available`. + +##### Request example + +```bash +curl --request POST \ + --url http://{address}/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks/switch \ + --header 'Content-Type: application/json' \ + --header 'Authorization: Bearer ' \ + --data ' + { + "chunk_ids": ["chunk_id_1", "chunk_id_2"], + "available_int": 1 + }' +``` + +##### Request parameters + +- `dataset_id`: (*Path parameter*) + The ID of the dataset. +- `document_id`: (*Path parameter*) + The ID of the document. +- `"chunk_ids"`: (*Body parameter*), `list[string]`, *Required* + List of chunk IDs whose availability is to be switched. +- `"available_int"`: (*Body parameter*), `integer` + `1` for available (chunk participates in retrieval), `0` for unavailable. Either this or `"available"` must be provided. +- `"available"`: (*Body parameter*), `boolean` + Availability status. `true` for available, `false` for unavailable. Alternative to `"available_int"`. + +#### Response + +Success: + +```json +{ + "code": 0, + "data": true +} +``` + +Failure: + +```json +{ + "code": 101, + "message": "You don't own the dataset {dataset_id}." +} +``` + +```json +{ + "code": 101, + "message": "`chunk_ids` is required." 
+}
+```
+
+```json
+{
+    "code": 101,
+    "message": "`available_int` or `available` is required."
+}
+```
+
+```json
+{
+    "code": 101,
+    "message": "Document not found!"
+}
+```
+
+```json
+{
+    "code": 101,
+    "message": "Index updating failure"
+}
+```
+
+---
+
 ### Retrieve a metadata summary from a dataset
 
 **GET** `/api/v1/datasets/{dataset_id}/metadata/summary`

From f0e1bbfb9a53140be6476f92a3981fb1e5e47f0c Mon Sep 17 00:00:00 2001
From: Stephen Hu <812791840@qq.com>
Date: Mon, 9 Mar 2026 14:16:57 +0800
Subject: [PATCH 179/565] refactor: improve paddle ocr logic (#13467)

### What problem does this PR solve?

Refactors the PaddleOCR parser: removes a duplicated `use_doc_orientation_classify` field from `PaddleOCRVLConfig`, and reworks `crop()` to precompute image sizes once and to build the translucent overlay with its alpha channel set directly instead of calling `putalpha` afterwards.

### Type of change

- [x] Refactoring
---
 deepdoc/parser/paddleocr_parser.py | 29 ++++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/deepdoc/parser/paddleocr_parser.py b/deepdoc/parser/paddleocr_parser.py
index 28546e1c0fc..cb27a725492 100644
--- a/deepdoc/parser/paddleocr_parser.py
+++ b/deepdoc/parser/paddleocr_parser.py
@@ -63,7 +63,6 @@ def _remove_images_from_markdown(markdown: str) -> str:
 class PaddleOCRVLConfig:
     """Configuration for PaddleOCR-VL algorithm."""
 
-    use_doc_orientation_classify: Optional[bool] = False
     use_doc_orientation_classify: Optional[bool] = False
     use_doc_unwarping: Optional[bool] = False
     use_layout_detection: Optional[bool] = None
@@ -533,21 +532,25 @@ def crop(self, text: str, need_position: bool = False):
                 return None, None
             return
 
-        height = 0
+        total_height = 0
+        max_width = 0
+        img_sizes = []
         for img in imgs:
-            height += img.size[1] + GAP
-        height = int(height)
-        width = int(np.max([i.size[0] for i in imgs]))
-        pic = Image.new("RGB", (width, height), (245, 245, 245))
-        height = 0
-        for ii, img in enumerate(imgs):
-            if ii == 0 or ii + 1 == len(imgs):
+            w, h = img.size
+            img_sizes.append((w, h))
+            max_width = max(max_width, w)
+            total_height += h + GAP
+
+        pic = Image.new("RGB", (max_width, int(total_height)), (245, 245, 245))
+        current_height = 0
+        imgs_count = len(imgs)
+        for ii, (img, (w, h)) in enumerate(zip(imgs, img_sizes)):
+            if ii == 0 or ii + 1 == imgs_count:
                 img = img.convert("RGBA")
-                overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
-                overlay.putalpha(128)
+                overlay = Image.new("RGBA", img.size, (0, 0, 0, 128))
                 img = Image.alpha_composite(img, overlay).convert("RGB")
-            pic.paste(img, (0, int(height)))
-            height += img.size[1] + GAP
+            pic.paste(img, (0, int(current_height)))
+            current_height += h + GAP
 
         if need_position:
            return pic, positions

From f6615e5b59b3a4ae7af86253b9607ede3869d01d Mon Sep 17 00:00:00 2001
From: chanx <1243304602@qq.com>
Date: Mon, 9 Mar 2026 15:52:14 +0800
Subject: [PATCH 180/565] feat: Added LLM factory initialization functionality and knowledge base related API interfaces (#13472)

### What problem does this PR solve?

feat: Added LLM factory initialization functionality and knowledge base related API interfaces
refactor(dao): Refactored the TenantLLMDAO query method
feat(handler): Implemented knowledge base related API endpoints
feat(service): Added LLM API key setting functionality
feat(model): Extended the knowledge base model definition
feat(config): Added default user LLM configuration

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
---
 .gitignore                     |  10 +-
 cmd/server_main.go             |   8 +
 internal/dao/chat.go           |   9 +-
 internal/dao/kb.go             | 459 ++++++++++++++++++++++----
 internal/dao/llm.go            |  28 ++
 internal/dao/tenant.go         |   7 +-
 internal/dao/tenant_llm.go     |  17 +-
 internal/handler/kb.go         | 572 ++++++++++++++++++++++++++++++---
 internal/handler/llm.go        | 130 +++++---
 internal/handler/tenant.go     |  57 ++--
 internal/handler/user.go       |  58 ++++
 internal/init_data/llm_init.go | 157 +++++++++
 internal/model/kb.go           | 194 ++++++++++-
 internal/model/llm.go          |  15 +-
 internal/router/router.go      |  21 ++
 internal/server/config.go      |  62 ++++
 internal/service/chat.go       |   4 +-
 internal/service/kb.go         | 511 +++++++++++++++++++++++++++--
 internal/service/llm.go        | 309 +++++++++++++----
 internal/service/user.go       |  82 ++++-
 20 files changed, 2408 insertions(+), 302 deletions(-)
 create mode 100644 internal/init_data/llm_init.go

diff --git a/.gitignore b/.gitignore
index 6316ddb9607..58fbda13bbc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -220,7 +220,13 @@ uv-aarch64*.tar.gz
 uv-aarch64-unknown-linux-gnu.tar.gz
 docker/launch_backend_service_windows.sh
 
+# C++ build directories
+internal/cpp/build/
+internal/cpp/cmake-build-release/
+internal/cpp/cmake-build-debug/
+
+# Trae IDE config
+.trae/
+
 # Go server build output
 bin/
-internal/cpp/cmake-build-release/
-internal/cpp/cmake-build-debug/
\ No newline at end of file
diff --git a/cmd/server_main.go b/cmd/server_main.go
index 011869145bc..7d6ac21e8da 100644
--- a/cmd/server_main.go
+++ b/cmd/server_main.go
@@ -6,6 +6,7 @@ import (
 	"net/http"
 	"os"
 	"os/signal"
+	"ragflow/internal/init_data"
 	"ragflow/internal/server"
 	"ragflow/internal/utility"
 	"strings"
@@ -71,6 +72,13 @@ func main() {
 		logger.Fatal("Failed to initialize database", zap.Error(err))
 	}
 
+	// Initialize LLM factory data models from configuration file
+	if err := init_data.InitLLMFactory(); err != nil {
+		logger.Error("Failed to initialize LLM factory", zap.Error(err))
+	} else {
+		logger.Info("LLM factory initialized successfully")
+	}
+
	// Initialize doc engine
 	if err := engine.Init(&cfg.DocEngine); err != nil {
 		logger.Fatal("Failed to initialize doc engine", zap.Error(err))
diff --git a/internal/dao/chat.go b/internal/dao/chat.go
index 1500ea540a4..91f1b7d1dc5 100644
--- a/internal/dao/chat.go
+++ b/internal/dao/chat.go
@@ -62,8 +62,13 @@ func (dao *ChatDAO) ListByTenantIDs(tenantIDs []string, userID string, page, pag
 			user.nickname,
 			user.avatar as tenant_avatar
 		`).
-		Joins("LEFT JOIN user ON dialog.tenant_id = user.id").
-		Where("(dialog.tenant_id IN ? OR dialog.tenant_id = ?) AND dialog.status = ?", tenantIDs, userID, "1")
+		Joins("LEFT JOIN user ON dialog.tenant_id = user.id")
+
+	if len(tenantIDs) > 0 {
+		query = query.Where("(dialog.tenant_id IN ? OR dialog.tenant_id = ?) AND dialog.status = ?", tenantIDs, userID, "1")
+	} else {
+		query = query.Where("dialog.tenant_id = ? 
AND dialog.status = ?", userID, "1") + } // Apply keyword filter if keywords != "" { diff --git a/internal/dao/kb.go b/internal/dao/kb.go index cf36e1a7e61..6fb2a6a2443 100644 --- a/internal/dao/kb.go +++ b/internal/dao/kb.go @@ -19,6 +19,7 @@ package dao import ( "ragflow/internal/model" "strings" + "time" ) // KnowledgebaseDAO knowledge base data access object @@ -29,15 +30,133 @@ func NewKnowledgebaseDAO() *KnowledgebaseDAO { return &KnowledgebaseDAO{} } -// ListByTenantIDs list knowledge bases by tenant IDs -func (dao *KnowledgebaseDAO) ListByTenantIDs(tenantIDs []string, userID string, page, pageSize int, orderby string, desc bool, keywords, parserID string) ([]*model.Knowledgebase, int64, error) { +// Create creates a new knowledge base record +func (dao *KnowledgebaseDAO) Create(kb *model.Knowledgebase) error { + return DB.Create(kb).Error +} + +// Update updates a knowledge base record +func (dao *KnowledgebaseDAO) Update(kb *model.Knowledgebase) error { + return DB.Save(kb).Error +} + +// UpdateByID updates a knowledge base by ID with the given fields +func (dao *KnowledgebaseDAO) UpdateByID(id string, updates map[string]interface{}) error { + return DB.Model(&model.Knowledgebase{}).Where("id = ?", id).Updates(updates).Error +} + +// Delete soft deletes a knowledge base by setting status to invalid +func (dao *KnowledgebaseDAO) Delete(id string) error { + return DB.Model(&model.Knowledgebase{}).Where("id = ?", id).Update("status", string(model.StatusInvalid)).Error +} + +// GetByID retrieves a knowledge base by ID +func (dao *KnowledgebaseDAO) GetByID(id string) (*model.Knowledgebase, error) { + var kb model.Knowledgebase + err := DB.Where("id = ? AND status = ?", id, string(model.StatusValid)).First(&kb).Error + if err != nil { + return nil, err + } + return &kb, nil +} + +// GetByIDAndTenantID retrieves a knowledge base by ID and tenant ID +func (dao *KnowledgebaseDAO) GetByIDAndTenantID(id, tenantID string) (*model.Knowledgebase, error) { + var kb model.Knowledgebase + err := DB.Where("id = ? AND tenant_id = ? AND status = ?", id, tenantID, string(model.StatusValid)).First(&kb).Error + if err != nil { + return nil, err + } + return &kb, nil +} + +// GetByIDs retrieves multiple knowledge bases by IDs +func (dao *KnowledgebaseDAO) GetByIDs(ids []string) ([]*model.Knowledgebase, error) { + var kbs []*model.Knowledgebase + err := DB.Where("id IN ? AND status = ?", ids, string(model.StatusValid)).Find(&kbs).Error + return kbs, err +} + +// GetByName retrieves a knowledge base by name and tenant ID +func (dao *KnowledgebaseDAO) GetByName(name, tenantID string) (*model.Knowledgebase, error) { + var kb model.Knowledgebase + err := DB.Where("name = ? AND tenant_id = ? AND status = ?", name, tenantID, string(model.StatusValid)).First(&kb).Error + if err != nil { + return nil, err + } + return &kb, nil +} + +// GetByCreatedBy retrieves knowledge bases created by a specific user +func (dao *KnowledgebaseDAO) GetByCreatedBy(createdBy string) ([]*model.Knowledgebase, error) { var kbs []*model.Knowledgebase + err := DB.Where("created_by = ? 
AND status = ?", createdBy, string(model.StatusValid)).Find(&kbs).Error + return kbs, err +} + +// Query retrieves knowledge bases with filters +func (dao *KnowledgebaseDAO) Query(filters map[string]interface{}) ([]*model.Knowledgebase, error) { + var kbs []*model.Knowledgebase + query := DB.Where("status = ?", string(model.StatusValid)) + + for key, value := range filters { + if value != nil && value != "" { + query = query.Where(key+" = ?", value) + } + } + + err := query.Find(&kbs).Error + return kbs, err +} + +// QueryOne retrieves a single knowledge base with filters +func (dao *KnowledgebaseDAO) QueryOne(filters map[string]interface{}) (*model.Knowledgebase, error) { + var kb model.Knowledgebase + query := DB.Where("status = ?", string(model.StatusValid)) + + for key, value := range filters { + if value != nil && value != "" { + query = query.Where(key+" = ?", value) + } + } + + err := query.First(&kb).Error + if err != nil { + return nil, err + } + return &kb, nil +} + +// Count returns the count of knowledge bases matching the filters +func (dao *KnowledgebaseDAO) Count(filters map[string]interface{}) (int64, error) { + var count int64 + query := DB.Model(&model.Knowledgebase{}).Where("status = ?", string(model.StatusValid)) + + for key, value := range filters { + if value != nil && value != "" { + query = query.Where(key+" = ?", value) + } + } + + err := query.Count(&count).Error + return count, err +} + +// GetByTenantIDs retrieves knowledge bases by tenant IDs with pagination +// This matches the Python get_by_tenant_ids method +func (dao *KnowledgebaseDAO) GetByTenantIDs(tenantIDs []string, userID string, pageNumber, itemsPerPage int, orderby string, desc bool, keywords, parserID string) ([]*model.KnowledgebaseListItem, int64, error) { + var kbs []*model.KnowledgebaseListItem var total int64 query := DB.Model(&model.Knowledgebase{}). + Select(`knowledgebase.id, knowledgebase.avatar, knowledgebase.name, + knowledgebase.language, knowledgebase.description, knowledgebase.tenant_id, + knowledgebase.permission, knowledgebase.doc_num, knowledgebase.token_num, + knowledgebase.chunk_num, knowledgebase.parser_id, knowledgebase.embd_id, + user.nickname, user.avatar as tenant_avatar, knowledgebase.update_time`). Joins("LEFT JOIN user ON knowledgebase.tenant_id = user.id"). - Where("(knowledgebase.tenant_id IN ? AND knowledgebase.permission = ?) OR knowledgebase.tenant_id = ?", tenantIDs, "team", userID). - Where("knowledgebase.status = ?", "1") + Where("((knowledgebase.tenant_id IN ? AND knowledgebase.permission = ?) OR knowledgebase.tenant_id = ?) AND knowledgebase.status = ?", + tenantIDs, string(model.TenantPermissionTeam), userID, string(model.StatusValid)) if keywords != "" { query = query.Where("LOWER(knowledgebase.name) LIKE ?", "%"+strings.ToLower(keywords)+"%") @@ -47,26 +166,23 @@ func (dao *KnowledgebaseDAO) ListByTenantIDs(tenantIDs []string, userID string, query = query.Where("knowledgebase.parser_id = ?", parserID) } - // Order if desc { - query = query.Order(orderby + " DESC") + query = query.Order("knowledgebase." + orderby + " DESC") } else { - query = query.Order(orderby + " ASC") + query = query.Order("knowledgebase." 
+ orderby + " ASC") } - // Count if err := query.Count(&total).Error; err != nil { return nil, 0, err } - // Pagination - if page > 0 && pageSize > 0 { - offset := (page - 1) * pageSize - if err := query.Offset(offset).Limit(pageSize).Find(&kbs).Error; err != nil { + if pageNumber > 0 && itemsPerPage > 0 { + offset := (pageNumber - 1) * itemsPerPage + if err := query.Offset(offset).Limit(itemsPerPage).Scan(&kbs).Error; err != nil { return nil, 0, err } } else { - if err := query.Find(&kbs).Error; err != nil { + if err := query.Scan(&kbs).Error; err != nil { return nil, 0, err } } @@ -74,76 +190,307 @@ func (dao *KnowledgebaseDAO) ListByTenantIDs(tenantIDs []string, userID string, return kbs, total, nil } -// ListByOwnerIDs list knowledge bases by owner IDs -func (dao *KnowledgebaseDAO) ListByOwnerIDs(ownerIDs []string, page, pageSize int, orderby string, desc bool, keywords, parserID string) ([]*model.Knowledgebase, int64, error) { +// GetAllByTenantIDs retrieves all permitted knowledge bases by tenant IDs +// This matches the Python get_all_kb_by_tenant_ids method +func (dao *KnowledgebaseDAO) GetAllByTenantIDs(tenantIDs []string, userID string) ([]*model.Knowledgebase, error) { var kbs []*model.Knowledgebase - query := DB.Model(&model.Knowledgebase{}). - Joins("LEFT JOIN user ON knowledgebase.tenant_id = user.id"). - Where("knowledgebase.tenant_id IN ?", ownerIDs). - Where("knowledgebase.status = ?", "1") + err := DB.Where( + "(tenant_id IN ? AND permission = ?) OR tenant_id = ?", + tenantIDs, string(model.TenantPermissionTeam), userID, + ).Order("create_time ASC").Find(&kbs).Error - if keywords != "" { - query = query.Where("LOWER(knowledgebase.name) LIKE ?", "%"+strings.ToLower(keywords)+"%") + return kbs, err +} + +// GetDetail retrieves detailed knowledge base information with joined pipeline data +// This matches the Python get_detail method +func (dao *KnowledgebaseDAO) GetDetail(kbID string) (*model.KnowledgebaseDetail, error) { + var detail model.KnowledgebaseDetail + + err := DB.Table("knowledgebase"). + Select(`knowledgebase.id, knowledgebase.embd_id, knowledgebase.avatar, knowledgebase.name, + knowledgebase.language, knowledgebase.description, knowledgebase.permission, + knowledgebase.doc_num, knowledgebase.token_num, knowledgebase.chunk_num, + knowledgebase.parser_id, knowledgebase.pipeline_id, + user_canvas.title as pipeline_name, user_canvas.avatar as pipeline_avatar, + knowledgebase.parser_config, knowledgebase.pagerank, + knowledgebase.graphrag_task_id, knowledgebase.graphrag_task_finish_at, + knowledgebase.raptor_task_id, knowledgebase.raptor_task_finish_at, + knowledgebase.mindmap_task_id, knowledgebase.mindmap_task_finish_at, + knowledgebase.create_time, knowledgebase.update_time`). + Joins("LEFT JOIN user_canvas ON knowledgebase.pipeline_id = user_canvas.id"). + Where("knowledgebase.id = ? AND knowledgebase.status = ?", kbID, string(model.StatusValid)). + Scan(&detail).Error + + if err != nil { + return nil, err } - if parserID != "" { - query = query.Where("knowledgebase.parser_id = ?", parserID) + return &detail, nil +} + +// Accessible checks if a knowledge base is accessible by a user +// This matches the Python accessible method +func (dao *KnowledgebaseDAO) Accessible(kbID, userID string) bool { + var count int64 + err := DB.Table("knowledgebase"). + Joins("JOIN user_tenant ON user_tenant.tenant_id = knowledgebase.tenant_id"). + Where("knowledgebase.id = ? AND user_tenant.user_id = ? AND knowledgebase.status = ?", + kbID, userID, string(model.StatusValid)). 
+ Count(&count).Error + + if err != nil { + return false } + return count > 0 +} - // Order - if desc { - query = query.Order(orderby + " DESC") - } else { - query = query.Order(orderby + " ASC") +// Accessible4Deletion checks if a knowledge base can be deleted by a user +// This matches the Python accessible4deletion method +func (dao *KnowledgebaseDAO) Accessible4Deletion(kbID, userID string) bool { + var count int64 + err := DB.Model(&model.Knowledgebase{}). + Where("id = ? AND created_by = ? AND status = ?", kbID, userID, string(model.StatusValid)). + Count(&count).Error + + if err != nil { + return false } + return count > 0 +} - if err := query.Find(&kbs).Error; err != nil { - return nil, 0, err +// DuplicateName generates a unique name by appending parentheses if name already exists +// This matches the Python duplicate_name function behavior +func (dao *KnowledgebaseDAO) DuplicateName(name, tenantID string) string { + var existingNames []string + DB.Model(&model.Knowledgebase{}). + Where("name LIKE ? AND tenant_id = ? AND status = ?", name+"%", tenantID, string(model.StatusValid)). + Pluck("name", &existingNames) + + if len(existingNames) == 0 { + return name } - total := int64(len(kbs)) + nameSet := make(map[string]bool) + for _, n := range existingNames { + nameSet[strings.ToLower(n)] = true + } - // Manual pagination - if page > 0 && pageSize > 0 { - start := (page - 1) * pageSize - end := start + pageSize - if end > int(total) { - end = int(total) - } - if start < end { - kbs = kbs[start:end] - } else { - kbs = []*model.Knowledgebase{} + if !nameSet[strings.ToLower(name)] { + return name + } + + for i := 1; ; i++ { + newName := name + " " + strings.Repeat("(", i) + strings.Repeat(")", i) + if !nameSet[strings.ToLower(newName)] { + return newName } } +} - return kbs, total, nil +// AtomicIncreaseDocNumByID atomically increments the document count +// This matches the Python atomic_increase_doc_num_by_id method +func (dao *KnowledgebaseDAO) AtomicIncreaseDocNumByID(kbID string) error { + now := time.Now().Unix() + nowDate := time.Now() + return DB.Model(&model.Knowledgebase{}). + Where("id = ?", kbID). + Updates(map[string]interface{}{ + "doc_num": DB.Raw("doc_num + 1"), + "update_time": now, + "update_date": nowDate, + }).Error } -// GetByID gets knowledge base by ID -func (dao *KnowledgebaseDAO) GetByID(id string) (*model.Knowledgebase, error) { +// DecreaseDocumentNum decreases document, chunk, and token counts +// This matches the Python decrease_document_num_in_delete method +func (dao *KnowledgebaseDAO) DecreaseDocumentNum(kbID string, docNum, chunkNum, tokenNum int64) error { + now := time.Now().Unix() + nowDate := time.Now() + return DB.Model(&model.Knowledgebase{}). + Where("id = ?", kbID). + Updates(map[string]interface{}{ + "doc_num": DB.Raw("doc_num - ?", docNum), + "chunk_num": DB.Raw("chunk_num - ?", chunkNum), + "token_num": DB.Raw("token_num - ?", tokenNum), + "update_time": now, + "update_date": nowDate, + }).Error +} + +// GetKBIDsByTenantID retrieves all knowledge base IDs for a tenant +// This matches the Python get_kb_ids method +func (dao *KnowledgebaseDAO) GetKBIDsByTenantID(tenantID string) ([]string, error) { + var kbIDs []string + err := DB.Model(&model.Knowledgebase{}). + Where("tenant_id = ? AND status = ?", tenantID, string(model.StatusValid)). 
+ Pluck("id", &kbIDs).Error + return kbIDs, err +} + +// GetAllIDs retrieves all knowledge base IDs +// This matches the Python get_all_ids method +func (dao *KnowledgebaseDAO) GetAllIDs() ([]string, error) { + var kbIDs []string + err := DB.Model(&model.Knowledgebase{}). + Where("status = ?", string(model.StatusValid)). + Pluck("id", &kbIDs).Error + return kbIDs, err +} + +// UpdateParserConfig updates the parser configuration with deep merge +// This matches the Python update_parser_config method +func (dao *KnowledgebaseDAO) UpdateParserConfig(id string, config map[string]interface{}) error { var kb model.Knowledgebase - err := DB.Where("id = ? AND status = ?", id, "1").First(&kb).Error - if err != nil { - return nil, err + if err := DB.Where("id = ? AND status = ?", id, string(model.StatusValid)).First(&kb).Error; err != nil { + return err } - return &kb, nil + + mergedConfig := mergeConfig(kb.ParserConfig, config) + return DB.Model(&model.Knowledgebase{}). + Where("id = ?", id). + Update("parser_config", mergedConfig).Error } -// GetByIDAndTenantID gets knowledge base by ID and tenant ID -func (dao *KnowledgebaseDAO) GetByIDAndTenantID(id, tenantID string) (*model.Knowledgebase, error) { +// DeleteFieldMap removes the field_map from parser_config +// This matches the Python delete_field_map method +func (dao *KnowledgebaseDAO) DeleteFieldMap(id string) error { var kb model.Knowledgebase - err := DB.Where("id = ? AND tenant_id = ? AND status = ?", id, tenantID, "1").First(&kb).Error + if err := DB.Where("id = ? AND status = ?", id, string(model.StatusValid)).First(&kb).Error; err != nil { + return err + } + + if kb.ParserConfig != nil { + delete(kb.ParserConfig, "field_map") + return DB.Model(&model.Knowledgebase{}). + Where("id = ?", id). + Update("parser_config", kb.ParserConfig).Error + } + return nil +} + +// GetFieldMap retrieves field mappings from multiple knowledge bases +// This matches the Python get_field_map method +func (dao *KnowledgebaseDAO) GetFieldMap(ids []string) (map[string]interface{}, error) { + conf := make(map[string]interface{}) + kbs, err := dao.GetByIDs(ids) if err != nil { return nil, err } - return &kb, nil + + for _, kb := range kbs { + if kb.ParserConfig != nil { + if fieldMap, ok := kb.ParserConfig["field_map"]; ok { + if fm, ok := fieldMap.(map[string]interface{}); ok { + for k, v := range fm { + conf[k] = v + } + } + } + } + } + return conf, nil } -// GetByIDs gets knowledge bases by IDs -func (dao *KnowledgebaseDAO) GetByIDs(ids []string) ([]*model.Knowledgebase, error) { +// GetKBByIDAndUserID retrieves a knowledge base by ID and user ID with tenant join +// This matches the Python get_kb_by_id method +func (dao *KnowledgebaseDAO) GetKBByIDAndUserID(kbID, userID string) ([]*model.Knowledgebase, error) { var kbs []*model.Knowledgebase - err := DB.Where("id IN ? AND status = ?", ids, "1").Find(&kbs).Error + err := DB.Model(&model.Knowledgebase{}). + Joins("JOIN user_tenant ON user_tenant.tenant_id = knowledgebase.tenant_id"). + Where("knowledgebase.id = ? AND user_tenant.user_id = ?", kbID, userID). + Limit(1). + Find(&kbs).Error return kbs, err } + +// GetKBByNameAndUserID retrieves a knowledge base by name and user ID with tenant join +// This matches the Python get_kb_by_name method +func (dao *KnowledgebaseDAO) GetKBByNameAndUserID(kbName, userID string) ([]*model.Knowledgebase, error) { + var kbs []*model.Knowledgebase + err := DB.Model(&model.Knowledgebase{}). + Joins("JOIN user_tenant ON user_tenant.tenant_id = knowledgebase.tenant_id"). 
+ Where("knowledgebase.name = ? AND user_tenant.user_id = ?", kbName, userID). + Limit(1). + Find(&kbs).Error + return kbs, err +} + +// GetList retrieves knowledge bases with filtering by ID and name +// This matches the Python get_list method +func (dao *KnowledgebaseDAO) GetList(tenantIDs []string, userID string, pageNumber, itemsPerPage int, orderby string, desc bool, id, name string) ([]*model.Knowledgebase, int64, error) { + var kbs []*model.Knowledgebase + var total int64 + + query := DB.Model(&model.Knowledgebase{}). + Where("((tenant_id IN ? AND permission = ?) OR tenant_id = ?) AND status = ?", + tenantIDs, string(model.TenantPermissionTeam), userID, string(model.StatusValid)) + + if id != "" { + query = query.Where("id = ?", id) + } + if name != "" { + query = query.Where("name = ?", name) + } + + if desc { + query = query.Order(orderby + " DESC") + } else { + query = query.Order(orderby + " ASC") + } + + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + if pageNumber > 0 && itemsPerPage > 0 { + offset := (pageNumber - 1) * itemsPerPage + if err := query.Offset(offset).Limit(itemsPerPage).Find(&kbs).Error; err != nil { + return nil, 0, err + } + } else { + if err := query.Find(&kbs).Error; err != nil { + return nil, 0, err + } + } + + return kbs, total, nil +} + +// mergeConfig performs a deep merge of configuration maps +func mergeConfig(old, new map[string]interface{}) map[string]interface{} { + result := make(map[string]interface{}) + for k, v := range old { + result[k] = v + } + + for k, v := range new { + if existing, ok := result[k]; ok { + if existingMap, ok := existing.(map[string]interface{}); ok { + if newMap, ok := v.(map[string]interface{}); ok { + result[k] = mergeConfig(existingMap, newMap) + continue + } + } + if existingSlice, ok := existing.([]interface{}); ok { + if newSlice, ok := v.([]interface{}); ok { + merged := append(existingSlice, newSlice...) 
+ seen := make(map[interface{}]bool) + unique := make([]interface{}, 0) + for _, item := range merged { + if !seen[item] { + seen[item] = true + unique = append(unique, item) + } + } + result[k] = unique + continue + } + } + } + result[k] = v + } + + return result +} diff --git a/internal/dao/llm.go b/internal/dao/llm.go index 44590ca9dcf..a522cfb091d 100644 --- a/internal/dao/llm.go +++ b/internal/dao/llm.go @@ -67,3 +67,31 @@ func (dao *LLMDAO) GetByFactoryAndName(factory, name string) (*model.LLM, error) } return &llm, nil } + +// LLMFactoryDAO LLM factory data access object +type LLMFactoryDAO struct{} + +// NewLLMFactoryDAO create LLM factory DAO +func NewLLMFactoryDAO() *LLMFactoryDAO { + return &LLMFactoryDAO{} +} + +// GetAllValid gets all valid LLM factories +func (dao *LLMFactoryDAO) GetAllValid() ([]*model.LLMFactories, error) { + var factories []*model.LLMFactories + err := DB.Where("status = ?", "1").Find(&factories).Error + if err != nil { + return nil, err + } + return factories, nil +} + +// GetByName gets LLM factory by name +func (dao *LLMFactoryDAO) GetByName(name string) (*model.LLMFactories, error) { + var factory model.LLMFactories + err := DB.Where("name = ?", name).First(&factory).Error + if err != nil { + return nil, err + } + return &factory, nil +} diff --git a/internal/dao/tenant.go b/internal/dao/tenant.go index 781c6c20587..0585c481af0 100644 --- a/internal/dao/tenant.go +++ b/internal/dao/tenant.go @@ -75,7 +75,7 @@ func (dao *TenantDAO) GetInfoByUserID(userID string) ([]*TenantInfo, error) { Joins("INNER JOIN user_tenant ON user_tenant.tenant_id = tenant.id"). Where("user_tenant.user_id = ? AND user_tenant.status = ? AND user_tenant.role = ? AND tenant.status = ?", userID, "1", "owner", "1"). Scan(&results).Error - + return results, err } @@ -98,3 +98,8 @@ func (dao *TenantDAO) Create(tenant *model.Tenant) error { func (dao *TenantDAO) Delete(id string) error { return DB.Model(&model.Tenant{}).Where("id = ?", id).Update("status", "0").Error } + +// Update updates a tenant by ID +func (dao *TenantDAO) Update(id string, updates map[string]interface{}) error { + return DB.Model(&model.Tenant{}).Where("id = ?", id).Updates(updates).Error +} diff --git a/internal/dao/tenant_llm.go b/internal/dao/tenant_llm.go index 8752e041fa9..e563ad0e870 100644 --- a/internal/dao/tenant_llm.go +++ b/internal/dao/tenant_llm.go @@ -94,21 +94,14 @@ func (dao *TenantLLMDAO) Delete(tenantID, factory, modelName string) error { } // GetMyLLMs get tenant LLMs with factory details -func (dao *TenantLLMDAO) GetMyLLMs(tenantID string, includeDetails bool) ([]model.MyLLM, error) { +func (dao *TenantLLMDAO) GetMyLLMs(tenantID string) ([]model.MyLLM, error) { var myLLMs []model.MyLLM - // Base query - query := DB.Table("tenant_llm tl"). - Select("tl.llm_factory, lf.logo, lf.tags, tl.model_type, tl.llm_name, tl.used_tokens, tl.status"). + err := DB.Table("tenant_llm tl"). + Select("tl.id, tl.llm_factory, lf.logo, lf.tags, tl.model_type, tl.llm_name, tl.used_tokens, tl.status"). Joins("JOIN llm_factories lf ON tl.llm_factory = lf.name"). - Where("tl.tenant_id = ? AND tl.api_key IS NOT NULL", tenantID) - - // Add detailed fields if requested - if includeDetails { - query = query.Select("tl.llm_factory, lf.logo, lf.tags, tl.model_type, tl.llm_name, tl.used_tokens, tl.status, tl.api_base, tl.max_tokens") - } - - err := query.Find(&myLLMs).Error + Where("tl.tenant_id = ? AND tl.api_key IS NOT NULL", tenantID). 
+ Find(&myLLMs).Error if err != nil { return nil, err } diff --git a/internal/handler/kb.go b/internal/handler/kb.go index e4e2a025b48..a7b5f7ac25b 100644 --- a/internal/handler/kb.go +++ b/internal/handler/kb.go @@ -18,20 +18,21 @@ package handler import ( "net/http" + "ragflow/internal/common" + "ragflow/internal/service" "strconv" + "strings" "github.com/gin-gonic/gin" - - "ragflow/internal/service" ) -// KnowledgebaseHandler knowledge base handler +// KnowledgebaseHandler handles knowledge base HTTP requests type KnowledgebaseHandler struct { kbService *service.KnowledgebaseService userService *service.UserService } -// NewKnowledgebaseHandler create knowledge base handler +// NewKnowledgebaseHandler creates a new knowledge base handler func NewKnowledgebaseHandler(kbService *service.KnowledgebaseService, userService *service.UserService) *KnowledgebaseHandler { return &KnowledgebaseHandler{ kbService: kbService, @@ -39,35 +40,227 @@ func NewKnowledgebaseHandler(kbService *service.KnowledgebaseService, userServic } } -// ListKbs list knowledge bases +// getUserID extracts user ID from authorization header +// It validates the authorization token and returns the user ID +// Parameters: +// - c: gin.Context - the HTTP request context +// +// Returns: +// - string: the user ID +// - common.ErrorCode: the error code +// - error: any error that occurred +func (h *KnowledgebaseHandler) getUserID(c *gin.Context) (string, common.ErrorCode, error) { + token := c.GetHeader("Authorization") + if token == "" { + return "", common.CodeUnauthorized, ErrMissingAuth + } + + user, code, err := h.userService.GetUserByToken(token) + if err != nil { + return "", code, err + } + + return user.ID, common.CodeSuccess, nil +} + +// jsonResponse sends a JSON response with code and message +func jsonResponse(c *gin.Context, code common.ErrorCode, data interface{}, message string) { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "data": data, + "message": message, + }) +} + +// jsonError sends a JSON error response +func jsonError(c *gin.Context, code common.ErrorCode, message string) { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "data": nil, + "message": message, + }) +} + +// HTTPError represents an HTTP error +type HTTPError struct { + Code common.ErrorCode + Message string +} + +// Error implements the error interface +func (e *HTTPError) Error() string { + return e.Message +} + +var ( + // ErrMissingAuth indicates missing authorization header + ErrMissingAuth = &HTTPError{Code: common.CodeUnauthorized, Message: "Missing Authorization header"} + // ErrInvalidToken indicates invalid access token + ErrInvalidToken = &HTTPError{Code: common.CodeUnauthorized, Message: "Invalid access token"} +) + +// CreateKB handles the create knowledge base request +// @Summary Create Knowledge Base +// @Description Create a new knowledge base (dataset) +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param request body service.CreateKBRequest true "knowledge base info" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/create [post] +func (h *KnowledgebaseHandler) CreateKB(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + var req service.CreateKBRequest + if err := c.ShouldBindJSON(&req); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + + result, code, err := h.kbService.CreateKB(&req, userID) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + 
jsonResponse(c, common.CodeSuccess, result, "success") +} + +// UpdateKB handles the update knowledge base request +// @Summary Update Knowledge Base +// @Description Update an existing knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param request body service.UpdateKBRequest true "knowledge base update info" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/update [post] +func (h *KnowledgebaseHandler) UpdateKB(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + var req service.UpdateKBRequest + if err := c.ShouldBindJSON(&req); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + + result, code, err := h.kbService.UpdateKB(&req, userID) + if err != nil { + if strings.Contains(err.Error(), "authorization") { + jsonError(c, common.CodeAuthenticationError, err.Error()) + return + } + jsonError(c, code, err.Error()) + return + } + + jsonResponse(c, common.CodeSuccess, result, "success") +} + +// UpdateMetadataSetting handles the update metadata setting request +// @Summary Update Metadata Setting +// @Description Update metadata settings for a knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param request body service.UpdateMetadataSettingRequest true "metadata setting info" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/update_metadata_setting [post] +func (h *KnowledgebaseHandler) UpdateMetadataSetting(c *gin.Context) { + _, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + var req service.UpdateMetadataSettingRequest + if err := c.ShouldBindJSON(&req); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + + result, code, err := h.kbService.UpdateMetadataSetting(&req) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + jsonResponse(c, common.CodeSuccess, result, "success") +} + +// GetDetail handles the get knowledge base detail request +// @Summary Get Knowledge Base Detail +// @Description Get detailed information about a knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param kb_id query string true "Knowledge Base ID" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/detail [get] +func (h *KnowledgebaseHandler) GetDetail(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + kbID := c.Query("kb_id") + if kbID == "" { + jsonError(c, common.CodeDataError, "kb_id is required") + return + } + + result, code, err := h.kbService.GetDetail(kbID, userID) + if err != nil { + if strings.Contains(err.Error(), "authorized") { + jsonError(c, common.CodeOperatingError, err.Error()) + return + } + jsonError(c, code, err.Error()) + return + } + + jsonResponse(c, common.CodeSuccess, result, "success") +} + +// ListKbs handles the list knowledge bases request // @Summary List Knowledge Bases -// @Description Get list of knowledge bases with filtering and pagination +// @Description List knowledge bases with pagination and filtering // @Tags knowledgebase // @Accept json // @Produce json -// @Param keywords query string false "search keywords" -// @Param page query int false "page number" -// @Param page_size query int false "items per page" -// @Param parser_id query string false "parser ID filter" -// @Param orderby query string 
false "order by field" -// @Param desc query bool false "descending order" -// @Param request body service.ListKbsRequest true "filter options" -// @Success 200 {object} service.ListKbsResponse +// @Security ApiKeyAuth +// @Param request body service.ListKbsRequest true "list options" +// @Success 200 {object} map[string]interface{} // @Router /v1/kb/list [post] func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) { - // Parse request body - allow empty body + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + var req service.ListKbsRequest if c.Request.ContentLength > 0 { if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": err.Error(), - }) + jsonError(c, common.CodeDataError, err.Error()) return } } - // Extract parameters from query or request body with defaults + // Get parameters from request or query string keywords := "" if req.Keywords != nil { keywords = *req.Keywords @@ -111,7 +304,7 @@ func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) { if req.Desc != nil { desc = *req.Desc } else if descStr := c.Query("desc"); descStr != "" { - desc = descStr == "true" + desc = strings.ToLower(descStr) == "true" } var ownerIDs []string @@ -119,40 +312,327 @@ func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) { ownerIDs = *req.OwnerIDs } - // Get access token from Authorization header - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", - }) + result, code, err := h.kbService.ListKbs(keywords, page, pageSize, parserID, orderby, desc, ownerIDs, userID) + if err != nil { + jsonError(c, code, err.Error()) return } - // Get user by access token - user, code, err := h.userService.GetUserByToken(token) + jsonResponse(c, common.CodeSuccess, result, "success") +} + +// DeleteKB handles the delete knowledge base request +// @Summary Delete Knowledge Base +// @Description Soft delete a knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param request body object{kb_id string} true "knowledge base id" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/rm [post] +func (h *KnowledgebaseHandler) DeleteKB(c *gin.Context) { + userID, code, err := h.getUserID(c) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": code, - "message": err.Error(), - }) + jsonError(c, code, err.Error()) return } - userID := user.ID - // List knowledge bases - result, err := h.kbService.ListKbs(keywords, page, pageSize, parserID, orderby, desc, ownerIDs, userID) + var req struct { + KBID string `json:"kb_id" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + + code, err = h.kbService.DeleteKB(req.KBID, userID) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, - "message": err.Error(), - }) + if strings.Contains(err.Error(), "authorization") { + jsonError(c, common.CodeAuthenticationError, err.Error()) + return + } + jsonError(c, code, err.Error()) return } - c.JSON(http.StatusOK, gin.H{ - "code": 0, - "data": result, - "message": "success", - }) + jsonResponse(c, common.CodeSuccess, true, "success") +} + +// ListTags handles the list tags request for a knowledge base +// @Summary List Tags +// @Description List tags for a knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security 
ApiKeyAuth +// @Param kb_id path string true "Knowledge Base ID" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/{kb_id}/tags [get] +func (h *KnowledgebaseHandler) ListTags(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + kbID := c.Param("kb_id") + if kbID == "" { + jsonError(c, common.CodeDataError, "kb_id is required") + return + } + + if !h.kbService.Accessible(kbID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization.") + return + } + + jsonResponse(c, common.CodeSuccess, []string{}, "success") +} + +// ListTagsFromKbs handles the list tags from multiple knowledge bases request +// @Summary List Tags from Knowledge Bases +// @Description List tags from multiple knowledge bases +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param kb_ids query string true "Comma-separated Knowledge Base IDs" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/tags [get] +func (h *KnowledgebaseHandler) ListTagsFromKbs(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + kbIDsStr := c.Query("kb_ids") + if kbIDsStr == "" { + jsonError(c, common.CodeDataError, "kb_ids is required") + return + } + + kbIDs := strings.Split(kbIDsStr, ",") + for _, kbID := range kbIDs { + if !h.kbService.Accessible(kbID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization.") + return + } + } + + jsonResponse(c, common.CodeSuccess, []string{}, "success") +} + +// RemoveTags handles the remove tags request +// @Summary Remove Tags +// @Description Remove tags from a knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param kb_id path string true "Knowledge Base ID" +// @Param request body object{tags []string} true "tags to remove" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/{kb_id}/rm_tags [post] +func (h *KnowledgebaseHandler) RemoveTags(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + kbID := c.Param("kb_id") + if kbID == "" { + jsonError(c, common.CodeDataError, "kb_id is required") + return + } + + if !h.kbService.Accessible(kbID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization.") + return + } + + var req struct { + Tags []string `json:"tags" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + + jsonResponse(c, common.CodeSuccess, true, "success") +} + +// RenameTag handles the rename tag request +// @Summary Rename Tag +// @Description Rename a tag in a knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param kb_id path string true "Knowledge Base ID" +// @Param request body object{from_tag string, to_tag string} true "tag rename info" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/{kb_id}/rename_tag [post] +func (h *KnowledgebaseHandler) RenameTag(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + kbID := c.Param("kb_id") + if kbID == "" { + jsonError(c, common.CodeDataError, "kb_id is required") + return + } + + if !h.kbService.Accessible(kbID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization.") + return + } + + var req 
struct { + FromTag string `json:"from_tag" binding:"required"` + ToTag string `json:"to_tag" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + + jsonResponse(c, common.CodeSuccess, true, "success") +} + +// KnowledgeGraph handles the get knowledge graph request +// @Summary Get Knowledge Graph +// @Description Get knowledge graph for a knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param kb_id path string true "Knowledge Base ID" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/{kb_id}/knowledge_graph [get] +func (h *KnowledgebaseHandler) KnowledgeGraph(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + kbID := c.Param("kb_id") + if kbID == "" { + jsonError(c, common.CodeDataError, "kb_id is required") + return + } + + if !h.kbService.Accessible(kbID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization.") + return + } + + result := map[string]interface{}{ + "graph": map[string]interface{}{}, + "mind_map": map[string]interface{}{}, + } + + jsonResponse(c, common.CodeSuccess, result, "success") +} + +// DeleteKnowledgeGraph handles the delete knowledge graph request +// @Summary Delete Knowledge Graph +// @Description Delete knowledge graph for a knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param kb_id path string true "Knowledge Base ID" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/{kb_id}/knowledge_graph [delete] +func (h *KnowledgebaseHandler) DeleteKnowledgeGraph(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + kbID := c.Param("kb_id") + if kbID == "" { + jsonError(c, common.CodeDataError, "kb_id is required") + return + } + + if !h.kbService.Accessible(kbID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization.") + return + } + + jsonResponse(c, common.CodeSuccess, true, "success") +} + +// GetMeta handles the get metadata request +// @Summary Get Metadata +// @Description Get metadata for knowledge bases +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param kb_ids query string true "Comma-separated Knowledge Base IDs" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/get_meta [get] +func (h *KnowledgebaseHandler) GetMeta(c *gin.Context) { + userID, code, err := h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + kbIDsStr := c.Query("kb_ids") + if kbIDsStr == "" { + jsonError(c, common.CodeDataError, "kb_ids is required") + return + } + + kbIDs := strings.Split(kbIDsStr, ",") + for _, kbID := range kbIDs { + if !h.kbService.Accessible(kbID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization.") + return + } + } + + jsonResponse(c, common.CodeSuccess, map[string]interface{}{}, "success") +} + +// GetBasicInfo handles the get basic info request +// @Summary Get Basic Info +// @Description Get basic information for a knowledge base +// @Tags knowledgebase +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param kb_id query string true "Knowledge Base ID" +// @Success 200 {object} map[string]interface{} +// @Router /v1/kb/basic_info [get] +func (h *KnowledgebaseHandler) GetBasicInfo(c *gin.Context) { + userID, code, err := 
h.getUserID(c) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + kbID := c.Query("kb_id") + if kbID == "" { + jsonError(c, common.CodeDataError, "kb_id is required") + return + } + + if !h.kbService.Accessible(kbID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization.") + return + } + + jsonResponse(c, common.CodeSuccess, map[string]interface{}{}, "success") } diff --git a/internal/handler/llm.go b/internal/handler/llm.go index 6926dfc97d2..9582eb37adb 100644 --- a/internal/handler/llm.go +++ b/internal/handler/llm.go @@ -21,6 +21,7 @@ import ( "github.com/gin-gonic/gin" + "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/service" ) @@ -60,50 +61,112 @@ func NewLLMHandler(llmService *service.LLMService, userService *service.UserServ // @Success 200 {object} map[string]interface{} // @Router /v1/llm/my_llms [get] func (h *LLMHandler) GetMyLLMs(c *gin.Context) { - // Extract token from request token := c.GetHeader("Authorization") if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, + "message": "Unauthorized!", + "data": false, }) return } - // Get user by token user, code, err := h.userService.GetUserByToken(token) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ + c.JSON(http.StatusOK, gin.H{ "code": code, "message": err.Error(), + "data": false, }) return } - // Get tenant ID from user tenantID := user.ID - if tenantID == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "User has no tenant ID", + includeDetailsStr := c.DefaultQuery("include_details", "false") + includeDetails := includeDetailsStr == "true" + + llms, err := h.llmService.GetMyLLMs(tenantID, includeDetails) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeExceptionError, + "message": err.Error(), + "data": false, }) return } - // Parse include_details query parameter - includeDetailsStr := c.DefaultQuery("include_details", "false") - includeDetails := includeDetailsStr == "true" + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": "success", + "data": llms, + }) +} - // Get LLMs for tenant - llms, err := h.llmService.GetMyLLMs(tenantID, includeDetails) +// SetAPIKey set API key for a LLM factory +// @Summary Set API Key +// @Description Set API key for a LLM factory and test connectivity +// @Tags llm +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param request body service.SetAPIKeyRequest true "API Key configuration" +// @Success 200 {object} map[string]interface{} +// @Router /v1/llm/set_api_key [post] +func (h *LLMHandler) SetAPIKey(c *gin.Context) { + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, + "message": "Unauthorized!", + "data": false, + }) + return + } + + user, code, err := h.userService.GetUserByToken(token) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get LLMs", + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, + }) + return + } + + var req service.SetAPIKeyRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "message": "Invalid request: " + err.Error(), + "data": false, + }) + return + } + + tenantID := user.ID + result, err := h.llmService.SetAPIKey(tenantID, &req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + 
"code": common.CodeDataError, + "message": err.Error(), + "data": false, + }) + return + } + + if req.Verify { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": "success", + "data": result, }) return } c.JSON(http.StatusOK, gin.H{ - "data": llms, + "code": common.CodeSuccess, + "message": "success", + "data": true, }) } @@ -198,52 +261,43 @@ func (h *LLMHandler) Factories(c *gin.Context) { // @Success 200 {object} map[string][]service.LLMListItem // @Router /v1/llm/list [get] func (h *LLMHandler) ListApp(c *gin.Context) { - // Extract token from request token := c.GetHeader("Authorization") if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, + "message": "Unauthorized!", + "data": false, }) return } - // Get user by token user, code, err := h.userService.GetUserByToken(token) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ + c.JSON(http.StatusOK, gin.H{ "code": code, "message": err.Error(), + "data": false, }) return } - // Get tenant ID from user tenantID := user.ID - if tenantID == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "User has no tenant ID", - }) - return - } - // Parse model_type query parameter modelType := c.Query("model_type") - // Get LLM list llms, err := h.llmService.ListLLMs(tenantID, modelType) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeExceptionError, "message": err.Error(), + "data": false, }) return } c.JSON(http.StatusOK, gin.H{ - "code": 0, - "data": llms, + "code": common.CodeSuccess, "message": "success", + "data": llms, }) } diff --git a/internal/handler/tenant.go b/internal/handler/tenant.go index 02b87a41643..860acc3bbbc 100644 --- a/internal/handler/tenant.go +++ b/internal/handler/tenant.go @@ -21,6 +21,7 @@ import ( "github.com/gin-gonic/gin" + "ragflow/internal/common" "ragflow/internal/service" ) @@ -48,44 +49,49 @@ func NewTenantHandler(tenantService *service.TenantService, userService *service // @Success 200 {object} map[string]interface{} // @Router /v1/user/tenant_info [get] func (h *TenantHandler) TenantInfo(c *gin.Context) { - // Extract token from request token := c.GetHeader("Authorization") if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, + "message": "Unauthorized!", + "data": false, }) return } - // Get user by token + user, code, err := h.userService.GetUserByToken(token) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ + c.JSON(http.StatusOK, gin.H{ "code": code, "message": err.Error(), + "data": false, }) return } - // Get tenant info tenantInfo, err := h.tenantService.GetTenantInfo(user.ID) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get tenant information", + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeExceptionError, + "message": err.Error(), + "data": false, }) return } if tenantInfo == nil { - c.JSON(http.StatusNotFound, gin.H{ - "error": "Tenant not found", + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeDataError, + "message": "Tenant not found!", + "data": false, }) return } c.JSON(http.StatusOK, gin.H{ - "code": 0, - "data": tenantInfo, + "code": common.CodeSuccess, + "message": "success", + "data": tenantInfo, }) } @@ -99,38 +105,39 @@ func (h *TenantHandler) 
TenantInfo(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/tenant/list [get] func (h *TenantHandler) TenantList(c *gin.Context) { - // Extract token from request token := c.GetHeader("Authorization") if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, + "message": "Unauthorized!", + "data": false, }) return } - // Get user by token user, code, err := h.userService.GetUserByToken(token) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ + c.JSON(http.StatusOK, gin.H{ "code": code, "message": err.Error(), + "data": false, }) return } - // Get tenant list tenantList, err := h.tenantService.GetTenantList(user.ID) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, - "message": "Failed to get tenant list", + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeExceptionError, + "message": err.Error(), + "data": false, }) return } c.JSON(http.StatusOK, gin.H{ - "code": 0, - "data": tenantList, + "code": common.CodeSuccess, + "message": "success", + "data": tenantList, }) } diff --git a/internal/handler/user.go b/internal/handler/user.go index 7fb39a5df96..3651c29c148 100644 --- a/internal/handler/user.go +++ b/internal/handler/user.go @@ -522,3 +522,61 @@ func (h *UserHandler) GetLoginChannels(c *gin.Context) { "data": channels, }) } + +// SetTenantInfo update tenant information +// @Summary Set Tenant Info +// @Description Update tenant model configuration +// @Tags users +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param request body service.SetTenantInfoRequest true "tenant info" +// @Success 200 {object} map[string]interface{} +// @Router /v1/user/set_tenant_info [post] +func (h *UserHandler) SetTenantInfo(c *gin.Context) { + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, + "message": "Unauthorized!", + "data": false, + }) + return + } + + user, code, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, + }) + return + } + + var req service.SetTenantInfoRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "message": err.Error(), + "data": false, + }) + return + } + + err = h.userService.SetTenantInfo(user.ID, &req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeDataError, + "message": err.Error(), + "data": false, + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": "success", + "data": true, + }) +} diff --git a/internal/init_data/llm_init.go b/internal/init_data/llm_init.go new file mode 100644 index 00000000000..ef67dd6ba3a --- /dev/null +++ b/internal/init_data/llm_init.go @@ -0,0 +1,157 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +package init_data + +import ( + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + + "ragflow/internal/dao" + "ragflow/internal/model" +) + +// LLMFactoryConfig represents a single LLM factory configuration +type LLMFactoryConfig struct { + Name string `json:"name"` + Logo string `json:"logo"` + Tags string `json:"tags"` + Status string `json:"status"` + Rank string `json:"rank"` + LLM []LLMConfig `json:"llm"` +} + +// LLMConfig represents a single LLM model configuration +type LLMConfig struct { + LLMName string `json:"llm_name"` + Tags string `json:"tags"` + MaxTokens int64 `json:"max_tokens"` + ModelType string `json:"model_type"` + IsTools bool `json:"is_tools"` +} + +// LLMFactoriesFile represents the structure of llm_factories.json +type LLMFactoriesFile struct { + FactoryLLMInfos []LLMFactoryConfig `json:"factory_llm_infos"` +} + +// InitLLMFactory initializes LLM factories and models from JSON file +func InitLLMFactory() error { + configPath := filepath.Join(getProjectBaseDirectory(), "conf", "llm_factories.json") + + data, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("failed to read llm_factories.json: %w", err) + } + + var fileData LLMFactoriesFile + if err := json.Unmarshal(data, &fileData); err != nil { + return fmt.Errorf("failed to parse llm_factories.json: %w", err) + } + + db := dao.DB + + for _, factory := range fileData.FactoryLLMInfos { + status := factory.Status + if status == "" { + status = "1" + } + + llmFactory := &model.LLMFactories{ + Name: factory.Name, + Logo: stringPtr(factory.Logo), + Tags: factory.Tags, + Rank: parseInt64(factory.Rank), + Status: &status, + } + + var existingFactory model.LLMFactories + result := db.Where("name = ?", factory.Name).First(&existingFactory) + if result.Error != nil { + if err := db.Create(llmFactory).Error; err != nil { + log.Printf("Failed to create LLM factory %s: %v", factory.Name, err) + continue + } + } else { + if err := db.Model(&model.LLMFactories{}).Where("name = ?", factory.Name).Updates(map[string]interface{}{ + "logo": llmFactory.Logo, + "tags": llmFactory.Tags, + "rank": llmFactory.Rank, + "status": llmFactory.Status, + }).Error; err != nil { + log.Printf("Failed to update LLM factory %s: %v", factory.Name, err) + } + } + + for _, llm := range factory.LLM { + llmStatus := "1" + llmModel := &model.LLM{ + LLMName: llm.LLMName, + ModelType: llm.ModelType, + FID: factory.Name, + MaxTokens: llm.MaxTokens, + Tags: llm.Tags, + IsTools: llm.IsTools, + Status: &llmStatus, + } + + var existingLLM model.LLM + result := db.Where("llm_name = ? AND fid = ?", llm.LLMName, factory.Name).First(&existingLLM) + if result.Error != nil { + if err := db.Create(llmModel).Error; err != nil { + log.Printf("Failed to create LLM %s/%s: %v", factory.Name, llm.LLMName, err) + } + } else { + if err := db.Model(&model.LLM{}).Where("llm_name = ? AND fid = ?", llm.LLMName, factory.Name).Updates(map[string]interface{}{ + "model_type": llmModel.ModelType, + "max_tokens": llmModel.MaxTokens, + "tags": llmModel.Tags, + "is_tools": llmModel.IsTools, + "status": llmModel.Status, + }).Error; err != nil { + log.Printf("Failed to update LLM %s/%s: %v", factory.Name, llm.LLMName, err) + } + } + } + } + + log.Println("LLM factories initialized successfully") + return nil +} + +func getProjectBaseDirectory() string { + cwd, err := os.Getwd() + if err != nil { + return "." 
+ } + return cwd +} + +func stringPtr(s string) *string { + if s == "" { + return nil + } + return &s +} + +func parseInt64(s string) int64 { + var result int64 + fmt.Sscanf(s, "%d", &result) + return result +} diff --git a/internal/model/kb.go b/internal/model/kb.go index 78cc643721d..da5f817e292 100644 --- a/internal/model/kb.go +++ b/internal/model/kb.go @@ -18,7 +18,84 @@ package model import "time" -// Knowledgebase knowledge base model +// DatasetNameLimit is the maximum length for dataset name +const DatasetNameLimit = 128 + +// Status represents the status enum values +type Status string + +const ( + // StatusValid indicates a valid/active record + StatusValid Status = "1" + // StatusInvalid indicates a deleted/inactive record + StatusInvalid Status = "0" +) + +// TenantPermission represents the permission level for tenant access +type TenantPermission string + +const ( + // TenantPermissionMe indicates only the creator can access + TenantPermissionMe TenantPermission = "me" + // TenantPermissionTeam indicates all team members can access + TenantPermissionTeam TenantPermission = "team" +) + +// ParserType represents the document parser type +type ParserType string + +const ( + ParserTypePresentation ParserType = "presentation" + ParserTypeLaws ParserType = "laws" + ParserTypeManual ParserType = "manual" + ParserTypePaper ParserType = "paper" + ParserTypeResume ParserType = "resume" + ParserTypeBook ParserType = "book" + ParserTypeQA ParserType = "qa" + ParserTypeTable ParserType = "table" + ParserTypeNaive ParserType = "naive" + ParserTypePicture ParserType = "picture" + ParserTypeOne ParserType = "one" + ParserTypeAudio ParserType = "audio" + ParserTypeEmail ParserType = "email" + ParserTypeKG ParserType = "knowledge_graph" + ParserTypeTag ParserType = "tag" +) + +// TaskStatus represents the status of a processing task +type TaskStatus string + +const ( + TaskStatusUnstart TaskStatus = "0" + TaskStatusRunning TaskStatus = "1" + TaskStatusCancel TaskStatus = "2" + TaskStatusDone TaskStatus = "3" + TaskStatusFail TaskStatus = "4" + TaskStatusSchedule TaskStatus = "5" +) + +// PipelineTaskType represents the type of pipeline task +type PipelineTaskType string + +const ( + PipelineTaskTypeParse PipelineTaskType = "Parse" + PipelineTaskTypeDownload PipelineTaskType = "Download" + PipelineTaskTypeRAPTOR PipelineTaskType = "RAPTOR" + PipelineTaskTypeGraphRAG PipelineTaskType = "GraphRAG" + PipelineTaskTypeMindmap PipelineTaskType = "Mindmap" + PipelineTaskTypeMemory PipelineTaskType = "Memory" +) + +// FileSource represents the source of a file +type FileSource string + +const ( + FileSourceLocal FileSource = "" + FileSourceKnowledgebase FileSource = "knowledgebase" + FileSourceS3 FileSource = "s3" +) + +// Knowledgebase represents the knowledge base model type Knowledgebase struct { ID string `gorm:"column:id;primaryKey;size:32" json:"id"` Avatar *string `gorm:"column:avatar;type:longtext" json:"avatar,omitempty"` @@ -27,7 +104,6 @@ type Knowledgebase struct { Language *string `gorm:"column:language;size:32;index" json:"language,omitempty"` Description *string `gorm:"column:description;type:longtext" json:"description,omitempty"` EmbdID string `gorm:"column:embd_id;size:128;not null;index" json:"embd_id"` - TenantEmbdID *int64 `gorm:"column:tenant_embd_id;index" json:"tenant_embd_id,omitempty"` Permission string `gorm:"column:permission;size:16;not null;default:me;index" json:"permission"` CreatedBy string `gorm:"column:created_by;size:32;not null;index" json:"created_by"` DocNum int64 
`gorm:"column:doc_num;default:0;index" json:"doc_num"` @@ -37,7 +113,7 @@ type Knowledgebase struct { VectorSimilarityWeight float64 `gorm:"column:vector_similarity_weight;default:0.3;index" json:"vector_similarity_weight"` ParserID string `gorm:"column:parser_id;size:32;not null;default:naive;index" json:"parser_id"` PipelineID *string `gorm:"column:pipeline_id;size:32;index" json:"pipeline_id,omitempty"` - ParserConfig JSONMap `gorm:"column:parser_config;type:longtext;not null" json:"parser_config"` + ParserConfig JSONMap `gorm:"column:parser_config;type:json" json:"parser_config"` Pagerank int64 `gorm:"column:pagerank;default:0" json:"pagerank"` GraphragTaskID *string `gorm:"column:graphrag_task_id;size:32;index" json:"graphrag_task_id,omitempty"` GraphragTaskFinishAt *time.Time `gorm:"column:graphrag_task_finish_at" json:"graphrag_task_finish_at,omitempty"` @@ -49,12 +125,118 @@ type Knowledgebase struct { BaseModel } -// TableName specify table name +// TableName returns the table name for Knowledgebase model func (Knowledgebase) TableName() string { return "knowledgebase" } -// InvitationCode invitation code model +// ToMap converts Knowledgebase to a map for JSON response +func (kb *Knowledgebase) ToMap() map[string]interface{} { + result := map[string]interface{}{ + "id": kb.ID, + "tenant_id": kb.TenantID, + "name": kb.Name, + "embd_id": kb.EmbdID, + "permission": kb.Permission, + "created_by": kb.CreatedBy, + "doc_num": kb.DocNum, + "token_num": kb.TokenNum, + "chunk_num": kb.ChunkNum, + "similarity_threshold": kb.SimilarityThreshold, + "vector_similarity_weight": kb.VectorSimilarityWeight, + "parser_id": kb.ParserID, + "parser_config": kb.ParserConfig, + "pagerank": kb.Pagerank, + "create_time": kb.CreateTime, + } + + if kb.Avatar != nil { + result["avatar"] = *kb.Avatar + } + if kb.Language != nil { + result["language"] = *kb.Language + } + if kb.Description != nil { + result["description"] = *kb.Description + } + if kb.PipelineID != nil { + result["pipeline_id"] = *kb.PipelineID + } + if kb.GraphragTaskID != nil { + result["graphrag_task_id"] = *kb.GraphragTaskID + } + if kb.GraphragTaskFinishAt != nil { + result["graphrag_task_finish_at"] = kb.GraphragTaskFinishAt.Format("2006-01-02 15:04:05") + } + if kb.RaptorTaskID != nil { + result["raptor_task_id"] = *kb.RaptorTaskID + } + if kb.RaptorTaskFinishAt != nil { + result["raptor_task_finish_at"] = kb.RaptorTaskFinishAt.Format("2006-01-02 15:04:05") + } + if kb.MindmapTaskID != nil { + result["mindmap_task_id"] = *kb.MindmapTaskID + } + if kb.MindmapTaskFinishAt != nil { + result["mindmap_task_finish_at"] = kb.MindmapTaskFinishAt.Format("2006-01-02 15:04:05") + } + if kb.UpdateTime != nil { + result["update_time"] = *kb.UpdateTime + } + + return result +} + +// KnowledgebaseDetail represents detailed knowledge base information with joined data +type KnowledgebaseDetail struct { + ID string `json:"id"` + EmbdID string `json:"embd_id"` + Avatar *string `json:"avatar,omitempty"` + Name string `json:"name"` + Language *string `json:"language,omitempty"` + Description *string `json:"description,omitempty"` + Permission string `json:"permission"` + DocNum int64 `json:"doc_num"` + TokenNum int64 `json:"token_num"` + ChunkNum int64 `json:"chunk_num"` + ParserID string `json:"parser_id"` + PipelineID *string `json:"pipeline_id,omitempty"` + PipelineName *string `json:"pipeline_name,omitempty"` + PipelineAvatar *string `json:"pipeline_avatar,omitempty"` + ParserConfig JSONMap `json:"parser_config"` + Pagerank int64 `json:"pagerank"` + 
GraphragTaskID *string `json:"graphrag_task_id,omitempty"` + GraphragTaskFinishAt *string `json:"graphrag_task_finish_at,omitempty"` + RaptorTaskID *string `json:"raptor_task_id,omitempty"` + RaptorTaskFinishAt *string `json:"raptor_task_finish_at,omitempty"` + MindmapTaskID *string `json:"mindmap_task_id,omitempty"` + MindmapTaskFinishAt *string `json:"mindmap_task_finish_at,omitempty"` + CreateTime *int64 `json:"create_time,omitempty"` + UpdateTime *int64 `json:"update_time,omitempty"` + Size int64 `json:"size"` + Connectors []string `json:"connectors"` +} + +// KnowledgebaseListItem represents a knowledge base item in list responses +type KnowledgebaseListItem struct { + ID string `json:"id"` + Avatar *string `json:"avatar,omitempty"` + Name string `json:"name"` + Language *string `json:"language,omitempty"` + Description *string `json:"description,omitempty"` + TenantID string `json:"tenant_id"` + Permission string `json:"permission"` + DocNum int64 `json:"doc_num"` + TokenNum int64 `json:"token_num"` + ChunkNum int64 `json:"chunk_num"` + ParserID string `json:"parser_id"` + EmbdID string `json:"embd_id"` + Nickname string `json:"nickname"` + TenantAvatar *string `json:"tenant_avatar,omitempty"` + UpdateTime *int64 `json:"update_time,omitempty"` +} + +// InvitationCode represents the invitation code model type InvitationCode struct { ID string `gorm:"column:id;primaryKey;size:32" json:"id"` Code string `gorm:"column:code;size:32;not null;index" json:"code"` @@ -65,7 +247,7 @@ type InvitationCode struct { BaseModel } -// TableName specify table name +// TableName returns the table name for InvitationCode model func (InvitationCode) TableName() string { return "invitation_code" } diff --git a/internal/model/llm.go b/internal/model/llm.go index 9b9054e7e68..665dee78679 100644 --- a/internal/model/llm.go +++ b/internal/model/llm.go @@ -64,13 +64,14 @@ func (TenantLangfuse) TableName() string { // MyLLM represents LLM information for a tenant with factory details type MyLLM struct { + ID string `gorm:"column:id" json:"id"` LLMFactory string `gorm:"column:llm_factory" json:"llm_factory"` Logo *string `gorm:"column:logo" json:"logo,omitempty"` - Tags string `gorm:"column:tags" json:"tags"` - ModelType string `gorm:"column:model_type" json:"model_type"` - LLMName string `gorm:"column:llm_name" json:"llm_name"` - UsedTokens int64 `gorm:"column:used_tokens" json:"used_tokens"` - Status string `gorm:"column:status" json:"status"` - APIBase string `gorm:"column:api_base" json:"api_base,omitempty"` - MaxTokens int64 `gorm:"column:max_tokens" json:"max_tokens,omitempty"` + Tags *string `gorm:"column:tags" json:"tags"` + ModelType *string `gorm:"column:model_type" json:"model_type"` + LLMName *string `gorm:"column:llm_name" json:"llm_name"` + UsedTokens *int64 `gorm:"column:used_tokens" json:"used_tokens"` + Status *string `gorm:"column:status" json:"status"` + APIBase *string `gorm:"column:api_base" json:"api_base,omitempty"` + MaxTokens *int64 `gorm:"column:max_tokens" json:"max_tokens,omitempty"` } diff --git a/internal/router/router.go b/internal/router/router.go index 5f41765d60f..cebd6b97ac7 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -101,6 +101,8 @@ func (r *Router) Setup(engine *gin.Engine) { engine.POST("/v1/user/setting", r.userHandler.Setting) // User change password endpoint engine.POST("/v1/user/setting/password", r.userHandler.ChangePassword) + // User set tenant info endpoint + engine.POST("/v1/user/set_tenant_info", r.userHandler.SetTenantInfo) // API v1 
route group v1 := engine.Group("/api/v1") @@ -134,7 +136,25 @@ func (r *Router) Setup(engine *gin.Engine) { // Knowledge base routes kb := engine.Group("/v1/kb") { + kb.POST("/create", r.knowledgebaseHandler.CreateKB) + kb.POST("/update", r.knowledgebaseHandler.UpdateKB) + kb.POST("/update_metadata_setting", r.knowledgebaseHandler.UpdateMetadataSetting) + kb.GET("/detail", r.knowledgebaseHandler.GetDetail) kb.POST("/list", r.knowledgebaseHandler.ListKbs) + kb.POST("/rm", r.knowledgebaseHandler.DeleteKB) + kb.GET("/tags", r.knowledgebaseHandler.ListTagsFromKbs) + kb.GET("/get_meta", r.knowledgebaseHandler.GetMeta) + kb.GET("/basic_info", r.knowledgebaseHandler.GetBasicInfo) + + // KB ID specific routes + kbByID := kb.Group("/:kb_id") + { + kbByID.GET("/tags", r.knowledgebaseHandler.ListTags) + kbByID.POST("/rm_tags", r.knowledgebaseHandler.RemoveTags) + kbByID.POST("/rename_tag", r.knowledgebaseHandler.RenameTag) + kbByID.GET("/knowledge_graph", r.knowledgebaseHandler.KnowledgeGraph) + kbByID.DELETE("/knowledge_graph", r.knowledgebaseHandler.DeleteKnowledgeGraph) + } } // Chunk routes @@ -149,6 +169,7 @@ func (r *Router) Setup(engine *gin.Engine) { llm.GET("/my_llms", r.llmHandler.GetMyLLMs) llm.GET("/factories", r.llmHandler.Factories) llm.GET("/list", r.llmHandler.ListApp) + llm.POST("/set_api_key", r.llmHandler.SetAPIKey) } // Chat routes diff --git a/internal/server/config.go b/internal/server/config.go index 5a8fbf1e1d5..fe9cdea48ed 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -40,6 +40,29 @@ type Config struct { DocEngine DocEngineConfig `mapstructure:"doc_engine"` RegisterEnabled int `mapstructure:"register_enabled"` OAuth map[string]OAuthConfig `mapstructure:"oauth"` + UserDefaultLLM UserDefaultLLMConfig `mapstructure:"user_default_llm"` +} + +// UserDefaultLLMConfig user default LLM configuration +type UserDefaultLLMConfig struct { + DefaultModels DefaultModelsConfig `mapstructure:"default_models"` +} + +// DefaultModelsConfig default models configuration +type DefaultModelsConfig struct { + ChatModel ModelConfig `mapstructure:"chat_model"` + EmbeddingModel ModelConfig `mapstructure:"embedding_model"` + RerankModel ModelConfig `mapstructure:"rerank_model"` + ASRModel ModelConfig `mapstructure:"asr_model"` + Image2TextModel ModelConfig `mapstructure:"image2text_model"` +} + +// ModelConfig model configuration +type ModelConfig struct { + Name string `mapstructure:"name"` + APIKey string `mapstructure:"api_key"` + BaseURL string `mapstructure:"base_url"` + Factory string `mapstructure:"factory"` } // OAuthConfig OAuth configuration for a channel @@ -414,6 +437,45 @@ func Init(configPath string) error { } } + // Map user_default_llm section to UserDefaultLLMConfig + if v.IsSet("user_default_llm") { + userDefaultLLMConfig := v.Sub("user_default_llm") + if userDefaultLLMConfig != nil { + if defaultModels := userDefaultLLMConfig.Sub("default_models"); defaultModels != nil { + globalConfig.UserDefaultLLM.DefaultModels.ChatModel = ModelConfig{ + Name: defaultModels.GetString("chat_model.name"), + APIKey: defaultModels.GetString("chat_model.api_key"), + BaseURL: defaultModels.GetString("chat_model.base_url"), + Factory: defaultModels.GetString("chat_model.factory"), + } + globalConfig.UserDefaultLLM.DefaultModels.EmbeddingModel = ModelConfig{ + Name: defaultModels.GetString("embedding_model.name"), + APIKey: defaultModels.GetString("embedding_model.api_key"), + BaseURL: defaultModels.GetString("embedding_model.base_url"), + Factory: 
defaultModels.GetString("embedding_model.factory"), + } + globalConfig.UserDefaultLLM.DefaultModels.RerankModel = ModelConfig{ + Name: defaultModels.GetString("rerank_model.name"), + APIKey: defaultModels.GetString("rerank_model.api_key"), + BaseURL: defaultModels.GetString("rerank_model.base_url"), + Factory: defaultModels.GetString("rerank_model.factory"), + } + globalConfig.UserDefaultLLM.DefaultModels.ASRModel = ModelConfig{ + Name: defaultModels.GetString("asr_model.name"), + APIKey: defaultModels.GetString("asr_model.api_key"), + BaseURL: defaultModels.GetString("asr_model.base_url"), + Factory: defaultModels.GetString("asr_model.factory"), + } + globalConfig.UserDefaultLLM.DefaultModels.Image2TextModel = ModelConfig{ + Name: defaultModels.GetString("image2text_model.name"), + APIKey: defaultModels.GetString("image2text_model.api_key"), + BaseURL: defaultModels.GetString("image2text_model.base_url"), + Factory: defaultModels.GetString("image2text_model.factory"), + } + } + } + } + return nil } diff --git a/internal/service/chat.go b/internal/service/chat.go index b53706d180e..c1a7915ac77 100644 --- a/internal/service/chat.go +++ b/internal/service/chat.go @@ -82,7 +82,7 @@ func (s *ChatService) ListChats(userID string, status string) (*ListChatsRespons } // Enrich with knowledge base names - var chatsWithKBNames []*ChatWithKBNames + chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats)) for _, chat := range chats { kbNames := s.getKBNames(chat.KBIDs) chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{ @@ -148,7 +148,7 @@ func (s *ChatService) ListChatsNext(userID string, keywords string, page, pageSi } // Enrich with knowledge base names - var chatsWithKBNames []*ChatWithKBNames + chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats)) for _, chat := range chats { kbNames := s.getKBNames(chat.KBIDs) chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{ diff --git a/internal/service/kb.go b/internal/service/kb.go index 8b982ebe6f6..143c0e9a40e 100644 --- a/internal/service/kb.go +++ b/internal/service/kb.go @@ -17,25 +17,76 @@ package service import ( + "errors" + "fmt" + "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/model" + "ragflow/internal/utility" + "strings" + "time" + + "github.com/google/uuid" ) -// KnowledgebaseService knowledge base service +// KnowledgebaseService service class for managing dataset operations type KnowledgebaseService struct { kbDAO *dao.KnowledgebaseDAO userTenantDAO *dao.UserTenantDAO + userDAO *dao.UserDAO + tenantDAO *dao.TenantDAO + connectorDAO *dao.ConnectorDAO } -// NewKnowledgebaseService create knowledge base service +// NewKnowledgebaseService creates a new knowledge base service func NewKnowledgebaseService() *KnowledgebaseService { return &KnowledgebaseService{ kbDAO: dao.NewKnowledgebaseDAO(), userTenantDAO: dao.NewUserTenantDAO(), + userDAO: dao.NewUserDAO(), + tenantDAO: dao.NewTenantDAO(), + connectorDAO: dao.NewConnectorDAO(), } } -// ListKbsRequest list knowledge bases request +// CreateKBRequest represents the request for creating a knowledge base +type CreateKBRequest struct { + Name string `json:"name" binding:"required"` + ParserID *string `json:"parser_id,omitempty"` + Description *string `json:"description,omitempty"` + Language *string `json:"language,omitempty"` + Permission *string `json:"permission,omitempty"` + Avatar *string `json:"avatar,omitempty"` + ParserConfig map[string]interface{} `json:"parser_config,omitempty"` +} + +// CreateKBResponse represents the response for 
creating a knowledge base
+type CreateKBResponse struct {
+	KBID string `json:"kb_id"`
+}
+
+// UpdateKBRequest represents the request for updating a knowledge base
+type UpdateKBRequest struct {
+	KBID         string                 `json:"kb_id" binding:"required"`
+	Name         string                 `json:"name" binding:"required"`
+	Description  *string                `json:"description"`
+	ParserID     string                 `json:"parser_id" binding:"required"`
+	Permission   *string                `json:"permission,omitempty"`
+	Language     *string                `json:"language,omitempty"`
+	Avatar       *string                `json:"avatar,omitempty"`
+	Pagerank     *int64                 `json:"pagerank,omitempty"`
+	ParserConfig map[string]interface{} `json:"parser_config,omitempty"`
+	Connectors   []string               `json:"connectors,omitempty"`
+}
+
+// UpdateMetadataSettingRequest represents the request for updating metadata settings
+type UpdateMetadataSettingRequest struct {
+	KBID           string                 `json:"kb_id" binding:"required"`
+	Metadata       map[string]interface{} `json:"metadata" binding:"required"`
+	EnableMetadata *bool                  `json:"enable_metadata,omitempty"`
+}
+
+// ListKbsRequest represents the request for listing knowledge bases
 type ListKbsRequest struct {
 	Keywords *string `json:"keywords,omitempty"`
 	Page     *int    `json:"page,omitempty"`
@@ -46,37 +97,461 @@ type ListKbsRequest struct {
 	OwnerIDs *[]string `json:"owner_ids,omitempty"`
 }
 
-// ListKbsResponse list knowledge bases response
+// ListKbsResponse represents the response for listing knowledge bases
 type ListKbsResponse struct {
-	KBs   []*model.Knowledgebase `json:"kbs"`
-	Total int64                  `json:"total"`
+	KBs   []map[string]interface{} `json:"kbs"`
+	Total int64                    `json:"total"`
+}
+
+// CreateKB creates a new knowledge base
+// This matches the Python create endpoint in kb_app.py
+func (s *KnowledgebaseService) CreateKB(req *CreateKBRequest, tenantID string) (*CreateKBResponse, common.ErrorCode, error) {
+	// Validate name is a string
+	if !isValidString(req.Name) {
+		return nil, common.CodeDataError, errors.New("Dataset name must be a string.")
+	}
+
+	// Trim and validate name
+	name := strings.TrimSpace(req.Name)
+	if name == "" {
+		return nil, common.CodeDataError, errors.New("Dataset name can't be empty.")
+	}
+
+	// Check name length; note that Go's len() counts UTF-8 bytes, whereas
+	// Python's len() counts characters, so multibyte names diverge here
+	if len(name) > model.DatasetNameLimit {
+		return nil, common.CodeDataError, fmt.Errorf("Dataset name length is %d, which is larger than %d", len(name), model.DatasetNameLimit)
+	}
+
+	// Verify tenant exists
+	tenant, err := s.tenantDAO.GetByID(tenantID)
+	if err != nil {
+		return nil, common.CodeDataError, errors.New("Tenant not found.")
+	}
+
+	// Deduplicate name within tenant
+	duplicateName := s.kbDAO.DuplicateName(name, tenantID)
+
+	// Get parser ID (default to "naive")
+	parserID := "naive"
+	if req.ParserID != nil && *req.ParserID != "" {
+		parserID = *req.ParserID
+	}
+
+	// Get parser config with defaults
+	parserConfig := getParserConfig(parserID, req.ParserConfig)
+	parserConfig["llm_id"] = tenant.LLMID
+
+	// Generate KB ID
+	kbID := strings.ReplaceAll(uuid.New().String(), "-", "")
+
+	// Create knowledge base model
+	now := time.Now().Unix()
+	nowDate := time.Now()
+	kb := &model.Knowledgebase{
+		ID:           kbID,
+		Name:         duplicateName,
+		TenantID:     tenantID,
+		CreatedBy:    tenantID,
+		ParserID:     parserID,
+		ParserConfig: parserConfig,
+		Permission:   "me",
+		EmbdID:       "",
+	}
+	kb.CreateTime = &now
+	kb.UpdateTime = &now
+	kb.CreateDate = &nowDate
+	kb.UpdateDate = &nowDate
+	status := string(model.StatusValid)
+	kb.Status = &status
+
+	// Set optional fields
+	if req.Description != nil {
+		kb.Description = req.Description
+	}
+	if req.Language != nil {
+		kb.Language = req.Language
+	}
+	if req.Permission != nil {
+		kb.Permission = *req.Permission
+	}
+	if req.Avatar != nil {
+		kb.Avatar = req.Avatar
+	}
+
+	// Create in database
+	if err := s.kbDAO.Create(kb); err != nil {
+		return nil, common.CodeServerError, fmt.Errorf("failed to create knowledge base: %w", err)
+	}
+
+	return &CreateKBResponse{KBID: kbID}, common.CodeSuccess, nil
+}
+
+// UpdateKB updates an existing knowledge base
+// This matches the Python update endpoint in kb_app.py
+func (s *KnowledgebaseService) UpdateKB(req *UpdateKBRequest, userID string) (map[string]interface{}, common.ErrorCode, error) {
+	// Validate name is a string
+	if !isValidString(req.Name) {
+		return nil, common.CodeDataError, errors.New("Dataset name must be a string.")
+	}
+
+	// Trim and validate name
+	name := strings.TrimSpace(req.Name)
+	if name == "" {
+		return nil, common.CodeDataError, errors.New("Dataset name can't be empty.")
+	}
+
+	// Check name length
+	if len(name) > model.DatasetNameLimit {
+		return nil, common.CodeDataError, fmt.Errorf("Dataset name length is %d, which is larger than %d", len(name), model.DatasetNameLimit)
+	}
+
+	// Check authorization
+	if !s.kbDAO.Accessible4Deletion(req.KBID, userID) {
+		return nil, common.CodeAuthenticationError, errors.New("No authorization.")
+	}
+
+	// Verify ownership
+	kbs, err := s.kbDAO.Query(map[string]interface{}{"created_by": userID, "id": req.KBID})
+	if err != nil || len(kbs) == 0 {
+		return nil, common.CodeOperatingError, errors.New("only owner of dataset authorized for this operation")
+	}
+
+	// Get existing KB
+	kb, err := s.kbDAO.GetByID(req.KBID)
+	if err != nil {
+		return nil, common.CodeDataError, errors.New("can't find this dataset")
+	}
+
+	// Check for duplicate name (case-insensitive)
+	if !strings.EqualFold(name, kb.Name) {
+		existingKB, _ := s.kbDAO.GetByName(name, userID)
+		if existingKB != nil {
+			return nil, common.CodeDataError, errors.New("duplicated dataset name")
+		}
+	}
+
+	// Build updates
+	updates := map[string]interface{}{
+		"name":      name,
+		"parser_id": req.ParserID,
+	}
+
+	if req.Description != nil {
+		updates["description"] = *req.Description
+	}
+	if req.Permission != nil {
+		updates["permission"] = *req.Permission
+	}
+	if req.Language != nil {
+		updates["language"] = *req.Language
+	}
+	if req.Avatar != nil {
+		updates["avatar"] = *req.Avatar
+	}
+	if req.Pagerank != nil {
+		updates["pagerank"] = *req.Pagerank
+	}
+	if req.ParserConfig != nil {
+		updates["parser_config"] = req.ParserConfig
+	}
+
+	now := time.Now().Unix()
+	nowDate := time.Now()
+	updates["update_time"] = now
+	updates["update_date"] = nowDate
+
+	// Update in database
+	if err := s.kbDAO.UpdateByID(req.KBID, updates); err != nil {
+		return nil, common.CodeServerError, fmt.Errorf("failed to update knowledge base: %w", err)
+	}
+
+	// Get updated KB
+	updatedKB, err := s.kbDAO.GetByID(req.KBID)
+	if err != nil {
+		return nil, common.CodeDataError, errors.New("database error (knowledgebase rename)")
+	}
+
+	result := updatedKB.ToMap()
+	result["connectors"] = req.Connectors
+
+	return result, common.CodeSuccess, nil
 }
 
-// ListKbs list knowledge bases
-func (s *KnowledgebaseService) ListKbs(keywords string, page int, pageSize int, parserID string, orderby string, desc bool, ownerIDs []string, userID string) (*ListKbsResponse, error) {
-	var kbs []*model.Knowledgebase
+// UpdateMetadataSetting updates the metadata settings for a knowledge base
+func (s *KnowledgebaseService) UpdateMetadataSetting(req *UpdateMetadataSettingRequest) (map[string]interface{}, common.ErrorCode, error) {
+	kb, err := s.kbDAO.GetByID(req.KBID)
+	if err != nil {
+		return nil, common.CodeDataError, errors.New("database error (knowledgebase not found)")
+	}
+
+	parserConfig := kb.ParserConfig
+	if parserConfig == nil {
+		parserConfig = make(map[string]interface{})
+	}
+
+	parserConfig["metadata"] = req.Metadata
+	enableMetadata := true
+	if req.EnableMetadata != nil {
+		enableMetadata = *req.EnableMetadata
+	}
+	parserConfig["enable_metadata"] = enableMetadata
+
+	if err := s.kbDAO.UpdateParserConfig(req.KBID, parserConfig); err != nil {
+		return nil, common.CodeServerError, fmt.Errorf("failed to update metadata setting: %w", err)
+	}
+
+	result := kb.ToMap()
+	result["parser_config"] = parserConfig
+
+	return result, common.CodeSuccess, nil
+}
+
+// GetDetail retrieves detailed information about a knowledge base
+// This matches the Python kb_detail endpoint in kb_app.py
+func (s *KnowledgebaseService) GetDetail(kbID, userID string) (*model.KnowledgebaseDetail, common.ErrorCode, error) {
+	// Check authorization
+	if !s.kbDAO.Accessible(kbID, userID) {
+		return nil, common.CodeOperatingError, errors.New("only owner of dataset authorized for this operation")
+	}
+
+	// Get detail
+	detail, err := s.kbDAO.GetDetail(kbID)
+	if err != nil {
+		return nil, common.CodeDataError, errors.New("can't find this dataset")
+	}
+
+	// Set connectors (empty for now)
+	detail.Connectors = []string{}
+
+	return detail, common.CodeSuccess, nil
+}
+
+// ListKbs lists knowledge bases with pagination and filtering
+// This matches the Python list endpoint in kb_app.py
+func (s *KnowledgebaseService) ListKbs(keywords string, page int, pageSize int, parserID string, orderby string, desc bool, ownerIDs []string, userID string) (*ListKbsResponse, common.ErrorCode, error) {
+	var kbs []*model.KnowledgebaseListItem
 	var total int64
 	var err error
 
-	// If owner IDs are provided, filter by them
-	if ownerIDs != nil && len(ownerIDs) > 0 {
-		kbs, total, err = s.kbDAO.ListByOwnerIDs(ownerIDs, page, pageSize, orderby, desc, keywords, parserID)
+	if len(ownerIDs) > 0 {
+		// List by owner IDs
+		kbs, total, err = s.kbDAO.GetByTenantIDs(ownerIDs, userID, page, pageSize, orderby, desc, keywords, parserID)
 	} else {
-		// Get tenant IDs by user ID
-		tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
+		// Get tenant IDs for user; assign to the outer err rather than
+		// declaring a new one with :=, which would shadow it and let the
+		// error check below miss failures from this branch
+		var tenantIDs []string
+		tenantIDs, err = s.userTenantDAO.GetTenantIDsByUserID(userID)
 		if err != nil {
-			return nil, err
+			return nil, common.CodeServerError, err
 		}
-		kbs, total, err = s.kbDAO.ListByTenantIDs(tenantIDs, userID, page, pageSize, orderby, desc, keywords, parserID)
+		kbs, total, err = s.kbDAO.GetByTenantIDs(tenantIDs, userID, page, pageSize, orderby, desc, keywords, parserID)
 	}
 
 	if err != nil {
-		return nil, err
+		return nil, common.CodeServerError, err
+	}
+
+	// Convert to map slice
+	kbMaps := make([]map[string]interface{}, len(kbs))
+	for i, kb := range kbs {
+		kbMaps[i] = map[string]interface{}{
+			"id":            kb.ID,
+			"avatar":        kb.Avatar,
+			"name":          kb.Name,
+			"language":      kb.Language,
+			"description":   kb.Description,
+			"tenant_id":     kb.TenantID,
+			"permission":    kb.Permission,
+			"doc_num":       kb.DocNum,
+			"token_num":     kb.TokenNum,
+			"chunk_num":     kb.ChunkNum,
+			"parser_id":     kb.ParserID,
+			"embd_id":       kb.EmbdID,
+			"nickname":      kb.Nickname,
+			"tenant_avatar": kb.TenantAvatar,
+			"update_time":   kb.UpdateTime,
+		}
 	}
 
 	return &ListKbsResponse{
-		KBs: kbs,
+		KBs:   kbMaps,
 		Total: total,
-	}, nil
+	}, common.CodeSuccess, nil
+}
+
+// DeleteKB soft deletes a knowledge base
+// This matches the Python rm endpoint in
kb_app.py +func (s *KnowledgebaseService) DeleteKB(kbID, userID string) (common.ErrorCode, error) { + // Check authorization + if !s.kbDAO.Accessible4Deletion(kbID, userID) { + return common.CodeAuthenticationError, errors.New("No authorization.") + } + + // Verify ownership + kbs, err := s.kbDAO.Query(map[string]interface{}{"created_by": userID, "id": kbID}) + if err != nil || len(kbs) == 0 { + return common.CodeOperatingError, errors.New("only owner of dataset authorized for this operation") + } + + // Soft delete + if err := s.kbDAO.Delete(kbID); err != nil { + return common.CodeServerError, fmt.Errorf("database error (knowledgebase removal): %w", err) + } + + return common.CodeSuccess, nil +} + +// Accessible checks if a knowledge base is accessible by a user +func (s *KnowledgebaseService) Accessible(kbID, userID string) bool { + return s.kbDAO.Accessible(kbID, userID) +} + +// GetByID retrieves a knowledge base by ID +func (s *KnowledgebaseService) GetByID(kbID string) (*model.Knowledgebase, error) { + return s.kbDAO.GetByID(kbID) +} + +// GetKBIDsByTenantID retrieves all knowledge base IDs for a tenant +func (s *KnowledgebaseService) GetKBIDsByTenantID(tenantID string) ([]string, error) { + return s.kbDAO.GetKBIDsByTenantID(tenantID) +} + +// isValidString checks if a value is a non-empty string +func isValidString(v interface{}) bool { + str, ok := v.(string) + return ok && str != "" +} + +// getParserConfig returns the parser configuration with defaults +// This matches the Python get_parser_config function +func getParserConfig(parserID string, customConfig map[string]interface{}) map[string]interface{} { + config := map[string]interface{}{ + "pages": [][]int{{1, 1000000}}, + "table_context_size": 0, + "image_context_size": 0, + } + + switch parserID { + case "table": + config["layout_recognize"] = false + config["chunk_token_num"] = 128 + config["delimiter"] = "\n!?;。;!?" + config["html4excel"] = false + case "naive": + config["chunk_token_num"] = 128 + config["delimiter"] = "\n!?;。;!?" 
+ config["html4excel"] = false + default: + config["raptor"] = map[string]interface{}{ + "use_raptor": false, + } + } + + // Merge custom config if provided + if customConfig != nil { + config = mergeParserConfig(config, customConfig) + } + + return config +} + +// mergeParserConfig merges two parser configurations +func mergeParserConfig(base, override map[string]interface{}) map[string]interface{} { + result := make(map[string]interface{}) + for k, v := range base { + result[k] = v + } + + for k, v := range override { + if existing, ok := result[k]; ok { + if existingMap, ok := existing.(map[string]interface{}); ok { + if newMap, ok := v.(map[string]interface{}); ok { + result[k] = mergeParserConfig(existingMap, newMap) + continue + } + } + } + result[k] = v + } + + return result +} + +// GenerateUUID generates a UUID string without dashes +func GenerateUUID() string { + return strings.ReplaceAll(uuid.New().String(), "-", "") +} + +// GetUserByToken gets user by authorization token +func (s *KnowledgebaseService) GetUserByToken(authorization string) (*model.User, common.ErrorCode, error) { + userService := NewUserService() + return userService.GetUserByToken(authorization) +} + +// GetUserByID gets user by ID +func (s *KnowledgebaseService) GetUserByID(id string) (*model.User, error) { + return s.userDAO.GetByAccessToken(id) +} + +// GetTenantIDsByUserID gets tenant IDs for a user +func (s *KnowledgebaseService) GetTenantIDsByUserID(userID string) ([]string, error) { + return s.userTenantDAO.GetTenantIDsByUserID(userID) +} + +// GetConnectorsByTenantID gets connectors for a tenant +func (s *KnowledgebaseService) GetConnectorsByTenantID(tenantID string) ([]*dao.ConnectorListItem, error) { + return s.connectorDAO.ListByTenantID(tenantID) +} + +// GetKBList retrieves knowledge bases with ID and name filtering +func (s *KnowledgebaseService) GetKBList(tenantIDs []string, userID string, page, pageSize int, orderby string, desc bool, id, name string) ([]*model.Knowledgebase, int64, common.ErrorCode, error) { + kbs, total, err := s.kbDAO.GetList(tenantIDs, userID, page, pageSize, orderby, desc, id, name) + if err != nil { + return nil, 0, common.CodeServerError, err + } + return kbs, total, common.CodeSuccess, nil +} + +// GetKBByIDAndUserID retrieves a knowledge base by ID and user ID +func (s *KnowledgebaseService) GetKBByIDAndUserID(kbID, userID string) ([]*model.Knowledgebase, error) { + return s.kbDAO.GetKBByIDAndUserID(kbID, userID) +} + +// GetKBByNameAndUserID retrieves a knowledge base by name and user ID +func (s *KnowledgebaseService) GetKBByNameAndUserID(kbName, userID string) ([]*model.Knowledgebase, error) { + return s.kbDAO.GetKBByNameAndUserID(kbName, userID) +} + +// AtomicIncreaseDocNumByID atomically increments the document count +func (s *KnowledgebaseService) AtomicIncreaseDocNumByID(kbID string) error { + return s.kbDAO.AtomicIncreaseDocNumByID(kbID) +} + +// DecreaseDocumentNum decreases document, chunk, and token counts +func (s *KnowledgebaseService) DecreaseDocumentNum(kbID string, docNum, chunkNum, tokenNum int64) error { + return s.kbDAO.DecreaseDocumentNum(kbID, docNum, chunkNum, tokenNum) +} + +// UpdateParserConfig updates the parser configuration +func (s *KnowledgebaseService) UpdateParserConfig(id string, config map[string]interface{}) error { + return s.kbDAO.UpdateParserConfig(id, config) +} + +// DeleteFieldMap removes the field_map from parser_config +func (s *KnowledgebaseService) DeleteFieldMap(id string) error { + return s.kbDAO.DeleteFieldMap(id) +} + 
+// GetFieldMap retrieves field mappings from multiple knowledge bases +func (s *KnowledgebaseService) GetFieldMap(ids []string) (map[string]interface{}, error) { + return s.kbDAO.GetFieldMap(ids) +} + +// GetAllIDs retrieves all knowledge base IDs +func (s *KnowledgebaseService) GetAllIDs() ([]string, error) { + return s.kbDAO.GetAllIDs() +} + +// ExtractAccessToken extracts access token from authorization header +func ExtractAccessToken(authorization, secretKey string) (string, error) { + return utility.ExtractAccessToken(authorization, secretKey) } diff --git a/internal/service/llm.go b/internal/service/llm.go index 85b1cd99f8f..a284f4d2c67 100644 --- a/internal/service/llm.go +++ b/internal/service/llm.go @@ -17,11 +17,16 @@ package service import ( + "fmt" + "strconv" "strings" "ragflow/internal/dao" + "ragflow/internal/model" ) +var DB = dao.DB + // LLMService LLM service type LLMService struct { tenantLLMDAO *dao.TenantLLMDAO @@ -38,6 +43,7 @@ func NewLLMService() *LLMService { // MyLLMItem represents a single LLM item in the response type MyLLMItem struct { + ID string `json:"id"` Type string `json:"type"` Name string `json:"name"` UsedToken int64 `json:"used_token"` @@ -46,67 +52,89 @@ type MyLLMItem struct { MaxTokens int64 `json:"max_tokens,omitempty"` } -// MyLLMResponse represents the response structure for my LLMs -type MyLLMResponse struct { +// MyLLMFactory represents the response structure for a factory in my LLMs +type MyLLMFactory struct { Tags string `json:"tags"` LLM []MyLLMItem `json:"llm"` } // GetMyLLMs get my LLMs for a tenant -func (s *LLMService) GetMyLLMs(tenantID string, includeDetails bool) (map[string]MyLLMResponse, error) { - // Get LLM list from database - myLLMs, err := s.tenantLLMDAO.GetMyLLMs(tenantID, includeDetails) - if err != nil { - return nil, err - } +func (s *LLMService) GetMyLLMs(tenantID string, includeDetails bool) (map[string]MyLLMFactory, error) { + result := make(map[string]MyLLMFactory) - // Group by factory - result := make(map[string]MyLLMResponse) - providerDAO := dao.NewModelProviderDAO() - for _, llm := range myLLMs { - // Get or create factory entry - resp, exists := result[llm.LLMFactory] - if !exists { - resp = MyLLMResponse{ - Tags: llm.Tags, - LLM: []MyLLMItem{}, - } + if includeDetails { + objs, err := s.tenantLLMDAO.ListAllByTenant(tenantID) + if err != nil { + return nil, err } - // Create LLM item - item := MyLLMItem{ - Type: llm.ModelType, - Name: llm.LLMName, - UsedToken: llm.UsedTokens, - Status: llm.Status, + factoryDAO := dao.NewLLMFactoryDAO() + factories, err := factoryDAO.GetAllValid() + if err != nil { + return nil, err } - // Add detailed fields if requested - if includeDetails { - item.APIBase = llm.APIBase - item.MaxTokens = llm.MaxTokens - - // If APIBase is empty, try to get from model provider configuration - if item.APIBase == "" { - provider := providerDAO.GetProviderByName(llm.LLMFactory) - if provider != nil { - // Determine appropriate API base URL based on model type - switch llm.ModelType { - case "embedding": - if provider.DefaultEmbeddingURL != "" { - item.APIBase = provider.DefaultEmbeddingURL - } - // Add other model types here if needed - // case "chat": - // case "rerank": - // etc. 
- } + factoryTagsMap := make(map[string]string) + for _, f := range factories { + if f.Tags != "" { + factoryTagsMap[f.Name] = f.Tags + } + } + + for _, o := range objs { + llmFactory := o.LLMFactory + if _, exists := result[llmFactory]; !exists { + tags := factoryTagsMap[llmFactory] + result[llmFactory] = MyLLMFactory{ + Tags: tags, + LLM: []MyLLMItem{}, } } + + item := MyLLMItem{ + ID: int64ToString(o.ID), + Type: getStringValue(o.ModelType), + Name: getStringValue(o.LLMName), + UsedToken: o.UsedTokens, + Status: getValidStatus(o.Status), + } + + if includeDetails { + item.APIBase = getStringValueDefault(o.APIBase, "") + item.MaxTokens = o.MaxTokens + } + + factory := result[llmFactory] + factory.LLM = append(factory.LLM, item) + result[llmFactory] = factory } + } else { + objs, err := s.tenantLLMDAO.GetMyLLMs(tenantID) + if err != nil { + return nil, err + } + + for _, o := range objs { + llmFactory := o.LLMFactory + if _, exists := result[llmFactory]; !exists { + result[llmFactory] = MyLLMFactory{ + Tags: getStringValue(o.Tags), + LLM: []MyLLMItem{}, + } + } - resp.LLM = append(resp.LLM, item) - result[llm.LLMFactory] = resp + item := MyLLMItem{ + ID: o.ID, + Type: getStringValue(o.ModelType), + Name: getStringValue(o.LLMName), + UsedToken: getInt64Value(o.UsedTokens), + Status: getStringValueDefault(o.Status, "1"), + } + + factory := result[llmFactory] + factory.LLM = append(factory.LLM, item) + result[llmFactory] = factory + } } return result, nil @@ -114,6 +142,7 @@ func (s *LLMService) GetMyLLMs(tenantID string, includeDetails bool) (map[string // LLMListItem represents a single LLM item in the list response type LLMListItem struct { + ID string `json:"id"` LLMName string `json:"llm_name"` ModelType string `json:"model_type"` FID string `json:"fid"` @@ -142,37 +171,32 @@ func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsRespon "GPUStack": true, } - // Get tenant LLMs - tenantLLMs, err := s.tenantLLMDAO.ListAllByTenant(tenantID) + objs, err := s.tenantLLMDAO.ListAllByTenant(tenantID) if err != nil { return nil, err } - // Build set of factories with valid API keys facts := make(map[string]bool) - // Build set of valid LLM names@factories status := make(map[string]bool) - for _, tl := range tenantLLMs { - if tl.APIKey != nil && *tl.APIKey != "" && tl.Status == "1" { - facts[tl.LLMFactory] = true - } - llmName := "" - if tl.LLMName != nil { - llmName = *tl.LLMName + tenantLLMMapping := make(map[string]string) + + for _, o := range objs { + if o.APIKey != nil && *o.APIKey != "" && getValidStatus(o.Status) == "1" { + facts[o.LLMFactory] = true } - key := llmName + "@" + tl.LLMFactory - if tl.Status == "1" { + llmName := getStringValue(o.LLMName) + key := llmName + "@" + o.LLMFactory + if getValidStatus(o.Status) == "1" { status[key] = true } + tenantLLMMapping[key] = int64ToString(o.ID) } - // Get all valid LLMs allLLMs, err := s.llmDAO.GetAllValid() if err != nil { return nil, err } - // Filter and build result llmSet := make(map[string]bool) result := make(ListLLMsResponse) @@ -183,20 +207,18 @@ func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsRespon key := llm.LLMName + "@" + llm.FID - // Check if valid (Builtin factory or in status set) if llm.FID != "Builtin" && !status[key] { continue } - // Filter by model type if specified if modelType != "" && !strings.Contains(llm.ModelType, modelType) { continue } - // Determine availability - available := facts[llm.FID] || selfDeployed[llm.FID] || llm.LLMName == "flag-embedding" + available := 
facts[llm.FID] || selfDeployed[llm.FID] || strings.ToLower(llm.LLMName) == "flag-embedding" item := LLMListItem{ + ID: tenantLLMMapping[key], LLMName: llm.LLMName, ModelType: llm.ModelType, FID: llm.FID, @@ -207,7 +229,6 @@ func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsRespon Tags: llm.Tags, } - // Add BaseModel fields if llm.CreateDate != nil { createDateStr := llm.CreateDate.Format("2006-01-02T15:04:05") item.CreateDate = &createDateStr @@ -225,36 +246,160 @@ func (s *LLMService) ListLLMs(tenantID string, modelType string) (ListLLMsRespon llmSet[key] = true } - // Add tenant LLMs that are not in the global list - for _, tl := range tenantLLMs { - llmName := "" - if tl.LLMName != nil { - llmName = *tl.LLMName - } - key := llmName + "@" + tl.LLMFactory + for _, o := range objs { + llmName := getStringValue(o.LLMName) + key := llmName + "@" + o.LLMFactory if llmSet[key] { continue } - // Filter by model type if specified - modelTypeValue := "" - if tl.ModelType != nil { - modelTypeValue = *tl.ModelType - } + modelTypeValue := getStringValue(o.ModelType) if modelType != "" && !strings.Contains(modelTypeValue, modelType) { continue } item := LLMListItem{ + ID: int64ToString(o.ID), LLMName: llmName, ModelType: modelTypeValue, - FID: tl.LLMFactory, + FID: o.LLMFactory, Available: true, - Status: tl.Status, + Status: getValidStatus(o.Status), } - result[tl.LLMFactory] = append(result[tl.LLMFactory], item) + result[o.LLMFactory] = append(result[o.LLMFactory], item) } return result, nil } + +func getStringValue(s *string) string { + if s == nil { + return "" + } + return *s +} + +func getStringValueDefault(s *string, defaultVal string) string { + if s == nil || *s == "" { + return defaultVal + } + return *s +} + +func getValidStatus(status string) string { + if status == "" { + return "1" + } + return status +} + +func getInt64Value(i *int64) int64 { + if i == nil { + return 0 + } + return *i +} + +func getInt64ValueDefault(i *int64, defaultVal int64) int64 { + if i == nil || *i == 0 { + return defaultVal + } + return *i +} + +func getBoolValue(b *bool) bool { + if b == nil { + return false + } + return *b +} + +func int64ToString(n int64) string { + return strconv.FormatInt(n, 10) +} + +// SetAPIKeyRequest represents the request for setting API key +type SetAPIKeyRequest struct { + LLMFactory string `json:"llm_factory"` + APIKey string `json:"api_key"` + BaseURL string `json:"base_url"` + SourceFID string `json:"source_fid"` + ModelType string `json:"model_type"` + LLMName string `json:"llm_name"` + Verify bool `json:"verify"` + MaxTokens int64 `json:"max_tokens"` +} + +// SetAPIKeyResult represents the result of setting API key +type SetAPIKeyResult struct { + Message string `json:"message"` + Success bool `json:"success"` +} + +// SetAPIKey sets API key for a LLM factory +func (s *LLMService) SetAPIKey(tenantID string, req *SetAPIKeyRequest) (*SetAPIKeyResult, error) { + factory := req.LLMFactory + baseURL := req.BaseURL + sourceFactory := req.SourceFID + if sourceFactory == "" { + sourceFactory = factory + } + + sourceLLMs, err := s.llmDAO.GetByFactory(sourceFactory) + if err != nil || len(sourceLLMs) == 0 { + msg := "No models configured for " + factory + " (source: " + sourceFactory + ")." 
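+		// When verifying a connection, surface the failure in the result
+		// payload (Success=false) instead of returning an error.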
+ if req.Verify { + return &SetAPIKeyResult{Message: msg, Success: false}, nil + } + return nil, fmt.Errorf(msg) + } + + llmConfig := map[string]interface{}{ + "api_key": req.APIKey, + "api_base": baseURL, + } + + if req.ModelType != "" { + llmConfig["model_type"] = req.ModelType + } + if req.LLMName != "" { + llmConfig["llm_name"] = req.LLMName + } + + for _, llm := range sourceLLMs { + maxTokens := llm.MaxTokens + if maxTokens == 0 { + maxTokens = 8192 + } + llmConfig["max_tokens"] = maxTokens + + existingLLM, _ := s.tenantLLMDAO.GetByTenantFactoryAndModelName(tenantID, factory, llm.LLMName) + if existingLLM != nil { + updates := map[string]interface{}{ + "api_key": req.APIKey, + "api_base": baseURL, + "max_tokens": maxTokens, + } + DB.Model(&model.TenantLLM{}). + Where("tenant_id = ? AND llm_factory = ? AND llm_name = ?", tenantID, factory, llm.LLMName). + Updates(updates) + } else { + modelType := llm.ModelType + llmName := llm.LLMName + tenantLLM := &model.TenantLLM{ + TenantID: tenantID, + LLMFactory: factory, + ModelType: &modelType, + LLMName: &llmName, + APIKey: &req.APIKey, + APIBase: &baseURL, + MaxTokens: maxTokens, + Status: "1", + } + s.tenantLLMDAO.Create(tenantLLM) + } + } + + return &SetAPIKeyResult{Message: "", Success: true}, nil +} diff --git a/internal/service/user.go b/internal/service/user.go index ccf737e3e3b..eb8b2e6f1eb 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -151,15 +151,38 @@ func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorC user.LastLoginTime = &now_date tenantName := req.Nickname + "'s Kingdom" + + llmID := cfg.UserDefaultLLM.DefaultModels.ChatModel.Name + if llmID == "" { + llmID = "" + } + embdID := cfg.UserDefaultLLM.DefaultModels.EmbeddingModel.Name + if embdID == "" { + embdID = "" + } + asrID := cfg.UserDefaultLLM.DefaultModels.ASRModel.Name + if asrID == "" { + asrID = "" + } + img2txtID := cfg.UserDefaultLLM.DefaultModels.Image2TextModel.Name + if img2txtID == "" { + img2txtID = "" + } + rerankID := cfg.UserDefaultLLM.DefaultModels.RerankModel.Name + if rerankID == "" { + rerankID = "" + } + tenant := &model.Tenant{ ID: userID, Name: &tenantName, - LLMID: cfg.Server.Mode, - EmbdID: cfg.Server.Mode, - ASRID: cfg.Server.Mode, - Img2TxtID: cfg.Server.Mode, - RerankID: cfg.Server.Mode, + LLMID: llmID, + EmbdID: embdID, + ASRID: asrID, + Img2TxtID: img2txtID, + RerankID: rerankID, ParserIDs: "naive:General,Q&A:Q&A,manual:Manual,table:Table,paper:Research Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag", + Status: &status, } tenant.CreateTime = &now tenant.UpdateTime = &now @@ -753,3 +776,52 @@ func (s *UserService) GetLoginChannels() ([]*LoginChannel, common.ErrorCode, err return channels, common.CodeSuccess, nil } + +// SetTenantInfoRequest represents the request for setting tenant info +type SetTenantInfoRequest struct { + TenantID string `json:"tenant_id"` + ASRID string `json:"asr_id"` + EmbdID string `json:"embd_id"` + Img2TxtID string `json:"img2txt_id"` + LLMID string `json:"llm_id"` + RerankID string `json:"rerank_id"` + TTSID string `json:"tts_id"` +} + +// SetTenantInfo updates tenant model configuration +func (s *UserService) SetTenantInfo(userID string, req *SetTenantInfoRequest) error { + tenantDAO := dao.NewTenantDAO() + + _, err := tenantDAO.GetByID(req.TenantID) + if err != nil { + return fmt.Errorf("tenant not found: %w", err) + } + + updates := make(map[string]interface{}) + if req.LLMID != "" { + updates["llm_id"] = 
req.LLMID
+	}
+	if req.EmbdID != "" {
+		updates["embd_id"] = req.EmbdID
+	}
+	if req.ASRID != "" {
+		updates["asr_id"] = req.ASRID
+	}
+	if req.Img2TxtID != "" {
+		updates["img2txt_id"] = req.Img2TxtID
+	}
+	if req.RerankID != "" {
+		updates["rerank_id"] = req.RerankID
+	}
+	if req.TTSID != "" {
+		updates["tts_id"] = req.TTSID
+	}
+
+	if len(updates) > 0 {
+		if err := tenantDAO.Update(req.TenantID, updates); err != nil {
+			return fmt.Errorf("failed to update tenant: %w", err)
+		}
+	}
+
+	return nil
+}

From 551efa2ebe205d6ff859b4c74a4348407053a734 Mon Sep 17 00:00:00 2001
From: Jin Hai
Date: Mon, 9 Mar 2026 15:52:31 +0800
Subject: [PATCH 181/565] Refactor the go_binding to binding (#13469)

### What problem does this PR solve?

As title.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai
---
 internal/{go_binding => binding}/rag_analyzer.go | 0
 internal/cpp/Makefile                            | 12 ++++++------
 internal/tokenizer/tokenizer.go                  |  2 +-
 3 files changed, 7 insertions(+), 7 deletions(-)
 rename internal/{go_binding => binding}/rag_analyzer.go (100%)

diff --git a/internal/cpp/Makefile b/internal/cpp/Makefile
index cbf66ac70ff..9ddf0244059 100644
--- a/internal/cpp/Makefile
+++ b/internal/cpp/Makefile
@@ -51,20 +51,20 @@ c_api_debug: $(BUILD_DIR)
 
 # Test the Go bindings
 test_go: c_api
-	cd go_bindings/example && go run main.go ../../$(BUILD_DIR) "这是一个测试文本。This is a test."
+	cd bindings/example && go run main.go ../../$(BUILD_DIR) "这是一个测试文本。This is a test."
 
 # Run memory test
 test_memory: c_api
-	cd go_bindings/example && go run memory_leak_check.go
+	cd bindings/example && go run memory_leak_check.go
 
 # Run with valgrind
 valgrind: c_api
-	cd go_bindings/example && bash run_valgrind.sh
+	cd bindings/example && bash run_valgrind.sh
 
 # Run with AddressSanitizer
 asan: c_api_asan
 	@echo "Running with AddressSanitizer..."
-	cd go_bindings/example && \
+	cd bindings/example && \
 	ASAN_OPTIONS=detect_leaks=1:print_stats=1:verbosity=0 \
 	go run memory_leak_check.go
 
@@ -77,5 +77,5 @@ install: c_api
 clean:
 	rm -rf $(BUILD_DIR)
 	rm -rf $(ASAN_BUILD_DIR)
-	rm -f go_bindings/example/valgrind.log
-	rm -f go_bindings/example/memory_test_bin
+	rm -f bindings/example/valgrind.log
+	rm -f bindings/example/memory_test_bin
diff --git a/internal/tokenizer/tokenizer.go b/internal/tokenizer/tokenizer.go
index 9fe895e7118..d3dd867abd4 100644
--- a/internal/tokenizer/tokenizer.go
+++ b/internal/tokenizer/tokenizer.go
@@ -26,7 +26,7 @@ import (
 
 	"go.uber.org/zap"
 
-	rag "ragflow/internal/go_binding"
+	rag "ragflow/internal/binding"
 	"ragflow/internal/logger"
 )
 
From f8040ae6214cad625c431b37c9b99591f0ecbebc Mon Sep 17 00:00:00 2001
From: Jin Hai
Date: Mon, 9 Mar 2026 17:48:29 +0800
Subject: [PATCH 182/565] Add scheduled tasks (#13470)

### What problem does this PR solve?

1. RAGFlow server will send heartbeats periodically.
2. This PR includes:
   - Scheduled task
   - API server message sending
   - Admin server API to receive the message.
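
For reference, a condensed sketch of the wiring this patch adds on the API-server side, pieced together from the `cmd/server_main.go` hunk below. Names and signatures are the ones introduced in this patch; the reporter ticks every 3 seconds, and `SendHeartbeat` itself throttles successful reports to roughly every tenth tick:

```go
// Sketch only: condensed from cmd/server_main.go in this patch.
localIP := utility.GetLocalIPWithFallback("127.0.0.1")

sender := service.NewHeartbeatSender(
	logger.Logger,
	common.ServerTypeAPI,
	fmt.Sprintf("ragflow-server-%d", config.Server.Port),
	localIP,
	config.Server.Port,
)

if err := sender.InitHTTPClient(); err != nil {
	logger.Warn("Failed to initialize heartbeat service", zap.Error(err))
} else {
	// ScheduledTask runs its loop in a goroutine and skips a tick when
	// the previous run is still executing.
	reporter := utility.NewScheduledTask("Heartbeat reporter", 3*time.Second, func() {
		if err := sender.SendHeartbeat(); err != nil {
			logger.Warn("Failed to send heartbeat", zap.Error(err))
		}
	})
	reporter.Start()
	defer reporter.Stop()
}
```

On the receiving end, `POST /api/v1/admin/reports` feeds `GlobalServerStatusStore`, which keys each report by server name and can evict stale entries via `CleanupStaleStatuses`.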
### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Signed-off-by: Jin Hai --- cmd/server_main.go | 71 +++++-- internal/admin/handler.go | 37 +++- internal/admin/heartbeat.go | 76 ++++++++ internal/admin/router.go | 2 + internal/admin/service.go | 16 ++ internal/common/status_message.go | 33 ++++ internal/cpp/Makefile | 2 +- internal/handler/user.go | 2 +- internal/server/config.go | 49 ++++- internal/service/heartbeat_sender.go | 136 +++++++++++++ internal/service/user.go | 4 +- internal/utility/http_client.go | 274 +++++++++++++++++++++++++++ internal/utility/network.go | 49 +++++ internal/utility/scheduled_task.go | 156 +++++++++++++++ 14 files changed, 871 insertions(+), 36 deletions(-) create mode 100644 internal/admin/heartbeat.go create mode 100644 internal/common/status_message.go create mode 100644 internal/service/heartbeat_sender.go create mode 100644 internal/utility/http_client.go create mode 100644 internal/utility/network.go create mode 100644 internal/utility/scheduled_task.go diff --git a/cmd/server_main.go b/cmd/server_main.go index 7d6ac21e8da..dbae3efbeae 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -6,6 +6,7 @@ import ( "net/http" "os" "os/signal" + "ragflow/internal/common" "ragflow/internal/init_data" "ragflow/internal/server" "ragflow/internal/utility" @@ -30,7 +31,7 @@ import ( func main() { // Initialize logger with default level // logger.Init("info"); // set debug log level - if err := logger.Init("debug"); err != nil { + if err := logger.Init("info"); err != nil { panic(fmt.Sprintf("Failed to initialize logger: %v", err)) } @@ -45,28 +46,21 @@ func main() { } logger.Info("Model providers loaded", zap.Int("count", len(server.GetModelProviders()))) - cfg := server.GetConfig() + config := server.GetConfig() // Reinitialize logger with configured level if different - if cfg.Log.Level != "" && cfg.Log.Level != "info" { - if err := logger.Init(cfg.Log.Level); err != nil { + if config.Log.Level != "" && config.Log.Level != "info" { + if err := logger.Init(config.Log.Level); err != nil { logger.Error("Failed to reinitialize logger with configured level", err) } } server.SetLogger(logger.Logger) - logger.Info("Server mode", zap.String("mode", cfg.Server.Mode)) + logger.Info("Server mode", zap.String("mode", config.Server.Mode)) // Print all configuration settings server.PrintAll() - // Set Gin mode - if cfg.Server.Mode == "release" { - gin.SetMode(gin.ReleaseMode) - } else { - gin.SetMode(gin.DebugMode) - } - // Initialize database if err := dao.InitDB(); err != nil { logger.Fatal("Failed to initialize database", zap.Error(err)) @@ -80,13 +74,13 @@ func main() { } // Initialize doc engine - if err := engine.Init(&cfg.DocEngine); err != nil { + if err := engine.Init(&config.DocEngine); err != nil { logger.Fatal("Failed to initialize doc engine", zap.Error(err)) } defer engine.Close() // Initialize Redis cache - if err := cache.Init(&cfg.Redis); err != nil { + if err := cache.Init(&config.Redis); err != nil { logger.Fatal("Failed to initialize Redis", zap.Error(err)) } defer cache.Close() @@ -112,6 +106,20 @@ func main() { logger.Fatal("Failed to initialize query builder", zap.Error(err)) } + startServer(config) + + logger.Info("Server exited") +} + +func startServer(config *server.Config) { + + // Set Gin mode + if config.Server.Mode == "release" { + gin.SetMode(gin.ReleaseMode) + } else { + gin.SetMode(gin.DebugMode) + } + // Initialize service layer userService := service.NewUserService() documentService := 
service.NewDocumentService() @@ -147,7 +155,7 @@ func main() { ginEngine := gin.New() // Middleware - if cfg.Server.Mode == "debug" { + if config.Server.Mode == "debug" { ginEngine.Use(gin.Logger()) } ginEngine.Use(gin.Recovery()) @@ -156,7 +164,7 @@ func main() { r.Setup(ginEngine) // Create HTTP server - addr := fmt.Sprintf(":%d", cfg.Server.Port) + addr := fmt.Sprintf(":%d", config.Server.Port) srv := &http.Server{ Addr: addr, Handler: ginEngine, @@ -172,12 +180,39 @@ func main() { " /_/ |_|/_/ |_|\\____//_/ /_/ \\____/ |__/|__/\n", ) logger.Info(fmt.Sprintf("Version: %s", utility.GetRAGFlowVersion())) - logger.Info(fmt.Sprintf("Server starting on port: %d", cfg.Server.Port)) + logger.Info(fmt.Sprintf("Server starting on port: %d", config.Server.Port)) if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { logger.Fatal("Failed to start server", zap.Error(err)) } }() + // Get local IP address for heartbeat reporting + localIP := utility.GetLocalIP() + if localIP == "" { + localIP = "127.0.0.1" + } + + // Initialize and start heartbeat reporter to admin server + heartbeatService := service.NewHeartbeatSender( + logger.Logger, + common.ServerTypeAPI, + fmt.Sprintf("ragflow-server-%d", config.Server.Port), + localIP, + config.Server.Port, + ) + if err := heartbeatService.InitHTTPClient(); err != nil { + logger.Warn("Failed to initialize heartbeat service", zap.Error(err)) + } else { + // Start heartbeat reporter with 30 seconds interval + heartbeatReporter := utility.NewScheduledTask("Heartbeat reporter", 3*time.Second, func() { + if err := heartbeatService.SendHeartbeat(); err != nil { + logger.Warn("Failed to send heartbeat", zap.Error(err)) + } + }) + heartbeatReporter.Start() + defer heartbeatReporter.Stop() + } + // Wait for interrupt signal to gracefully shutdown quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGUSR2) @@ -194,6 +229,4 @@ func main() { if err := srv.Shutdown(ctx); err != nil { logger.Fatal("Server forced to shutdown", zap.Error(err)) } - - logger.Info("Server exited") } diff --git a/internal/admin/handler.go b/internal/admin/handler.go index 526a22fa912..155ebe1685c 100644 --- a/internal/admin/handler.go +++ b/internal/admin/handler.go @@ -19,10 +19,12 @@ package admin import ( "errors" "net/http" + "ragflow/internal/common" "ragflow/internal/server" "ragflow/internal/service" "ragflow/internal/utility" "strconv" + "time" "github.com/gin-gonic/gin" ) @@ -111,7 +113,7 @@ func (h *Handler) Login(c *gin.Context) { return } - user, code, err := h.userService.LoginByEmail(&req) + user, code, err := h.userService.LoginByEmail(&req, true) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ "code": code, @@ -135,8 +137,9 @@ func (h *Handler) Login(c *gin.Context) { c.Header("Access-Control-Expose-Headers", "Authorization") c.JSON(http.StatusOK, gin.H{ - "code": 0, - "message": "Login successful", + "code": common.CodeSuccess, + "message": "Welcome back!", + "data": user, }) } @@ -943,3 +946,31 @@ func (h *Handler) HandleNoRoute(c *gin.Context) { Message: "The requested resource was not found", }) } + +// Reports handle heartbeat reports from servers +func (h *Handler) Reports(c *gin.Context) { + var req common.BaseMessage + if err := c.ShouldBindJSON(&req); err != nil { + errorResponse(c, "Invalid request body: "+err.Error(), 400) + return + } + + // Set default timestamp if not provided + if req.Timestamp.IsZero() { + req.Timestamp = time.Now() + } + + // Only process heartbeat messages for 
now + if req.MessageType != common.MessageHeartbeat { + errorResponse(c, "Unsupported report type: "+string(req.MessageType), 400) + return + } + + // Handle the heartbeat + if err := h.service.HandleHeartbeat(&req); err != nil { + errorResponse(c, "Failed to process heartbeat: "+err.Error(), 500) + return + } + + successNoData(c, "Heartbeat received successfully") +} diff --git a/internal/admin/heartbeat.go b/internal/admin/heartbeat.go new file mode 100644 index 00000000000..b7e41e61147 --- /dev/null +++ b/internal/admin/heartbeat.go @@ -0,0 +1,76 @@ +package admin + +import ( + "ragflow/internal/common" + "sync" + "time" +) + +// ServerStatusStore is a thread-safe global server status storage +type ServerStatusStore struct { + mu sync.RWMutex + servers map[string]*common.BaseMessage // key: server_id +} + +// GlobalServerStatusStore is the global instance +var GlobalServerStatusStore = &ServerStatusStore{ + servers: make(map[string]*common.BaseMessage), +} + +// UpdateStatus updates or adds a server status +func (s *ServerStatusStore) UpdateStatus(serverID string, status *common.BaseMessage) { + s.mu.Lock() + defer s.mu.Unlock() + s.servers[serverID] = status +} + +// GetStatus gets a single server status +func (s *ServerStatusStore) GetStatus(serverID string) (*common.BaseMessage, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + status, ok := s.servers[serverID] + return status, ok +} + +// GetAllStatuses gets all server statuses +func (s *ServerStatusStore) GetAllStatuses() []*common.BaseMessage { + s.mu.RLock() + defer s.mu.RUnlock() + result := make([]*common.BaseMessage, 0, len(s.servers)) + for _, status := range s.servers { + result = append(result, status) + } + return result +} + +// GetStatusesByType gets server statuses by type +func (s *ServerStatusStore) GetStatusesByType(serverType common.ServerType) []*common.BaseMessage { + s.mu.RLock() + defer s.mu.RUnlock() + result := make([]*common.BaseMessage, 0) + for _, status := range s.servers { + if status.ServerType == serverType { + result = append(result, status) + } + } + return result +} + +// RemoveStatus removes a server status +func (s *ServerStatusStore) RemoveStatus(serverID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.servers, serverID) +} + +// CleanupStaleStatuses cleans up servers that haven't reported for a specified duration +func (s *ServerStatusStore) CleanupStaleStatuses(maxAge time.Duration) { + s.mu.Lock() + defer s.mu.Unlock() + now := time.Now() + for id, status := range s.servers { + if now.Sub(status.Timestamp) > maxAge { + delete(s.servers, id) + } + } +} diff --git a/internal/admin/router.go b/internal/admin/router.go index 3dc03c2c143..e2d9cedf1c1 100644 --- a/internal/admin/router.go +++ b/internal/admin/router.go @@ -47,6 +47,8 @@ func (r *Router) Setup(engine *gin.Engine) { admin.GET("/ping", r.handler.Ping) admin.POST("/login", r.handler.Login) + admin.POST("/reports", r.handler.Reports) + // Protected routes protected := admin.Group("") protected.Use(r.handler.AuthMiddleware()) diff --git a/internal/admin/service.go b/internal/admin/service.go index 2438ef6c85d..b45714396a1 100644 --- a/internal/admin/service.go +++ b/internal/admin/service.go @@ -25,6 +25,7 @@ import ( "net/http" "os" "ragflow/internal/cache" + "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/engine/elasticsearch" "ragflow/internal/model" @@ -732,3 +733,18 @@ func (s *Service) TestSandboxConnection(providerType string, config map[string]i "connected": true, }, nil } + +// HandleHeartbeat handle 
heartbeat +func (s *Service) HandleHeartbeat(msg *common.BaseMessage) error { + status := &common.BaseMessage{ + ServerName: msg.ServerName, + ServerType: msg.ServerType, + Host: msg.Host, + Port: msg.Port, + Version: msg.Version, + Timestamp: msg.Timestamp, + Ext: msg.Ext, + } + GlobalServerStatusStore.UpdateStatus(msg.ServerName, status) + return nil +} diff --git a/internal/common/status_message.go b/internal/common/status_message.go new file mode 100644 index 00000000000..76d29ac3eb9 --- /dev/null +++ b/internal/common/status_message.go @@ -0,0 +1,33 @@ +package common + +import ( + "time" +) + +type MessageType string + +const ( + MessageHeartbeat MessageType = "heartbeat" + MessageMetric MessageType = "metric" + MessageEvent MessageType = "event" +) + +type ServerType string + +const ( + ServerTypeAPI ServerType = "api_server" // API server + ServerTypeWorker ServerType = "ingestor" // Ingestion server + ServerTypeScheduler ServerType = "data_collector" // Data collection server +) + +type BaseMessage struct { + MessageID int64 `json:"report_id"` + MessageType MessageType `json:"report_type"` + ServerName string `json:"server_id"` + ServerType ServerType `json:"server_type"` + Host string `json:"host"` + Port int `json:"port"` + Version string `json:"version"` + Timestamp time.Time `json:"timestamp"` + Ext map[string]interface{} `json:"ext,omitempty"` +} diff --git a/internal/cpp/Makefile b/internal/cpp/Makefile index 9ddf0244059..e45843e85dc 100644 --- a/internal/cpp/Makefile +++ b/internal/cpp/Makefile @@ -51,7 +51,7 @@ c_api_debug: $(BUILD_DIR) # Test the Go bindings test_go: c_api - cd bindings/example && go run main.go ../../$(BUILD_DIR) "这是一个测试文本。This is a test." + cd bindings/example && go run main.go ../../$(BUILD_DIR) "This is a test." 
# Run memory test test_memory: c_api diff --git a/internal/handler/user.go b/internal/handler/user.go index 3651c29c148..8ec2d314f38 100644 --- a/internal/handler/user.go +++ b/internal/handler/user.go @@ -164,7 +164,7 @@ func (h *UserHandler) LoginByEmail(c *gin.Context) { return } - user, code, err := h.userService.LoginByEmail(&req) + user, code, err := h.userService.LoginByEmail(&req, false) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": code, diff --git a/internal/server/config.go b/internal/server/config.go index fe9cdea48ed..b028ae76ce2 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -40,9 +40,16 @@ type Config struct { DocEngine DocEngineConfig `mapstructure:"doc_engine"` RegisterEnabled int `mapstructure:"register_enabled"` OAuth map[string]OAuthConfig `mapstructure:"oauth"` + Admin AdminConfig `mapstructure:"admin"` UserDefaultLLM UserDefaultLLMConfig `mapstructure:"user_default_llm"` } +// AdminConfig admin server configuration +type AdminConfig struct { + Host string `mapstructure:"host"` + Port int `mapstructure:"http_port"` +} + // UserDefaultLLMConfig user default LLM configuration type UserDefaultLLMConfig struct { DefaultModels DefaultModelsConfig `mapstructure:"default_models"` @@ -50,19 +57,19 @@ type UserDefaultLLMConfig struct { // DefaultModelsConfig default models configuration type DefaultModelsConfig struct { - ChatModel ModelConfig `mapstructure:"chat_model"` - EmbeddingModel ModelConfig `mapstructure:"embedding_model"` - RerankModel ModelConfig `mapstructure:"rerank_model"` - ASRModel ModelConfig `mapstructure:"asr_model"` + ChatModel ModelConfig `mapstructure:"chat_model"` + EmbeddingModel ModelConfig `mapstructure:"embedding_model"` + RerankModel ModelConfig `mapstructure:"rerank_model"` + ASRModel ModelConfig `mapstructure:"asr_model"` Image2TextModel ModelConfig `mapstructure:"image2text_model"` } // ModelConfig model configuration type ModelConfig struct { - Name string `mapstructure:"name"` - APIKey string `mapstructure:"api_key"` - BaseURL string `mapstructure:"base_url"` - Factory string `mapstructure:"factory"` + Name string `mapstructure:"name"` + APIKey string `mapstructure:"api_key"` + BaseURL string `mapstructure:"base_url"` + Factory string `mapstructure:"factory"` } // OAuthConfig OAuth configuration for a channel @@ -325,6 +332,20 @@ func Init(configPath string) error { return fmt.Errorf("unmarshal config error: %w", err) } + // Set default values for admin configuration if not configured + if globalConfig.Admin.Host == "" { + globalConfig.Admin.Host = v.GetString("admin.host") + } + if globalConfig.Admin.Host == "" { + globalConfig.Admin.Host = "127.0.0.1" + } + if globalConfig.Admin.Port == 0 { + globalConfig.Admin.Port = v.GetInt("admin.http_port") + } + if globalConfig.Admin.Port == 0 { + globalConfig.Admin.Port = 9381 + } + // Load REGISTER_ENABLED from environment variable (default: 1) registerEnabled := 1 if envVal := os.Getenv("REGISTER_ENABLED"); envVal != "" { @@ -357,8 +378,8 @@ func Init(configPath string) error { if v.IsSet("ragflow") { ragflowConfig := v.Sub("ragflow") if ragflowConfig != nil { - globalConfig.Server.Port = ragflowConfig.GetInt("http_port") + 2 // 9382, by default - // globalConfig.Server.Port = ragflowConfig.GetInt("http_port") // Correct + //globalConfig.Server.Port = ragflowConfig.GetInt("http_port") + 2 // 9382, by default + globalConfig.Server.Port = ragflowConfig.GetInt("http_port") // Correct // If mode is not set, default to debug if globalConfig.Server.Mode == "" { 
globalConfig.Server.Mode = "release" @@ -484,6 +505,14 @@ func GetConfig() *Config { return globalConfig } +// GetAdminConfig gets the admin server configuration +func GetAdminConfig() *AdminConfig { + if globalConfig == nil { + return nil + } + return &globalConfig.Admin +} + // SetLogger sets the logger instance func SetLogger(l *zap.Logger) { zapLogger = l diff --git a/internal/service/heartbeat_sender.go b/internal/service/heartbeat_sender.go new file mode 100644 index 00000000000..ec2b198320b --- /dev/null +++ b/internal/service/heartbeat_sender.go @@ -0,0 +1,136 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "encoding/json" + "fmt" + "ragflow/internal/common" + "ragflow/internal/server" + "ragflow/internal/utility" + "time" + + "go.uber.org/zap" +) + +// HeartbeatSender is responsible for sending heartbeat reports to the admin server +type HeartbeatSender struct { + client *utility.HTTPClient + logger *zap.Logger + serverType common.ServerType + serverName string + host string + port int + version string + lastSuccess bool + attemptCount int +} + +// NewHeartbeatSender creates a new heartbeat service instance +func NewHeartbeatSender(logger *zap.Logger, serverType common.ServerType, serverName, host string, port int) *HeartbeatSender { + return &HeartbeatSender{ + logger: logger, + serverType: serverType, + serverName: serverName, + host: host, + port: port, + version: utility.GetRAGFlowVersion(), + lastSuccess: false, + attemptCount: 0, + } +} + +// InitHTTPClient initializes the HTTP client with admin server configuration +func (h *HeartbeatSender) InitHTTPClient() error { + adminConfig := server.GetAdminConfig() + if adminConfig == nil { + return fmt.Errorf("admin configuration not found") + } + + h.client = utility.NewHTTPClientBuilder(). + WithHost(adminConfig.Host). + WithPort(adminConfig.Port). + WithTimeout(10 * time.Second). 
+ Build() + + h.logger.Info("Heartbeat HTTP client initialized", + zap.String("admin_host", adminConfig.Host), + zap.Int("admin_port", adminConfig.Port+2), + ) + + return nil +} + +// SendHeartbeat sends a heartbeat message to the admin server +func (h *HeartbeatSender) SendHeartbeat() error { + + if h.attemptCount < 10 { + if h.lastSuccess { + h.attemptCount++ + return nil + } + } + h.attemptCount = 0 + h.lastSuccess = false + + if h.client == nil { + if err := h.InitHTTPClient(); err != nil { + h.logger.Error("Failed to initialize HTTP client", zap.Error(err)) + return err + } + } + + message := &common.BaseMessage{ + MessageID: time.Now().UnixNano(), + MessageType: common.MessageHeartbeat, + ServerName: h.serverName, + ServerType: h.serverType, + Host: h.host, + Port: h.port, + Version: h.version, + Timestamp: time.Now(), + Ext: make(map[string]interface{}), + } + + jsonData, err := json.Marshal(message) + if err != nil { + h.logger.Error("Failed to marshal heartbeat message", zap.Error(err)) + return err + } + + resp, err := h.client.PostJSON("/api/v1/admin/reports", jsonData) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + h.logger.Error("Heartbeat request failed", + zap.Int("status_code", resp.StatusCode), + ) + return fmt.Errorf("heartbeat request failed with status code: %d", resp.StatusCode) + } + + h.logger.Debug("Heartbeat sent successfully", + zap.String("server_id", h.serverName), + zap.String("server_type", string(h.serverType)), + ) + + h.lastSuccess = true + + return nil +} diff --git a/internal/service/user.go b/internal/service/user.go index eb8b2e6f1eb..a87260b6805 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -309,8 +309,8 @@ func (s *UserService) Login(req *LoginRequest) (*model.User, common.ErrorCode, e // - CodeAuthenticationError (109): Email not registered or password mismatch // - CodeServerError (500): Password decryption failure // - CodeForbidden (403): Account disabled -func (s *UserService) LoginByEmail(req *EmailLoginRequest) (*model.User, common.ErrorCode, error) { - if req.Email == "admin@ragflow.io" { +func (s *UserService) LoginByEmail(req *EmailLoginRequest, adminLogin bool) (*model.User, common.ErrorCode, error) { + if !adminLogin && req.Email == "admin@ragflow.io" { return nil, common.CodeAuthenticationError, fmt.Errorf("default admin account cannot be used to login normal services") } diff --git a/internal/utility/http_client.go b/internal/utility/http_client.go new file mode 100644 index 00000000000..464b5530af0 --- /dev/null +++ b/internal/utility/http_client.go @@ -0,0 +1,274 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package utility + +import ( + "bytes" + "crypto/tls" + "fmt" + "io" + "net/http" + "net/url" + "time" +) + +// HTTPClient is a configurable HTTP client +type HTTPClient struct { + host string + port int + useSSL bool + timeout time.Duration + headers map[string]string + httpClient *http.Client +} + +// HTTPClientBuilder is a builder for HTTPClient +type HTTPClientBuilder struct { + client *HTTPClient +} + +// NewHTTPClientBuilder creates a new HTTPClientBuilder with default values +func NewHTTPClientBuilder() *HTTPClientBuilder { + return &HTTPClientBuilder{ + client: &HTTPClient{ + host: "localhost", + port: 80, + useSSL: false, + timeout: 30 * time.Second, + headers: make(map[string]string), + }, + } +} + +// WithHost sets the host +func (b *HTTPClientBuilder) WithHost(host string) *HTTPClientBuilder { + b.client.host = host + return b +} + +// WithPort sets the port +func (b *HTTPClientBuilder) WithPort(port int) *HTTPClientBuilder { + b.client.port = port + return b +} + +// WithSSL enables or disables SSL +func (b *HTTPClientBuilder) WithSSL(useSSL bool) *HTTPClientBuilder { + b.client.useSSL = useSSL + return b +} + +// WithTimeout sets the timeout duration +func (b *HTTPClientBuilder) WithTimeout(timeout time.Duration) *HTTPClientBuilder { + b.client.timeout = timeout + return b +} + +// WithHeader adds a single header +func (b *HTTPClientBuilder) WithHeader(key, value string) *HTTPClientBuilder { + b.client.headers[key] = value + return b +} + +// WithHeaders sets multiple headers +func (b *HTTPClientBuilder) WithHeaders(headers map[string]string) *HTTPClientBuilder { + for key, value := range headers { + b.client.headers[key] = value + } + return b +} + +// Build creates the HTTPClient +func (b *HTTPClientBuilder) Build() *HTTPClient { + transport := &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: false, + }, + } + + // If SSL is disabled, allow insecure connections + if !b.client.useSSL { + transport.TLSClientConfig.InsecureSkipVerify = true + } + + b.client.httpClient = &http.Client{ + Timeout: b.client.timeout, + Transport: transport, + } + + return b.client +} + +// SetHost sets the host +func (c *HTTPClient) SetHost(host string) { + c.host = host +} + +// SetPort sets the port +func (c *HTTPClient) SetPort(port int) { + c.port = port +} + +// SetSSL enables or disables SSL +func (c *HTTPClient) SetSSL(useSSL bool) { + c.useSSL = useSSL +} + +// SetTimeout sets the timeout duration +func (c *HTTPClient) SetTimeout(timeout time.Duration) { + c.timeout = timeout + c.httpClient.Timeout = timeout +} + +// SetHeader sets a single header +func (c *HTTPClient) SetHeader(key, value string) { + c.headers[key] = value +} + +// SetHeaders sets multiple headers +func (c *HTTPClient) SetHeaders(headers map[string]string) { + c.headers = headers +} + +// AddHeader adds a header without removing existing ones +func (c *HTTPClient) AddHeader(key, value string) { + c.headers[key] = value +} + +// GetHeaders returns a copy of all headers +func (c *HTTPClient) GetHeaders() map[string]string { + headersCopy := make(map[string]string) + for k, v := range c.headers { + headersCopy[k] = v + } + return headersCopy +} + +// GetBaseURL returns the base URL +func (c *HTTPClient) GetBaseURL() string { + scheme := "http" + if c.useSSL { + scheme = "https" + } + return fmt.Sprintf("%s://%s:%d", scheme, c.host, c.port) +} + +// GetFullURL returns the full URL for a given path +func (c *HTTPClient) GetFullURL(path string) string { + baseURL := c.GetBaseURL() + // Ensure path 
starts with / + if path != "" && path[0] != '/' { + path = "/" + path + } + return baseURL + path +} + +// prepareRequest creates an HTTP request with configured headers +func (c *HTTPClient) prepareRequest(method, urlStr string, body io.Reader) (*http.Request, error) { + req, err := http.NewRequest(method, urlStr, body) + if err != nil { + return nil, err + } + + // Add configured headers + for key, value := range c.headers { + req.Header.Set(key, value) + } + + return req, nil +} + +// Get performs a GET request +func (c *HTTPClient) Get(path string) (*http.Response, error) { + urlStr := c.GetFullURL(path) + req, err := c.prepareRequest(http.MethodGet, urlStr, nil) + if err != nil { + return nil, err + } + return c.httpClient.Do(req) +} + +// GetWithParams performs a GET request with query parameters +func (c *HTTPClient) GetWithParams(path string, params map[string]string) (*http.Response, error) { + urlStr := c.GetFullURL(path) + u, err := url.Parse(urlStr) + if err != nil { + return nil, err + } + + query := u.Query() + for key, value := range params { + query.Set(key, value) + } + u.RawQuery = query.Encode() + + req, err := c.prepareRequest(http.MethodGet, u.String(), nil) + if err != nil { + return nil, err + } + return c.httpClient.Do(req) +} + +// Post performs a POST request +func (c *HTTPClient) Post(path string, body []byte) (*http.Response, error) { + urlStr := c.GetFullURL(path) + req, err := c.prepareRequest(http.MethodPost, urlStr, bytes.NewReader(body)) + if err != nil { + return nil, err + } + return c.httpClient.Do(req) +} + +// PostJSON performs a POST request with JSON content type +func (c *HTTPClient) PostJSON(path string, body []byte) (*http.Response, error) { + c.SetHeader("Content-Type", "application/json") + return c.Post(path, body) +} + +// Put performs a PUT request +func (c *HTTPClient) Put(path string, body []byte) (*http.Response, error) { + urlStr := c.GetFullURL(path) + req, err := c.prepareRequest(http.MethodPut, urlStr, bytes.NewReader(body)) + if err != nil { + return nil, err + } + return c.httpClient.Do(req) +} + +// Delete performs a DELETE request +func (c *HTTPClient) Delete(path string) (*http.Response, error) { + urlStr := c.GetFullURL(path) + req, err := c.prepareRequest(http.MethodDelete, urlStr, nil) + if err != nil { + return nil, err + } + return c.httpClient.Do(req) +} + +// Do performs a request with the given method +func (c *HTTPClient) Do(method, path string, body []byte) (*http.Response, error) { + urlStr := c.GetFullURL(path) + var bodyReader io.Reader + if body != nil { + bodyReader = bytes.NewReader(body) + } + req, err := c.prepareRequest(method, urlStr, bodyReader) + if err != nil { + return nil, err + } + return c.httpClient.Do(req) +} diff --git a/internal/utility/network.go b/internal/utility/network.go new file mode 100644 index 00000000000..bf8ad982010 --- /dev/null +++ b/internal/utility/network.go @@ -0,0 +1,49 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +package utility + +import ( + "net" +) + +// GetLocalIP returns the first non-loopback local IP address of the host +func GetLocalIP() string { + addrs, err := net.InterfaceAddrs() + if err != nil { + return "" + } + + for _, addr := range addrs { + // Check the address type and skip loopback addresses + if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { + if ipnet.IP.To4() != nil { + return ipnet.IP.String() + } + } + } + + return "" +} + +// GetLocalIPWithFallback returns the local IP address with a fallback value +func GetLocalIPWithFallback(fallback string) string { + ip := GetLocalIP() + if ip == "" { + return fallback + } + return ip +} diff --git a/internal/utility/scheduled_task.go b/internal/utility/scheduled_task.go new file mode 100644 index 00000000000..88c9886d17a --- /dev/null +++ b/internal/utility/scheduled_task.go @@ -0,0 +1,156 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package utility + +import ( + "encoding/json" + "fmt" + "ragflow/internal/logger" + "sync/atomic" + "time" + + "go.uber.org/zap" +) + +type StatusMessage struct { + ID int `json:"id"` + Version string `json:"version"` + Timestamp time.Time `json:"timestamp"` + NodeName string `json:"node_name"` + ExtInfo string `json:"ext_info"` +} + +func NewStatusMessage(id int, version string, nodeName string, extInfo string) *StatusMessage { + return &StatusMessage{ + ID: id, + Version: version, + Timestamp: time.Now(), + NodeName: nodeName, + ExtInfo: extInfo, + } +} + +func StatusMessageSending() { + // Construct status message + statusMessage := NewStatusMessage(0, "v1", "ragflow", "") + + // Serialize to JSON + jsonData, err := json.Marshal(statusMessage) + if err != nil { + logger.Error("Failed to marshal status message", err) + return + } + + // Create HTTP client + client := NewHTTPClientBuilder(). + WithHost("127.0.0.1"). + WithPort(9381). + WithSSL(false). + WithTimeout(10 * time.Second). 
+ Build() + + // Send POST request + resp, err := client.PostJSON("/v1/admin/status", jsonData) + if err != nil { + logger.Error("Error sending status message", err) + return + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + logger.Error("Failed to send status message", fmt.Errorf("status: %d", resp.StatusCode)) + } +} + +// ScheduledTask represents a periodic task +type ScheduledTask struct { + Name string + Interval time.Duration + Job func() + stop chan struct{} + running bool + executing int32 // atomic flag: 0 - not executed, 1 running +} + +// NewScheduledTask creates a new simple task +func NewScheduledTask(name string, interval time.Duration, job func()) *ScheduledTask { + return &ScheduledTask{ + Name: name, + Interval: interval, + Job: job, + stop: make(chan struct{}), + } +} + +// Start begins the periodic task +func (t *ScheduledTask) Start() { + if t.running { + return + } + t.running = true + + go func() { + ticker := time.NewTicker(t.Interval) + defer ticker.Stop() + + logger.Info("Task started", zap.String("name", t.Name)) + + for { + select { + case <-ticker.C: + t.runSafely() + case <-t.stop: + logger.Info("Task stopped", zap.String("name", t.Name)) + return + } + } + }() +} + +// runSafely executes the job with panic recovery and prevents overlap +func (t *ScheduledTask) runSafely() { + // Attempt to set the flag + if !atomic.CompareAndSwapInt32(&t.executing, 0, 1) { + logger.Warn("Task skipped - previous execution still running", zap.String("name", t.Name)) + return + } + + // Clear atomic flag after execution + defer atomic.StoreInt32(&t.executing, 0) + + defer func() { + if r := recover(); r != nil { + logger.Fatal("Task panicked", zap.String("name", t.Name), zap.Any("recover", r)) + } + }() + + t.Job() +} + +// Stop stops the periodic task +func (t *ScheduledTask) Stop() { + if !t.running { + return + } + t.running = false + close(t.stop) +} + +// IsExecuting returns whether the task is currently executing +func (t *ScheduledTask) IsExecuting() bool { + return atomic.LoadInt32(&t.executing) == 1 +} From fb647dc614e465a8f1d93d80f2312ec38227bc24 Mon Sep 17 00:00:00 2001 From: Yongteng Lei Date: Mon, 9 Mar 2026 19:00:17 +0800 Subject: [PATCH 183/565] Refa: convert download_img to async (#13477) ### What problem does this PR solve? Convert download_img to async. 
### Type of change - [x] Refactoring - [x] Performance Improvement --- api/apps/user_app.py | 6 +++--- common/misc_utils.py | 6 +++--- pyproject.toml | 1 + test/unit_test/common/test_misc_utils.py | 11 +++++++---- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/api/apps/user_app.py b/api/apps/user_app.py index e08e434d490..702e1bd8557 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -222,7 +222,7 @@ async def oauth_callback(channel): if not users: try: try: - avatar = download_img(user_info.avatar_url) + avatar = await download_img(user_info.avatar_url) except Exception as e: logging.exception(e) avatar = "" @@ -317,7 +317,7 @@ async def github_callback(): # User isn't try to register try: try: - avatar = download_img(user_info["avatar_url"]) + avatar = await download_img(user_info["avatar_url"]) except Exception as e: logging.exception(e) avatar = "" @@ -421,7 +421,7 @@ async def feishu_callback(): # User isn't try to register try: try: - avatar = download_img(user_info["avatar_url"]) + avatar = await download_img(user_info["avatar_url"]) except Exception as e: logging.exception(e) avatar = "" diff --git a/common/misc_utils.py b/common/misc_utils.py index 19b608ca7fe..1826be77f30 100644 --- a/common/misc_utils.py +++ b/common/misc_utils.py @@ -27,16 +27,16 @@ from concurrent.futures import ThreadPoolExecutor -import requests def get_uuid(): return uuid.uuid1().hex -def download_img(url): +async def download_img(url): if not url: return "" - response = requests.get(url) + from common.http_client import async_request + response = await async_request("GET", url) return "data:" + \ response.headers.get('Content-Type', 'image/jpg') + ";" + \ "base64," + base64.b64encode(response.content).decode("utf-8") diff --git a/pyproject.toml b/pyproject.toml index 73006ac28f6..02efa55a335 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -218,6 +218,7 @@ markers = [ "p3: low priority test cases", "smoke: smoke test cases", "auth: authentication UI tests", + "asyncio: mark test as async", ] # Test collection and runtime configuration diff --git a/test/unit_test/common/test_misc_utils.py b/test/unit_test/common/test_misc_utils.py index b407c49b7d4..82c8f976576 100644 --- a/test/unit_test/common/test_misc_utils.py +++ b/test/unit_test/common/test_misc_utils.py @@ -15,6 +15,7 @@ # import uuid import hashlib +import pytest from common.misc_utils import get_uuid, download_img, hash_str2int, convert_bytes @@ -91,14 +92,16 @@ def test_hex_characters_only(self): class TestDownloadImg: """Test cases for download_img function""" - def test_empty_url_returns_empty_string(self): + @pytest.mark.asyncio + async def test_empty_url_returns_empty_string(self): """Test that empty URL returns empty string""" - result = download_img("") + result = await download_img("") assert result == "" - def test_none_url_returns_empty_string(self): + @pytest.mark.asyncio + async def test_none_url_returns_empty_string(self): """Test that None URL returns empty string""" - result = download_img(None) + result = await download_img(None) assert result == "" From 5370eb540a7a0aaeb23e496d0d6da800f30897f1 Mon Sep 17 00:00:00 2001 From: writinwaters <93570324+writinwaters@users.noreply.github.com> Date: Mon, 9 Mar 2026 21:14:45 +0800 Subject: [PATCH 184/565] Docs: Updated Switch chunk availability (#13482) ### What problem does this PR solve? A quick editorial pass. 
### Type of change - [x] Documentation Update --- docs/references/http_api_reference.md | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/docs/references/http_api_reference.md b/docs/references/http_api_reference.md index 907e2202308..1f50d5753ee 100644 --- a/docs/references/http_api_reference.md +++ b/docs/references/http_api_reference.md @@ -2220,11 +2220,11 @@ Failure: --- -### Switch chunks availability +### Update chunk availability **POST** `/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks/switch` -Switches the availability of specified chunks (enable or disable chunks for retrieval). +Updates or switches the availability status of specified chunks, controlling whether they are available for retrieval. #### Request @@ -2234,9 +2234,9 @@ Switches the availability of specified chunks (enable or disable chunks for retr - `'Content-Type: application/json'` - `'Authorization: Bearer '` - Body: - - `"chunk_ids"`: `list[string]` (*Required*) List of chunk IDs to switch. - - `"available_int"`: `integer` (*Optional*) `1` for available, `0` for unavailable. Mutually exclusive with `"available"`. - - `"available"`: `boolean` (*Optional*) Availability status. Mutually exclusive with `"available_int"`. Must provide either `available_int` or `available`. + - `"chunk_ids"`: `list[string]` (*Required*) + - `"available_int"`: `integer` (*Optional*) + - `"available"`: `boolean` (*Optional*) ##### Request example @@ -2258,12 +2258,16 @@ curl --request POST \ The ID of the dataset. - `document_id`: (*Path parameter*) The ID of the document. -- `"chunk_ids"`: (*Body parameter*), `list[string]`, *Required* - List of chunk IDs whose availability is to be switched. -- `"available_int"`: (*Body parameter*), `integer` - `1` for available (chunk participates in retrieval), `0` for unavailable. Either this or `"available"` must be provided. -- `"available"`: (*Body parameter*), `boolean` - Availability status. `true` for available, `false` for unavailable. Alternative to `"available_int"`. +- `"chunk_ids"`: (*Body parameter*), `list[string]` (*Required*) + IDs of the chunks whose availability status is to be updated. +- `"available_int"`: (*Body parameter*), `integer` (*Optional*) + Availability status for the specified chunks. Mutually exclusive with `"available"`. You must provide either `available_int` or `available`, *not* both. + - `1`: Available, + - `0`: Unavailable. +- `"available"`: (*Body parameter*), `boolean` (*Optional*) + Availability status of the specified chunks. Mutually exclusive with `"available_int"`. You must provide either `available` or `available_int`, *not* both. + - `true`: Available, + - `false`: Unavailable. #### Response From e16dbd8a309d7adf0a62dae9f6eac50d7e05cac8 Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Tue, 10 Mar 2026 09:56:43 +0800 Subject: [PATCH 185/565] Service list and minio status (#13480) ### What problem does this PR solve? 1. Resolve standard user can access admin service 2. Get RAGFlow service status 3. 
Fix minio status fetching ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) --------- Signed-off-by: Jin Hai --- cmd/admin_server.go | 3 --- internal/admin/handler.go | 13 +++++++++- internal/admin/heartbeat.go | 8 +++--- internal/admin/router.go | 2 +- internal/admin/service.go | 52 ++++++++++++++++++++++++++++++------- internal/handler/kb.go | 5 ++++ 6 files changed, 64 insertions(+), 19 deletions(-) diff --git a/cmd/admin_server.go b/cmd/admin_server.go index 103ad6d227c..1165ce219a5 100644 --- a/cmd/admin_server.go +++ b/cmd/admin_server.go @@ -80,9 +80,6 @@ func main() { logger.Info("Server mode", zap.String("mode", cfg.Server.Mode)) - // Print all configuration settings - server.PrintAll() - // Set Gin mode if cfg.Server.Mode == "release" { gin.SetMode(gin.ReleaseMode) diff --git a/internal/admin/handler.go b/internal/admin/handler.go index 155ebe1685c..eb5b4d50554 100644 --- a/internal/admin/handler.go +++ b/internal/admin/handler.go @@ -644,7 +644,10 @@ func (h *Handler) GetUserPermission(c *gin.Context) { func (h *Handler) GetServices(c *gin.Context) { services, err := h.service.GetAllServices() if err != nil { - errorResponse(c, err.Error(), 500) + c.JSON(http.StatusInternalServerError, gin.H{ + "code": common.CodeServerError, + "message": err.Error(), + }) return } @@ -932,6 +935,14 @@ func (h *Handler) AuthMiddleware() gin.HandlerFunc { return } + if !*user.IsSuperuser { + c.JSON(http.StatusForbidden, gin.H{ + "code": common.CodeForbidden, + "message": "Permission denied", + }) + return + } + c.Set("user", user) c.Set("user_id", user.ID) c.Set("email", user.Email) diff --git a/internal/admin/heartbeat.go b/internal/admin/heartbeat.go index b7e41e61147..fc8901f4404 100644 --- a/internal/admin/heartbeat.go +++ b/internal/admin/heartbeat.go @@ -18,17 +18,17 @@ var GlobalServerStatusStore = &ServerStatusStore{ } // UpdateStatus updates or adds a server status -func (s *ServerStatusStore) UpdateStatus(serverID string, status *common.BaseMessage) { +func (s *ServerStatusStore) UpdateStatus(serverName string, status *common.BaseMessage) { s.mu.Lock() defer s.mu.Unlock() - s.servers[serverID] = status + s.servers[serverName] = status } // GetStatus gets a single server status -func (s *ServerStatusStore) GetStatus(serverID string) (*common.BaseMessage, bool) { +func (s *ServerStatusStore) GetStatus(serverName string) (*common.BaseMessage, bool) { s.mu.RLock() defer s.mu.RUnlock() - status, ok := s.servers[serverID] + status, ok := s.servers[serverName] return status, ok } diff --git a/internal/admin/router.go b/internal/admin/router.go index e2d9cedf1c1..4e9dd213465 100644 --- a/internal/admin/router.go +++ b/internal/admin/router.go @@ -46,6 +46,7 @@ func (r *Router) Setup(engine *gin.Engine) { // Public routes admin.GET("/ping", r.handler.Ping) admin.POST("/login", r.handler.Login) + admin.GET("/logout", r.handler.Logout) admin.POST("/reports", r.handler.Reports) @@ -55,7 +56,6 @@ func (r *Router) Setup(engine *gin.Engine) { { // Auth protected.GET("/auth", r.handler.AuthCheck) - protected.GET("/logout", r.handler.Logout) // User management protected.GET("/users", r.handler.ListUsers) diff --git a/internal/admin/service.go b/internal/admin/service.go index b45714396a1..7a236769b9b 100644 --- a/internal/admin/service.go +++ b/internal/admin/service.go @@ -31,6 +31,7 @@ import ( "ragflow/internal/model" "ragflow/internal/server" "ragflow/internal/utility" + "strconv" "time" ) @@ -282,20 +283,37 
@@ func (s *Service) GetAllServices() ([]map[string]interface{}, error) { var result []map[string]interface{} for _, configDict := range allConfigs { - // Get service details to check status - serviceDetail, err := s.GetServiceDetails(configDict) - if err == nil { - if status, ok := serviceDetail["status"]; ok { - configDict["status"] = status + serviceType := configDict["service_type"] + if serviceType != "ragflow_server" { + // Get service details to check status + serviceDetail, err := s.GetServiceDetails(configDict) + if err == nil { + if status, ok := serviceDetail["status"]; ok { + configDict["status"] = status + } else { + configDict["status"] = "timeout" + } } else { configDict["status"] = "timeout" } - } else { - configDict["status"] = "timeout" + result = append(result, configDict) } - result = append(result, configDict) + } + id := len(result) + serverList := GlobalServerStatusStore.GetAllStatuses() + for _, serverStatus := range serverList { + serverItem := make(map[string]interface{}) + serverItem["name"] = serverStatus.ServerName + serverItem["service_type"] = serverStatus.ServerType + serverItem["id"] = id + id++ + serverItem["host"] = serverStatus.Host + serverItem["port"] = serverStatus.Port + serverItem["status"] = "alive" + result = append(result, serverItem) + } return result, nil } @@ -540,6 +558,7 @@ func (s *Service) checkMinioAlive(name string) (map[string]interface{}, error) { // Get minio config from allConfigs var host string + var port int var secure bool var verify bool = true @@ -550,6 +569,16 @@ func (s *Service) checkMinioAlive(name string) (map[string]interface{}, error) { if h, ok := config["host"].(string); ok { host = h } + + if p, ok := config["port"].(int); ok { + port = p + } else if p, ok := config["port"].(float64); ok { + port = int(p) + } else if p, ok := config["port"].(string); ok { + if parsedPort, err := strconv.Atoi(p); err == nil { + port = parsedPort + } + } // Get secure from extra config if extra, ok := config["extra"].(map[string]interface{}); ok { if s, ok := extra["secure"].(bool); ok { @@ -569,7 +598,10 @@ func (s *Service) checkMinioAlive(name string) (map[string]interface{}, error) { // Default host if host == "" { - host = "localhost:9000" + host = "localhost" + } + if port == 0 { + port = 9000 } // Determine scheme @@ -578,7 +610,7 @@ func (s *Service) checkMinioAlive(name string) (map[string]interface{}, error) { scheme = "https" } - url := fmt.Sprintf("%s://%s/minio/health/live", scheme, host) + url := fmt.Sprintf("%s://%s:%d/minio/health/live", scheme, host, port) // Create HTTP client with timeout client := &http.Client{ diff --git a/internal/handler/kb.go b/internal/handler/kb.go index a7b5f7ac25b..d4d4e848ef9 100644 --- a/internal/handler/kb.go +++ b/internal/handler/kb.go @@ -60,6 +60,10 @@ func (h *KnowledgebaseHandler) getUserID(c *gin.Context) (string, common.ErrorCo return "", code, err } + if *user.IsSuperuser { + return "", common.CodeForbidden, ErrForbidden + } + return user.ID, common.CodeSuccess, nil } @@ -97,6 +101,7 @@ var ( ErrMissingAuth = &HTTPError{Code: common.CodeUnauthorized, Message: "Missing Authorization header"} // ErrInvalidToken indicates invalid access token ErrInvalidToken = &HTTPError{Code: common.CodeUnauthorized, Message: "Invalid access token"} + ErrForbidden = &HTTPError{Code: common.CodeForbidden, Message: "Forbidden user"} ) // CreateKB handles the create knowledge base request From 8b2bbcba74755781654c67bdcdb9eac28a26d579 Mon Sep 17 00:00:00 2001 From: atian8179 Date: Tue, 10 Mar 2026 
10:02:21 +0800 Subject: [PATCH 186/565] fix: include missing modules in ragflow-cli PyPI package (#13457) ## Problem The `ragflow-cli` PyPI package (v0.24.0) is missing `http_client.py`, `ragflow_client.py`, and `user.py`, causing import errors when installed from PyPI. ## Root Cause `pyproject.toml` only lists `ragflow_cli` and `parser` in `[tool.setuptools] py-modules`. ## Fix Add the three missing modules to `py-modules`. Fixes #13456 Co-authored-by: atian8179 --- admin/client/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/admin/client/pyproject.toml b/admin/client/pyproject.toml index 4b5e2cd31b8..1b79c85c31e 100644 --- a/admin/client/pyproject.toml +++ b/admin/client/pyproject.toml @@ -21,7 +21,7 @@ test = [ ] [tool.setuptools] -py-modules = ["ragflow_cli", "parser"] +py-modules = ["ragflow_cli", "parser", "http_client", "ragflow_client", "user"] [project.scripts] ragflow-cli = "ragflow_cli:main" From 404b5271bcbdc85dafa4b8c429738932f2d87a3d Mon Sep 17 00:00:00 2001 From: chanx <1243304602@qq.com> Date: Tue, 10 Mar 2026 10:35:55 +0800 Subject: [PATCH 187/565] refactor: Moves the LLM factory initialization logic to the `dao` package. (#13476) ### What problem does this PR solve? refactor: Moves the LLM factory initialization logic to the `dao` package. Removes the `init_data` package and integrates the LLM factory initialization functionality into the `dao` package. Adds a `utility` package to provide general utility functions. Updates `server_main.go` to use the new initialization path. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) Co-authored-by: Jin Hai --- cmd/server_main.go | 5 +- internal/dao/database.go | 125 +++++++++++++++++++++++++- internal/init_data/llm_init.go | 157 --------------------------------- internal/utility/convert.go | 80 +++++++++++++++++ 4 files changed, 204 insertions(+), 163 deletions(-) delete mode 100644 internal/init_data/llm_init.go create mode 100644 internal/utility/convert.go diff --git a/cmd/server_main.go b/cmd/server_main.go index dbae3efbeae..c4919abeca8 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -6,8 +6,7 @@ import ( "net/http" "os" "os/signal" - "ragflow/internal/common" - "ragflow/internal/init_data" + "ragflow/internal/common" "ragflow/internal/server" "ragflow/internal/utility" "strings" @@ -67,7 +66,7 @@ func main() { } // Initialize LLM factory data models from configuration file - if err := init_data.InitLLMFactory(); err != nil { + if err := dao.InitLLMFactory(); err != nil { logger.Error("Failed to initialize LLM factory", err) } else { logger.Info("LLM factory initialized successfully") diff --git a/internal/dao/database.go b/internal/dao/database.go index 391759431ed..1529088fb4e 100644 --- a/internal/dao/database.go +++ b/internal/dao/database.go @@ -17,21 +17,50 @@ package dao import ( + "encoding/json" "fmt" + "log" + "os" + "path/filepath" + "time" + + "ragflow/internal/logger" "ragflow/internal/model" "ragflow/internal/server" - "time" + "ragflow/internal/utility" gormLogger "gorm.io/gorm/logger" "gorm.io/driver/mysql" "gorm.io/gorm" - - "ragflow/internal/logger" ) var DB *gorm.DB +// LLMFactoryConfig represents a single LLM factory configuration +type LLMFactoryConfig struct { + Name string `json:"name"` + Logo string `json:"logo"` + Tags string `json:"tags"` + Status string `json:"status"` + Rank string `json:"rank"` + LLM []LLMConfig `json:"llm"` +} + +// LLMConfig represents a single LLM model configuration +type LLMConfig struct { + LLMName string 
`json:"llm_name"` + Tags string `json:"tags"` + MaxTokens int64 `json:"max_tokens"` + ModelType string `json:"model_type"` + IsTools bool `json:"is_tools"` +} + +// LLMFactoriesFile represents the structure of llm_factories.json +type LLMFactoriesFile struct { + FactoryLLMInfos []LLMFactoryConfig `json:"factory_llm_infos"` +} + // InitDB initialize database connection func InitDB() error { cfg := server.GetConfig() @@ -132,3 +161,93 @@ func InitDB() error { func GetDB() *gorm.DB { return DB } + +// InitLLMFactory initializes LLM factories and models from JSON file. +// It reads the llm_factories.json configuration file and populates the database +// with LLM factory and model information. If a factory or model already exists, +// it will be updated with the new configuration. +// +// Returns: +// - error: An error if the initialization fails, nil otherwise. +func InitLLMFactory() error { + configPath := filepath.Join(utility.GetProjectBaseDirectory(), "conf", "llm_factories.json") + + data, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("failed to read llm_factories.json: %w", err) + } + + var fileData LLMFactoriesFile + if err := json.Unmarshal(data, &fileData); err != nil { + return fmt.Errorf("failed to parse llm_factories.json: %w", err) + } + + db := DB + + for _, factory := range fileData.FactoryLLMInfos { + status := factory.Status + if status == "" { + status = "1" + } + + llmFactory := &model.LLMFactories{ + Name: factory.Name, + Logo: utility.StringPtr(factory.Logo), + Tags: factory.Tags, + Rank: utility.ParseInt64(factory.Rank), + Status: &status, + } + + var existingFactory model.LLMFactories + result := db.Where("name = ?", factory.Name).First(&existingFactory) + if result.Error != nil { + if err := db.Create(llmFactory).Error; err != nil { + log.Printf("Failed to create LLM factory %s: %v", factory.Name, err) + continue + } + } else { + if err := db.Model(&model.LLMFactories{}).Where("name = ?", factory.Name).Updates(map[string]interface{}{ + "logo": llmFactory.Logo, + "tags": llmFactory.Tags, + "rank": llmFactory.Rank, + "status": llmFactory.Status, + }).Error; err != nil { + log.Printf("Failed to update LLM factory %s: %v", factory.Name, err) + } + } + + for _, llm := range factory.LLM { + llmStatus := "1" + llmModel := &model.LLM{ + LLMName: llm.LLMName, + ModelType: llm.ModelType, + FID: factory.Name, + MaxTokens: llm.MaxTokens, + Tags: llm.Tags, + IsTools: llm.IsTools, + Status: &llmStatus, + } + + var existingLLM model.LLM + result := db.Where("llm_name = ? AND fid = ?", llm.LLMName, factory.Name).First(&existingLLM) + if result.Error != nil { + if err := db.Create(llmModel).Error; err != nil { + log.Printf("Failed to create LLM %s/%s: %v", factory.Name, llm.LLMName, err) + } + } else { + if err := db.Model(&model.LLM{}).Where("llm_name = ? AND fid = ?", llm.LLMName, factory.Name).Updates(map[string]interface{}{ + "model_type": llmModel.ModelType, + "max_tokens": llmModel.MaxTokens, + "tags": llmModel.Tags, + "is_tools": llmModel.IsTools, + "status": llmModel.Status, + }).Error; err != nil { + log.Printf("Failed to update LLM %s/%s: %v", factory.Name, llm.LLMName, err) + } + } + } + } + + log.Println("LLM factories initialized successfully") + return nil +} diff --git a/internal/init_data/llm_init.go b/internal/init_data/llm_init.go deleted file mode 100644 index ef67dd6ba3a..00000000000 --- a/internal/init_data/llm_init.go +++ /dev/null @@ -1,157 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package init_data - -import ( - "encoding/json" - "fmt" - "log" - "os" - "path/filepath" - - "ragflow/internal/dao" - "ragflow/internal/model" -) - -// LLMFactoryConfig represents a single LLM factory configuration -type LLMFactoryConfig struct { - Name string `json:"name"` - Logo string `json:"logo"` - Tags string `json:"tags"` - Status string `json:"status"` - Rank string `json:"rank"` - LLM []LLMConfig `json:"llm"` -} - -// LLMConfig represents a single LLM model configuration -type LLMConfig struct { - LLMName string `json:"llm_name"` - Tags string `json:"tags"` - MaxTokens int64 `json:"max_tokens"` - ModelType string `json:"model_type"` - IsTools bool `json:"is_tools"` -} - -// LLMFactoriesFile represents the structure of llm_factories.json -type LLMFactoriesFile struct { - FactoryLLMInfos []LLMFactoryConfig `json:"factory_llm_infos"` -} - -// InitLLMFactory initializes LLM factories and models from JSON file -func InitLLMFactory() error { - configPath := filepath.Join(getProjectBaseDirectory(), "conf", "llm_factories.json") - - data, err := os.ReadFile(configPath) - if err != nil { - return fmt.Errorf("failed to read llm_factories.json: %w", err) - } - - var fileData LLMFactoriesFile - if err := json.Unmarshal(data, &fileData); err != nil { - return fmt.Errorf("failed to parse llm_factories.json: %w", err) - } - - db := dao.DB - - for _, factory := range fileData.FactoryLLMInfos { - status := factory.Status - if status == "" { - status = "1" - } - - llmFactory := &model.LLMFactories{ - Name: factory.Name, - Logo: stringPtr(factory.Logo), - Tags: factory.Tags, - Rank: parseInt64(factory.Rank), - Status: &status, - } - - var existingFactory model.LLMFactories - result := db.Where("name = ?", factory.Name).First(&existingFactory) - if result.Error != nil { - if err := db.Create(llmFactory).Error; err != nil { - log.Printf("Failed to create LLM factory %s: %v", factory.Name, err) - continue - } - } else { - if err := db.Model(&model.LLMFactories{}).Where("name = ?", factory.Name).Updates(map[string]interface{}{ - "logo": llmFactory.Logo, - "tags": llmFactory.Tags, - "rank": llmFactory.Rank, - "status": llmFactory.Status, - }).Error; err != nil { - log.Printf("Failed to update LLM factory %s: %v", factory.Name, err) - } - } - - for _, llm := range factory.LLM { - llmStatus := "1" - llmModel := &model.LLM{ - LLMName: llm.LLMName, - ModelType: llm.ModelType, - FID: factory.Name, - MaxTokens: llm.MaxTokens, - Tags: llm.Tags, - IsTools: llm.IsTools, - Status: &llmStatus, - } - - var existingLLM model.LLM - result := db.Where("llm_name = ? AND fid = ?", llm.LLMName, factory.Name).First(&existingLLM) - if result.Error != nil { - if err := db.Create(llmModel).Error; err != nil { - log.Printf("Failed to create LLM %s/%s: %v", factory.Name, llm.LLMName, err) - } - } else { - if err := db.Model(&model.LLM{}).Where("llm_name = ? 
AND fid = ?", llm.LLMName, factory.Name).Updates(map[string]interface{}{ - "model_type": llmModel.ModelType, - "max_tokens": llmModel.MaxTokens, - "tags": llmModel.Tags, - "is_tools": llmModel.IsTools, - "status": llmModel.Status, - }).Error; err != nil { - log.Printf("Failed to update LLM %s/%s: %v", factory.Name, llm.LLMName, err) - } - } - } - } - - log.Println("LLM factories initialized successfully") - return nil -} - -func getProjectBaseDirectory() string { - cwd, err := os.Getwd() - if err != nil { - return "." - } - return cwd -} - -func stringPtr(s string) *string { - if s == "" { - return nil - } - return &s -} - -func parseInt64(s string) int64 { - var result int64 - fmt.Sscanf(s, "%d", &result) - return result -} diff --git a/internal/utility/convert.go b/internal/utility/convert.go new file mode 100644 index 00000000000..281b3614864 --- /dev/null +++ b/internal/utility/convert.go @@ -0,0 +1,80 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package utility + +import ( + "fmt" + "os" +) + +// GetProjectBaseDirectory returns the current working directory. +// If an error occurs while getting the current directory, it returns ".". +// +// Returns: +// - string: The current working directory path, or "." if an error occurs. +// +// Example: +// +// baseDir := utility.GetProjectBaseDirectory() +// configPath := filepath.Join(baseDir, "conf", "config.json") +func GetProjectBaseDirectory() string { + cwd, err := os.Getwd() + if err != nil { + return "." + } + return cwd +} + +// StringPtr converts a string to a pointer of string. +// If the input string is empty, it returns nil. +// +// Parameters: +// - s: The string to convert to a pointer. +// +// Returns: +// - *string: A pointer to the input string, or nil if the input is empty. +// +// Example: +// +// name := utility.StringPtr("example") // returns &"example" +// empty := utility.StringPtr("") // returns nil +func StringPtr(s string) *string { + if s == "" { + return nil + } + return &s +} + +// ParseInt64 parses a string to int64. +// If parsing fails, it returns 0. +// +// Parameters: +// - s: The string to parse. +// +// Returns: +// - int64: The parsed integer value, or 0 if parsing fails. +// +// Example: +// +// val := utility.ParseInt64("123") // returns 123 +// val := utility.ParseInt64("abc") // returns 0 +// val := utility.ParseInt64("") // returns 0 +func ParseInt64(s string) int64 { + var result int64 + fmt.Sscanf(s, "%d", &result) + return result +} From 2e46bb8db6b54b252f4636c56c622ed255cfe915 Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Tue, 10 Mar 2026 10:49:39 +0800 Subject: [PATCH 188/565] Update ext field type of heartbeat message (#13490) ### What problem does this PR solve? 
As title

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Signed-off-by: Jin Hai
---
 internal/admin/handler.go            | 17 +++++++++++++----
 internal/admin/service.go            |  4 ++--
 internal/common/status_message.go    | 18 +++++++++---------
 internal/service/heartbeat_sender.go |  2 +-
 4 files changed, 25 insertions(+), 16 deletions(-)

diff --git a/internal/admin/handler.go b/internal/admin/handler.go
index eb5b4d50554..4a4e36bfb6d 100644
--- a/internal/admin/handler.go
+++ b/internal/admin/handler.go
@@ -642,7 +642,7 @@ func (h *Handler) GetUserPermission(c *gin.Context) {
 
 // GetServices handle get all services
 func (h *Handler) GetServices(c *gin.Context) {
-	services, err := h.service.GetAllServices()
+	services, err := h.service.ListServices()
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{
 			"code":    common.CodeServerError,
@@ -908,7 +908,10 @@ func (h *Handler) TestSandboxConnection(c *gin.Context) {
 
 	result, err := h.service.TestSandboxConnection(req.ProviderType, req.Config)
 	if err != nil {
-		errorResponse(c, err.Error(), 400)
+		c.JSON(http.StatusBadRequest, gin.H{
+			"code":    common.CodeBadRequest,
+			"message": err.Error(),
+		})
 		return
 	}
 
@@ -962,7 +965,10 @@ func (h *Handler) HandleNoRoute(c *gin.Context) {
 func (h *Handler) Reports(c *gin.Context) {
 	var req common.BaseMessage
 	if err := c.ShouldBindJSON(&req); err != nil {
-		errorResponse(c, "Invalid request body: "+err.Error(), 400)
+		c.JSON(http.StatusBadRequest, gin.H{
+			"code":    common.CodeBadRequest,
+			"message": "Invalid request body: " + err.Error(),
+		})
 		return
 	}
 
@@ -973,7 +979,10 @@ func (h *Handler) Reports(c *gin.Context) {
 
 	// Only process heartbeat messages for now
 	if req.MessageType != common.MessageHeartbeat {
-		errorResponse(c, "Unsupported report type: "+string(req.MessageType), 400)
+		c.JSON(http.StatusBadRequest, gin.H{
+			"code":    common.CodeBadRequest,
+			"message": "Unsupported report type: " + string(req.MessageType),
+		})
 		return
 	}
 
diff --git a/internal/admin/service.go b/internal/admin/service.go
index 7a236769b9b..6e2e1b346cc 100644
--- a/internal/admin/service.go
+++ b/internal/admin/service.go
@@ -277,8 +277,8 @@ func (s *Service) GetUserPermission(username string) ([]map[string]interface{},
 	return []map[string]interface{}{}, nil
 }
 
-// GetAllServices get all services
-func (s *Service) GetAllServices() ([]map[string]interface{}, error) {
+// ListServices gets all services
+func (s *Service) ListServices() ([]map[string]interface{}, error) {
 	allConfigs := server.GetAllConfigs()
 
 	var result []map[string]interface{}
diff --git a/internal/common/status_message.go b/internal/common/status_message.go
index 76d29ac3eb9..d538848a9eb 100644
--- a/internal/common/status_message.go
+++ b/internal/common/status_message.go
@@ -21,13 +21,13 @@ const (
 )
 
 type BaseMessage struct {
-	MessageID   int64                  `json:"report_id"`
-	MessageType MessageType            `json:"report_type"`
-	ServerName  string                 `json:"server_id"`
-	ServerType  ServerType             `json:"server_type"`
-	Host        string                 `json:"host"`
-	Port        int                    `json:"port"`
-	Version     string                 `json:"version"`
-	Timestamp   time.Time              `json:"timestamp"`
-	Ext         map[string]interface{} `json:"ext,omitempty"`
+	MessageID   int64       `json:"report_id"`
+	MessageType MessageType `json:"report_type"`
+	ServerName  string      `json:"server_id"`
+	ServerType  ServerType  `json:"server_type"`
+	Host        string      `json:"host"`
+	Port        int         `json:"port"`
+	Version     string      `json:"version"`
+	Timestamp   time.Time   `json:"timestamp"`
+	Ext         interface{} `json:"ext,omitempty"`
 }
 
diff --git
a/internal/service/heartbeat_sender.go b/internal/service/heartbeat_sender.go index ec2b198320b..8a80c471376 100644 --- a/internal/service/heartbeat_sender.go +++ b/internal/service/heartbeat_sender.go @@ -103,7 +103,7 @@ func (h *HeartbeatSender) SendHeartbeat() error { Port: h.port, Version: h.version, Timestamp: time.Now(), - Ext: make(map[string]interface{}), + Ext: nil, } jsonData, err := json.Marshal(message) From 453641fd0ed563c514cd2fbf7294b631ab5ee858 Mon Sep 17 00:00:00 2001 From: tunsuy <957126743@qq.com> Date: Tue, 10 Mar 2026 11:20:31 +0800 Subject: [PATCH 189/565] fix: detect and fallback garbled PDF text to OCR (#13366) (#13404) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Problem When PDF fonts lack ToUnicode/CMap mappings, pdfplumber (pdfminer) cannot map CIDs to correct Unicode characters, outputting PUA characters (U+E000~U+F8FF) or `(cid:xxx)` placeholders. The original code fully trusted pdfplumber text without any garbled detection, causing garbled output in the final parsed result. Relates to #13366 ## Solution ### 1. Garbled text detection functions - `_is_garbled_char(ch)`: Detects PUA characters (BMP/Plane 15/16), replacement character U+FFFD, control characters, and unassigned/surrogate codepoints - `_is_garbled_text(text, threshold)`: Calculates garbled ratio and detects `(cid:xxx)` patterns ### 2. Box-level fallback (in `__ocr()`) When a text box has ≥50% garbled characters, discard pdfplumber text and fallback to OCR recognition. ### 3. Page-level detection (in `__images__()`) Sample characters from each page; if garbled rate ≥30%, clear all pdfplumber characters for that page, forcing full OCR. ### 4. Layout recognizer CID filtering Filter out `(cid:xxx)` patterns in `layout_recognizer.py` text processing to prevent them from polluting layout analysis. ## Testing - 29 unit tests covering: normal CJK/English text, PUA characters, CID patterns, mixed text, boundary thresholds, edge cases - All 85 existing project unit tests pass without regression --- deepdoc/parser/pdf_parser.py | 180 ++++++- deepdoc/vision/layout_recognizer.py | 7 +- .../parser/test_pdf_garbled_detection.py | 438 ++++++++++++++++++ 3 files changed, 619 insertions(+), 6 deletions(-) create mode 100644 test/unit_test/deepdoc/parser/test_pdf_garbled_detection.py diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index 49880c3c552..6020361c07b 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -22,6 +22,7 @@ import re import sys import threading +import unicodedata from collections import Counter, defaultdict from copy import deepcopy from io import BytesIO @@ -197,6 +198,127 @@ def _has_color(self, o): return False return True + # CID pattern regex for unmapped font characters from pdfminer + _CID_PATTERN = re.compile(r"\(cid\s*:\s*\d+\s*\)") + + @staticmethod + def _is_garbled_char(ch): + """Check if a single character is garbled (unmappable from PDF font encoding). + + A character is considered garbled if it falls into Unicode Private Use Areas + or certain replacement/control character ranges that typically indicate + pdfminer failed to map a CID to a valid Unicode codepoint. 
+ """ + if not ch: + return False + cp = ord(ch) + if 0xE000 <= cp <= 0xF8FF: + return True + if 0xF0000 <= cp <= 0xFFFFF: + return True + if 0x100000 <= cp <= 0x10FFFF: + return True + if cp == 0xFFFD: + return True + if cp < 0x20 and ch not in ('\t', '\n', '\r'): + return True + if 0x80 <= cp <= 0x9F: + return True + cat = unicodedata.category(ch) + if cat in ("Cn", "Cs"): + return True + return False + + @staticmethod + def _is_garbled_text(text, threshold=0.5): + """Check if a text string contains too many garbled characters. + + Examines each character and determines if the overall proportion + of garbled characters exceeds the given threshold. Also detects + pdfminer's CID placeholder patterns like '(cid:123)'. + """ + if not text or not text.strip(): + return False + if RAGFlowPdfParser._CID_PATTERN.search(text): + return True + garbled_count = 0 + total = 0 + for ch in text: + if ch.isspace(): + continue + total += 1 + if RAGFlowPdfParser._is_garbled_char(ch): + garbled_count += 1 + if total == 0: + return False + return garbled_count / total >= threshold + + @staticmethod + def _has_subset_font_prefix(fontname): + """Check if a font name has a subset prefix (e.g. 'DY1+ZLQDm1-1'). + + PDF subset fonts use a 6-letter uppercase tag followed by '+' before + the actual font name. Some tools use shorter tags (e.g. 'DY1+'). + """ + if not fontname: + return False + return bool(re.match(r"^[A-Z0-9]{2,6}\+", fontname)) + + @staticmethod + def _is_garbled_by_font_encoding(page_chars, min_chars=20): + """Detect garbled text caused by broken font encoding mappings. + + Some PDFs (especially older Chinese standards) embed custom fonts that + map CJK glyphs to ASCII codepoints. The extracted text appears as + random ASCII punctuation/symbols instead of actual CJK characters. + + Detection strategy: if a significant proportion of characters come from + subset-embedded fonts and the page produces overwhelmingly ASCII + (punctuation, digits, symbols) with virtually no CJK/Hangul/Kana + characters, the page is likely garbled due to broken font encoding. + """ + if not page_chars or len(page_chars) < min_chars: + return False + + subset_font_count = 0 + total_non_space = 0 + ascii_punct_sym = 0 + cjk_like = 0 + + for c in page_chars: + text = c.get("text", "") + fontname = c.get("fontname", "") + if not text or text.isspace(): + continue + total_non_space += 1 + + if RAGFlowPdfParser._has_subset_font_prefix(fontname): + subset_font_count += 1 + + cp = ord(text[0]) + if (0x2E80 <= cp <= 0x9FFF or 0xF900 <= cp <= 0xFAFF + or 0x20000 <= cp <= 0x2FA1F + or 0xAC00 <= cp <= 0xD7AF + or 0x3040 <= cp <= 0x30FF): + cjk_like += 1 + elif (0x21 <= cp <= 0x2F or 0x3A <= cp <= 0x40 + or 0x5B <= cp <= 0x60 or 0x7B <= cp <= 0x7E): + ascii_punct_sym += 1 + + if total_non_space < min_chars: + return False + + subset_ratio = subset_font_count / total_non_space + if subset_ratio < 0.3: + return False + + cjk_ratio = cjk_like / total_non_space + punct_ratio = ascii_punct_sym / total_non_space + if cjk_ratio < 0.05 and punct_ratio > 0.4: + return True + + return False + def _evaluate_table_orientation(self, table_img, sample_ratio=0.3): """ Evaluate the best rotation orientation for a table image. 
@@ -618,14 +740,40 @@ def __ocr(self, pagenum, img, chars, ZM=3, device_id: int | None = None): if not b["chars"]: del b["chars"] continue - m_ht = np.mean([c["height"] for c in b["chars"]]) - for c in Recognizer.sort_Y_firstly(b["chars"], m_ht): + box_chars = b["chars"] + m_ht = np.mean([c["height"] for c in box_chars]) + garbled_count = 0 + total_count = 0 + for c in Recognizer.sort_Y_firstly(box_chars, m_ht): if c["text"] == " " and b["text"]: if re.match(r"[0-9a-zA-Zа-яА-Я,.?;:!%%]", b["text"][-1]): b["text"] += " " else: b["text"] += c["text"] + for ch in c["text"]: + if not ch.isspace(): + total_count += 1 + if self._is_garbled_char(ch): + garbled_count += 1 del b["chars"] + # If the majority of characters from pdfplumber are garbled, + # clear the text so OCR recognition will be used as fallback. + # Strategy 1: PUA / unmapped CID characters + if total_count > 0 and garbled_count / total_count >= 0.5: + logging.info( + "Page %d: detected garbled pdfplumber text (garbled=%d/%d), falling back to OCR for box at (%.1f, %.1f)", + pagenum, garbled_count, total_count, b["x0"], b["top"], + ) + b["text"] = "" + continue + # Strategy 2: font-encoding garbling — all chars are ASCII + # punctuation from subset fonts (no CJK output) + if total_count > 0 and self._is_garbled_by_font_encoding(box_chars, min_chars=5): + logging.info( + "Page %d: detected font-encoding garbled text (%d chars), falling back to OCR for box at (%.1f, %.1f)", + pagenum, total_count, b["x0"], b["top"], + ) + b["text"] = "" logging.info(f"__ocr sorting {len(chars)} chars cost {timer() - start}s") start = timer() @@ -1400,6 +1548,34 @@ def __images__(self, fnm, zoomin=3, page_from=0, page_to=299, callback=None): logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}") self.page_chars = [[] for _ in range(page_to - page_from)] # If failed to extract, using empty list instead. + # Detect garbled pages and clear their chars so the OCR + # path will be used instead. 
Two detection strategies: + # 1) PUA / unmapped CID characters (threshold=0.3) + # 2) Font-encoding garbling: subset fonts mapping CJK to ASCII + for pi, page_ch in enumerate(self.page_chars): + if not page_ch: + continue + # Strategy 1: PUA / CID garbling + sample = page_ch if len(page_ch) <= 200 else page_ch[:200] + sample_text = "".join(c.get("text", "") for c in sample) + if self._is_garbled_text(sample_text, threshold=0.3): + logging.warning( + "Page %d: pdfplumber extracted mostly garbled characters (%d chars), " + "clearing to use OCR fallback.", + page_from + pi + 1, len(page_ch), + ) + self.page_chars[pi] = [] + continue + # Strategy 2: font-encoding garbling (CJK mapped to ASCII) + if self._is_garbled_by_font_encoding(page_ch): + logging.warning( + "Page %d: detected font-encoding garbled text " + "(subset fonts with no CJK output, %d chars), " + "clearing to use OCR fallback.", + page_from + pi + 1, len(page_ch), + ) + self.page_chars[pi] = [] + self.total_page = len(self.pdf.pages) except Exception as e: diff --git a/deepdoc/vision/layout_recognizer.py b/deepdoc/vision/layout_recognizer.py index 5b79e2bf5c6..be1f8667cec 100644 --- a/deepdoc/vision/layout_recognizer.py +++ b/deepdoc/vision/layout_recognizer.py @@ -17,7 +17,7 @@ import logging import math import os -# import re +import re from collections import Counter from copy import deepcopy @@ -62,9 +62,8 @@ def __init__(self, domain): def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True): def __is_garbage(b): - return False - # patt = [r"^•+$", "^[0-9]{1,2} / ?[0-9]{1,2}$", r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", "\\(cid *: *[0-9]+ *\\)"] - # return any([re.search(p, b["text"]) for p in patt]) + patt = [r"\(cid\s*:\s*\d+\s*\)"] + return any([re.search(p, b.get("text", "")) for p in patt]) if self.client: layouts = self.client.predict(image_list) diff --git a/test/unit_test/deepdoc/parser/test_pdf_garbled_detection.py b/test/unit_test/deepdoc/parser/test_pdf_garbled_detection.py new file mode 100644 index 00000000000..fa7c4a8b76b --- /dev/null +++ b/test/unit_test/deepdoc/parser/test_pdf_garbled_detection.py @@ -0,0 +1,438 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Unit tests for PDF garbled text detection and layout garbage filtering. + +Tests cover: +- RAGFlowPdfParser static methods: _is_garbled_char, _is_garbled_text, + _has_subset_font_prefix, _is_garbled_by_font_encoding +- layout_recognizer.__is_garbage: CID pattern filtering +""" + +import re +import sys +import os +import importlib.util +from unittest import mock + +# Import RAGFlowPdfParser directly by file path to avoid triggering +# deepdoc/parser/__init__.py which pulls in heavy dependencies +# (pdfplumber, xgboost, etc.) that may not be available in test environments. +# +# We mock the heavy third-party modules so that pdf_parser.py can be loaded +# purely for its static detection methods. 
+_MOCK_MODULES = [ + "numpy", "np", "pdfplumber", "xgboost", "xgb", + "huggingface_hub", "PIL", "PIL.Image", "pypdf", + "sklearn", "sklearn.cluster", "sklearn.metrics", + "common", "common.file_utils", "common.misc_utils", "common.settings", + "common.token_utils", + "deepdoc", "deepdoc.vision", "deepdoc.parser", + "rag", "rag.nlp", "rag.prompts", "rag.prompts.generator", +] +for _m in _MOCK_MODULES: + if _m not in sys.modules: + sys.modules[_m] = mock.MagicMock() + +def _find_project_root(marker="pyproject.toml"): + """Walk up from this file until a directory containing *marker* is found.""" + cur = os.path.dirname(os.path.abspath(__file__)) + while True: + if os.path.exists(os.path.join(cur, marker)): + return cur + parent = os.path.dirname(cur) + if parent == cur: + raise FileNotFoundError(f"Could not locate project root (missing {marker})") + cur = parent + + +_MODULE_PATH = os.path.join(_find_project_root(), "deepdoc", "parser", "pdf_parser.py") +_spec = importlib.util.spec_from_file_location("pdf_parser", _MODULE_PATH) +_mod = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_mod) + +_Parser = _mod.RAGFlowPdfParser +is_garbled_char = _Parser._is_garbled_char +is_garbled_text = _Parser._is_garbled_text +has_subset_font_prefix = _Parser._has_subset_font_prefix +is_garbled_by_font_encoding = _Parser._is_garbled_by_font_encoding + + +# --------------------------------------------------------------------------- +# Tests for is_garbled_char +# --------------------------------------------------------------------------- + + +class TestIsGarbledChar: + """Tests for the is_garbled_char function.""" + + def test_normal_ascii_chars(self): + for ch in "Hello World 123 !@#": + assert is_garbled_char(ch) is False + + def test_normal_chinese_chars(self): + for ch in "中文测试你好世界": + assert is_garbled_char(ch) is False + + def test_normal_japanese_chars(self): + for ch in "日本語テスト": + assert is_garbled_char(ch) is False + + def test_normal_korean_chars(self): + for ch in "한국어테스트": + assert is_garbled_char(ch) is False + + def test_common_whitespace_not_garbled(self): + assert is_garbled_char('\t') is False + assert is_garbled_char('\n') is False + assert is_garbled_char('\r') is False + assert is_garbled_char(' ') is False + + def test_pua_chars_are_garbled(self): + assert is_garbled_char('\uE000') is True + assert is_garbled_char('\uF000') is True + assert is_garbled_char('\uF8FF') is True + + def test_supplementary_pua_a(self): + assert is_garbled_char(chr(0xF0000)) is True + assert is_garbled_char(chr(0xFFFFF)) is True + + def test_supplementary_pua_b(self): + assert is_garbled_char(chr(0x100000)) is True + assert is_garbled_char(chr(0x10FFFF)) is True + + def test_replacement_char(self): + assert is_garbled_char('\uFFFD') is True + + def test_c0_control_chars(self): + assert is_garbled_char('\x00') is True + assert is_garbled_char('\x01') is True + assert is_garbled_char('\x1F') is True + + def test_c1_control_chars(self): + assert is_garbled_char('\x80') is True + assert is_garbled_char('\x8F') is True + assert is_garbled_char('\x9F') is True + + def test_empty_string(self): + assert is_garbled_char('') is False + + def test_common_punctuation(self): + for ch in ".,;:!?()[]{}\"'-/\\@#$%^&*+=<>~`|": + assert is_garbled_char(ch) is False + + def test_unicode_symbols(self): + for ch in "©®™°±²³µ¶·¹º»¼½¾": + assert is_garbled_char(ch) is False + + +# --------------------------------------------------------------------------- +# Tests for is_garbled_text +# 
--------------------------------------------------------------------------- + + +class TestIsGarbledText: + """Tests for the is_garbled_text function.""" + + def test_normal_chinese_text(self): + assert is_garbled_text("这是一段正常的中文文本") is False + + def test_normal_english_text(self): + assert is_garbled_text("This is normal English text.") is False + + def test_mixed_normal_text(self): + assert is_garbled_text("Hello 你好 World 世界 123") is False + + def test_empty_text(self): + assert is_garbled_text("") is False + assert is_garbled_text(" ") is False + + def test_none_text(self): + assert is_garbled_text(None) is False + + def test_all_pua_chars(self): + text = "\uE000\uE001\uE002\uE003\uE004" + assert is_garbled_text(text) is True + + def test_mostly_garbled(self): + text = "\uE000\uE001\uE002好" + assert is_garbled_text(text, threshold=0.5) is True + + def test_few_garbled_below_threshold(self): + text = "这是正常文本\uE000" + assert is_garbled_text(text, threshold=0.5) is False + + def test_cid_pattern_detected(self): + assert is_garbled_text("Hello (cid:123) World") is True + assert is_garbled_text("(cid : 45)") is True + assert is_garbled_text("(cid:0)") is True + + def test_cid_like_but_not_matching(self): + assert is_garbled_text("This is a valid cid reference") is False + + def test_whitespace_only_text(self): + assert is_garbled_text(" \t\n ") is False + + def test_custom_threshold(self): + text = "\uE000正常" + assert is_garbled_text(text, threshold=0.3) is True + assert is_garbled_text(text, threshold=0.5) is False + + def test_replacement_chars_in_text(self): + text = "文档\uFFFD\uFFFD解析" + assert is_garbled_text(text, threshold=0.5) is False + assert is_garbled_text(text, threshold=0.3) is True + + def test_real_world_garbled_pattern(self): + text = "\uE000\uE001\uE002\uE003\uE004\uE005\uE006\uE007" + assert is_garbled_text(text) is True + + def test_mixed_garbled_and_normal_at_boundary(self): + text = "AB\uE000\uE001" + assert is_garbled_text(text, threshold=0.5) is True + text2 = "ABC\uE000" + assert is_garbled_text(text2, threshold=0.5) is False + + +# --------------------------------------------------------------------------- +# Tests for has_subset_font_prefix +# --------------------------------------------------------------------------- + + +class TestHasSubsetFontPrefix: + """Tests for the has_subset_font_prefix function.""" + + def test_standard_subset_prefix(self): + assert has_subset_font_prefix("ABCDEF+Arial") is True + assert has_subset_font_prefix("XYZABC+TimesNewRoman") is True + + def test_short_subset_prefix(self): + assert has_subset_font_prefix("DY1+ZLQDm1-1") is True + assert has_subset_font_prefix("AB+Font") is True + + def test_alphanumeric_prefix(self): + assert has_subset_font_prefix("DY2+ZLQDnC-2") is True + assert has_subset_font_prefix("A1B2C3+MyFont") is True + + def test_no_prefix(self): + assert has_subset_font_prefix("Arial") is False + assert has_subset_font_prefix("TimesNewRoman") is False + + def test_empty_or_none(self): + assert has_subset_font_prefix("") is False + assert has_subset_font_prefix(None) is False + + def test_plus_in_middle_not_prefix(self): + assert has_subset_font_prefix("Font+Name") is False + + def test_lowercase_not_prefix(self): + assert has_subset_font_prefix("abc+Font") is False + + +# --------------------------------------------------------------------------- +# Tests for is_garbled_by_font_encoding +# --------------------------------------------------------------------------- + + +def _make_chars(texts, fontname="DY1+ZLQDm1-1"): + 
"""Helper to create a list of pdfplumber-like char dicts.""" + return [{"text": t, "fontname": fontname} for t in texts] + + +class TestIsGarbledByFontEncoding: + """Tests for font-encoding garbled text detection. + + This covers the scenario where PDF fonts with broken ToUnicode + mappings cause CJK characters to be extracted as ASCII + punctuation/symbols (e.g. GB.18067-2000.pdf). + """ + + def test_ascii_punct_from_subset_font_is_garbled(self): + """Simulates GB.18067-2000.pdf: all chars are ASCII punct from subset fonts.""" + chars = _make_chars( + list('!"#$%&\'(\'&)\'"*$!"#$%&\'\'()*+,$-'), + fontname="DY1+ZLQDm1-1", + ) + assert is_garbled_by_font_encoding(chars) is True + + def test_normal_cjk_text_not_garbled(self): + """Normal Chinese text from subset fonts should not be flagged.""" + chars = _make_chars( + list("这是一段正常的中文文本用于测试的示例内容没有问题"), + fontname="ABCDEF+SimSun", + ) + assert is_garbled_by_font_encoding(chars) is False + + def test_mixed_cjk_and_ascii_not_garbled(self): + """Mixed CJK and ASCII content should not be flagged.""" + chars = _make_chars( + list("GB18067-2000居住区大气中酚卫生标准"), + fontname="DY1+ZLQDm1-1", + ) + assert is_garbled_by_font_encoding(chars) is False + + def test_non_subset_font_not_flagged(self): + """ASCII punct from non-subset fonts should not be flagged.""" + chars = _make_chars( + list('!"#$%&\'()*+,-./!"#$%&\'()*+,-./'), + fontname="Arial", + ) + assert is_garbled_by_font_encoding(chars) is False + + def test_too_few_chars_not_flagged(self): + """Pages with very few chars should not trigger detection.""" + chars = _make_chars(list('!"#$'), fontname="DY1+ZLQDm1-1") + assert is_garbled_by_font_encoding(chars) is False + + def test_mostly_digits_not_garbled(self): + """Pages with lots of digits (like data tables) should not be flagged.""" + chars = _make_chars( + list("1234567890" * 3), + fontname="DY1+ZLQDm1-1", + ) + assert is_garbled_by_font_encoding(chars) is False + + def test_english_letters_not_garbled(self): + """Pages with English letters should not be flagged.""" + chars = _make_chars( + list("The quick brown fox jumps over the lazy dog"), + fontname="ABCDEF+Arial", + ) + assert is_garbled_by_font_encoding(chars) is False + + def test_real_world_gb18067_page1(self): + """Simulate actual GB.18067-2000.pdf Page 1 character distribution.""" + page_text = '!"#$%&\'(\'&)\'"*$!"#$%&\'\'()*+,$-' + chars = _make_chars(list(page_text), fontname="DY1+ZLQDm1-1") + assert is_garbled_by_font_encoding(chars) is True + + def test_real_world_gb18067_page3(self): + """Simulate actual GB.18067-2000.pdf Page 3 character distribution.""" + page_text = '!"#$%&\'()*+,-.*+/0+123456789:;<' + chars = _make_chars(list(page_text), fontname="DY1+ZLQDnC-1") + assert is_garbled_by_font_encoding(chars) is True + + def test_empty_chars(self): + assert is_garbled_by_font_encoding([]) is False + assert is_garbled_by_font_encoding(None) is False + + def test_only_spaces(self): + chars = _make_chars([" "] * 30, fontname="DY1+ZLQDm1-1") + assert is_garbled_by_font_encoding(chars) is False + + def test_small_min_chars_threshold(self): + """With reduced min_chars, even small boxes can be detected.""" + chars = _make_chars(list('!"#$%&'), fontname="DY1+ZLQDm1-1") + assert is_garbled_by_font_encoding(chars, min_chars=5) is True + assert is_garbled_by_font_encoding(chars, min_chars=20) is False + + def test_boundary_cjk_ratio(self): + """Just below 5% CJK threshold should still be flagged.""" + # 1 CJK out of 25 chars = 4% CJK, rest are punct from subset font + chars = 
_make_chars(list('!"#$%&\'()*+,-./!@#$%^&*'), fontname="DY1+Font") + chars.append({"text": "中", "fontname": "DY1+Font"}) + assert is_garbled_by_font_encoding(chars, min_chars=5) is True + + def test_boundary_above_cjk_threshold(self): + """Above 5% CJK ratio should NOT be flagged.""" + # 3 CJK out of 23 chars = ~13% CJK + chars = _make_chars(list('!"#$%&\'()*+,-./!@#$'), fontname="DY1+Font") + for ch in "中文字": + chars.append({"text": ch, "fontname": "DY1+Font"}) + assert is_garbled_by_font_encoding(chars, min_chars=5) is False + + def test_low_subset_ratio_not_flagged(self): + """When only a few chars come from subset fonts, should not be flagged. + + Addresses reviewer feedback: a single subset font should not cause + the entire page to be flagged as garbled. + """ + # 5 chars from subset font, 20 from normal font -> 20% subset ratio < 30% + chars = _make_chars(list('!"#$%'), fontname="DY1+Font") + chars.extend(_make_chars(list('!"#$%&\'()*+,-./!@#$%'), fontname="Arial")) + assert is_garbled_by_font_encoding(chars, min_chars=5) is False + + def test_high_subset_ratio_flagged(self): + """When most chars come from subset fonts, detection should trigger.""" + # All 30 chars from subset font with punct -> garbled + chars = _make_chars( + list('!"#$%&\'()*+,-./!@#$%^&*()[]{}'), + fontname="BCDGEE+R0015", + ) + assert is_garbled_by_font_encoding(chars) is True + + +# --------------------------------------------------------------------------- +# Tests for layout_recognizer.__is_garbage +# --------------------------------------------------------------------------- + + +def _is_garbage(b): + """Reproduce LayoutRecognizer.__is_garbage for unit testing. + + The original is a closure nested inside LayoutRecognizer.__call__ + (deepdoc/vision/layout_recognizer.py). We replicate it here because + it cannot be directly imported. + """ + patt = [r"\(cid\s*:\s*\d+\s*\)"] + return any([re.search(p, b.get("text", "")) for p in patt]) + + +class TestLayoutRecognizerIsGarbage: + """Tests for the layout_recognizer __is_garbage function. + + This function filters out text boxes containing CID patterns like + (cid:123) which indicate unmapped characters in PDF fonts. 
+ """ + + def test_cid_pattern_simple(self): + assert _is_garbage({"text": "(cid:123)"}) is True + + def test_cid_pattern_with_spaces(self): + assert _is_garbage({"text": "(cid : 45)"}) is True + assert _is_garbage({"text": "(cid : 0)"}) is True + + def test_cid_pattern_embedded_in_text(self): + assert _is_garbage({"text": "Hello (cid:99) World"}) is True + + def test_cid_pattern_multiple(self): + assert _is_garbage({"text": "(cid:1)(cid:2)(cid:3)"}) is True + + def test_normal_text_not_garbage(self): + assert _is_garbage({"text": "This is normal text."}) is False + + def test_chinese_text_not_garbage(self): + assert _is_garbage({"text": "这是正常的中文内容"}) is False + + def test_empty_text_not_garbage(self): + assert _is_garbage({"text": ""}) is False + + def test_missing_text_key_not_garbage(self): + assert _is_garbage({}) is False + + def test_parentheses_without_cid_not_garbage(self): + assert _is_garbage({"text": "(hello:123)"}) is False + assert _is_garbage({"text": "cid:123"}) is False + + def test_partial_cid_not_garbage(self): + assert _is_garbage({"text": "(cid:)"}) is False + assert _is_garbage({"text": "(cid)"}) is False + + def test_cid_with_zero(self): + assert _is_garbage({"text": "(cid:0)"}) is True + + def test_cid_with_large_number(self): + assert _is_garbage({"text": "(cid:99999)"}) is True From 7d36f2bc3441b54c758ecb925c3bbd60c377ab93 Mon Sep 17 00:00:00 2001 From: qinling0210 <88864212+qinling0210@users.noreply.github.com> Date: Tue, 10 Mar 2026 11:57:32 +0800 Subject: [PATCH 190/565] Fix retrieval function when metadata_condtion is specified in retrieval API (#13473) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? Fix https://github.com/infiniflow/ragflow/issues/13388 The following command returns empty when there is doc with the meta data ``` curl --request POST \ --url http://localhost:9222/api/v1/retrieval \ --header 'Content-Type: application/json' \ --header 'Authorization: Bearer ragflow-fO3mPFePfLgUYg8-9gjBVVXbvHqrvMPLGaW0P86PvAk' \ --data '{ "question": "any question", "dataset_ids": ["9bb4f0591b8811f18a4a84ba59049aa3"], "metadata_condition": { "logic": "and", "conditions": [ { "name": "character", "comparison_operator": "is", "value": "刘备" } ] } }' ``` When metadata_condtion is specified in the retrieval API, it is converted to doc_ids and doc_ids is passed to retrieval function. In retrieval funciton, when doc_ids is explicitly provided , we should bypass threshold. 
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---
 api/apps/sdk/doc.py             | 2 +-
 rag/nlp/search.py               | 6 ++++++
 .../test_doc_sdk_routes_unit.py | 7 ++++++-
 3 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py
index 7ed5d0cca4c..364a959cd86 100644
--- a/api/apps/sdk/doc.py
+++ b/api/apps/sdk/doc.py
@@ -1682,7 +1682,7 @@ async def retrieval_test(tenant_id):
     if not doc_ids:
         metadata_condition = req.get("metadata_condition")
         if metadata_condition:
-            metas = DocMetadataService.get_meta_by_kbs(kb_ids)
+            metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
             doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
             # If metadata_condition has conditions but no docs match, return empty result
             if not doc_ids and metadata_condition.get("conditions"):
diff --git a/rag/nlp/search.py b/rag/nlp/search.py
index 30f6f047de1..3cf70b6d949 100644
--- a/rag/nlp/search.py
+++ b/rag/nlp/search.py
@@ -438,6 +438,12 @@ async def retrieval(
 
         # When vector_similarity_weight is 0, similarity_threshold is not meaningful for term-only scores.
         post_threshold = 0.0 if vector_similarity_weight <= 0 else similarity_threshold
+
+        # When doc_ids is explicitly provided (metadata or document filtering), bypass the threshold:
+        # the caller wants those specific documents regardless of their relevance score.
+        if doc_ids:
+            post_threshold = 0.0
+
         valid_idx = [int(i) for i in sorted_idx if sim_np[i] >= post_threshold]
         filtered_count = len(valid_idx)
         ranks["total"] = int(filtered_count)
diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py
index 872563ccaeb..cd4c2e9d238 100644
--- a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py
+++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py
@@ -993,7 +998,7 @@ def test_retrieval_validation_matrix(self, monkeypatch):
             "get_request_json",
             lambda: _AwaitableValue({"dataset_ids": ["ds-1"], "question": "q", "metadata_condition": {"logic": "and"}}),
         )
-        monkeypatch.setattr(module.DocMetadataService, "get_meta_by_kbs", lambda _ids: [])
+        monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [])
         monkeypatch.setattr(module, "meta_filter", lambda *_args, **_kwargs: [])
         res = _run(module.retrieval_test.__wrapped__("tenant-1"))
         assert "code" in res

From 557318e203b1846e6277db8cdd514e868042823c Mon Sep 17 00:00:00 2001
From: Magicbook1108
Date: Tue, 10 Mar 2026 13:44:17 +0800
Subject: [PATCH 191/565] Fix: chats_openai in non-stream condition (#13495)

### What problem does this PR solve?
Fix: chats_openai in none stream condition #13453 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- rag/prompts/generator.py | 6 +++++- .../test_session_management/test_session_sdk_routes_unit.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index b9c2113d8f2..3037e0c4a30 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -40,6 +40,9 @@ def get_value(d, k1, k2): def chunks_format(reference): if not reference or not isinstance(reference, dict): return [] + raw_chunks = reference.get("chunks", []) + if not isinstance(raw_chunks, list): + return [] return [ { "id": get_value(chunk, "chunk_id", "id"), @@ -55,7 +58,8 @@ def chunks_format(reference): "term_similarity": chunk.get("term_similarity"), "doc_type": get_value(chunk, "doc_type_kwd", "doc_type"), } - for chunk in reference.get("chunks", []) + for chunk in raw_chunks + if isinstance(chunk, dict) ] diff --git a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py index 6852024db30..a83a3564d4c 100644 --- a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py @@ -762,7 +762,7 @@ async def fake_async_chat(_dia, _msg, _stream, **_kwargs): res = _run(inspect.unwrap(module.chat_completion_openai_like)("tenant-1", "chat-1")) assert res["choices"][0]["message"]["content"] == "world" - + @pytest.mark.p2 def test_agents_openai_compatibility_unit(monkeypatch): From 92248e462737c2296f39bbaa9e48cdb6ccd0790c Mon Sep 17 00:00:00 2001 From: qinling0210 <88864212+qinling0210@users.noreply.github.com> Date: Tue, 10 Mar 2026 13:44:24 +0800 Subject: [PATCH 192/565] Fix delete_document_metadata (#13496) ### What problem does this PR solve? Avoid getting doc in function delete_document_metadata as the doc might have been removed. 
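A self-contained sketch of the ordering problem, with plain dictionaries standing in for the document table and the metadata index (names are illustrative): once the document row is deleted first, any cleanup that re-queries the document finds nothing and leaves the metadata row orphaned, which is why the new signature threads kb_id (and optionally tenant_id) in from the caller.

```
documents = {"doc-1": {"kb_id": "kb-1"}}
metadata = {"doc-1": {"author": "alice"}}

def delete_metadata_old(doc_id):
    doc = documents.get(doc_id)  # re-query fails once the row is gone
    if not doc:
        return False             # ...and the metadata row is orphaned
    metadata.pop(doc_id, None)
    return True

def delete_metadata_new(doc_id, kb_id, tenant_id=None):
    # kb_id/tenant_id come from the caller, so no document lookup is needed
    metadata.pop(doc_id, None)
    return True

doc = documents.pop("doc-1")  # the document row is removed first
assert delete_metadata_old("doc-1") is False and "doc-1" in metadata
assert delete_metadata_new("doc-1", doc["kb_id"]) is True and "doc-1" not in metadata
```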
### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- api/db/joint_services/user_account_service.py | 4 +-- api/db/services/doc_metadata_service.py | 25 +++++++++---------- api/db/services/document_service.py | 4 +-- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/api/db/joint_services/user_account_service.py b/api/db/joint_services/user_account_service.py index 7490c9bad22..f9a25b498c5 100644 --- a/api/db/joint_services/user_account_service.py +++ b/api/db/joint_services/user_account_service.py @@ -173,7 +173,7 @@ def delete_user_data(user_id: str) -> dict: if doc_ids: for doc in doc_ids: try: - DocMetadataService.delete_document_metadata(doc["id"], skip_empty_check=True) + DocMetadataService.delete_document_metadata(doc["id"], doc["kb_id"], tenant_id=None, skip_empty_check=True) except Exception as e: logging.warning(f"Failed to delete metadata for document {doc['id']}: {e}") @@ -290,7 +290,7 @@ def delete_user_data(user_id: str) -> dict: done_msg += f"- Deleted {doc_delete_res} documents.\n" for doc in created_documents: try: - DocMetadataService.delete_document_metadata(doc['id']) + DocMetadataService.delete_document_metadata(doc['id'], doc['kb_id'], doc['tenant_id']) except Exception as e: logging.warning(f"Failed to delete metadata for document {doc['id']}: {e}") # step2.1.6 update dataset doc&chunk&token cnt diff --git a/api/db/services/doc_metadata_service.py b/api/db/services/doc_metadata_service.py index a63e12dcee7..02a6b892743 100644 --- a/api/db/services/doc_metadata_service.py +++ b/api/db/services/doc_metadata_service.py @@ -474,7 +474,7 @@ def update_document_metadata(cls, doc_id: str, meta_fields: Dict) -> bool: # For Infinity or as fallback: use delete+insert logging.debug(f"[update_document_metadata] Using delete+insert method for doc_id: {doc_id}") - cls.delete_document_metadata(doc_id, skip_empty_check=True) + cls.delete_document_metadata(doc_id, kb_id, tenant_id, skip_empty_check=True) return cls.insert_document_metadata(doc_id, processed_meta) except Exception as e: @@ -483,7 +483,7 @@ def update_document_metadata(cls, doc_id: str, meta_fields: Dict) -> bool: @classmethod @DB.connection_context() - def delete_document_metadata(cls, doc_id: str, skip_empty_check: bool = False) -> bool: + def delete_document_metadata(cls, doc_id: str, kb_id: str, tenant_id: str = None, skip_empty_check: bool = False) -> bool: """ Delete document metadata from ES/Infinity. Also drops the metadata table if it becomes empty (efficiently). 
@@ -491,6 +491,8 @@ def delete_document_metadata(cls, doc_id: str, skip_empty_check: bool = False) - Args: doc_id: Document ID + kb_id: Knowledge base ID + tenant_id: Tenant ID, if not provided, get it from kb_id skip_empty_check: If True, skip checking/dropping empty table (for bulk deletions) Returns: @@ -498,18 +500,15 @@ def delete_document_metadata(cls, doc_id: str, skip_empty_check: bool = False) - """ try: logging.debug(f"[METADATA DELETE] Starting metadata deletion for document: {doc_id}") - # Get document with tenant_id - doc_query = Document.select(Document, Knowledgebase.tenant_id).join( - Knowledgebase, on=(Knowledgebase.id == Document.kb_id) - ).where(Document.id == doc_id) - doc = doc_query.first() - if not doc: - logging.warning(f"Document {doc_id} not found for metadata deletion") - return False + # Get tenant_id from kb_id if not provided + if tenant_id is None: + kb = Knowledgebase.get_or_none(Knowledgebase.id == kb_id) + if not kb: + logging.warning(f"Knowledgebase {kb_id} not found for metadata deletion") + return False + tenant_id = kb.tenant_id - tenant_id = doc.knowledgebase.tenant_id - kb_id = doc.kb_id index_name = cls._get_doc_meta_index_name(tenant_id) logging.debug(f"[delete_document_metadata] Deleting doc_id: {doc_id}, kb_id: {kb_id}, index: {index_name}") @@ -1143,7 +1142,7 @@ def _apply_deletes(meta): logging.debug(f"[batch_update_metadata] Updating doc_id: {doc_id}, meta: {meta}") # If metadata is empty, delete the row entirely instead of keeping empty metadata if not meta: - cls.delete_document_metadata(doc_id, skip_empty_check=True) + cls.delete_document_metadata(doc_id, kb_id, tenant_id=None, skip_empty_check=True) else: cls.update_document_metadata(doc_id, meta) updated_docs += 1 diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 8809373a323..f5b2d9d5102 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -311,7 +311,7 @@ def get_total_size_by_kb_id(cls, kb_id, keywords="", run_status=[], types=[]): @classmethod @DB.connection_context() def get_all_doc_ids_by_kb_ids(cls, kb_ids): - fields = [cls.model.id] + fields = [cls.model.id, cls.model.kb_id] docs = cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids)) docs.order_by(cls.model.create_time.asc()) # maybe cause slow query by deep paginate, optimize later @@ -399,7 +399,7 @@ def remove_document(cls, doc, tenant_id): # Delete document metadata (non-critical, log and continue) try: - DocMetadataService.delete_document_metadata(doc.id) + DocMetadataService.delete_document_metadata(doc.id, doc.kb_id, tenant_id) except Exception as e: logging.warning(f"Failed to delete metadata for document {doc.id}: {e}") From 39f2e28e7cd323948172dd2a4b1eb492bfb9b09b Mon Sep 17 00:00:00 2001 From: Idriss Sbaaoui <112825897+6ba3i@users.noreply.github.com> Date: Tue, 10 Mar 2026 14:24:33 +0800 Subject: [PATCH 193/565] Fix missmatch docnm_kwd in raptor chunks (#13451) ### What problem does this PR solve? 
issue #13393 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- rag/svr/task_executor.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 2b3ba461512..3a12f78266f 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -778,6 +778,14 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si res = [] tk_count = 0 max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3)) + doc_name_by_id = {} + for doc_id in set(doc_ids): + ok, source_doc = DocumentService.get_by_id(doc_id) + if not ok or not source_doc: + continue + source_name = getattr(source_doc, "name", "") + if source_name: + doc_name_by_id[doc_id] = source_name async def generate(chunks, did): nonlocal tk_count, res @@ -792,11 +800,12 @@ async def generate(chunks, did): ) original_length = len(chunks) chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"]) + effective_doc_name = row["name"] if did == fake_doc_id else doc_name_by_id.get(did, row["name"]) doc = { "doc_id": did, "kb_id": [str(row["kb_id"])], - "docnm_kwd": row["name"], - "title_tks": rag_tokenizer.tokenize(row["name"]), + "docnm_kwd": effective_doc_name, + "title_tks": rag_tokenizer.tokenize(effective_doc_name), "raptor_kwd": "raptor" } if row["pagerank"]: @@ -1047,7 +1056,7 @@ async def do_handle_task(task): return # bind LLM for raptor - chat_model_config = get_model_config_by_type_and_name(task_dataset_id, LLMType.CHAT, kb_task_llm_id) + chat_model_config = get_model_config_by_type_and_name(task_tenant_id, LLMType.CHAT, kb_task_llm_id) chat_model = LLMBundle(task_tenant_id, chat_model_config, lang=task_language) # run RAPTOR async with kg_limiter: From c9b5f43dfe6a5292ff5605cba4d8afc25d8ff442 Mon Sep 17 00:00:00 2001 From: balibabu Date: Tue, 10 Mar 2026 14:25:27 +0800 Subject: [PATCH 194/565] Feat: Display release status in agent version history. (#13479) ### What problem does this PR solve? Feat: Display release status in agent version history. 
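Besides surfacing the flag in the UI, the version-history pruning in this patch treats released snapshots as protected: only unpublished versions count against the 20-version cap. A toy sketch of that retention rule (plain data stand-ins, not the actual service layer):

```
def versions_to_delete(versions, keep=20):
    # versions: newest-first snapshots shaped like {"id": ..., "release": bool};
    # released snapshots never count against the cap and are never pruned.
    unpublished = [v for v in versions if not v.get("release")]
    return [v["id"] for v in unpublished[keep:]]

history = [{"id": f"v{i}", "release": i % 10 == 0} for i in range(30)]
doomed = versions_to_delete(history)
assert len(doomed) == 7     # 27 unpublished snapshots, keep the 20 newest
assert "v20" not in doomed  # released snapshots survive regardless of age
```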
### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: balibabu --- api/apps/canvas_app.py | 20 ++++++++++ api/db/db_models.py | 2 + api/db/services/canvas_service.py | 20 +++++++++- api/db/services/user_canvas_version.py | 39 +++++++++++++++---- web/src/components/home-card.tsx | 27 +++++++++++-- web/src/custom.d.ts | 2 + web/src/hooks/use-agent-request.ts | 2 +- web/src/interfaces/database/agent.ts | 3 ++ web/src/interfaces/database/flow.ts | 1 + web/src/locales/en.ts | 5 +++ web/src/locales/zh.ts | 7 +++- .../components/publish-confirm-dialog.tsx | 7 ++-- web/src/pages/agent/version-dialog/index.tsx | 29 ++++++++++++-- web/src/pages/agents/agent-card.tsx | 7 +++- 14 files changed, 148 insertions(+), 23 deletions(-) diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index 3c5c501fd5e..0c4abe0596b 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -99,6 +99,7 @@ async def save(): user_canvas_id=req["id"], dsl=req["dsl"], title=UserCanvasVersionService.build_version_title(getattr(current_user, "nickname", current_user.id), req.get("title")), + release=req.get("release"), ) replica_ok = CanvasReplicaService.replace_for_set( canvas_id=req["id"], @@ -133,6 +134,25 @@ def get(canvas_id): ) except ValueError as e: return get_data_error_result(message=str(e)) + + # Get the last publication time (latest released version's update_time) + last_publish_time = None + versions = UserCanvasVersionService.list_by_canvas_id(canvas_id) + if versions: + released_versions = [v for v in versions if v.release] + if released_versions: + # Sort by update_time descending and get the latest + released_versions.sort(key=lambda x: x.update_time, reverse=True) + last_publish_time = released_versions[0].update_time + + # Add last_publish_time to response data + if isinstance(c, dict): + c["last_publish_time"] = last_publish_time + else: + # If c is a model object, convert to dict first + c = c.to_dict() + c["last_publish_time"] = last_publish_time + return get_json_result(data=c) diff --git a/api/db/db_models.py b/api/db/db_models.py index 6348a68a304..2e3824050cd 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -1075,6 +1075,7 @@ class UserCanvasVersion(DataBaseModel): title = CharField(max_length=255, null=True, help_text="Canvas title") description = TextField(null=True, help_text="Canvas description") + release = BooleanField(null=False, help_text="is released", default=False, index=True) dsl = JSONField(null=True, default={}) class Meta: @@ -1538,6 +1539,7 @@ def migrate_db(): alter_db_add_column(migrator, "dialog", "tenant_rerank_id", IntegerField(null=True, help_text="id in tenant_llm", index=True)) alter_db_add_column(migrator, "memory", "tenant_embd_id", IntegerField(null=True, help_text="id in tenant_llm", index=True)) alter_db_add_column(migrator, "memory", "tenant_llm_id", IntegerField(null=True, help_text="id in tenant_llm", index=True)) + alter_db_add_column(migrator, "user_canvas_version", "release", BooleanField(null=False, help_text="is released", default=False, index=True)) logging.disable(logging.NOTSET) # this is after re-enabling logging to allow logging changed user emails migrate_add_unique_email(migrator) diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py index 838951a9a98..f6aa41c4a44 100644 --- a/api/db/services/canvas_service.py +++ b/api/db/services/canvas_service.py @@ -19,7 +19,7 @@ from uuid import uuid4 from agent.canvas import Canvas from api.db import 
CanvasCategory, TenantPermission -from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversation +from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversation, UserCanvasVersion from api.db.services.api_service import API4ConversationService from api.db.services.common_service import CommonService from common.misc_utils import get_uuid @@ -173,7 +173,23 @@ def get_by_tenant_ids(cls, joined_tenant_ids, user_id, count = agents.count() if page_number and items_per_page: agents = agents.paginate(page_number, items_per_page) - return list(agents.dicts()), count + + agents_list = list(agents.dicts()) + + # Get latest release time for each canvas + if agents_list: + canvas_ids = [a['id'] for a in agents_list] + release_times = ( + UserCanvasVersion.select(UserCanvasVersion.user_canvas_id, fn.MAX(UserCanvasVersion.create_time).alias("release_time")) + .where((UserCanvasVersion.user_canvas_id.in_(canvas_ids)) & (UserCanvasVersion.release)) + .group_by(UserCanvasVersion.user_canvas_id) + ) + release_time_map = {r.user_canvas_id: r.release_time for r in release_times} + + for agent in agents_list: + agent['release_time'] = release_time_map.get(agent['id']) + + return agents_list, count @classmethod @DB.connection_context() diff --git a/api/db/services/user_canvas_version.py b/api/db/services/user_canvas_version.py index f8bd6dae01e..d2861d576e2 100644 --- a/api/db/services/user_canvas_version.py +++ b/api/db/services/user_canvas_version.py @@ -45,7 +45,8 @@ def list_by_canvas_id(cls, user_canvas_id): cls.model.create_date, cls.model.update_date, cls.model.user_canvas_id, - cls.model.update_time] + cls.model.update_time, + cls.model.release] ).where(cls.model.user_canvas_id == user_canvas_id) return user_canvas_version except DoesNotExist: @@ -74,14 +75,14 @@ def get_all_canvas_version_by_canvas_ids(cls, canvas_ids): @DB.connection_context() def delete_all_versions(cls, user_canvas_id): try: - user_canvas_version = cls.model.select().where(cls.model.user_canvas_id == user_canvas_id).order_by( - cls.model.create_time.desc()) - if user_canvas_version.count() > 20: - delete_ids = [] - for i in range(20, user_canvas_version.count()): - delete_ids.append(user_canvas_version[i].id) + # Only get unpublished versions (False or None), keep all released versions + unpublished = cls.model.select().where(cls.model.user_canvas_id == user_canvas_id, (~cls.model.release) | (cls.model.release.is_null(True))).order_by(cls.model.create_time.desc()) + # Only delete old unpublished versions beyond the limit + if unpublished.count() > 20: + delete_ids = [v.id for v in unpublished[20:]] cls.delete_by_ids(delete_ids) + return True except DoesNotExist: return None @@ -90,12 +91,15 @@ def delete_all_versions(cls, user_canvas_id): @classmethod @DB.connection_context() - def save_or_replace_latest(cls, user_canvas_id, dsl, title=None, description=None): + def save_or_replace_latest(cls, user_canvas_id, dsl, title=None, description=None, release=None): """ Persist a canvas snapshot into version history. If the latest version has the same DSL content, update that version in place instead of creating a new row. + + Exception: If the latest version is released (release=True) and current save is not, + create a new version to protect the released version. 
""" try: normalized_dsl = cls._normalize_dsl(dsl) @@ -107,11 +111,28 @@ def save_or_replace_latest(cls, user_canvas_id, dsl, title=None, description=Non ) if latest and cls._normalize_dsl(latest.dsl) == normalized_dsl: + # Protect released version: if latest is released and current is not, + # create a new version instead of updating + if latest.release and not release: + insert_data = {"user_canvas_id": user_canvas_id, "dsl": normalized_dsl} + if title is not None: + insert_data["title"] = title + if description is not None: + insert_data["description"] = description + if release is not None: + insert_data["release"] = release + cls.insert(**insert_data) + cls.delete_all_versions(user_canvas_id) + return None, True + + # Normal case: update existing version update_data = {"dsl": normalized_dsl} if title is not None: update_data["title"] = title if description is not None: update_data["description"] = description + if release is not None: + update_data["release"] = release cls.update_by_id(latest.id, update_data) cls.delete_all_versions(user_canvas_id) return latest.id, False @@ -121,6 +142,8 @@ def save_or_replace_latest(cls, user_canvas_id, dsl, title=None, description=Non insert_data["title"] = title if description is not None: insert_data["description"] = description + if release is not None: + insert_data["release"] = release cls.insert(**insert_data) cls.delete_all_versions(user_canvas_id) return None, True diff --git a/web/src/components/home-card.tsx b/web/src/components/home-card.tsx index 7320960b954..9e57a355f3d 100644 --- a/web/src/components/home-card.tsx +++ b/web/src/components/home-card.tsx @@ -2,6 +2,7 @@ import { RAGFlowAvatar } from '@/components/ragflow-avatar'; import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; import { formatDate } from '@/utils/date'; import { ReactNode } from 'react'; +import { useTranslation } from 'react-i18next'; interface IProps { data: { @@ -9,6 +10,7 @@ interface IProps { description?: string; avatar?: string; update_time?: string | number; + release_time?: number; }; onClick?: () => void; moreDropdown: React.ReactNode; @@ -16,6 +18,12 @@ interface IProps { icon?: React.ReactNode; testId?: string; } + +function Time({ time }: { time: string | number | undefined }) { + return ( +

{formatDate(time)}

+ ); +} export function HomeCard({ data, onClick, @@ -24,6 +32,8 @@ export function HomeCard({ icon, testId, }: IProps) { + const { t } = useTranslation(); + return (
-

- {formatDate(data.update_time)} -

+ {data.release_time ? ( +
+
+ {`${t('flow.lastSavedAt')}:`} + +
+
+ {`${t('flow.publishedAt')}:`} + +
+
+ ) : ( + + )} {sharedBadge}
diff --git a/web/src/custom.d.ts b/web/src/custom.d.ts index dafdf09f1b5..9495d4f3c3c 100644 --- a/web/src/custom.d.ts +++ b/web/src/custom.d.ts @@ -1,3 +1,5 @@ +type Nullable = T | null; + declare module '*.md' { const content: string; export default content; diff --git a/web/src/hooks/use-agent-request.ts b/web/src/hooks/use-agent-request.ts index df577530e27..5efd47975fd 100644 --- a/web/src/hooks/use-agent-request.ts +++ b/web/src/hooks/use-agent-request.ts @@ -529,7 +529,7 @@ export const useFetchInputForm = (componentId?: string) => { export const useFetchVersionList = () => { const { id } = useParams(); const { data, isFetching: loading } = useQuery< - Array<{ created_at: string; title: string; id: string }> + Array<{ created_at: string; title: string; id: string; release?: boolean }> >({ queryKey: [AgentApiAction.FetchVersionList], initialData: [], diff --git a/web/src/interfaces/database/agent.ts b/web/src/interfaces/database/agent.ts index 3d23cc81cb0..c9a08b7f204 100644 --- a/web/src/interfaces/database/agent.ts +++ b/web/src/interfaces/database/agent.ts @@ -77,6 +77,9 @@ export declare interface IFlow { nickname: string; operator_permission: number; canvas_category: string; + release?: boolean; + release_time?: number; + last_publish_time?: number; } export interface IFlowTemplate { diff --git a/web/src/interfaces/database/flow.ts b/web/src/interfaces/database/flow.ts index 96de7ad218a..aa266615084 100644 --- a/web/src/interfaces/database/flow.ts +++ b/web/src/interfaces/database/flow.ts @@ -41,6 +41,7 @@ export declare interface IFlow { user_id: string; permission: string; nickname: string; + release?: boolean; } export interface IFlowTemplate { diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 26c253051bc..f561c41bfe6 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -1500,6 +1500,7 @@ Example: Virtual Hosted Style`, other: 'Other', ingestionPipeline: 'Ingestion pipeline', agents: 'Agents', + publishedAt: 'Published at', days: 'Days', beginInput: 'Begin input', ref: 'Variable', @@ -1581,6 +1582,7 @@ Example: Virtual Hosted Style`, citeTip: 'citeTip', name: 'Name', nameMessage: 'Please input name', + lastSavedAt: 'Last saved at', description: 'Description', descriptionMessage: 'This is an agent for a specific task.', examples: 'Examples', @@ -2185,6 +2187,9 @@ This process aggregates variables from multiple branches into a single variable 'Write your SQL query here. You can use variables, raw SQL, or mix both using variable syntax.', frameworkPrompts: 'Framework', release: 'Publish', + production: 'Production', + productionTooltip: + 'This version is published to production. 
Access it via the API or the embedded page.', confirmPublish: 'Confirm Publish', publishDescription: 'You are about to publish this data pipeline.', linkedDataset: 'Linked dataset', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index 3a624d42ebc..07215383414 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -919,8 +919,7 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 githubDescription: '连接 GitHub,可同步 Pull Request 与 Issue 内容用于检索。', airtableDescription: '连接 Airtable,同步指定工作区下指定表格中的文件。', - dingtalkAITableDescription: - '连接钉钉AI表格,同步指定表格中的记录。', + dingtalkAITableDescription: '连接钉钉AI表格,同步指定表格中的记录。', gitlabDescription: '连接 GitLab,同步仓库、Issue、合并请求(MR)及相关文档内容。', asanaDescription: '连接 Asana,同步工作区中的文件。', @@ -1250,6 +1249,7 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 consumerApp: '消费者应用', other: '其他', agents: '智能体', + publishedAt: '发布于', beginInput: '开始输入', seconds: '秒', ref: '引用变量', @@ -1345,6 +1345,7 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 cite: '引用', citeTip: '引用', nameMessage: '请输入名称', + lastSavedAt: '上次保存于', description: '描述', examples: '示例', to: '下一步', @@ -1888,6 +1889,8 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 '在此处编写您的 SQL 查询。您可以使用变量、原始 SQL,或使用变量语法混合使用两者。', frameworkPrompts: '框架', release: '发布', + production: '正式版', + productionTooltip: '此版本已发布到生产环境。可通过 API 或嵌入页面访问。', createFromBlank: '从空白创建', createFromTemplate: '从模板创建', importJsonFile: '导入 JSON 文件', diff --git a/web/src/pages/agent/components/publish-confirm-dialog.tsx b/web/src/pages/agent/components/publish-confirm-dialog.tsx index a058bef9cbf..85fe297515e 100644 --- a/web/src/pages/agent/components/publish-confirm-dialog.tsx +++ b/web/src/pages/agent/components/publish-confirm-dialog.tsx @@ -8,6 +8,7 @@ import { DialogTitle, DialogTrigger, } from '@/components/ui/dialog'; +import { IFlow } from '@/interfaces/database/agent'; import { Operator } from '@/pages/agent/constant'; import useGraphStore from '@/pages/agent/store'; import { formatDate } from '@/utils/date'; @@ -16,7 +17,7 @@ import { useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; interface PublishConfirmDialogProps { - agentDetail: { title: string; update_time?: number }; + agentDetail: IFlow; loading: boolean; onPublish: () => void; } @@ -42,8 +43,8 @@ export function PublishConfirmDialog({ }, [nodes]); const lastPublished = useMemo(() => { - if (agentDetail?.update_time) { - return formatDate(agentDetail.update_time); + if (agentDetail?.last_publish_time) { + return formatDate(agentDetail.last_publish_time); } return '-'; }, [agentDetail?.update_time]); diff --git a/web/src/pages/agent/version-dialog/index.tsx b/web/src/pages/agent/version-dialog/index.tsx index 6a4bdee9b31..a1d8d32b9c3 100644 --- a/web/src/pages/agent/version-dialog/index.tsx +++ b/web/src/pages/agent/version-dialog/index.tsx @@ -10,6 +10,7 @@ import { } from '@/components/ui/dialog'; import { RAGFlowPagination } from '@/components/ui/ragflow-pagination'; import { Spin } from '@/components/ui/spin'; +import { RAGFlowTooltip } from '@/components/ui/tooltip'; import { useClientPagination } from '@/hooks/logic-hooks/use-pagination'; import { useFetchVersion, @@ -25,6 +26,12 @@ import { ReactNode, useCallback, useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { nodeTypes } from '../canvas'; +function Dot() { + return ( + + ); +} + export function VersionDialog({ hideModal, }: IModalProps & { initialName?: string; title?: ReactNode }) { @@ -58,7 +65,7 @@ export function 
VersionDialog({ return ( - + {t('flow.historyVersion')} @@ -78,7 +85,10 @@ export function VersionDialog({ })} onClick={handleClick(x.id)} > - {x.title} +
+ {x.title} + {x.release && } +
))} @@ -92,11 +102,24 @@ export function VersionDialog({
-
{agent?.title}
+
+ {agent?.title} + {agent?.release && ( + + + + )} +

Created: {formatDate(agent?.create_date)}

+ diff --git a/web/src/pages/agents/agent-card.tsx b/web/src/pages/agents/agent-card.tsx index 678d574b0d6..a9ee760802e 100644 --- a/web/src/pages/agents/agent-card.tsx +++ b/web/src/pages/agents/agent-card.tsx @@ -19,7 +19,12 @@ export function AgentCard({ data, showAgentRenameModal }: DatasetCardProps) { return ( From 933f0437899ec88fc366bf46aeefe0b0f4684960 Mon Sep 17 00:00:00 2001 From: Alexander Vostres Date: Tue, 10 Mar 2026 09:02:01 +0200 Subject: [PATCH 195/565] Fix "Coordinate lower is less than upper" error with MinerU (#13483) ### What problem does this PR solve? Fixes #6004 #7142 #11959 Unlike #9207 we actually normalize the coordinates here ### Type of change - [X] Bug Fix (non-breaking change which fixes an issue) --- deepdoc/parser/mineru_parser.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/deepdoc/parser/mineru_parser.py b/deepdoc/parser/mineru_parser.py index d63a04c843e..e334995a210 100644 --- a/deepdoc/parser/mineru_parser.py +++ b/deepdoc/parser/mineru_parser.py @@ -340,6 +340,11 @@ def _line_tag(self, bx): pn = [bx["page_idx"] + 1] positions = bx.get("bbox", (0, 0, 0, 0)) x0, top, x1, bott = positions + # Normalize flipped coordinates (MinerU may report inverted bbox for flipped images) + if x0 > x1: + x0, x1 = x1, x0 + if top > bott: + top, bott = bott, top if hasattr(self, "page_images") and self.page_images and len(self.page_images) > bx["page_idx"]: page_width, page_height = self.page_images[bx["page_idx"]].size @@ -429,6 +434,12 @@ def crop(self, text, ZM=1, need_position=False): img0 = self.page_images[pns[0]] x0, y0, x1, y1 = int(left), int(top), int(right), int(min(bottom, img0.size[1])) + if x0 > x1: + x0, x1 = x1, x0 + if y0 > y1: + y0, y1 = y1, y0 + if x1 <= x0 or y1 <= y0: + continue crop0 = img0.crop((x0, y0, x1, y1)) imgs.append(crop0) if 0 < ii < len(poss) - 1: @@ -442,6 +453,13 @@ def crop(self, text, ZM=1, need_position=False): continue page = self.page_images[pn] x0, y0, x1, y1 = int(left), 0, int(right), int(min(bottom, page.size[1])) + if x0 > x1: + x0, x1 = x1, x0 + if y0 > y1: + y0, y1 = y1, y0 + if x1 <= x0 or y1 <= y0: + bottom -= page.size[1] + continue cimgp = page.crop((x0, y0, x1, y1)) imgs.append(cimgp) if 0 < ii < len(poss) - 1: From b74edabd295d16acd2d0b87b4dbf75ce31ffd7d4 Mon Sep 17 00:00:00 2001 From: Magicbook1108 Date: Tue, 10 Mar 2026 15:02:24 +0800 Subject: [PATCH 196/565] Refact: optimize confluence performance (#13497) ### What problem does this PR solve? 
Refact: optimize confluence performance #13494 ### Type of change - [x] Refactoring --- common/data_source/confluence_connector.py | 4 ++-- rag/svr/sync_data_source.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/common/data_source/confluence_connector.py b/common/data_source/confluence_connector.py index d2494c3de74..58a7d2f82bd 100644 --- a/common/data_source/confluence_connector.py +++ b/common/data_source/confluence_connector.py @@ -1310,7 +1310,7 @@ def __init__( self._confluence_client: OnyxConfluence | None = None self._low_timeout_confluence_client: OnyxConfluence | None = None self._fetched_titles: set[str] = set() - self.allow_images = False + self.allow_images = True # Track document names to detect duplicates self._document_name_counts: dict[str, int] = {} self._document_name_paths: dict[str, list[str]] = {} @@ -1597,7 +1597,7 @@ def _convert_page_to_document( id=page_url, source=DocumentSource.CONFLUENCE, semantic_identifier=semantic_identifier, - extension=".html", # Confluence pages are HTML + extension=".txt", # Confluence pages are HTML blob=page_content.encode("utf-8"), # Encode page content as bytes doc_updated_at=datetime_from_string(page["version"]["when"]), size_bytes=len(page_content.encode("utf-8")), # Calculate size in bytes diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py index 044c7484dff..87bb8af9b23 100644 --- a/rag/svr/sync_data_source.py +++ b/rag/svr/sync_data_source.py @@ -275,6 +275,7 @@ async def _generate(self, task: dict): space=space, page_id=page_id, index_recursively=index_recursively, + ) credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"], From 2d3f01396e0e39cf228de0b043b28d90c5bb81bb Mon Sep 17 00:00:00 2001 From: balibabu Date: Tue, 10 Mar 2026 16:01:31 +0800 Subject: [PATCH 197/565] Fix: The number of deleted session prompts is displayed incorrectly. #13499 (#13500) ### What problem does this PR solve? Fix: The number of deleted session prompts is displayed incorrectly. #13499 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- web/src/locales/zh.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index 07215383414..b020ff02a1e 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -851,7 +851,7 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 tocEnhance: 'PageIndex', tocEnhanceTip: `解析文档时生成了目录信息(见General方法的'启用目录抽取'),让大模型返回和用户问题相关的目录项,从而利用目录项拿到相关chunk,对这些chunk在排序中进行加权。这种方法来源于模仿人类查询书本中知识的行为逻辑`, batchDeleteSessions: '批量删除', - deleteSelectedConfirm: '删除选中的 {count} 个会话?', + deleteSelectedConfirm: '删除选中的 {{count}} 个会话?', }, setting: { Verify: '验证', From 5836df81a0595c36d36afe5b8cee47610ed390d9 Mon Sep 17 00:00:00 2001 From: chanx <1243304602@qq.com> Date: Tue, 10 Mar 2026 17:30:21 +0800 Subject: [PATCH 198/565] feat(admin): Implemented default administrator initialization and login functionality. (#13504) ### What problem does this PR solve? feat(admin): Implemented default administrator initialization and login functionality. Added support for default administrator configuration, including super user nickname, email, and password. 
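For reference, the layout the new Go code both generates and verifies is `pbkdf2:sha256:<iterations>$<base64 salt>$<base64 digest>`. Whether that byte-for-byte matches hashes produced by a given werkzeug release depends on that release's encoding defaults, so treat the following pure-stdlib Python as a round-trip check of this patch's format only:

```
import base64
import hashlib
import os

def make_hash(password: str, iterations: int = 150000) -> str:
    salt = os.urandom(16)
    key = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, iterations, dklen=32)
    return "pbkdf2:sha256:%d$%s$%s" % (
        iterations,
        base64.b64encode(salt).decode(),
        base64.b64encode(key).decode(),
    )

def check_hash(stored: str, password: str) -> bool:
    method, salt_b64, digest_b64 = stored.split("$")
    _, algo, iters = method.split(":")
    key = hashlib.pbkdf2_hmac(algo, password.encode(),
                              base64.b64decode(salt_b64), int(iters), dklen=32)
    return base64.b64encode(key).decode() == digest_b64

# The default-admin flow base64-encodes the raw password before hashing.
encoded = base64.b64encode(b"admin").decode()
assert check_hash(make_hash(encoded), encoded)
```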
### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- cmd/admin_server.go | 21 ++-- internal/admin/handler.go | 54 +++++++--- internal/admin/password.go | 111 +++++++++++++++++++++ internal/admin/router.go | 5 +- internal/admin/service.go | 195 +++++++++++++++++++++++++++++-------- internal/service/user.go | 74 ++++++++++++++ 6 files changed, 390 insertions(+), 70 deletions(-) create mode 100644 internal/admin/password.go diff --git a/cmd/admin_server.go b/cmd/admin_server.go index 1165ce219a5..f9b095c908d 100644 --- a/cmd/admin_server.go +++ b/cmd/admin_server.go @@ -33,21 +33,18 @@ import ( "ragflow/internal/admin" "ragflow/internal/dao" - "ragflow/internal/handler" "ragflow/internal/logger" "ragflow/internal/server" - "ragflow/internal/service" "ragflow/internal/utility" ) // AdminServer admin server type AdminServer struct { - router *admin.Router - handler *admin.Handler - service *admin.Service - userHandler *handler.UserHandler - engine *gin.Engine - port string + router *admin.Router + handler *admin.Handler + service *admin.Service + engine *gin.Engine + port string } func main() { @@ -112,8 +109,12 @@ func main() { } adminService := admin.NewService() - userService := service.NewUserService() - adminHandler := admin.NewHandler(adminService, userService) + adminHandler := admin.NewHandler(adminService) + + // Initialize default admin user + if err := adminService.InitDefaultAdmin(); err != nil { + logger.Error("Failed to initialize default admin user", err) + } // Initialize router r := admin.NewRouter(adminHandler) diff --git a/internal/admin/handler.go b/internal/admin/handler.go index 4a4e36bfb6d..a47628f6468 100644 --- a/internal/admin/handler.go +++ b/internal/admin/handler.go @@ -18,6 +18,7 @@ package admin import ( "errors" + "fmt" "net/http" "ragflow/internal/common" "ragflow/internal/server" @@ -31,9 +32,7 @@ import ( // Common errors var ( - ErrInvalidCredentials = errors.New("invalid credentials") - ErrUserNotFound = errors.New("user not found") - ErrInvalidToken = errors.New("invalid token") + ErrUserNotFound = errors.New("user not found") ) // Handler admin handler @@ -43,8 +42,11 @@ type Handler struct { } // NewHandler create admin handler -func NewHandler(service *Service, userService *service.UserService) *Handler { - return &Handler{service: service, userService: userService} +func NewHandler(svc *Service) *Handler { + return &Handler{ + service: svc, + userService: service.NewUserService(), + } } // SuccessResponse success response @@ -96,40 +98,58 @@ func (h *Handler) Ping(c *gin.Context) { successNoData(c, "PONG") } -// LoginHTTPRequest login request body -type LoginHTTPRequest struct { - Email string `json:"email" binding:"required"` - Password string `json:"password" binding:"required"` -} - // Login handle admin login +// @Summary Admin Login +// @Description Admin login verification using email, only superuser can login +// @Tags admin +// @Accept json +// @Produce json +// @Param request body service.EmailLoginRequest true "login info with email" +// @Success 200 {object} map[string]interface{} +// @Router /admin/login [post] func (h *Handler) Login(c *gin.Context) { var req service.EmailLoginRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, + "code": common.CodeBadRequest, "message": err.Error(), }) return } + // Use userService.LoginByEmail with adminLogin=true + // This allows default admin account to login admin system user, code, err := h.userService.LoginByEmail(&req, 
true) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ + c.JSON(http.StatusOK, gin.H{ "code": code, "message": err.Error(), }) return } + // Check if user is superuser (admin) + if user.IsSuperuser == nil || !*user.IsSuperuser { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeForbidden, + "message": "Only superuser can login admin system", + }) + return + } + variables := server.GetVariables() secretKey := variables.SecretKey authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": fmt.Sprintf("Failed to generate auth token: %s", err.Error()), + }) + return + } // Set Authorization header with access_token - if user.AccessToken != nil { - c.Header("Authorization", authToken) - } + c.Header("Authorization", authToken) // Set CORS headers c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Methods", "*") @@ -919,6 +939,7 @@ func (h *Handler) TestSandboxConnection(c *gin.Context) { } // AuthMiddleware JWT auth middleware +// Validates that the user is authenticated and is a superuser (admin) func (h *Handler) AuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { token := c.GetHeader("Authorization") @@ -935,6 +956,7 @@ func (h *Handler) AuthMiddleware() gin.HandlerFunc { "code": code, "message": "Invalid access token", }) + c.Abort() return } diff --git a/internal/admin/password.go b/internal/admin/password.go new file mode 100644 index 00000000000..ab81b169baa --- /dev/null +++ b/internal/admin/password.go @@ -0,0 +1,111 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package admin + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" + "hash" + "strconv" + "strings" + + "golang.org/x/crypto/pbkdf2" +) + +// CheckWerkzeugPassword verifies a password against a werkzeug password hash +// Format: pbkdf2:sha256:iterations$salt$hash +func CheckWerkzeugPassword(password, hashStr string) bool { + parts := strings.Split(hashStr, "$") + if len(parts) != 3 { + return false + } + + // Parse method (e.g., "pbkdf2:sha256:150000") + methodParts := strings.Split(parts[0], ":") + if len(methodParts) != 3 { + return false + } + + if methodParts[0] != "pbkdf2" { + return false + } + + var hashFunc func() hash.Hash + switch methodParts[1] { + case "sha256": + hashFunc = sha256.New + case "sha512": + // sha512 not supported in this implementation + return false + default: + return false + } + + iterations, err := strconv.Atoi(methodParts[2]) + if err != nil { + return false + } + + salt := parts[1] + expectedHash := parts[2] + + // Decode salt from base64 + saltBytes, err := base64.StdEncoding.DecodeString(salt) + if err != nil { + // Try hex encoding + saltBytes, err = hex.DecodeString(salt) + if err != nil { + return false + } + } + + // Generate hash using PBKDF2 + key := pbkdf2.Key([]byte(password), saltBytes, iterations, 32, hashFunc) + computedHash := base64.StdEncoding.EncodeToString(key) + + return computedHash == expectedHash +} + +// IsWerkzeugHash checks if a hash is in werkzeug format +func IsWerkzeugHash(hashStr string) bool { + return strings.HasPrefix(hashStr, "pbkdf2:") +} + +// GenerateWerkzeugPasswordHash generates a werkzeug-compatible password hash +func GenerateWerkzeugPasswordHash(password string, iterations int) (string, error) { + if iterations == 0 { + iterations = 150000 + } + + // Generate random salt + salt := make([]byte, 16) + if _, err := rand.Read(salt); err != nil { + return "", err + } + + // Generate hash using PBKDF2-SHA256 + key := pbkdf2.Key([]byte(password), salt, iterations, 32, sha256.New) + + // Format: pbkdf2:sha256:iterations$base64(salt)$base64(hash) + saltB64 := base64.StdEncoding.EncodeToString(salt) + hashB64 := base64.StdEncoding.EncodeToString(key) + + return fmt.Sprintf("pbkdf2:sha256:%d$%s$%s", iterations, saltB64, hashB64), nil +} diff --git a/internal/admin/router.go b/internal/admin/router.go index 4e9dd213465..6e2239d4de6 100644 --- a/internal/admin/router.go +++ b/internal/admin/router.go @@ -18,14 +18,11 @@ package admin import ( "github.com/gin-gonic/gin" - - "ragflow/internal/handler" ) // Router admin router type Router struct { - handler *Handler - userHandler *handler.UserHandler + handler *Handler } // NewRouter create admin router diff --git a/internal/admin/service.go b/internal/admin/service.go index 6e2e1b346cc..79fba699551 100644 --- a/internal/admin/service.go +++ b/internal/admin/service.go @@ -19,6 +19,7 @@ package admin import ( "crypto/rand" "crypto/tls" + "encoding/base64" "encoding/hex" "errors" "fmt" @@ -35,6 +36,13 @@ import ( "time" ) +// Service errors +var ( + ErrInvalidToken = errors.New("invalid token") + ErrNotAdmin = errors.New("user is not admin") + ErrUserInactive = errors.New("user is inactive") +) + // Service admin service layer type Service struct { userDAO *dao.UserDAO @@ -47,55 +55,32 @@ func NewService() *Service { } } -// LoginRequest login request -type LoginRequest struct { - Email string - Password string -} - -// LoginResponse login response -type LoginResponse struct { - Token string - UserID string - Email string - Nickname string 
+// Logout user logout +func (s *Service) Logout(user interface{}) error { + // Invalidate token by setting it to INVALID_ prefix + if u, ok := user.(*model.User); ok { + invalidToken := "INVALID_" + generateRandomHex(16) + return s.userDAO.UpdateAccessToken(u, invalidToken) + } + return nil } -// Login admin login -func (s *Service) Login(req *LoginRequest) (*LoginResponse, error) { - // Get user by email - user, err := s.userDAO.GetByEmail(req.Email) +// GetUserByToken get user by access token +func (s *Service) GetUserByToken(token string) (*model.User, error) { + user, err := s.userDAO.GetByAccessToken(token) if err != nil { - return nil, ErrInvalidCredentials + return nil, ErrInvalidToken } - // Check if user is active - if user.IsActive != "1" { - return nil, errors.New("user is not active") + if user.IsSuperuser == nil || !*user.IsSuperuser { + return nil, ErrNotAdmin } - // Generate access token - token := utility.GenerateToken() - if err := s.userDAO.UpdateAccessToken(user, token); err != nil { - return nil, err + if user.IsActive != "1" { + return nil, fmt.Errorf("user inactive") } - return &LoginResponse{ - Token: token, - UserID: user.ID, - Email: user.Email, - Nickname: user.Nickname, - }, nil -} - -// Logout user logout -func (s *Service) Logout(user interface{}) error { - // Invalidate token by setting it to INVALID_ prefix - if u, ok := user.(*model.User); ok { - invalidToken := "INVALID_" + generateRandomHex(16) - return s.userDAO.UpdateAccessToken(u, invalidToken) - } - return nil + return user, nil } // generateRandomHex generate random hex string @@ -780,3 +765,133 @@ func (s *Service) HandleHeartbeat(msg *common.BaseMessage) error { GlobalServerStatusStore.UpdateStatus(msg.ServerName, status) return nil } + +// InitDefaultAdmin initialize default admin user +// This matches Python's init_default_admin behavior +func (s *Service) InitDefaultAdmin() error { + // Default superuser settings (matching Python's DEFAULT_SUPERUSER_* defaults) + defaultNickname := "admin" + defaultEmail := "admin@ragflow.io" + defaultPassword := "admin" + + // Query superusers + var users []*model.User + err := dao.DB.Where("is_superuser = ? 
AND status = ?", true, "1").Find(&users).Error + if err != nil { + return fmt.Errorf("failed to query superusers: %w", err) + } + + if len(users) == 0 { + now := time.Now().Unix() + nowDate := time.Now() + userID := utility.GenerateToken() + accessToken := utility.GenerateToken() + status := "1" + loginChannel := "password" + isSuperuser := true + + // Python: password = encode_to_base64(password) = base64.b64encode(password) + // Then: generate_password_hash(base64_password) creates werkzeug hash + password := base64.StdEncoding.EncodeToString([]byte(defaultPassword)) + hashedPassword, err := GenerateWerkzeugPasswordHash(password, 150000) + if err != nil { + return fmt.Errorf("failed to hash password: %w", err) + } + + user := &model.User{ + ID: userID, + Email: defaultEmail, + Nickname: defaultNickname, + Password: &hashedPassword, + AccessToken: &accessToken, + Status: &status, + IsActive: "1", + IsAuthenticated: "1", + IsAnonymous: "0", + LoginChannel: &loginChannel, + IsSuperuser: &isSuperuser, + BaseModel: model.BaseModel{ + CreateTime: &now, + CreateDate: &nowDate, + UpdateTime: &now, + UpdateDate: &nowDate, + }, + } + + if err := dao.DB.Create(user).Error; err != nil { + return fmt.Errorf("can't init admin: %w", err) + } + + if err := s.addTenantForAdmin(userID, defaultNickname); err != nil { + return fmt.Errorf("failed to add tenant for admin: %w", err) + } + + return nil + } + + for _, user := range users { + if user.IsActive != "1" { + return fmt.Errorf("no active admin. Please update 'is_active' in db manually") + } + } + + for _, user := range users { + if user.Email == defaultEmail { + // Check if tenant exists + var count int64 + dao.DB.Model(&model.UserTenant{}).Where("user_id = ? AND status = ?", user.ID, "1").Count(&count) + if count == 0 { + nickname := defaultNickname + if user.Nickname != "" { + nickname = user.Nickname + } + if err := s.addTenantForAdmin(user.ID, nickname); err != nil { + return err + } + } + break + } + } + + return nil +} + +// addTenantForAdmin add tenant for admin user +func (s *Service) addTenantForAdmin(userID, nickname string) error { + now := time.Now().Unix() + nowDate := time.Now() + status := "1" + role := "owner" + tenantName := nickname + "'s Kingdom" + + tenant := &model.Tenant{ + ID: userID, + Name: &tenantName, + BaseModel: model.BaseModel{ + CreateTime: &now, + CreateDate: &nowDate, + UpdateTime: &now, + UpdateDate: &nowDate, + }, + } + + if err := dao.DB.Create(tenant).Error; err != nil { + return err + } + + userTenant := &model.UserTenant{ + TenantID: userID, + UserID: userID, + InvitedBy: userID, + Role: role, + Status: &status, + BaseModel: model.BaseModel{ + CreateTime: &now, + CreateDate: &nowDate, + UpdateTime: &now, + UpdateDate: &nowDate, + }, + } + + return dao.DB.Create(userTenant).Error +} diff --git a/internal/service/user.go b/internal/service/user.go index a87260b6805..bf3aff79525 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -18,12 +18,15 @@ package service import ( "crypto/rsa" + "crypto/sha256" + "crypto/sha512" "crypto/x509" "encoding/base64" "encoding/hex" "encoding/pem" "errors" "fmt" + "hash" "os" "ragflow/internal/common" "ragflow/internal/server" @@ -32,6 +35,7 @@ import ( "strings" "time" + "golang.org/x/crypto/pbkdf2" "golang.org/x/crypto/scrypt" "ragflow/internal/dao" @@ -408,7 +412,77 @@ func (s *UserService) HashPassword(password string) (string, error) { } // VerifyPassword verify password +// Supports both werkzeug pbkdf2 format (pbkdf2:sha256:iterations$salt$hash) and scrypt 
format func (s *UserService) VerifyPassword(hashedPassword, password string) bool { + // Check if it's pbkdf2 format (werkzeug) + if strings.HasPrefix(hashedPassword, "pbkdf2:") { + return s.verifyPBKDF2Password(hashedPassword, password) + } + + // Check if it's scrypt format + if strings.HasPrefix(hashedPassword, "scrypt:") { + return s.verifyScryptPassword(hashedPassword, password) + } + + return false +} + +// verifyPBKDF2Password verifies password using PBKDF2 (werkzeug format) +// Format: pbkdf2:sha256:iterations$salt$hash +func (s *UserService) verifyPBKDF2Password(hashedPassword, password string) bool { + parts := strings.Split(hashedPassword, "$") + if len(parts) != 3 { + return false + } + + // Parse method (e.g., "pbkdf2:sha256:150000") + methodParts := strings.Split(parts[0], ":") + if len(methodParts) != 3 { + return false + } + + if methodParts[0] != "pbkdf2" { + return false + } + + var hashFunc func() hash.Hash + switch methodParts[1] { + case "sha256": + hashFunc = sha256.New + case "sha512": + hashFunc = sha512.New + default: + return false + } + + iterations, err := strconv.Atoi(methodParts[2]) + if err != nil { + return false + } + + salt := parts[1] + expectedHash := parts[2] + + // Decode salt from base64 + saltBytes, err := base64.StdEncoding.DecodeString(salt) + if err != nil { + // Try hex encoding + saltBytes, err = hex.DecodeString(salt) + if err != nil { + return false + } + } + + // Generate hash using PBKDF2 + key := pbkdf2.Key([]byte(password), saltBytes, iterations, 32, hashFunc) + computedHash := base64.StdEncoding.EncodeToString(key) + + return computedHash == expectedHash +} + +// verifyScryptPassword verifies password using scrypt format +// Format: scrypt:n:r:p$salt$hash +func (s *UserService) verifyScryptPassword(hashedPassword, password string) bool { // Parse hash format: scrypt:n:r:p$salt$hash parts := strings.Split(hashedPassword, "$") if len(parts) != 3 { From 5a52ef3f2641f948efd25e4940c452317849f180 Mon Sep 17 00:00:00 2001 From: Liu An Date: Tue, 10 Mar 2026 17:31:20 +0800 Subject: [PATCH 199/565] Fix: bin directory cannot be copied to docker image introduced by #13444 (#13502) ### What problem does this PR solve? bin directory cannot be copied to docker image introduced by ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- .gitignore | 3 ++- Dockerfile | 8 +------- bin/.gitkeep | 0 3 files changed, 3 insertions(+), 8 deletions(-) create mode 100644 bin/.gitkeep diff --git a/.gitignore b/.gitignore index 58fbda13bbc..906c13dbfa4 100644 --- a/.gitignore +++ b/.gitignore @@ -229,4 +229,5 @@ internal/cpp/cmake-build-debug/ .trae/ # Go server build output -bin/ +bin/* +!bin/.gitkeep diff --git a/Dockerfile b/Dockerfile index 071efdfc33b..ee19086b3aa 100644 --- a/Dockerfile +++ b/Dockerfile @@ -212,13 +212,7 @@ COPY pyproject.toml uv.lock ./ COPY mcp mcp COPY common common COPY memory memory - -RUN if [ -d bin ]; then \ - cp -r bin ./; \ - echo "✓ bin copied"; \ - else \ - echo "✗ bin ignored"; \ - fi +COPY bin bin COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template COPY docker/entrypoint.sh ./ diff --git a/bin/.gitkeep b/bin/.gitkeep new file mode 100644 index 00000000000..e69de29bb2d From 875ebfb286e25179d4511e4a8b8b290f8749f981 Mon Sep 17 00:00:00 2001 From: Heyang Wang Date: Tue, 10 Mar 2026 18:05:45 +0800 Subject: [PATCH 200/565] Feat: Support get aggregated parsing status to dataset via the API (#13481) ### What problem does this PR solve? 
Support getting aggregated parsing status to dataset via the API Issue: #12810 ### Type of change - [x] New Feature (non-breaking change which adds functionality) Co-authored-by: heyang.why --- api/apps/sdk/dataset.py | 136 +++--- api/db/services/document_service.py | 407 ++++++++---------- api/utils/validation_utils.py | 7 +- docs/references/http_api_reference.md | 61 ++- docs/references/python_api_reference.md | 20 +- example/http/dataset_example.sh | 6 + ...est_document_service_get_parsing_status.py | 326 ++++++++++++++ 7 files changed, 654 insertions(+), 309 deletions(-) create mode 100644 test/unit_test/api/db/services/test_document_service_get_parsing_status.py diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index caa75ec02b8..58f0442b61a 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -138,12 +138,7 @@ async def create(tenant_id): parser_cfg["metadata"] = fields parser_cfg["enable_metadata"] = auto_meta.get("enabled", True) req["parser_config"] = parser_cfg - e, req = KnowledgebaseService.create_with_name( - name = req.pop("name", None), - tenant_id = tenant_id, - parser_id = req.pop("parser_id", None), - **req - ) + e, req = KnowledgebaseService.create_with_name(name=req.pop("name", None), tenant_id=tenant_id, parser_id=req.pop("parser_id", None), **req) if not e: return req @@ -159,19 +154,19 @@ async def create(tenant_id): if not ok: return err - try: - if not KnowledgebaseService.save(**req): - return get_error_data_result() - ok, k = KnowledgebaseService.get_by_id(req["id"]) - if not ok: - return get_error_data_result(message="Dataset created failed") - response_data = remap_dictionary_keys(k.to_dict()) - return get_result(data=response_data) + if not KnowledgebaseService.save(**req): + return get_error_data_result() + ok, k = KnowledgebaseService.get_by_id(req["id"]) + if not ok: + return get_error_data_result(message="Dataset created failed") + response_data = remap_dictionary_keys(k.to_dict()) + return get_result(data=response_data) except Exception as e: logging.exception(e) return get_error_data_result(message="Database operation failed") + @manager.route("/datasets", methods=["DELETE"]) # noqa: F821 @token_required async def delete(tenant_id): @@ -227,8 +222,7 @@ async def delete(tenant_id): continue kb_id_instance_pairs.append((kb_id, kb)) if len(error_kb_ids) > 0: - return get_error_permission_result( - message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""") + return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""") errors = [] success_count = 0 @@ -245,12 +239,12 @@ async def delete(tenant_id): ] ) File2DocumentService.delete_by_document_id(doc.id) - FileService.filter_delete( - [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name]) + FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name]) # Drop index for this dataset try: from rag.nlp import search + idxnm = search.index_name(kb.tenant_id) settings.docStoreConn.delete_idx(idxnm, kb_id) except Exception as e: @@ -352,8 +346,7 @@ async def update(tenant_id, dataset_id): try: kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) if kb is None: - return get_error_permission_result( - message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") + return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset 
'{dataset_id}'") # Map auto_metadata_config into parser_config if present auto_meta = req.pop("auto_metadata_config", None) @@ -384,8 +377,7 @@ async def update(tenant_id, dataset_id): del req["parser_config"] if "name" in req and req["name"].lower() != kb.name.lower(): - exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, - status=StatusEnum.VALID.value) + exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value) if exists: return get_error_data_result(message=f"Dataset name '{req['name']}' already exists") @@ -393,8 +385,7 @@ async def update(tenant_id, dataset_id): if not req["embd_id"]: req["embd_id"] = kb.embd_id if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id: - return get_error_data_result( - message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}") + return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}") ok, err = verify_embedding_availability(req["embd_id"], tenant_id) if not ok: return err @@ -404,12 +395,10 @@ async def update(tenant_id, dataset_id): return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch") if req["pagerank"] > 0: - settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, - search.index_name(kb.tenant_id), kb.id) + settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id) else: # Elasticsearch requires PAGERANK_FLD be non-zero! - settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, - search.index_name(kb.tenant_id), kb.id) + settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id) if not KnowledgebaseService.update_by_id(kb.id, req): return get_error_data_result(message="Update dataset error.(Database error)") @@ -470,6 +459,15 @@ def list_datasets(tenant_id): required: false default: true description: Order in descending. + - in: query + name: include_parsing_status + type: boolean + required: false + default: false + description: | + Whether to include document parsing status counts in the response. + When true, each dataset object will include: unstart_count, running_count, + cancel_count, done_count, and fail_count. 
- in: header name: Authorization type: string @@ -487,17 +485,18 @@ def list_datasets(tenant_id): if err is not None: return get_error_argument_result(err) + include_parsing_status = args.get("include_parsing_status", False) + try: kb_id = request.args.get("id") name = args.get("name") + # check whether user has permission for the dataset with specified id if kb_id: - kbs = KnowledgebaseService.get_kb_by_id(kb_id, tenant_id) - - if not kbs: + if not KnowledgebaseService.get_kb_by_id(kb_id, tenant_id): return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{kb_id}'") + # check whether user has permission for the dataset with specified name if name: - kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id) - if not kbs: + if not KnowledgebaseService.get_kb_by_name(name, tenant_id): return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'") tenants = TenantService.get_joined_tenants_by_user_id(tenant_id) @@ -512,9 +511,17 @@ def list_datasets(tenant_id): name, ) + parsing_status_map = {} + if include_parsing_status and kbs: + kb_ids = [kb["id"] for kb in kbs] + parsing_status_map = DocumentService.get_parsing_status_by_kb_ids(kb_ids) + response_data_list = [] for kb in kbs: - response_data_list.append(remap_dictionary_keys(kb)) + data = remap_dictionary_keys(kb) + if include_parsing_status: + data.update(parsing_status_map.get(kb["id"], {})) + response_data_list.append(data) return get_result(data=response_data_list, total=total) except OperationalError as e: logging.exception(e) @@ -530,9 +537,7 @@ def get_auto_metadata(tenant_id, dataset_id): try: kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) if kb is None: - return get_error_permission_result( - message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'" - ) + return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") parser_cfg = kb.parser_config or {} metadata = parser_cfg.get("metadata") or [] @@ -570,9 +575,7 @@ async def update_auto_metadata(tenant_id, dataset_id): try: kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) if kb is None: - return get_error_permission_result( - message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'" - ) + return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") parser_cfg = kb.parser_config or {} fields = [] @@ -598,20 +601,13 @@ async def update_auto_metadata(tenant_id, dataset_id): return get_error_data_result(message="Database operation failed") -@manager.route('/datasets//knowledge_graph', methods=['GET']) # noqa: F821 +@manager.route("/datasets//knowledge_graph", methods=["GET"]) # noqa: F821 @token_required async def knowledge_graph(tenant_id, dataset_id): if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return get_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) + return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) _, kb = KnowledgebaseService.get_by_id(dataset_id) - req = { - "kb_id": [dataset_id], - "knowledge_graph_kwd": ["graph"] - } + req = {"kb_id": [dataset_id], "knowledge_graph_kwd": ["graph"]} obj = {"graph": {}, "mind_map": {}} if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id): @@ -633,39 +629,29 @@ async def knowledge_graph(tenant_id, dataset_id): obj["graph"]["nodes"] = 
sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256] if "edges" in obj["graph"]: node_id_set = {o["id"] for o in obj["graph"]["nodes"]} - filtered_edges = [o for o in obj["graph"]["edges"] if - o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set] + filtered_edges = [o for o in obj["graph"]["edges"] if o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set] obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128] return get_result(data=obj) -@manager.route('/datasets//knowledge_graph', methods=['DELETE']) # noqa: F821 +@manager.route("/datasets//knowledge_graph", methods=["DELETE"]) # noqa: F821 @token_required def delete_knowledge_graph(tenant_id, dataset_id): if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return get_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) + return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) _, kb = KnowledgebaseService.get_by_id(dataset_id) - settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, - search.index_name(kb.tenant_id), dataset_id) + settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), dataset_id) return get_result(data=True) @manager.route("/datasets//run_graphrag", methods=["POST"]) # noqa: F821 @token_required -def run_graphrag(tenant_id,dataset_id): +def run_graphrag(tenant_id, dataset_id): if not dataset_id: return get_error_data_result(message='Lack of "Dataset ID"') if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return get_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) + return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) ok, kb = KnowledgebaseService.get_by_id(dataset_id) if not ok: @@ -707,15 +693,11 @@ def run_graphrag(tenant_id,dataset_id): @manager.route("/datasets//trace_graphrag", methods=["GET"]) # noqa: F821 @token_required -def trace_graphrag(tenant_id,dataset_id): +def trace_graphrag(tenant_id, dataset_id): if not dataset_id: return get_error_data_result(message='Lack of "Dataset ID"') if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return get_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) + return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) ok, kb = KnowledgebaseService.get_by_id(dataset_id) if not ok: @@ -734,15 +716,11 @@ def trace_graphrag(tenant_id,dataset_id): @manager.route("/datasets//run_raptor", methods=["POST"]) # noqa: F821 @token_required -def run_raptor(tenant_id,dataset_id): +def run_raptor(tenant_id, dataset_id): if not dataset_id: return get_error_data_result(message='Lack of "Dataset ID"') if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return get_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) + return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) ok, kb = KnowledgebaseService.get_by_id(dataset_id) if not ok: @@ -784,7 +762,7 @@ def run_raptor(tenant_id,dataset_id): @manager.route("/datasets//trace_raptor", methods=["GET"]) # noqa: F821 @token_required -def trace_raptor(tenant_id,dataset_id): +def trace_raptor(tenant_id, dataset_id): if 
not dataset_id: return get_error_data_result(message='Lack of "Dataset ID"') diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index f5b2d9d5102..fa4fc27ecaa 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -28,8 +28,7 @@ from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT from api.db import PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES, FileType, UserTenantRole, CanvasCategory -from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \ - User +from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, User from api.db.db_utils import bulk_insert_into_db from api.db.services.common_service import CommonService from api.db.services.knowledgebase_service import KnowledgebaseService @@ -78,24 +77,21 @@ def get_cls_model_fields(cls): @classmethod @DB.connection_context() - def get_list(cls, kb_id, page_number, items_per_page, - orderby, desc, keywords, id, name, suffix=None, run = None, doc_ids=None): + def get_list(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, id, name, suffix=None, run=None, doc_ids=None): fields = cls.get_cls_model_fields() - docs = cls.model.select(*[*fields, UserCanvas.title]).join(File2Document, on = (File2Document.document_id == cls.model.id))\ - .join(File, on = (File.id == File2Document.file_id))\ - .join(UserCanvas, on = ((cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)), join_type=JOIN.LEFT_OUTER)\ + docs = ( + cls.model.select(*[*fields, UserCanvas.title]) + .join(File2Document, on=(File2Document.document_id == cls.model.id)) + .join(File, on=(File.id == File2Document.file_id)) + .join(UserCanvas, on=((cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)), join_type=JOIN.LEFT_OUTER) .where(cls.model.kb_id == kb_id) + ) if id: - docs = docs.where( - cls.model.id == id) + docs = docs.where(cls.model.id == id) if name: - docs = docs.where( - cls.model.name == name - ) + docs = docs.where(cls.model.name == name) if keywords: - docs = docs.where( - fn.LOWER(cls.model.name).contains(keywords.lower()) - ) + docs = docs.where(fn.LOWER(cls.model.name).contains(keywords.lower())) if doc_ids: docs = docs.where(cls.model.id.in_(doc_ids)) if suffix: @@ -120,6 +116,7 @@ def get_list(cls, kb_id, page_number, items_per_page, @DB.connection_context() def check_doc_health(cls, tenant_id: str, filename): import os + MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0)) if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(tenant_id): raise RuntimeError("Exceed the maximum file number of a free user!") @@ -211,13 +208,14 @@ def get_filter_by_kb_id(cls, kb_id, keywords, run_status, types, suffix): """ fields = cls.get_cls_model_fields() if keywords: - query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where( - (cls.model.kb_id == kb_id), - (fn.LOWER(cls.model.name).contains(keywords.lower())) + query = ( + cls.model.select(*fields) + .join(File2Document, on=(File2Document.document_id == cls.model.id)) + .join(File, on=(File.id == File2Document.file_id)) + .where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower()))) ) else: - query = cls.model.select(*fields).join(File2Document, 
on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id) - + query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id) if run_status: query = query.where(cls.model.run.in_(run_status)) @@ -272,14 +270,60 @@ def get_filter_by_kb_id(cls, kb_id, keywords, run_status, types, suffix): "metadata": metadata_counter, }, total + @classmethod + @DB.connection_context() + def get_parsing_status_by_kb_ids(cls, kb_ids: list[str]) -> dict[str, dict[str, int]]: + """Return aggregated document parsing status counts grouped by dataset (kb_id). + + For each kb_id, counts documents in each run-status bucket: + - unstart_count (run == "0") + - running_count (run == "1") + - cancel_count (run == "2") + - done_count (run == "3") + - fail_count (run == "4") + + Returns a dict keyed by kb_id, e.g. + {"kb-abc": {"unstart_count": 10, "running_count": 2, ...}, ...} + """ + if not kb_ids: + return {} + + status_field_map = { + TaskStatus.UNSTART.value: "unstart_count", + TaskStatus.RUNNING.value: "running_count", + TaskStatus.CANCEL.value: "cancel_count", + TaskStatus.DONE.value: "done_count", + TaskStatus.FAIL.value: "fail_count", + } + + empty_status = {v: 0 for v in status_field_map.values()} + result: dict[str, dict[str, int]] = {kb_id: dict(empty_status) for kb_id in kb_ids} + + rows = ( + cls.model.select( + cls.model.kb_id, + cls.model.run, + fn.COUNT(cls.model.id).alias("cnt"), + ) + .where(cls.model.kb_id.in_(kb_ids)) + .group_by(cls.model.kb_id, cls.model.run) + .dicts() + ) + + for row in rows: + kb_id = row["kb_id"] + run_val = str(row["run"]) + field_name = status_field_map.get(run_val) + if field_name and kb_id in result: + result[kb_id][field_name] = int(row["cnt"]) + + return result + @classmethod @DB.connection_context() def count_by_kb_id(cls, kb_id, keywords, run_status, types): if keywords: - docs = cls.model.select().where( - (cls.model.kb_id == kb_id), - (fn.LOWER(cls.model.name).contains(keywords.lower())) - ) + docs = cls.model.select().where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower()))) else: docs = cls.model.select().where(cls.model.kb_id == kb_id) @@ -295,9 +339,7 @@ def count_by_kb_id(cls, kb_id, keywords, run_status, types): @classmethod @DB.connection_context() def get_total_size_by_kb_id(cls, kb_id, keywords="", run_status=[], types=[]): - query = cls.model.select(fn.COALESCE(fn.SUM(cls.model.size), 0)).where( - cls.model.kb_id == kb_id - ) + query = cls.model.select(fn.COALESCE(fn.SUM(cls.model.size), 0)).where(cls.model.kb_id == kb_id) if keywords: query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower())) @@ -329,12 +371,8 @@ def get_all_doc_ids_by_kb_ids(cls, kb_ids): @classmethod @DB.connection_context() def get_all_docs_by_creator_id(cls, creator_id): - fields = [ - cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id - ] - docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where( - cls.model.created_by == creator_id - ) + fields = [cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id] + docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.created_by == creator_id) docs.order_by(cls.model.create_time.asc()) # maybe cause slow query by deep paginate, 
optimize later offset, limit = 0, 100 @@ -361,6 +399,7 @@ def insert(cls, doc): @DB.connection_context() def remove_document(cls, doc, tenant_id): from api.db.services.task_service import TaskService, cancel_all_task_of + if not cls.delete_document_and_update_kb_counts(doc.id): return True @@ -406,17 +445,22 @@ def remove_document(cls, doc, tenant_id): # Cleanup knowledge graph references (non-critical, log and continue) try: graph_source = settings.docStoreConn.get_fields( - settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"] + settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), + ["source_id"], ) if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]: - settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id}, - {"remove": {"source_id": doc.id}}, - search.index_name(tenant_id), doc.kb_id) - settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, - {"removed_kwd": "Y"}, - search.index_name(tenant_id), doc.kb_id) - settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}}, - search.index_name(tenant_id), doc.kb_id) + settings.docStoreConn.update( + {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id}, + {"remove": {"source_id": doc.id}}, + search.index_name(tenant_id), + doc.kb_id, + ) + settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, {"removed_kwd": "Y"}, search.index_name(tenant_id), doc.kb_id) + settings.docStoreConn.delete( + {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}}, + search.index_name(tenant_id), + doc.kb_id, + ) except Exception as e: logging.warning(f"Failed to cleanup knowledge graph for document {doc.id}: {e}") @@ -428,9 +472,7 @@ def delete_chunk_images(cls, doc, tenant_id): page = 0 page_size = 1000 while True: - chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(), - page * page_size, page_size, search.index_name(tenant_id), - [doc.kb_id]) + chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(), page * page_size, page_size, search.index_name(tenant_id), [doc.kb_id]) chunk_ids = settings.docStoreConn.get_doc_ids(chunks) if not chunk_ids: break @@ -455,71 +497,61 @@ def get_newly_uploaded(cls): Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, - cls.model.update_time] - docs = cls.model.select(*fields) \ - .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \ - .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \ + cls.model.update_time, + ] + docs = ( + cls.model.select(*fields) + .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) + .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) .where( - cls.model.status == StatusEnum.VALID.value, - ~(cls.model.type == FileType.VIRTUAL.value), - cls.model.progress == 0, - cls.model.update_time >= current_timestamp() - 1000 * 600, - cls.model.run == TaskStatus.RUNNING.value) \ + 
cls.model.status == StatusEnum.VALID.value, + ~(cls.model.type == FileType.VIRTUAL.value), + cls.model.progress == 0, + cls.model.update_time >= current_timestamp() - 1000 * 600, + cls.model.run == TaskStatus.RUNNING.value, + ) .order_by(cls.model.update_time.asc()) + ) return list(docs.dicts()) @classmethod @DB.connection_context() def get_unfinished_docs(cls): - fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, - cls.model.run, cls.model.parser_id] - unfinished_task_query = Task.select(Task.doc_id).where( - (Task.progress >= 0) & (Task.progress < 1) - ) + fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run, cls.model.parser_id] + unfinished_task_query = Task.select(Task.doc_id).where((Task.progress >= 0) & (Task.progress < 1)) - docs = cls.model.select(*fields) \ - .where( + docs = cls.model.select(*fields).where( cls.model.status == StatusEnum.VALID.value, ~(cls.model.type == FileType.VIRTUAL.value), ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)), - (((cls.model.progress < 1) & (cls.model.progress > 0)) | - (cls.model.id.in_(unfinished_task_query)))) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap + (((cls.model.progress < 1) & (cls.model.progress > 0)) | (cls.model.id.in_(unfinished_task_query))), + ) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap return list(docs.dicts()) @classmethod @DB.connection_context() def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration): - num = cls.model.update(token_num=cls.model.token_num + token_num, - chunk_num=cls.model.chunk_num + chunk_num, - process_duration=cls.model.process_duration + duration).where( - cls.model.id == doc_id).execute() + num = ( + cls.model.update(token_num=cls.model.token_num + token_num, chunk_num=cls.model.chunk_num + chunk_num, process_duration=cls.model.process_duration + duration) + .where(cls.model.id == doc_id) + .execute() + ) if num == 0: logging.warning("Document not found which is supposed to be there") - num = Knowledgebase.update( - token_num=Knowledgebase.token_num + - token_num, - chunk_num=Knowledgebase.chunk_num + - chunk_num).where( - Knowledgebase.id == kb_id).execute() + num = Knowledgebase.update(token_num=Knowledgebase.token_num + token_num, chunk_num=Knowledgebase.chunk_num + chunk_num).where(Knowledgebase.id == kb_id).execute() return num @classmethod @DB.connection_context() def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration): - num = cls.model.update(token_num=cls.model.token_num - token_num, - chunk_num=cls.model.chunk_num - chunk_num, - process_duration=cls.model.process_duration + duration).where( - cls.model.id == doc_id).execute() + num = ( + cls.model.update(token_num=cls.model.token_num - token_num, chunk_num=cls.model.chunk_num - chunk_num, process_duration=cls.model.process_duration + duration) + .where(cls.model.id == doc_id) + .execute() + ) if num == 0: - raise LookupError( - "Document not found which is supposed to be there") - num = Knowledgebase.update( - token_num=Knowledgebase.token_num - - token_num, - chunk_num=Knowledgebase.chunk_num - - chunk_num - ).where( - Knowledgebase.id == kb_id).execute() + raise LookupError("Document not found which is supposed to be there") + num = Knowledgebase.update(token_num=Knowledgebase.token_num - token_num, chunk_num=Knowledgebase.chunk_num - chunk_num).where(Knowledgebase.id == kb_id).execute() return num @classmethod @@ -551,17 
+583,13 @@ def clear_chunk_num(cls, doc_id): doc = cls.model.get_by_id(doc_id) assert doc, "Can't fine document in database." - num = Knowledgebase.update( - token_num=Knowledgebase.token_num - - doc.token_num, - chunk_num=Knowledgebase.chunk_num - - doc.chunk_num, - doc_num=Knowledgebase.doc_num - 1 - ).where( - Knowledgebase.id == doc.kb_id).execute() + num = ( + Knowledgebase.update(token_num=Knowledgebase.token_num - doc.token_num, chunk_num=Knowledgebase.chunk_num - doc.chunk_num, doc_num=Knowledgebase.doc_num - 1) + .where(Knowledgebase.id == doc.kb_id) + .execute() + ) return num - @classmethod @DB.connection_context() def clear_chunk_num_when_rerun(cls, doc_id): @@ -578,15 +606,10 @@ def clear_chunk_num_when_rerun(cls, doc_id): ) return num - @classmethod @DB.connection_context() def get_tenant_id(cls, doc_id): - docs = cls.model.select( - Knowledgebase.tenant_id).join( - Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id)).where( - cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) + docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) docs = docs.dicts() if not docs: return None @@ -604,11 +627,7 @@ def get_knowledgebase_id(cls, doc_id): @classmethod @DB.connection_context() def get_tenant_id_by_name(cls, name): - docs = cls.model.select( - Knowledgebase.tenant_id).join( - Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id)).where( - cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value) + docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value) docs = docs.dicts() if not docs: return None @@ -617,12 +636,13 @@ def get_tenant_id_by_name(cls, name): @classmethod @DB.connection_context() def accessible(cls, doc_id, user_id): - docs = cls.model.select( - cls.model.id).join( - Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id) - ).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) - ).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1) + docs = ( + cls.model.select(cls.model.id) + .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)) + .join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)) + .where(cls.model.id == doc_id, UserTenant.user_id == user_id) + .paginate(0, 1) + ) docs = docs.dicts() if not docs: return False @@ -631,18 +651,13 @@ def accessible(cls, doc_id, user_id): @classmethod @DB.connection_context() def accessible4deletion(cls, doc_id, user_id): - docs = cls.model.select(cls.model.id - ).join( - Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id) - ).join( - UserTenant, on=( - (UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id)) - ).where( - cls.model.id == doc_id, - UserTenant.status == StatusEnum.VALID.value, - ((UserTenant.role == UserTenantRole.NORMAL) | (UserTenant.role == UserTenantRole.OWNER)) - ).paginate(0, 1) + docs = ( + cls.model.select(cls.model.id) + .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)) + .join(UserTenant, on=((UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id))) + .where(cls.model.id == doc_id, UserTenant.status == StatusEnum.VALID.value, ((UserTenant.role == UserTenantRole.NORMAL) | (UserTenant.role == UserTenantRole.OWNER))) + .paginate(0, 1) + ) docs = 
docs.dicts() if not docs: return False @@ -651,11 +666,7 @@ def accessible4deletion(cls, doc_id, user_id): @classmethod @DB.connection_context() def get_embd_id(cls, doc_id): - docs = cls.model.select( - Knowledgebase.embd_id).join( - Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id)).where( - cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) + docs = cls.model.select(Knowledgebase.embd_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) docs = docs.dicts() if not docs: return None @@ -664,11 +675,9 @@ def get_embd_id(cls, doc_id): @classmethod @DB.connection_context() def get_tenant_embd_id(cls, doc_id): - docs = cls.model.select( - Knowledgebase.tenant_embd_id).join( - Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id)).where( - cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) + docs = ( + cls.model.select(Knowledgebase.tenant_embd_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) + ) docs = docs.dicts() if not docs: return None @@ -705,8 +714,7 @@ def get_chunking_config(cls, doc_id): @DB.connection_context() def get_doc_id_by_doc_name(cls, doc_name): fields = [cls.model.id] - doc_id = cls.model.select(*fields) \ - .where(cls.model.name == doc_name) + doc_id = cls.model.select(*fields).where(cls.model.name == doc_name) doc_id = doc_id.dicts() if not doc_id: return None @@ -725,8 +733,7 @@ def get_doc_ids_by_doc_names(cls, doc_names): @DB.connection_context() def get_thumbnails(cls, docids): fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail] - return list(cls.model.select( - *fields).where(cls.model.id.in_(docids)).dicts()) + return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts()) @classmethod @DB.connection_context() @@ -755,9 +762,7 @@ def dfs_update(old, new): @classmethod @DB.connection_context() def get_doc_count(cls, tenant_id): - docs = cls.model.select(cls.model.id).join(Knowledgebase, - on=(Knowledgebase.id == cls.model.kb_id)).where( - Knowledgebase.tenant_id == tenant_id) + docs = cls.model.select(cls.model.id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(Knowledgebase.tenant_id == tenant_id) return len(docs) @classmethod @@ -768,7 +773,7 @@ def begin2parse(cls, doc_id, keep_progress=False): "process_begin_at": get_format_time(), } if not keep_progress: - info["progress"] = random.random() * 1 / 100. 
+ info["progress"] = random.random() * 1 / 100.0 info["run"] = TaskStatus.RUNNING.value # keep the doc in DONE state when keep_progress=True for GraphRAG, RAPTOR and Mindmap tasks @@ -781,19 +786,17 @@ def update_progress(cls): cls._sync_progress(docs) - @classmethod @DB.connection_context() - def update_progress_immediately(cls, docs:list[dict]): + def update_progress_immediately(cls, docs: list[dict]): if not docs: return cls._sync_progress(docs) - @classmethod @DB.connection_context() - def _sync_progress(cls, docs:list[dict]): + def _sync_progress(cls, docs: list[dict]): from api.db.services.task_service import TaskService for d in docs: @@ -841,27 +844,18 @@ def _sync_progress(cls, docs:list[dict]): # fallback cls.update_by_id(d["id"], {"process_begin_at": begin_at}) - info = { - "process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0), - "run": status} + info = {"process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0), "run": status} if prg != 0 and not freeze_progress: info["progress"] = prg if msg: info["progress_msg"] = msg if msg.endswith("created task graphrag") or msg.endswith("created task raptor") or msg.endswith("created task mindmap"): - info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority) + info["progress_msg"] += "\n%d tasks are ahead in the queue..." % get_queue_length(priority) else: - info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority) + info["progress_msg"] = "%d tasks are ahead in the queue..." % get_queue_length(priority) info["update_time"] = current_timestamp() info["update_date"] = get_format_time() - ( - cls.model.update(info) - .where( - (cls.model.id == d["id"]) - & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)) - ) - .execute() - ) + (cls.model.update(info).where((cls.model.id == d["id"]) & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))).execute()) except Exception as e: if str(e).find("'0'") < 0: logging.exception("fetch task exception") @@ -875,7 +869,7 @@ def get_kb_doc_count(cls, kb_id): @DB.connection_context() def get_all_kb_doc_count(cls): result = {} - rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias('count')).group_by(cls.model.kb_id) + rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias("count")).group_by(cls.model.kb_id) for row in rows: result[row.kb_id] = row.count return result @@ -890,33 +884,19 @@ def do_cancel(cls, doc_id): pass return False - @classmethod @DB.connection_context() def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]: # cancelled: run == "2" - cancelled = ( - cls.model.select(fn.COUNT(1)) - .where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL)) - .scalar() - ) - downloaded = ( - cls.model.select(fn.COUNT(1)) - .where( - cls.model.kb_id == kb_id, - cls.model.source_type != "local" - ) - .scalar() - ) + cancelled = cls.model.select(fn.COUNT(1)).where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL)).scalar() + downloaded = cls.model.select(fn.COUNT(1)).where(cls.model.kb_id == kb_id, cls.model.source_type != "local").scalar() row = ( cls.model.select( # finished: progress == 1 fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == 1, 1)], 0)), 0).alias("finished"), - # failed: progress == -1 fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == -1, 1)], 0)), 0).alias("failed"), - # processing: 0 <= progress < 1 fn.COALESCE( fn.SUM( @@ -931,24 +911,15 @@ def 
knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]: 0, ).alias("processing"), ) - .where( - (cls.model.kb_id == kb_id) - & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL)) - ) + .where((cls.model.kb_id == kb_id) & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL))) .dicts() .get() ) - return { - "processing": int(row["processing"]), - "finished": int(row["finished"]), - "failed": int(row["failed"]), - "cancelled": int(cancelled), - "downloaded": int(downloaded) - } + return {"processing": int(row["processing"]), "finished": int(row["finished"]), "failed": int(row["failed"]), "cancelled": int(cancelled), "downloaded": int(downloaded)} @classmethod - def run(cls, tenant_id:str, doc:dict, kb_table_num_map:dict): + def run(cls, tenant_id: str, doc: dict, kb_table_num_map: dict): from api.db.services.task_service import queue_dataflow, queue_tasks from api.db.services.file2document_service import File2DocumentService @@ -990,7 +961,7 @@ def new_task(): "from_page": 100000000, "to_page": 100000000, "task_type": ty, - "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty, + "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty, "begin_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), } @@ -1032,8 +1003,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): e, dia = DialogService.get_by_id(conv.dialog_id) if not dia.kb_ids: - raise LookupError("No dataset associated with this conversation. " - "Please add a dataset before uploading documents") + raise LookupError("No dataset associated with this conversation. Please add a dataset before uploading documents") kb_id = dia.kb_ids[0] e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: @@ -1050,12 +1020,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): def dummy(prog=None, msg=""): pass - FACTORY = { - ParserType.PRESENTATION.value: presentation, - ParserType.PICTURE.value: picture, - ParserType.AUDIO.value: audio, - ParserType.EMAIL.value: email - } + FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email} parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text", "table_context_size": 0, "image_context_size": 0} exe = ThreadPoolExecutor(max_workers=12) threads = [] @@ -1063,22 +1028,12 @@ def dummy(prog=None, msg=""): for d, blob in files: doc_nm[d["id"]] = d["name"] for d, blob in files: - kwargs = { - "callback": dummy, - "parser_config": parser_config, - "from_page": 0, - "to_page": 100000, - "tenant_id": kb.tenant_id, - "lang": kb.language - } + kwargs = {"callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": kb.tenant_id, "lang": kb.language} threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs)) for (docinfo, _), th in zip(files, threads): docs = [] - doc = { - "doc_id": docinfo["id"], - "kb_id": [kb.id] - } + doc = {"doc_id": docinfo["id"], "kb_id": [kb.id]} for ck in th.result(): d = deepcopy(doc) d.update(ck) @@ -1093,7 +1048,7 @@ def dummy(prog=None, msg=""): if isinstance(d["image"], bytes): output_buffer = BytesIO(d["image"]) else: - d["image"].save(output_buffer, format='JPEG') + d["image"].save(output_buffer, format="JPEG") settings.STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue()) d["img_id"] = "{}-{}".format(kb.id, d["id"]) @@ -1110,9 +1065,9 @@ def embedding(doc_id, cnts, 
batch_size=16): nonlocal embd_mdl, chunk_counts, token_counts vectors = [] for i in range(0, len(cnts), batch_size): - vts, c = embd_mdl.encode(cnts[i: i + batch_size]) + vts, c = embd_mdl.encode(cnts[i : i + batch_size]) vectors.extend(vts.tolist()) - chunk_counts[doc_id] += len(cnts[i:i + batch_size]) + chunk_counts[doc_id] += len(cnts[i : i + batch_size]) token_counts[doc_id] += c return vectors @@ -1127,22 +1082,25 @@ def embedding(doc_id, cnts, batch_size=16): if parser_ids[doc_id] != ParserType.PICTURE.value: from rag.graphrag.general.mind_map_extractor import MindMapExtractor + mindmap = MindMapExtractor(llm_bdl) try: mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])) mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2) if len(mind_map) < 32: raise Exception("Few content: " + mind_map) - cks.append({ - "id": get_uuid(), - "doc_id": doc_id, - "kb_id": [kb.id], - "docnm_kwd": doc_nm[doc_id], - "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])), - "content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"), - "content_with_weight": mind_map, - "knowledge_graph_kwd": "mind_map" - }) + cks.append( + { + "id": get_uuid(), + "doc_id": doc_id, + "kb_id": [kb.id], + "docnm_kwd": doc_nm[doc_id], + "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])), + "content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"), + "content_with_weight": mind_map, + "knowledge_graph_kwd": "mind_map", + } + ) except Exception: logging.exception("Mind map generation error") @@ -1156,9 +1114,8 @@ def embedding(doc_id, cnts, batch_size=16): if not settings.docStoreConn.index_exist(idxnm, kb_id): settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id) try_create_idx = False - settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id) + settings.docStoreConn.insert(cks[b : b + es_bulk_size], idxnm, kb_id) - DocumentService.increment_chunk_num( - doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) + DocumentService.increment_chunk_num(doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) return [d["id"] for d, _ in files] diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py index 9e0b39aae2a..5864e6b4d2a 100644 --- a/api/utils/validation_utils.py +++ b/api/utils/validation_utils.py @@ -34,7 +34,9 @@ from api.constants import DATASET_NAME_LIMIT -async def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]: +async def validate_and_parse_json_request( + request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False +) -> tuple[dict[str, Any] | None, str | None]: """ Validates and parses JSON requests through a multi-stage validation pipeline. @@ -742,4 +744,5 @@ def validate_id(cls, v: Any) -> str: return validate_uuid1_hex(v) -class ListDatasetReq(BaseListReq): ... 
+class ListDatasetReq(BaseListReq): + include_parsing_status: Annotated[bool, Field(default=False)] diff --git a/docs/references/http_api_reference.md b/docs/references/http_api_reference.md index 1f50d5753ee..8e7199b1192 100644 --- a/docs/references/http_api_reference.md +++ b/docs/references/http_api_reference.md @@ -835,14 +835,14 @@ Failure: ### List datasets -**GET** `/api/v1/datasets?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={dataset_name}&id={dataset_id}` +**GET** `/api/v1/datasets?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={dataset_name}&id={dataset_id}&include_parsing_status={include_parsing_status}` Lists datasets. #### Request - Method: GET -- URL: `/api/v1/datasets?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={dataset_name}&id={dataset_id}` +- URL: `/api/v1/datasets?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={dataset_name}&id={dataset_id}&include_parsing_status={include_parsing_status}` - Headers: - `'Authorization: Bearer '` @@ -854,6 +854,13 @@ curl --request GET \ --header 'Authorization: Bearer ' ``` +```bash +# List datasets with parsing status +curl --request GET \ + --url 'http://{address}/api/v1/datasets?include_parsing_status=true' \ + --header 'Authorization: Bearer ' +``` + ##### Request parameters - `page`: (*Filter parameter*) @@ -870,6 +877,13 @@ curl --request GET \ The name of the dataset to retrieve. - `id`: (*Filter parameter*) The ID of the dataset to retrieve. +- `include_parsing_status`: (*Filter parameter*) + Whether to include document parsing status counts in the response. Defaults to `false`. When set to `true`, each dataset object in the response will include the following additional fields: + - `unstart_count`: Number of documents not yet started parsing. + - `running_count`: Number of documents currently being parsed. + - `cancel_count`: Number of documents whose parsing was cancelled. + - `done_count`: Number of documents that have been successfully parsed. + - `fail_count`: Number of documents whose parsing failed. 
#### Response @@ -917,6 +931,49 @@ Success: } ``` +Success (with `include_parsing_status=true`): + +```json +{ + "code": 0, + "data": [ + { + "avatar": null, + "cancel_count": 0, + "chunk_count": 30, + "chunk_method": "qa", + "create_date": "2026-03-09T18:57:13", + "create_time": 1773053833094, + "created_by": "928f92a210b911f1ac4cc39e0b8fa3ad", + "description": null, + "document_count": 1, + "done_count": 1, + "embedding_model": "text-embedding-v2@Tongyi-Qianwen", + "fail_count": 0, + "id": "ba6586c21ba611f1a3dc476f0709e75e", + "language": "English", + "name": "Test Dataset", + "parser_config": { + "graphrag": { "use_graphrag": false }, + "llm_id": "deepseek-chat@DeepSeek", + "raptor": { "use_raptor": false } + }, + "permission": "me", + "running_count": 0, + "similarity_threshold": 0.2, + "status": "1", + "tenant_id": "928f92a210b911f1ac4cc39e0b8fa3ad", + "token_num": 1746, + "unstart_count": 0, + "update_date": "2026-03-09T18:59:32", + "update_time": 1773053972723, + "vector_similarity_weight": 0.3 + } + ], + "total_datasets": 1 +} +``` + Failure: ```json diff --git a/docs/references/python_api_reference.md b/docs/references/python_api_reference.md index 430e58a0f6f..8fa97f3d563 100644 --- a/docs/references/python_api_reference.md +++ b/docs/references/python_api_reference.md @@ -266,7 +266,8 @@ RAGFlow.list_datasets( orderby: str = "create_time", desc: bool = True, id: str = None, - name: str = None + name: str = None, + include_parsing_status: bool = False ) -> list[DataSet] ``` @@ -301,6 +302,16 @@ The ID of the dataset to retrieve. Defaults to `None`. The name of the dataset to retrieve. Defaults to `None`. +##### include_parsing_status: `bool` + +Whether to include document parsing status counts in each returned `DataSet` object. Defaults to `False`. When set to `True`, each `DataSet` object will include the following additional attributes: + +- `unstart_count`: `int` Number of documents not yet started parsing. +- `running_count`: `int` Number of documents currently being parsed. +- `cancel_count`: `int` Number of documents whose parsing was cancelled. +- `done_count`: `int` Number of documents that have been successfully parsed. +- `fail_count`: `int` Number of documents whose parsing failed. + #### Returns - Success: A list of `DataSet` objects. 
@@ -322,6 +333,13 @@ dataset = rag_object.list_datasets(id = "id_1") print(dataset[0]) ``` +##### List datasets with parsing status + +```python +for dataset in rag_object.list_datasets(include_parsing_status=True): + print(dataset.done_count, dataset.fail_count, dataset.running_count) +``` + --- ### Update dataset diff --git a/example/http/dataset_example.sh b/example/http/dataset_example.sh index 492d902d003..1d2e8fa68f3 100644 --- a/example/http/dataset_example.sh +++ b/example/http/dataset_example.sh @@ -41,6 +41,12 @@ curl --request GET \ --url http://127.0.0.1:9380/api/v1/datasets \ --header 'Authorization: Bearer ragflow-IzZmY1MGVhYTBhMjExZWZiYTdjMDI0Mm' +# List datasets with parsing status +echo -e "\n-- List datasets with parsing status" +curl --request GET \ + --url 'http://127.0.0.1:9380/api/v1/datasets?include_parsing_status=true' \ + --header 'Authorization: Bearer ragflow-IzZmY1MGVhYTBhMjExZWZiYTdjMDI0Mm' + # Delete datasets echo -e "\n-- Delete datasets" curl --request DELETE \ diff --git a/test/unit_test/api/db/services/test_document_service_get_parsing_status.py b/test/unit_test/api/db/services/test_document_service_get_parsing_status.py new file mode 100644 index 00000000000..997fe6f8611 --- /dev/null +++ b/test/unit_test/api/db/services/test_document_service_get_parsing_status.py @@ -0,0 +1,326 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import sys +import types +import warnings + +import pytest + +# xgboost imports pkg_resources and emits a deprecation warning that is promoted +# to error in our pytest configuration; ignore it for this unit test module. +warnings.filterwarnings( + "ignore", + message="pkg_resources is deprecated as an API.*", + category=UserWarning, +) + + +def _install_cv2_stub_if_unavailable(): + try: + import cv2 # noqa: F401 + return + except Exception: + pass + + stub = types.ModuleType("cv2") + + stub.INTER_LINEAR = 1 + stub.INTER_CUBIC = 2 + stub.BORDER_CONSTANT = 0 + stub.BORDER_REPLICATE = 1 + stub.COLOR_BGR2RGB = 0 + stub.COLOR_BGR2GRAY = 1 + stub.COLOR_GRAY2BGR = 2 + stub.IMREAD_IGNORE_ORIENTATION = 128 + stub.IMREAD_COLOR = 1 + stub.RETR_LIST = 1 + stub.CHAIN_APPROX_SIMPLE = 2 + + def _missing(*_args, **_kwargs): + raise RuntimeError("cv2 runtime call is unavailable in this test environment") + + def _module_getattr(name): + if name.isupper(): + return 0 + return _missing + + stub.__getattr__ = _module_getattr + sys.modules["cv2"] = stub + + +_install_cv2_stub_if_unavailable() + +from api.db.services.document_service import DocumentService # noqa: E402 +from common.constants import TaskStatus # noqa: E402 + +# --------------------------------------------------------------------------- +# Helpers to access the original function bypassing @DB.connection_context() +# --------------------------------------------------------------------------- + +def _unwrapped_get_parsing_status(): + """Return the original (un-decorated) get_parsing_status_by_kb_ids function. 
+ + @classmethod + @DB.connection_context() together means: + DocumentService.get_parsing_status_by_kb_ids.__func__ -> connection_context wrapper + ....__func__.__wrapped__ -> original function + """ + return DocumentService.get_parsing_status_by_kb_ids.__func__.__wrapped__ + + +# --------------------------------------------------------------------------- +# Fake ORM helpers – mimic the minimal peewee query chain used by the function +# --------------------------------------------------------------------------- + +class _FieldStub: + """Minimal stand-in for a peewee model field used in select/where/group_by.""" + + def in_(self, values): + """Called by .where(cls.model.kb_id.in_(kb_ids)) – no-op in tests.""" + return self + + def alias(self, name): + return self + + +class _FakeQuery: + """Chains .where(), .group_by(), .dicts() without touching a real database.""" + + def __init__(self, rows): + self._rows = rows + + def where(self, *_args, **_kwargs): + return self + + def group_by(self, *_args, **_kwargs): + return self + + def dicts(self): + return list(self._rows) + + +def _make_fake_model(rows): + """Create a fake Document model class whose select() returns *rows*.""" + + class _FakeModel: + id = _FieldStub() + kb_id = _FieldStub() + run = _FieldStub() + + @classmethod + def select(cls, *_args): + return _FakeQuery(rows) + + return _FakeModel + + +# --------------------------------------------------------------------------- +# Pytest fixture – patch DocumentService.model per test +# --------------------------------------------------------------------------- + +@pytest.fixture() +def call_with_rows(monkeypatch): + """Return a helper that runs get_parsing_status_by_kb_ids with fake DB rows.""" + + def _call(rows, kb_ids): + monkeypatch.setattr(DocumentService, "model", _make_fake_model(rows)) + fn = _unwrapped_get_parsing_status() + return fn(DocumentService, kb_ids) + + return _call + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +_ALL_STATUS_FIELDS = frozenset( + ["unstart_count", "running_count", "cancel_count", "done_count", "fail_count"] +) + + +@pytest.mark.p2 +class TestGetParsingStatusByKbIds: + + # ------------------------------------------------------------------ + # Edge-case: empty input list – must short-circuit before any DB call + # ------------------------------------------------------------------ + + def test_empty_kb_ids_returns_empty_dict(self, call_with_rows): + result = call_with_rows([], []) + assert result == {} + + # ------------------------------------------------------------------ + # A kb_id present in the input but with no matching documents + # ------------------------------------------------------------------ + + def test_single_kb_id_no_documents(self, call_with_rows): + result = call_with_rows(rows=[], kb_ids=["kb-1"]) + + assert set(result.keys()) == {"kb-1"} + assert set(result["kb-1"].keys()) == _ALL_STATUS_FIELDS + assert all(v == 0 for v in result["kb-1"].values()) + + # ------------------------------------------------------------------ + # A single kb_id with one document in each run-status bucket + # ------------------------------------------------------------------ + + def test_single_kb_id_all_five_statuses(self, call_with_rows): + rows = [ + {"kb_id": "kb-1", "run": TaskStatus.UNSTART.value, "cnt": 3}, + {"kb_id": "kb-1", "run": TaskStatus.RUNNING.value, "cnt": 1}, + {"kb_id": "kb-1", "run": TaskStatus.CANCEL.value, "cnt": 
2}, + {"kb_id": "kb-1", "run": TaskStatus.DONE.value, "cnt": 10}, + {"kb_id": "kb-1", "run": TaskStatus.FAIL.value, "cnt": 4}, + ] + result = call_with_rows(rows=rows, kb_ids=["kb-1"]) + + assert result["kb-1"]["unstart_count"] == 3 + assert result["kb-1"]["running_count"] == 1 + assert result["kb-1"]["cancel_count"] == 2 + assert result["kb-1"]["done_count"] == 10 + assert result["kb-1"]["fail_count"] == 4 + + # ------------------------------------------------------------------ + # Two kb_ids – counts must be independent per dataset + # ------------------------------------------------------------------ + + def test_multiple_kb_ids_aggregated_separately(self, call_with_rows): + rows = [ + {"kb_id": "kb-a", "run": TaskStatus.DONE.value, "cnt": 5}, + {"kb_id": "kb-a", "run": TaskStatus.FAIL.value, "cnt": 1}, + {"kb_id": "kb-b", "run": TaskStatus.UNSTART.value, "cnt": 7}, + {"kb_id": "kb-b", "run": TaskStatus.DONE.value, "cnt": 2}, + ] + result = call_with_rows(rows=rows, kb_ids=["kb-a", "kb-b"]) + + assert set(result.keys()) == {"kb-a", "kb-b"} + + assert result["kb-a"]["done_count"] == 5 + assert result["kb-a"]["fail_count"] == 1 + assert result["kb-a"]["unstart_count"] == 0 + assert result["kb-a"]["running_count"] == 0 + assert result["kb-a"]["cancel_count"] == 0 + + assert result["kb-b"]["unstart_count"] == 7 + assert result["kb-b"]["done_count"] == 2 + assert result["kb-b"]["fail_count"] == 0 + + # ------------------------------------------------------------------ + # An unrecognised run value must be silently ignored + # ------------------------------------------------------------------ + + def test_unknown_run_value_ignored(self, call_with_rows): + rows = [ + {"kb_id": "kb-1", "run": "9", "cnt": 99}, # "9" is not a TaskStatus + {"kb_id": "kb-1", "run": TaskStatus.DONE.value, "cnt": 4}, + ] + result = call_with_rows(rows=rows, kb_ids=["kb-1"]) + + assert result["kb-1"]["done_count"] == 4 + assert all( + result["kb-1"][f] == 0 + for f in _ALL_STATUS_FIELDS - {"done_count"} + ) + + # ------------------------------------------------------------------ + # A row whose kb_id was NOT requested must not appear in the output + # ------------------------------------------------------------------ + + def test_row_with_unrequested_kb_id_is_filtered_out(self, call_with_rows): + rows = [ + {"kb_id": "kb-requested", "run": TaskStatus.DONE.value, "cnt": 3}, + {"kb_id": "kb-unexpected", "run": TaskStatus.DONE.value, "cnt": 100}, + ] + result = call_with_rows(rows=rows, kb_ids=["kb-requested"]) + + assert "kb-unexpected" not in result + assert result["kb-requested"]["done_count"] == 3 + + # ------------------------------------------------------------------ + # cnt values must be treated as integers regardless of DB type hints + # ------------------------------------------------------------------ + + def test_cnt_is_cast_to_int(self, call_with_rows): + rows = [ + {"kb_id": "kb-1", "run": TaskStatus.RUNNING.value, "cnt": "7"}, + ] + result = call_with_rows(rows=rows, kb_ids=["kb-1"]) + + assert result["kb-1"]["running_count"] == 7 + assert isinstance(result["kb-1"]["running_count"], int) + + # ------------------------------------------------------------------ + # run value stored as integer in DB (some adapters may omit str cast) + # ------------------------------------------------------------------ + + def test_run_value_as_integer_is_handled(self, call_with_rows): + rows = [ + {"kb_id": "kb-1", "run": int(TaskStatus.DONE.value), "cnt": 5}, + ] + result = call_with_rows(rows=rows, kb_ids=["kb-1"]) + + 
assert result["kb-1"]["done_count"] == 5 + + # ------------------------------------------------------------------ + # All five status fields are initialised to 0 even when no rows exist + # ------------------------------------------------------------------ + + def test_all_five_fields_initialised_to_zero(self, call_with_rows): + result = call_with_rows(rows=[], kb_ids=["kb-empty"]) + + assert result["kb-empty"] == { + "unstart_count": 0, + "running_count": 0, + "cancel_count": 0, + "done_count": 0, + "fail_count": 0, + } + + # ------------------------------------------------------------------ + # Multiple kb_ids in the input – all should appear in the result + # even when no documents exist for some of them + # ------------------------------------------------------------------ + + def test_requested_kb_ids_all_present_in_result(self, call_with_rows): + rows = [ + {"kb_id": "kb-with-data", "run": TaskStatus.DONE.value, "cnt": 1}, + ] + result = call_with_rows( + rows=rows, kb_ids=["kb-with-data", "kb-empty-1", "kb-empty-2"] + ) + + assert set(result.keys()) == {"kb-with-data", "kb-empty-1", "kb-empty-2"} + assert result["kb-empty-1"] == {f: 0 for f in _ALL_STATUS_FIELDS} + assert result["kb-empty-2"] == {f: 0 for f in _ALL_STATUS_FIELDS} + + # ------------------------------------------------------------------ + # SCHEDULE (run=="5") is not mapped – must be silently ignored + # ------------------------------------------------------------------ + + def test_schedule_status_is_not_mapped(self, call_with_rows): + rows = [ + {"kb_id": "kb-1", "run": TaskStatus.SCHEDULE.value, "cnt": 3}, + {"kb_id": "kb-1", "run": TaskStatus.DONE.value, "cnt": 2}, + ] + result = call_with_rows(rows=rows, kb_ids=["kb-1"]) + + assert result["kb-1"]["done_count"] == 2 + # SCHEDULE is not a tracked bucket + assert "schedule_count" not in result["kb-1"] + assert all( + result["kb-1"][f] == 0 + for f in _ALL_STATUS_FIELDS - {"done_count"} + ) From 97e17d2cc06c30ff16e24941d1e79b8d0063c31f Mon Sep 17 00:00:00 2001 From: yzy <183124902@qq.com> Date: Tue, 10 Mar 2026 19:22:04 +0800 Subject: [PATCH 201/565] Fix: return structured JSON output for non-streaming agent API (#13389) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? Previously, when an Agent component was configured with structured output, the non-streaming /agents/{agent_id}/completions API never returned the structured field in its response. The root cause: the non-streaming code path only collected message events to build full_content, then returned the workflow_finished payload — which only contains the output of the last component in the execution path (typically a Message component). Any structured output set by upstream components (e.g., Agent or LLM) was silently discarded. This PR fixes the non-streaming handler to iterate node_finished events and collect structured output from intermediate components. If any component produced a non-empty structured value, it is included in the final response under data.structured. The streaming path is unaffected, as it already exposes node_finished events to the caller. 
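For illustration, a minimal consumer-side sketch of the fixed non-streaming response (not part of this patch): the `/agents/{agent_id}/completions` path and the `structured` field come from the description above, while the `api/v1` prefix, the Bearer auth header, and the request body shape are assumptions that may differ from the actual API.

```python
import requests


def find_structured(node):
    """Return the first non-empty 'structured' value in a nested payload.

    Walking the payload instead of hard-coding a nesting depth keeps this
    sketch independent of the exact response envelope.
    """
    if isinstance(node, dict):
        if node.get("structured"):
            return node["structured"]
        for value in node.values():
            found = find_structured(value)
            if found:
                return found
    elif isinstance(node, list):
        for item in node:
            found = find_structured(item)
            if found:
                return found
    return {}


def ask_agent(base_url, api_key, agent_id, question):
    # Hypothetical request shape; only the endpoint path is taken from the PR text.
    resp = requests.post(
        f"{base_url}/api/v1/agents/{agent_id}/completions",
        headers={"Authorization": f"Bearer {api_key}"},
        json={"question": question, "stream": False},
        timeout=60,
    )
    return find_structured(resp.json())
```

Before this fix, a consumer like `ask_agent` would always get `{}` back on the non-streaming path, because the `structured` value produced by intermediate components never reached the final payload.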
### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- api/apps/sdk/session.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 10439564d58..e628a3c960c 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -584,6 +584,7 @@ async def generate(): reference = {} final_ans = "" trace_items = [] + structured_output = {} async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req): try: ans = json.loads(answer[5:]) @@ -594,20 +595,26 @@ async def generate(): if ans.get("data", {}).get("reference", None): reference.update(ans["data"]["reference"]) - if return_trace and ans.get("event") == "node_finished": - data = ans.get("data", {}) - trace_items.append( - { - "component_id": data.get("component_id"), - "trace": [copy.deepcopy(data)], - } - ) + if ans.get("event") == "node_finished": + node_out = ans.get("data", {}).get("outputs", {}) + if node_out.get("structured"): + structured_output = node_out["structured"] + if return_trace: + data = ans.get("data", {}) + trace_items.append( + { + "component_id": data.get("component_id"), + "trace": [copy.deepcopy(data)], + } + ) final_ans = ans except Exception as e: return get_result(data=f"**ERROR**: {str(e)}") final_ans["data"]["content"] = full_content final_ans["data"]["reference"] = reference + if structured_output: + final_ans["data"]["structured"] = structured_output if return_trace and final_ans: final_ans["data"]["trace"] = trace_items return get_result(data=final_ans) From 53b2e221f69a74e6ceea44d82456ea578e8bd09d Mon Sep 17 00:00:00 2001 From: Yongteng Lei Date: Tue, 10 Mar 2026 21:13:14 +0800 Subject: [PATCH 202/565] Fix: support vLLM's new reasoning field (#13493) ### What problem does this PR solve? Support vLLM's new reasoning field ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- rag/llm/chat_model.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 10b2fb5155e..4476ccbbde4 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -149,12 +149,13 @@ async def _async_chat_streamly(self, history, gen_conf, **kwargs): continue if not resp.choices[0].delta.content: resp.choices[0].delta.content = "" - if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content: + _reasoning = getattr(resp.choices[0].delta, "reasoning_content", None) or getattr(resp.choices[0].delta, "reasoning", None) + if kwargs.get("with_reasoning", True) and _reasoning: ans = "" if not reasoning_start: reasoning_start = True ans = "" - ans += resp.choices[0].delta.reasoning_content + "" + ans += _reasoning + "" else: reasoning_start = False ans = resp.choices[0].delta.content @@ -294,8 +295,9 @@ async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict raise Exception(f"500 response structure error. 
Response: {response}") if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls: - if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content: - ans += "" + response.choices[0].message.reasoning_content + "" + _reasoning = getattr(response.choices[0].message, "reasoning_content", None) or getattr(response.choices[0].message, "reasoning", None) + if _reasoning: + ans += "" + _reasoning + "" ans += response.choices[0].message.content if response.choices[0].finish_reason == "length": @@ -370,12 +372,13 @@ async def async_chat_streamly_with_tools(self, system: str, history: list, gen_c if not hasattr(delta, "content") or delta.content is None: delta.content = "" - if hasattr(delta, "reasoning_content") and delta.reasoning_content: + _reasoning = getattr(delta, "reasoning_content", None) or getattr(delta, "reasoning", None) + if _reasoning: ans = "" if not reasoning_start: reasoning_start = True ans = "" - ans += delta.reasoning_content + "" + ans += _reasoning + "" yield ans else: reasoning_start = False @@ -1279,12 +1282,13 @@ async def async_chat_streamly(self, system, history, gen_conf, **kwargs): if not hasattr(delta, "content") or delta.content is None: delta.content = "" - if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content: + _reasoning = getattr(delta, "reasoning_content", None) or getattr(delta, "reasoning", None) + if kwargs.get("with_reasoning", True) and _reasoning: ans = "" if not reasoning_start: reasoning_start = True ans = "" - ans += delta.reasoning_content + "" + ans += _reasoning + "" else: reasoning_start = False ans = delta.content @@ -1404,8 +1408,9 @@ async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict message = response.choices[0].message if not hasattr(message, "tool_calls") or not message.tool_calls: - if hasattr(message, "reasoning_content") and message.reasoning_content: - ans += f"{message.reasoning_content}" + _reasoning = getattr(message, "reasoning_content", None) or getattr(message, "reasoning", None) + if _reasoning: + ans += f"{_reasoning}" ans += message.content or "" if response.choices[0].finish_reason == "length": ans = self._length_stop(ans) @@ -1485,12 +1490,13 @@ async def async_chat_streamly_with_tools(self, system: str, history: list, gen_c if not hasattr(delta, "content") or delta.content is None: delta.content = "" - if hasattr(delta, "reasoning_content") and delta.reasoning_content: + _reasoning = getattr(delta, "reasoning_content", None) or getattr(delta, "reasoning", None) + if _reasoning: ans = "" if not reasoning_start: reasoning_start = True ans = "" - ans += delta.reasoning_content + "" + ans += _reasoning + "" yield ans else: reasoning_start = False From 1038bff32bc43440c683fa652cdaf050029b5101 Mon Sep 17 00:00:00 2001 From: balibabu Date: Tue, 10 Mar 2026 22:18:27 +0800 Subject: [PATCH 203/565] Feat: Add a user_id field to the message and retrieval operators. (#13508) ### What problem does this PR solve? Feat: Add a user_id field to the message and retrieval operators. 
### Type of change - [x] New Feature (non-breaking change which adds functionality) --- web/src/locales/en.ts | 1 + web/src/locales/zh.ts | 1 + .../agent/form/components/user-id-form-field.tsx | 13 +++++++++++++ web/src/pages/agent/form/message-form/index.tsx | 3 +++ web/src/pages/agent/form/retrieval-form/next.tsx | 7 ++++++- 5 files changed, 24 insertions(+), 1 deletion(-) create mode 100644 web/src/pages/agent/form/components/user-id-form-field.tsx diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index f561c41bfe6..debbcf593cd 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -1839,6 +1839,7 @@ Example: Virtual Hosted Style`, dbType: 'Database type', database: 'Database', username: 'Username', + userId: 'User id', host: 'Host', port: 'Port', password: 'Password', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index b020ff02a1e..7efa7f261e4 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -1599,6 +1599,7 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 dbType: '数据库类型', database: '数据库', username: '用户名', + userId: '用户 ID', host: '主机', port: '端口', password: '密码', diff --git a/web/src/pages/agent/form/components/user-id-form-field.tsx b/web/src/pages/agent/form/components/user-id-form-field.tsx new file mode 100644 index 00000000000..422545cd351 --- /dev/null +++ b/web/src/pages/agent/form/components/user-id-form-field.tsx @@ -0,0 +1,13 @@ +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { useTranslation } from 'react-i18next'; +import { PromptEditor } from './prompt-editor'; + +export function UserIdFormField() { + const { t } = useTranslation(); + + return ( + + + + ); +} diff --git a/web/src/pages/agent/form/message-form/index.tsx b/web/src/pages/agent/form/message-form/index.tsx index 87071e5780d..ddc8a09ca74 100644 --- a/web/src/pages/agent/form/message-form/index.tsx +++ b/web/src/pages/agent/form/message-form/index.tsx @@ -21,6 +21,7 @@ import { ExportFileType } from '../../constant'; import { INextOperatorForm } from '../../interface'; import { FormWrapper } from '../components/form-wrapper'; import { PromptEditor } from '../components/prompt-editor'; +import { UserIdFormField } from '../components/user-id-form-field'; import { useShowWebhookResponseStatus } from './use-show-response-status'; import { useValues } from './use-values'; import { useWatchFormChange } from './use-watch-change'; @@ -42,6 +43,7 @@ function MessageForm({ node }: INextOperatorForm) { auto_play: z.boolean().optional(), status: z.number().optional(), memory_ids: z.array(z.string()).optional(), + user_id: z.string().optional(), }); const form = useForm({ @@ -163,6 +165,7 @@ function MessageForm({ node }: INextOperatorForm) { )} + ); diff --git a/web/src/pages/agent/form/retrieval-form/next.tsx b/web/src/pages/agent/form/retrieval-form/next.tsx index 4776e3af8fb..0e155076ef0 100644 --- a/web/src/pages/agent/form/retrieval-form/next.tsx +++ b/web/src/pages/agent/form/retrieval-form/next.tsx @@ -38,6 +38,7 @@ import { INextOperatorForm } from '../../interface'; import { FormWrapper } from '../components/form-wrapper'; import { Output } from '../components/output'; import { PromptEditor } from '../components/prompt-editor'; +import { UserIdFormField } from '../components/user-id-form-field'; import { useValues } from './use-values'; export const RetrievalPartialSchema = { @@ -54,6 +55,7 @@ export const RetrievalPartialSchema = { ...MetadataFilterSchema, memory_ids: z.array(z.string()).optional(), retrieval_from: z.string(), + user_id: 
z.string().optional(), }; export const FormSchema = z.object({ @@ -82,7 +84,10 @@ export function MemoryDatasetForm() { {retrievalFrom === RetrievalFrom.Memory ? ( - + <> + + + ) : ( )} From f7e2db070f5c073be908fadd68665010c8f47fa6 Mon Sep 17 00:00:00 2001 From: eviaaaaa <2278596667@qq.com> Date: Wed, 11 Mar 2026 10:00:07 +0800 Subject: [PATCH 204/565] Refa: implement unified lazy image loading for Docx parsers (qa/manual) (#13329) ## Summary This PR is the direct successor to the previous `docx` lazy-loading implementation. It addresses the technical debt intentionally left out in the last PR by fully migrating the `qa` and `manual` parsing strategies to the new lazy-loading model. Additionally, this PR comprehensively refactors the underlying `docx` parsing pipeline to eliminate significant code redundancy and introduces robust fallback mechanisms to handle completely corrupted image streams safely. ## What's Changed * **Centralized Abstraction (`docx_parser.py`)**: Moved the `get_picture` extraction logic up to the `RAGFlowDocxParser` base class. Previously, `naive`, `qa`, and `manual` parsers maintained separate, redundant copies of this method. All downstream strategies now natively gather raw blobs and return `LazyDocxImage` objects automatically. * **Robust Corrupted Image Fallback (`docx_parser.py`)**: Handled edge cases where `python-docx` encounters critically malformed magic headers. Implemented an explicit `try-except` structure that safely intercepts `UnrecognizedImageError` (and similar exceptions) and seamlessly falls back to retrieving the raw binary via `getattr(related_part, "blob", None)`, preventing parser crashes on damaged documents. * **Legacy Code & Redundancy Purge**: * Removed the duplicate `get_picture` methods from `naive.py`, `qa.py`, and `manual.py`. * Removed the standalone, immediate-decoding `concat_img` method in `manual.py`. It has been completely replaced by the globally unified, lazy-loading-compatible `rag.nlp.concat_img`. * Cleaned up unused legacy imports (e.g., `PIL.Image`, docx exception packages) across all updated strategy files. ## Scope To keep this PR focused, I have restricted these changes strictly to the unification of `docx` extraction logic and the lazy-load migration of `qa` and `manual`. ## Validation & Testing I've tested this to ensure no regressions and validated the fallback logic: * **Output Consistency**: Compared identical `.docx` inputs using `qa` and `manual` strategies before and after this branch: chunk counts, extracted text, table HTML, and attached images match perfectly. * **Memory Footprint Drop**: Confirmed a noticeable drop in peak memory usage when processing image-dense documents through the `qa` and `manual` pipelines, bringing them up to parity with the `naive` strategy's performance gains. ## Breaking Changes * None. 
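As a condensed sketch of the fallback described under "Robust Corrupted Image Fallback" above (the `safe_image_blob` helper is illustrative; the real logic lives inline in `RAGFlowDocxParser.get_picture`):

```python
import logging

from docx.image.exceptions import (
    InvalidImageStreamError,
    UnexpectedEndOfFileError,
    UnrecognizedImageError,
)

def safe_image_blob(related_part):
    """Best-effort blob extraction: parsed image first, raw part blob second."""
    try:
        image = related_part.image  # may raise on malformed magic headers
        if image is not None:
            return image.blob
    except (UnrecognizedImageError, UnexpectedEndOfFileError,
            InvalidImageStreamError, UnicodeDecodeError):
        logging.info("Damaged image encountered, attempting blob fallback")
    # Fall back to the raw binary; None if the part has no blob either.
    return getattr(related_part, "blob", None)
```

The collected blobs are wrapped in `LazyDocxImage` and only decoded to PIL images on first access, which is what keeps peak memory flat on image-dense documents.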
--- deepdoc/parser/docx_parser.py | 47 ++++++++++++++++++++++++++++++++++- rag/app/manual.py | 44 ++------------------------------ rag/app/naive.py | 36 --------------------------- rag/app/qa.py | 12 --------- rag/nlp/__init__.py | 15 +++++++++-- rag/utils/lazy_image.py | 13 ++++++++++ 6 files changed, 74 insertions(+), 93 deletions(-) diff --git a/deepdoc/parser/docx_parser.py b/deepdoc/parser/docx_parser.py index 2a65841e246..a17543cbf49 100644 --- a/deepdoc/parser/docx_parser.py +++ b/deepdoc/parser/docx_parser.py @@ -20,9 +20,54 @@ from collections import Counter from rag.nlp import rag_tokenizer from io import BytesIO - +import logging +from docx.image.exceptions import ( + InvalidImageStreamError, + UnexpectedEndOfFileError, + UnrecognizedImageError, +) +from rag.utils.lazy_image import LazyDocxImage class RAGFlowDocxParser: + def get_picture(self, document, paragraph): + imgs = paragraph._element.xpath(".//pic:pic") + if not imgs: + return None + image_blobs = [] + for img in imgs: + embed = img.xpath(".//a:blip/@r:embed") + if not embed: + continue + embed = embed[0] + image_blob = None + try: + related_part = document.part.related_parts[embed] + except Exception as e: + logging.warning(f"Skipping image due to unexpected error getting related_part: {e}") + continue + + try: + image = related_part.image + if image is not None: + image_blob = image.blob + except ( + UnrecognizedImageError, + UnexpectedEndOfFileError, + InvalidImageStreamError, + UnicodeDecodeError, + ) as e: + logging.info(f"Damaged image encountered, attempting blob fallback: {e}") + except Exception as e: + logging.warning(f"Unexpected error getting image, attempting blob fallback: {e}") + + if image_blob is None: + image_blob = getattr(related_part, "blob", None) + if image_blob: + image_blobs.append(image_blob) + if not image_blobs: + return None + return LazyDocxImage(image_blobs) + def __extract_table_content(self, tb): df = [] diff --git a/rag/app/manual.py b/rag/app/manual.py index 5f3b5879202..e2af0706f22 100644 --- a/rag/app/manual.py +++ b/rag/app/manual.py @@ -20,12 +20,11 @@ from common.constants import ParserType from io import BytesIO -from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level, attach_media_context +from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level, attach_media_context, concat_img from common.token_utils import num_tokens_from_string from deepdoc.parser import PdfParser, DocxParser from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper, vision_figure_parser_docx_wrapper from docx import Document -from PIL import Image from rag.app.naive import by_plaintext, PARSERS from common.parser_config_utils import normalize_layout_recognizer @@ -71,45 +70,6 @@ class Docx(DocxParser): def __init__(self): pass - def get_picture(self, document, paragraph): - img = paragraph._element.xpath(".//pic:pic") - if not img: - return None - try: - img = img[0] - embed = img.xpath(".//a:blip/@r:embed")[0] - related_part = document.part.related_parts[embed] - image = related_part.image - if image is not None: - image = Image.open(BytesIO(image.blob)) - return image - elif related_part.blob is not None: - image = Image.open(BytesIO(related_part.blob)) - return image - else: - return None - except Exception: - return None - - def concat_img(self, img1, img2): - if img1 and not img2: - return img1 - if not img1 and img2: - return img2 - if 
not img1 and not img2: - return None - width1, height1 = img1.size - width2, height2 = img2.size - - new_width = max(width1, width2) - new_height = height1 + height2 - new_image = Image.new("RGB", (new_width, new_height)) - - new_image.paste(img1, (0, 0)) - new_image.paste(img2, (0, height1)) - - return new_image - def __call__(self, filename, binary=None, from_page=0, to_page=100000, callback=None): self.doc = Document(filename) if not binary else Document(BytesIO(binary)) pn = 0 @@ -125,7 +85,7 @@ def __call__(self, filename, binary=None, from_page=0, to_page=100000, callback= if not question_level or question_level > 6: # not a question last_answer = f"{last_answer}\n{p_text}" current_image = self.get_picture(self.doc, p) - last_image = self.concat_img(last_image, current_image) + last_image = concat_img(last_image, current_image) else: # is a question if last_answer or last_image: sum_question = "\n".join(question_stack) diff --git a/rag/app/naive.py b/rag/app/naive.py index 06201907446..1d2d0ebbf7f 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -21,7 +21,6 @@ from io import BytesIO from timeit import default_timer as timer from docx import Document -from docx.image.exceptions import InvalidImageStreamError, UnexpectedEndOfFileError, UnrecognizedImageError from docx.opc.pkgreader import _SerializedRelationships, _SerializedRelationship from docx.table import Table as DocxTable from docx.text.paragraph import Paragraph @@ -34,7 +33,6 @@ from api.db.services.llm_service import LLMBundle from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name, get_tenant_default_model_by_type from rag.utils.file_utils import extract_embed_file, extract_links_from_pdf, extract_links_from_docx, extract_html -from rag.utils.lazy_image import LazyDocxImage from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, PdfParser, TxtParser from deepdoc.parser.figure_parser import VisionFigureParser, vision_figure_parser_docx_wrapper_naive, vision_figure_parser_pdf_wrapper from deepdoc.parser.pdf_parser import PlainParser, VisionParser @@ -265,40 +263,6 @@ class Docx(DocxParser): def __init__(self): pass - def get_picture(self, document, paragraph): - imgs = paragraph._element.xpath(".//pic:pic") - if not imgs: - return None - image_blobs = [] - for img in imgs: - embed = img.xpath(".//a:blip/@r:embed") - if not embed: - continue - embed = embed[0] - try: - related_part = document.part.related_parts[embed] - image_blob = related_part.image.blob - except UnrecognizedImageError: - logging.info("Unrecognized image format. Skipping image.") - continue - except UnexpectedEndOfFileError: - logging.info("EOF was unexpectedly encountered while reading an image stream. Skipping image.") - continue - except InvalidImageStreamError: - logging.info("The recognized image stream appears to be corrupted. Skipping image.") - continue - except UnicodeDecodeError: - logging.info("The recognized image stream appears to be corrupted. Skipping image.") - continue - except Exception as e: - logging.warning(f"The recognized image stream appears to be corrupted. 
Skipping image, exception: {e}") - continue - image_blobs.append(image_blob) - - if not image_blobs: - return None - return LazyDocxImage(image_blobs) - def __clean(self, line): line = re.sub(r"\u3000", " ", line).strip() return line diff --git a/rag/app/qa.py b/rag/app/qa.py index 95678faaa2b..da6d72cf736 100644 --- a/rag/app/qa.py +++ b/rag/app/qa.py @@ -27,7 +27,6 @@ from rag.nlp import rag_tokenizer, tokenize_table, concat_img from deepdoc.parser import PdfParser, ExcelParser, DocxParser from docx import Document -from PIL import Image from markdown import markdown from common.float_utils import get_float @@ -192,17 +191,6 @@ class Docx(DocxParser): def __init__(self): pass - def get_picture(self, document, paragraph): - img = paragraph._element.xpath('.//pic:pic') - if not img: - return None - img = img[0] - embed = img.xpath('.//a:blip/@r:embed')[0] - related_part = document.part.related_parts[embed] - image = related_part.image - image = Image.open(BytesIO(image.blob)).convert('RGB') - return image - def __call__(self, filename, binary=None, from_page=0, to_page=100000, callback=None): self.doc = Document( filename) if not binary else Document(BytesIO(binary)) diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py index 364e953881a..be1cef05b96 100644 --- a/rag/nlp/__init__.py +++ b/rag/nlp/__init__.py @@ -1200,7 +1200,7 @@ def add_chunk(t, image, pos=""): def docx_question_level(p, bull=-1): txt = re.sub(r"\u3000", " ", p.text).strip() - if p.style.name.startswith('Heading'): + if hasattr(p.style, 'name') and p.style.name and p.style.name.startswith('Heading'): return int(p.style.name.split(' ')[-1]), txt else: if bull < 0: @@ -1212,7 +1212,18 @@ def docx_question_level(p, bull=-1): def concat_img(img1, img2): - from rag.utils.lazy_image import ensure_pil_image + from rag.utils.lazy_image import ensure_pil_image, LazyDocxImage + + # Fast path: preserve laziness when both sides are LazyDocxImage or None. + if (img1 is None or isinstance(img1, LazyDocxImage)) and \ + (img2 is None or isinstance(img2, LazyDocxImage)): + if img1 and not img2: + return img1 + if not img1 and img2: + return img2 + if not img1 and not img2: + return None + return LazyDocxImage.merge(img1, img2) img1 = ensure_pil_image(img1) or img1 img2 = ensure_pil_image(img2) or img2 diff --git a/rag/utils/lazy_image.py b/rag/utils/lazy_image.py index 7de2bfd5ce5..c120bef913c 100644 --- a/rag/utils/lazy_image.py +++ b/rag/utils/lazy_image.py @@ -88,6 +88,19 @@ def __exit__(self, exc_type, exc, tb): self.close() return False + @staticmethod + def merge(a, b): + """ + Merge two LazyDocxImage instances by combining their blob lists. + """ + a_blobs = a._blobs if isinstance(a, LazyDocxImage) else [] + b_blobs = b._blobs if isinstance(b, LazyDocxImage) else [] + combined = a_blobs + b_blobs + if not combined: + return None + merged = LazyDocxImage(combined) + return merged + def ensure_pil_image(img): if isinstance(img, Image.Image): From 719301ca64bf117824f4695954a1a8c61fef292f Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Wed, 11 Mar 2026 11:23:13 +0800 Subject: [PATCH 205/565] Add auth middleware (#13506) ### What problem does this PR solve? Use auth middle-ware to check authorization. 
### Type of change - [x] Refactoring --------- Signed-off-by: Jin Hai --- cmd/server_main.go | 5 +- internal/dao/database.go | 1 - internal/handler/auth.go | 81 +++++++++++ internal/handler/chat.go | 81 ++--------- internal/handler/chat_session.go | 81 ++--------- internal/handler/chunk.go | 21 +-- internal/handler/common.go | 37 ++++++ internal/handler/connector.go | 21 +-- internal/handler/document.go | 37 ++++++ internal/handler/file.go | 81 ++--------- internal/handler/kb.go | 137 ++++++++----------- internal/handler/llm.go | 80 ++--------- internal/handler/search.go | 21 +-- internal/handler/tenant.go | 40 +----- internal/handler/user.go | 92 +++---------- internal/router/router.go | 221 ++++++++++++++++--------------- 16 files changed, 412 insertions(+), 625 deletions(-) create mode 100644 internal/handler/auth.go create mode 100644 internal/handler/common.go diff --git a/cmd/server_main.go b/cmd/server_main.go index c4919abeca8..9eef17d97e0 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -6,7 +6,7 @@ import ( "net/http" "os" "os/signal" - "ragflow/internal/common" + "ragflow/internal/common" "ragflow/internal/server" "ragflow/internal/utility" "strings" @@ -134,6 +134,7 @@ func startServer(config *server.Config) { fileService := service.NewFileService() // Initialize handler layer + authHandler := handler.NewAuthHandler() userHandler := handler.NewUserHandler(userService) tenantHandler := handler.NewTenantHandler(tenantService, userService) documentHandler := handler.NewDocumentHandler(documentService) @@ -148,7 +149,7 @@ func startServer(config *server.Config) { fileHandler := handler.NewFileHandler(fileService, userService) // Initialize router - r := router.NewRouter(userHandler, tenantHandler, documentHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler) + r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler) // Create Gin engine ginEngine := gin.New() diff --git a/internal/dao/database.go b/internal/dao/database.go index 1529088fb4e..b35e79c80c2 100644 --- a/internal/dao/database.go +++ b/internal/dao/database.go @@ -248,6 +248,5 @@ func InitLLMFactory() error { } } - log.Println("LLM factories initialized successfully") return nil } diff --git a/internal/handler/auth.go b/internal/handler/auth.go new file mode 100644 index 00000000000..ca232645a0c --- /dev/null +++ b/internal/handler/auth.go @@ -0,0 +1,81 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package handler + +import ( + "net/http" + "ragflow/internal/common" + "ragflow/internal/service" + + "github.com/gin-gonic/gin" +) + +// AuthHandler auth handler +type AuthHandler struct { + userService *service.UserService +} + +// NewAuthHandler create auth handler +func NewAuthHandler() *AuthHandler { + return &AuthHandler{ + userService: service.NewUserService(), + } +} + +// AuthMiddleware JWT auth middleware +// Validates that the user is authenticated and is not a superuser (admin) +func (h *AuthHandler) AuthMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + c.Abort() + return + } + + // Get user by access token + user, code, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": code, + "message": "Invalid access token", + }) + c.Abort() + return + } + + if *user.IsSuperuser { + c.JSON(http.StatusForbidden, gin.H{ + "code": common.CodeForbidden, + "message": "Super user should not access this URL", + }) + c.Abort() + return + } + + c.Set("user", user) + c.Set("user_id", user.ID) + c.Set("email", user.Email) + c.Next() + } +} + +func (h *AuthHandler) LoginByEmail1(c *gin.Context) { + println("hello") +} diff --git a/internal/handler/chat.go b/internal/handler/chat.go index c7b2dde9842..b5e92192c96 100644 --- a/internal/handler/chat.go +++ b/internal/handler/chat.go @@ -18,6 +18,7 @@ package handler import ( "net/http" + "ragflow/internal/common" "strconv" "github.com/gin-gonic/gin" @@ -48,23 +49,9 @@ func NewChatHandler(chatService *service.ChatService, userService *service.UserS // @Success 200 {object} service.ListChatsResponse // @Router /v1/dialog/list [get] func (h *ChatHandler) ListChats(c *gin.Context) { - // Get access token from Authorization header - token := c.GetHeader("Authorization") - if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", - }) - return - } - - // Get user by access token - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": code, - "message": err.Error(), - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } userID := user.ID @@ -257,23 +216,9 @@ type RemoveDialogsRequest struct { // @Success 200 {object} map[string]interface{} // @Router /v1/dialog/rm [post] func (h *ChatHandler) RemoveChats(c *gin.Context) { - // Get access token from Authorization header - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", - }) - return - } - - // Get user by access token - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": code, - "message": err.Error(), - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } userID := user.ID diff --git a/internal/handler/chat_session.go b/internal/handler/chat_session.go index 54995371a55..ebf293957ed 100644 --- a/internal/handler/chat_session.go +++ b/internal/handler/chat_session.go @@ -20,6 +20,7 @@ import ( "fmt" "io" "net/http" + "ragflow/internal/common" "github.com/gin-gonic/gin" @@ -50,23 +51,9 @@ func NewChatSessionHandler(chatSessionService *service.ChatSessionService, userS // @Success 200 {object} service.SetChatSessionResponse // @Router /v1/conversation/set [post] func (h *ChatSessionHandler) SetChatSession(c *gin.Context) { - // Get access token from Authorization header - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", - }) - return - } - - // Get user by access token - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": code, - "message": err.Error(), - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } userID := user.ID @@ -113,23 +100,9 @@ type RemoveChatSessionsRequest struct { // @Success 200 {object} map[string]interface{} // @Router /v1/conversation/rm [post] func (h *ChatSessionHandler) RemoveChatSessions(c *gin.Context) { - // Get access token from Authorization header - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", - }) - return - } - - // Get user by access token - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": code, - "message": err.Error(), - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } userID := user.ID @@ -179,23 +152,9 @@ func (h *ChatSessionHandler) RemoveChatSessions(c *gin.Context) { // @Success 200 {object} service.ListChatSessionsResponse // @Router /v1/conversation/list [get] func (h *ChatSessionHandler) ListChatSessions(c *gin.Context) { - // Get access token from Authorization header - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - 
"message": "Missing Authorization header", - }) - return - } - - // Get user by access token - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": code, - "message": err.Error(), - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } userID := user.ID @@ -259,23 +218,9 @@ type CompletionRequest struct { // @Success 200 {object} map[string]interface{} // @Router /v1/conversation/completion [post] func (h *ChatSessionHandler) Completion(c *gin.Context) { - // Get access token from Authorization header - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", - }) - return - } - - // Get user by access token - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": code, - "message": err.Error(), - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } userID := user.ID diff --git a/internal/handler/chunk.go b/internal/handler/chunk.go index d13f4ac2792..6b855ad4d14 100644 --- a/internal/handler/chunk.go +++ b/internal/handler/chunk.go @@ -18,6 +18,7 @@ package handler import ( "net/http" + "ragflow/internal/common" "github.com/gin-gonic/gin" @@ -48,23 +49,9 @@ func NewChunkHandler(chunkService *service.ChunkService, userService *service.Us // @Success 200 {object} map[string]interface{} // @Router /v1/chunk/retrieval_test [post] func (h *ChunkHandler) RetrievalTest(c *gin.Context) { - // Extract access token from Authorization header - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", - }) - return - } - - // Get user by access token - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": code, - "message": err.Error(), - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } diff --git a/internal/handler/common.go b/internal/handler/common.go new file mode 100644 index 00000000000..3eb0f6f15ae --- /dev/null +++ b/internal/handler/common.go @@ -0,0 +1,37 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package handler + +import ( + "ragflow/internal/common" + "ragflow/internal/model" + + "github.com/gin-gonic/gin" +) + +func GetUser(c *gin.Context) (*model.User, common.ErrorCode, string) { + userAny, exist := c.Get("user") + if !exist { + return nil, common.CodeUnauthorized, "User not found" + } + + user, ok := userAny.(*model.User) + if !ok { + return nil, common.CodeUnauthorized, "User not found" + } + return user, common.CodeSuccess, "" +} diff --git a/internal/handler/connector.go b/internal/handler/connector.go index 6c0ebedb051..5b1c5faf3ce 100644 --- a/internal/handler/connector.go +++ b/internal/handler/connector.go @@ -18,6 +18,7 @@ package handler import ( "net/http" + "ragflow/internal/common" "github.com/gin-gonic/gin" @@ -47,23 +48,9 @@ func NewConnectorHandler(connectorService *service.ConnectorService, userService // @Success 200 {object} service.ListConnectorsResponse // @Router /connector/list [get] func (h *ConnectorHandler) ListConnectors(c *gin.Context) { - // Get access token from Authorization header - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", - }) - return - } - - // Get user by access token - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": code, - "message": err.Error(), - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } userID := user.ID diff --git a/internal/handler/document.go b/internal/handler/document.go index 10f08b6baf8..dedd4146a70 100644 --- a/internal/handler/document.go +++ b/internal/handler/document.go @@ -18,6 +18,7 @@ package handler import ( "net/http" + "ragflow/internal/common" "strconv" "github.com/gin-gonic/gin" @@ -47,6 +48,12 @@ func NewDocumentHandler(documentService *service.DocumentService) *DocumentHandl // @Success 200 {object} map[string]interface{} // @Router /api/v1/documents [post] func (h *DocumentHandler) CreateDocument(c *gin.Context) { + _, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + var req service.CreateDocumentRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ @@ -79,6 +86,12 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /api/v1/documents/{id} [get] func (h *DocumentHandler) GetDocumentByID(c *gin.Context) { + _, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + id := c.Param("id") if id == "" { c.JSON(http.StatusBadRequest, gin.H{ @@ -111,6 +124,12 @@ func (h *DocumentHandler) GetDocumentByID(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /api/v1/documents/{id} [put] func (h *DocumentHandler) UpdateDocument(c *gin.Context) { + _, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + id := c.Param("id") if id == "" { c.JSON(http.StatusBadRequest, gin.H{ @@ -149,6 +168,12 @@ func (h *DocumentHandler) UpdateDocument(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /api/v1/documents/{id} [delete] func (h *DocumentHandler) DeleteDocument(c *gin.Context) { + _, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess 
{ + jsonError(c, errorCode, errorMessage) + return + } + id := c.Param("id") if id == "" { c.JSON(http.StatusBadRequest, gin.H{ @@ -180,6 +205,12 @@ func (h *DocumentHandler) DeleteDocument(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /api/v1/documents [get] func (h *DocumentHandler) ListDocuments(c *gin.Context) { + _, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10")) @@ -220,6 +251,12 @@ func (h *DocumentHandler) ListDocuments(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /api/v1/authors/{author_id}/documents [get] func (h *DocumentHandler) GetDocumentsByAuthorID(c *gin.Context) { + _, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + authorIDStr := c.Param("author_id") authorID, err := strconv.Atoi(authorIDStr) if err != nil { diff --git a/internal/handler/file.go b/internal/handler/file.go index 3474ce0cb52..cae393ffc8a 100644 --- a/internal/handler/file.go +++ b/internal/handler/file.go @@ -18,6 +18,7 @@ package handler import ( "net/http" + "ragflow/internal/common" "strconv" "github.com/gin-gonic/gin" @@ -54,23 +55,9 @@ func NewFileHandler(fileService *service.FileService, userService *service.UserS // @Success 200 {object} service.ListFilesResponse // @Router /v1/file/list [get] func (h *FileHandler) ListFiles(c *gin.Context) { - // Get access token from Authorization header - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", - }) - return - } - - // Get user by access token - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": code, - "message": err.Error(), - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } userID := user.ID @@ -130,23 +117,9 @@ func (h *FileHandler) ListFiles(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/file/root_folder [get] func (h *FileHandler) GetRootFolder(c *gin.Context) { - // Get access token from Authorization header - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", - }) - return - } - - // Get user by access token - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": code, - "message": err.Error(), - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } userID := user.ID @@ -178,23 +151,9 @@ func (h *FileHandler) GetRootFolder(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/file/parent_folder [get] func (h *FileHandler) GetParentFolder(c *gin.Context) { - // Get access token from Authorization header - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", - }) - return - } - - // Get user by access token (for validation) - _, code, err := h.userService.GetUserByToken(token) - if err != nil { - 
c.JSON(http.StatusUnauthorized, gin.H{ - "code": code, - "message": err.Error(), - }) + _, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -235,23 +194,9 @@ func (h *FileHandler) GetParentFolder(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/file/all_parent_folder [get] func (h *FileHandler) GetAllParentFolders(c *gin.Context) { - // Get access token from Authorization header - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", - }) - return - } - - // Get user by access token (for validation) - _, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": code, - "message": err.Error(), - }) + _, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } diff --git a/internal/handler/kb.go b/internal/handler/kb.go index d4d4e848ef9..ef608a2639f 100644 --- a/internal/handler/kb.go +++ b/internal/handler/kb.go @@ -40,33 +40,6 @@ func NewKnowledgebaseHandler(kbService *service.KnowledgebaseService, userServic } } -// getUserID extracts user ID from authorization header -// It validates the authorization token and returns the user ID -// Parameters: -// - c: gin.Context - the HTTP request context -// -// Returns: -// - string: the user ID -// - common.ErrorCode: the error code -// - error: any error that occurred -func (h *KnowledgebaseHandler) getUserID(c *gin.Context) (string, common.ErrorCode, error) { - token := c.GetHeader("Authorization") - if token == "" { - return "", common.CodeUnauthorized, ErrMissingAuth - } - - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - return "", code, err - } - - if *user.IsSuperuser { - return "", common.CodeForbidden, ErrForbidden - } - - return user.ID, common.CodeSuccess, nil -} - // jsonResponse sends a JSON response with code and message func jsonResponse(c *gin.Context, code common.ErrorCode, data interface{}, message string) { c.JSON(http.StatusOK, gin.H{ @@ -115,9 +88,9 @@ var ( // @Success 200 {object} map[string]interface{} // @Router /v1/kb/create [post] func (h *KnowledgebaseHandler) CreateKB(c *gin.Context) { - userID, code, err := h.getUserID(c) - if err != nil { - jsonError(c, code, err.Error()) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -127,7 +100,7 @@ func (h *KnowledgebaseHandler) CreateKB(c *gin.Context) { return } - result, code, err := h.kbService.CreateKB(&req, userID) + result, code, err := h.kbService.CreateKB(&req, user.ID) if err != nil { jsonError(c, code, err.Error()) return @@ -147,9 +120,9 @@ func (h *KnowledgebaseHandler) CreateKB(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/kb/update [post] func (h *KnowledgebaseHandler) UpdateKB(c *gin.Context) { - userID, code, err := h.getUserID(c) - if err != nil { - jsonError(c, code, err.Error()) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -159,7 +132,7 @@ func (h *KnowledgebaseHandler) UpdateKB(c *gin.Context) { return } - result, code, err := h.kbService.UpdateKB(&req, userID) + result, code, err := h.kbService.UpdateKB(&req, user.ID) if err != nil { if strings.Contains(err.Error(), 
"authorization") { jsonError(c, common.CodeAuthenticationError, err.Error()) @@ -183,9 +156,9 @@ func (h *KnowledgebaseHandler) UpdateKB(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/kb/update_metadata_setting [post] func (h *KnowledgebaseHandler) UpdateMetadataSetting(c *gin.Context) { - _, code, err := h.getUserID(c) - if err != nil { - jsonError(c, code, err.Error()) + _, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -215,9 +188,9 @@ func (h *KnowledgebaseHandler) UpdateMetadataSetting(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/kb/detail [get] func (h *KnowledgebaseHandler) GetDetail(c *gin.Context) { - userID, code, err := h.getUserID(c) - if err != nil { - jsonError(c, code, err.Error()) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -227,7 +200,7 @@ func (h *KnowledgebaseHandler) GetDetail(c *gin.Context) { return } - result, code, err := h.kbService.GetDetail(kbID, userID) + result, code, err := h.kbService.GetDetail(kbID, user.ID) if err != nil { if strings.Contains(err.Error(), "authorized") { jsonError(c, common.CodeOperatingError, err.Error()) @@ -251,9 +224,9 @@ func (h *KnowledgebaseHandler) GetDetail(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/kb/list [post] func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) { - userID, code, err := h.getUserID(c) - if err != nil { - jsonError(c, code, err.Error()) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -317,7 +290,7 @@ func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) { ownerIDs = *req.OwnerIDs } - result, code, err := h.kbService.ListKbs(keywords, page, pageSize, parserID, orderby, desc, ownerIDs, userID) + result, code, err := h.kbService.ListKbs(keywords, page, pageSize, parserID, orderby, desc, ownerIDs, user.ID) if err != nil { jsonError(c, code, err.Error()) return @@ -337,9 +310,9 @@ func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/kb/rm [post] func (h *KnowledgebaseHandler) DeleteKB(c *gin.Context) { - userID, code, err := h.getUserID(c) - if err != nil { - jsonError(c, code, err.Error()) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -351,7 +324,7 @@ func (h *KnowledgebaseHandler) DeleteKB(c *gin.Context) { return } - code, err = h.kbService.DeleteKB(req.KBID, userID) + code, err := h.kbService.DeleteKB(req.KBID, user.ID) if err != nil { if strings.Contains(err.Error(), "authorization") { jsonError(c, common.CodeAuthenticationError, err.Error()) @@ -375,9 +348,9 @@ func (h *KnowledgebaseHandler) DeleteKB(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/kb/{kb_id}/tags [get] func (h *KnowledgebaseHandler) ListTags(c *gin.Context) { - userID, code, err := h.getUserID(c) - if err != nil { - jsonError(c, code, err.Error()) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -387,7 +360,7 @@ func (h *KnowledgebaseHandler) ListTags(c *gin.Context) { return } - if !h.kbService.Accessible(kbID, userID) { + if !h.kbService.Accessible(kbID, user.ID) { jsonError(c, 
common.CodeAuthenticationError, "No authorization.") return } @@ -406,9 +379,9 @@ func (h *KnowledgebaseHandler) ListTags(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/kb/tags [get] func (h *KnowledgebaseHandler) ListTagsFromKbs(c *gin.Context) { - userID, code, err := h.getUserID(c) - if err != nil { - jsonError(c, code, err.Error()) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -420,7 +393,7 @@ func (h *KnowledgebaseHandler) ListTagsFromKbs(c *gin.Context) { kbIDs := strings.Split(kbIDsStr, ",") for _, kbID := range kbIDs { - if !h.kbService.Accessible(kbID, userID) { + if !h.kbService.Accessible(kbID, user.ID) { jsonError(c, common.CodeAuthenticationError, "No authorization.") return } @@ -441,9 +414,9 @@ func (h *KnowledgebaseHandler) ListTagsFromKbs(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/kb/{kb_id}/rm_tags [post] func (h *KnowledgebaseHandler) RemoveTags(c *gin.Context) { - userID, code, err := h.getUserID(c) - if err != nil { - jsonError(c, code, err.Error()) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -453,7 +426,7 @@ func (h *KnowledgebaseHandler) RemoveTags(c *gin.Context) { return } - if !h.kbService.Accessible(kbID, userID) { + if !h.kbService.Accessible(kbID, user.ID) { jsonError(c, common.CodeAuthenticationError, "No authorization.") return } @@ -481,9 +454,9 @@ func (h *KnowledgebaseHandler) RemoveTags(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/kb/{kb_id}/rename_tag [post] func (h *KnowledgebaseHandler) RenameTag(c *gin.Context) { - userID, code, err := h.getUserID(c) - if err != nil { - jsonError(c, code, err.Error()) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -493,7 +466,7 @@ func (h *KnowledgebaseHandler) RenameTag(c *gin.Context) { return } - if !h.kbService.Accessible(kbID, userID) { + if !h.kbService.Accessible(kbID, user.ID) { jsonError(c, common.CodeAuthenticationError, "No authorization.") return } @@ -521,9 +494,9 @@ func (h *KnowledgebaseHandler) RenameTag(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/kb/{kb_id}/knowledge_graph [get] func (h *KnowledgebaseHandler) KnowledgeGraph(c *gin.Context) { - userID, code, err := h.getUserID(c) - if err != nil { - jsonError(c, code, err.Error()) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -533,7 +506,7 @@ func (h *KnowledgebaseHandler) KnowledgeGraph(c *gin.Context) { return } - if !h.kbService.Accessible(kbID, userID) { + if !h.kbService.Accessible(kbID, user.ID) { jsonError(c, common.CodeAuthenticationError, "No authorization.") return } @@ -557,9 +530,9 @@ func (h *KnowledgebaseHandler) KnowledgeGraph(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/kb/{kb_id}/knowledge_graph [delete] func (h *KnowledgebaseHandler) DeleteKnowledgeGraph(c *gin.Context) { - userID, code, err := h.getUserID(c) - if err != nil { - jsonError(c, code, err.Error()) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -569,7 +542,7 @@ func (h *KnowledgebaseHandler) DeleteKnowledgeGraph(c *gin.Context) { return } - if 
!h.kbService.Accessible(kbID, userID) { + if !h.kbService.Accessible(kbID, user.ID) { jsonError(c, common.CodeAuthenticationError, "No authorization.") return } @@ -588,9 +561,9 @@ func (h *KnowledgebaseHandler) DeleteKnowledgeGraph(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/kb/get_meta [get] func (h *KnowledgebaseHandler) GetMeta(c *gin.Context) { - userID, code, err := h.getUserID(c) - if err != nil { - jsonError(c, code, err.Error()) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -602,7 +575,7 @@ func (h *KnowledgebaseHandler) GetMeta(c *gin.Context) { kbIDs := strings.Split(kbIDsStr, ",") for _, kbID := range kbIDs { - if !h.kbService.Accessible(kbID, userID) { + if !h.kbService.Accessible(kbID, user.ID) { jsonError(c, common.CodeAuthenticationError, "No authorization.") return } @@ -622,9 +595,9 @@ func (h *KnowledgebaseHandler) GetMeta(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/kb/basic_info [get] func (h *KnowledgebaseHandler) GetBasicInfo(c *gin.Context) { - userID, code, err := h.getUserID(c) - if err != nil { - jsonError(c, code, err.Error()) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -634,7 +607,7 @@ func (h *KnowledgebaseHandler) GetBasicInfo(c *gin.Context) { return } - if !h.kbService.Accessible(kbID, userID) { + if !h.kbService.Accessible(kbID, user.ID) { jsonError(c, common.CodeAuthenticationError, "No authorization.") return } diff --git a/internal/handler/llm.go b/internal/handler/llm.go index 9582eb37adb..90d28087977 100644 --- a/internal/handler/llm.go +++ b/internal/handler/llm.go @@ -61,23 +61,9 @@ func NewLLMHandler(llmService *service.LLMService, userService *service.UserServ // @Success 200 {object} map[string]interface{} // @Router /v1/llm/my_llms [get] func (h *LLMHandler) GetMyLLMs(c *gin.Context) { - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusOK, gin.H{ - "code": common.CodeUnauthorized, - "message": "Unauthorized!", - "data": false, - }) - return - } - - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "code": code, - "message": err.Error(), - "data": false, - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -113,23 +99,9 @@ func (h *LLMHandler) GetMyLLMs(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/llm/set_api_key [post] func (h *LLMHandler) SetAPIKey(c *gin.Context) { - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusOK, gin.H{ - "code": common.CodeUnauthorized, - "message": "Unauthorized!", - "data": false, - }) - return - } - - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "code": code, - "message": err.Error(), - "data": false, - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -180,23 +152,9 @@ func (h *LLMHandler) SetAPIKey(c *gin.Context) { // @Success 200 {array} FactoryResponse // @Router /v1/llm/factories [get] func (h *LLMHandler) Factories(c *gin.Context) { - // Extract token from request - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, 
- "message": "Missing Authorization header", - }) - return - } - - // Get user by token - _, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": code, - "message": err.Error(), - }) + _, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -261,23 +219,9 @@ func (h *LLMHandler) Factories(c *gin.Context) { // @Success 200 {object} map[string][]service.LLMListItem // @Router /v1/llm/list [get] func (h *LLMHandler) ListApp(c *gin.Context) { - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusOK, gin.H{ - "code": common.CodeUnauthorized, - "message": "Unauthorized!", - "data": false, - }) - return - } - - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "code": code, - "message": err.Error(), - "data": false, - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } diff --git a/internal/handler/search.go b/internal/handler/search.go index b291a780270..7eb17ea9bdb 100644 --- a/internal/handler/search.go +++ b/internal/handler/search.go @@ -18,6 +18,7 @@ package handler import ( "net/http" + "ragflow/internal/common" "strconv" "github.com/gin-gonic/gin" @@ -54,23 +55,9 @@ func NewSearchHandler(searchService *service.SearchService, userService *service // @Success 200 {object} service.ListSearchAppsResponse // @Router /v1/search/list [post] func (h *SearchHandler) ListSearchApps(c *gin.Context) { - // Get access token from Authorization header - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Missing Authorization header", - }) - return - } - - // Get user by access token - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": code, - "message": err.Error(), - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } userID := user.ID diff --git a/internal/handler/tenant.go b/internal/handler/tenant.go index 860acc3bbbc..bb43ffb98ac 100644 --- a/internal/handler/tenant.go +++ b/internal/handler/tenant.go @@ -49,23 +49,9 @@ func NewTenantHandler(tenantService *service.TenantService, userService *service // @Success 200 {object} map[string]interface{} // @Router /v1/user/tenant_info [get] func (h *TenantHandler) TenantInfo(c *gin.Context) { - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusOK, gin.H{ - "code": common.CodeUnauthorized, - "message": "Unauthorized!", - "data": false, - }) - return - } - - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "code": code, - "message": err.Error(), - "data": false, - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -105,23 +91,9 @@ func (h *TenantHandler) TenantInfo(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/tenant/list [get] func (h *TenantHandler) TenantList(c *gin.Context) { - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusOK, gin.H{ - "code": common.CodeUnauthorized, - "message": "Unauthorized!", - "data": false, - }) - return - } - - user, code, err := 
h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "code": code, - "message": err.Error(), - "data": false, - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } diff --git a/internal/handler/user.go b/internal/handler/user.go index 8ec2d314f38..2678ecf1bf1 100644 --- a/internal/handler/user.go +++ b/internal/handler/user.go @@ -291,30 +291,14 @@ func (h *UserHandler) ListUsers(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/user/logout [post] func (h *UserHandler) Logout(c *gin.Context) { - // Extract token from request - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusOK, gin.H{ - "code": common.CodeUnauthorized, - "message": "Missing Authorization header", - "data": false, - }) - return - } - - // Get user by token - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "code": code, - "message": err.Error(), - "data": false, - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } // Logout user - code, err = h.userService.Logout(user) + code, err := h.userService.Logout(user) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": code, @@ -341,25 +325,9 @@ func (h *UserHandler) Logout(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/user/info [get] func (h *UserHandler) Info(c *gin.Context) { - // Extract token from request - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusOK, gin.H{ - "code": common.CodeUnauthorized, - "message": "Missing Authorization header", - "data": false, - }) - return - } - - // Get user by token - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "code": code, - "message": err.Error(), - "data": false, - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -446,25 +414,9 @@ func (h *UserHandler) Setting(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/user/setting/password [post] func (h *UserHandler) ChangePassword(c *gin.Context) { - // Extract token from request - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusOK, gin.H{ - "code": common.CodeUnauthorized, - "message": "Missing Authorization header", - "data": false, - }) - return - } - - // Get user by token - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "code": code, - "message": err.Error(), - "data": false, - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -480,7 +432,7 @@ func (h *UserHandler) ChangePassword(c *gin.Context) { } // Change password - code, err = h.userService.ChangePassword(user, &req) + code, err := h.userService.ChangePassword(user, &req) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": code, @@ -534,23 +486,9 @@ func (h *UserHandler) GetLoginChannels(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/user/set_tenant_info [post] func (h *UserHandler) SetTenantInfo(c *gin.Context) { - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusOK, gin.H{ - "code": common.CodeUnauthorized, - "message": 
"Unauthorized!", - "data": false, - }) - return - } - - user, code, err := h.userService.GetUserByToken(token) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "code": code, - "message": err.Error(), - "data": false, - }) + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) return } @@ -564,7 +502,7 @@ func (h *UserHandler) SetTenantInfo(c *gin.Context) { return } - err = h.userService.SetTenantInfo(user.ID, &req) + err := h.userService.SetTenantInfo(user.ID, &req) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": common.CodeDataError, diff --git a/internal/router/router.go b/internal/router/router.go index cebd6b97ac7..b7f8b0a6714 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -24,6 +24,7 @@ import ( // Router router type Router struct { + authHandler *handler.AuthHandler userHandler *handler.UserHandler tenantHandler *handler.TenantHandler documentHandler *handler.DocumentHandler @@ -40,6 +41,7 @@ type Router struct { // NewRouter create router func NewRouter( + authHandler *handler.AuthHandler, userHandler *handler.UserHandler, tenantHandler *handler.TenantHandler, documentHandler *handler.DocumentHandler, @@ -54,6 +56,7 @@ func NewRouter( fileHandler *handler.FileHandler, ) *Router { return &Router{ + authHandler: authHandler, userHandler: userHandler, tenantHandler: tenantHandler, documentHandler: documentHandler, @@ -83,132 +86,138 @@ func (r *Router) Setup(engine *gin.Engine) { engine.GET("/v1/system/config", r.systemHandler.GetConfig) engine.GET("/v1/system/configs", r.systemHandler.GetConfigs) engine.GET("/v1/system/version", r.systemHandler.GetVersion) - - // User login by email endpoint - engine.POST("/v1/user/login", r.userHandler.LoginByEmail) engine.POST("/v1/user/register", r.userHandler.Register) // User login channels endpoint engine.GET("/v1/user/login/channels", r.userHandler.GetLoginChannels) - // User logout endpoint - engine.GET("/v1/user/logout", r.userHandler.Logout) - // User info endpoint - engine.GET("/v1/user/info", r.userHandler.Info) - // User tenant info endpoint - engine.GET("/v1/user/tenant_info", r.tenantHandler.TenantInfo) - // Tenant list endpoint - engine.GET("/v1/tenant/list", r.tenantHandler.TenantList) - // User settings endpoint - engine.POST("/v1/user/setting", r.userHandler.Setting) - // User change password endpoint - engine.POST("/v1/user/setting/password", r.userHandler.ChangePassword) - // User set tenant info endpoint - engine.POST("/v1/user/set_tenant_info", r.userHandler.SetTenantInfo) - - // API v1 route group - v1 := engine.Group("/api/v1") + + // User login by email endpoint + engine.POST("/v1/user/login", r.userHandler.LoginByEmail) + + // Protected routes + authorized := engine.Group("") + authorized.Use(r.authHandler.AuthMiddleware()) { - // User routes - users := v1.Group("/users") + // User logout endpoint + authorized.GET("/v1/user/logout", r.userHandler.Logout) + // User info endpoint + authorized.GET("/v1/user/info", r.userHandler.Info) + // User tenant info endpoint + authorized.GET("/v1/user/tenant_info", r.tenantHandler.TenantInfo) + // Tenant list endpoint + authorized.GET("/v1/tenant/list", r.tenantHandler.TenantList) + // User settings endpoint + authorized.POST("/v1/user/setting", r.userHandler.Setting) + // User change password endpoint + authorized.POST("/v1/user/setting/password", r.userHandler.ChangePassword) + // User set tenant info endpoint + authorized.POST("/v1/user/set_tenant_info", r.userHandler.SetTenantInfo) 
+ + // API v1 route group + v1 := authorized.Group("/api/v1") { - users.POST("/register", r.userHandler.Register) - users.POST("/login", r.userHandler.Login) - users.GET("", r.userHandler.ListUsers) - users.GET("/:id", r.userHandler.GetUserByID) + // User routes + users := v1.Group("/users") + { + users.POST("/register", r.userHandler.Register) + users.POST("/login", r.userHandler.Login) + users.GET("", r.userHandler.ListUsers) + users.GET("/:id", r.userHandler.GetUserByID) + } + + // Document routes + documents := v1.Group("/documents") + { + documents.POST("", r.documentHandler.CreateDocument) + documents.GET("", r.documentHandler.ListDocuments) + documents.GET("/:id", r.documentHandler.GetDocumentByID) + documents.PUT("/:id", r.documentHandler.UpdateDocument) + documents.DELETE("/:id", r.documentHandler.DeleteDocument) + } + + // Author routes + authors := v1.Group("/authors") + { + authors.GET("/:author_id/documents", r.documentHandler.GetDocumentsByAuthorID) + } } - // Document routes - documents := v1.Group("/documents") + // Knowledge base routes + kb := authorized.Group("/v1/kb") { - documents.POST("", r.documentHandler.CreateDocument) - documents.GET("", r.documentHandler.ListDocuments) - documents.GET("/:id", r.documentHandler.GetDocumentByID) - documents.PUT("/:id", r.documentHandler.UpdateDocument) - documents.DELETE("/:id", r.documentHandler.DeleteDocument) + kb.POST("/create", r.knowledgebaseHandler.CreateKB) + kb.POST("/update", r.knowledgebaseHandler.UpdateKB) + kb.POST("/update_metadata_setting", r.knowledgebaseHandler.UpdateMetadataSetting) + kb.GET("/detail", r.knowledgebaseHandler.GetDetail) + kb.POST("/list", r.knowledgebaseHandler.ListKbs) + kb.POST("/rm", r.knowledgebaseHandler.DeleteKB) + kb.GET("/tags", r.knowledgebaseHandler.ListTagsFromKbs) + kb.GET("/get_meta", r.knowledgebaseHandler.GetMeta) + kb.GET("/basic_info", r.knowledgebaseHandler.GetBasicInfo) + + // KB ID specific routes + kbByID := kb.Group("/:kb_id") + { + kbByID.GET("/tags", r.knowledgebaseHandler.ListTags) + kbByID.POST("/rm_tags", r.knowledgebaseHandler.RemoveTags) + kbByID.POST("/rename_tag", r.knowledgebaseHandler.RenameTag) + kbByID.GET("/knowledge_graph", r.knowledgebaseHandler.KnowledgeGraph) + kbByID.DELETE("/knowledge_graph", r.knowledgebaseHandler.DeleteKnowledgeGraph) + } } - // Author routes - authors := v1.Group("/authors") + // Chunk routes + chunk := authorized.Group("/v1/chunk") { - authors.GET("/:author_id/documents", r.documentHandler.GetDocumentsByAuthorID) + chunk.POST("/retrieval_test", r.chunkHandler.RetrievalTest) } - } - // Knowledge base routes - kb := engine.Group("/v1/kb") - { - kb.POST("/create", r.knowledgebaseHandler.CreateKB) - kb.POST("/update", r.knowledgebaseHandler.UpdateKB) - kb.POST("/update_metadata_setting", r.knowledgebaseHandler.UpdateMetadataSetting) - kb.GET("/detail", r.knowledgebaseHandler.GetDetail) - kb.POST("/list", r.knowledgebaseHandler.ListKbs) - kb.POST("/rm", r.knowledgebaseHandler.DeleteKB) - kb.GET("/tags", r.knowledgebaseHandler.ListTagsFromKbs) - kb.GET("/get_meta", r.knowledgebaseHandler.GetMeta) - kb.GET("/basic_info", r.knowledgebaseHandler.GetBasicInfo) - - // KB ID specific routes - kbByID := kb.Group("/:kb_id") + // LLM routes + llm := authorized.Group("/v1/llm") { - kbByID.GET("/tags", r.knowledgebaseHandler.ListTags) - kbByID.POST("/rm_tags", r.knowledgebaseHandler.RemoveTags) - kbByID.POST("/rename_tag", r.knowledgebaseHandler.RenameTag) - kbByID.GET("/knowledge_graph", r.knowledgebaseHandler.KnowledgeGraph) - 
kbByID.DELETE("/knowledge_graph", r.knowledgebaseHandler.DeleteKnowledgeGraph) + llm.GET("/my_llms", r.llmHandler.GetMyLLMs) + llm.GET("/factories", r.llmHandler.Factories) + llm.GET("/list", r.llmHandler.ListApp) + llm.POST("/set_api_key", r.llmHandler.SetAPIKey) } - } - // Chunk routes - chunk := engine.Group("/v1/chunk") - { - chunk.POST("/retrieval_test", r.chunkHandler.RetrievalTest) - } - - // LLM routes - llm := engine.Group("/v1/llm") - { - llm.GET("/my_llms", r.llmHandler.GetMyLLMs) - llm.GET("/factories", r.llmHandler.Factories) - llm.GET("/list", r.llmHandler.ListApp) - llm.POST("/set_api_key", r.llmHandler.SetAPIKey) - } - - // Chat routes - chat := engine.Group("/v1/dialog") - { - chat.GET("/list", r.chatHandler.ListChats) - chat.POST("/next", r.chatHandler.ListChatsNext) - chat.POST("/set", r.chatHandler.SetDialog) - chat.POST("/rm", r.chatHandler.RemoveChats) - } + // Chat routes + chat := authorized.Group("/v1/dialog") + { + chat.GET("/list", r.chatHandler.ListChats) + chat.POST("/next", r.chatHandler.ListChatsNext) + chat.POST("/set", r.chatHandler.SetDialog) + chat.POST("/rm", r.chatHandler.RemoveChats) + } - // Chat session (conversation) routes - session := engine.Group("/v1/conversation") - { - session.POST("/set", r.chatSessionHandler.SetChatSession) - session.POST("/rm", r.chatSessionHandler.RemoveChatSessions) - session.GET("/list", r.chatSessionHandler.ListChatSessions) - session.POST("/completion", r.chatSessionHandler.Completion) - } + // Chat session (conversation) routes + session := authorized.Group("/v1/conversation") + { + session.POST("/set", r.chatSessionHandler.SetChatSession) + session.POST("/rm", r.chatSessionHandler.RemoveChatSessions) + session.GET("/list", r.chatSessionHandler.ListChatSessions) + session.POST("/completion", r.chatSessionHandler.Completion) + } - // Connector routes - connector := engine.Group("/v1/connector") - { - connector.GET("/list", r.connectorHandler.ListConnectors) - } + // Connector routes + connector := authorized.Group("/v1/connector") + { + connector.GET("/list", r.connectorHandler.ListConnectors) + } - // Search routes - search := engine.Group("/v1/search") - { - search.POST("/list", r.searchHandler.ListSearchApps) - } + // Search routes + search := authorized.Group("/v1/search") + { + search.POST("/list", r.searchHandler.ListSearchApps) + } - // File routes - file := engine.Group("/v1/file") - { - file.GET("/list", r.fileHandler.ListFiles) - file.GET("/root_folder", r.fileHandler.GetRootFolder) - file.GET("/parent_folder", r.fileHandler.GetParentFolder) - file.GET("/all_parent_folder", r.fileHandler.GetAllParentFolders) + // File routes + file := authorized.Group("/v1/file") + { + file.GET("/list", r.fileHandler.ListFiles) + file.GET("/root_folder", r.fileHandler.GetRootFolder) + file.GET("/parent_folder", r.fileHandler.GetParentFolder) + file.GET("/all_parent_folder", r.fileHandler.GetAllParentFolders) + } } // Handle undefined routes From e294fee1452cbb6159e9a8bc852b6f03d582913d Mon Sep 17 00:00:00 2001 From: Jimmy Ben Klieve Date: Wed, 11 Mar 2026 11:27:20 +0800 Subject: [PATCH 206/565] refactor(ui): update knowledge graph, chunk, metadata, agent log styles (#13518) ### What problem does this PR solve? 
Update UI styles: - **Dataset** > **Knowledge graph** tooltip - **Dataset** > **Files** > **Manage metadata** modal - **Dataset** > **Files** > **Modify Chunking Method** > **Auto metadata** > **Manage generation settings** modal - **Agent** > **Canvas (Ingestion pipeline)** > **Dataflow result** ### Type of change - [x] Refactoring --- .../document-preview/document-header.tsx | 28 ++- web/src/components/document-preview/index.tsx | 2 +- .../document-preview/pdf-preview.tsx | 8 +- web/src/components/edit-tag/index.tsx | 39 ++-- web/src/components/list-filter-bar/index.tsx | 7 +- web/src/components/originui/timeline.tsx | 77 +++---- .../raptor-form-fields.tsx | 4 +- web/src/components/ui/button.tsx | 2 +- web/src/components/ui/checkbox.tsx | 10 +- web/src/components/ui/form.tsx | 9 +- web/src/components/ui/radio.tsx | 22 +- web/src/components/ui/segmented.tsx | 205 ++++++++++++------ web/src/components/ui/tooltip.tsx | 23 +- web/src/less/mixins.less | 4 +- web/src/locales/en.ts | 7 +- .../components/chunk-card/index.tsx | 42 ++-- .../chunk-result-bar/checkbox-sets.tsx | 59 +++-- .../components/chunk-result-bar/index.tsx | 70 +++--- .../components/knowledge-chunk/index.tsx | 135 ++++++------ .../components/time-line/index.tsx | 25 ++- web/src/pages/dataflow-result/index.tsx | 70 ++---- .../metedata/manage-modal-column.tsx | 14 +- .../components/metedata/manage-modal.tsx | 77 +++---- .../metedata/manage-values-modal.tsx | 4 +- .../dataset-setting/category-panel.tsx | 5 +- .../chunk-method-learn-more.tsx | 30 ++- .../configuration/common-item.tsx | 4 +- .../dataset/use-dataset-table-columns.tsx | 2 - .../dataset/knowledge-graph/force-graph.tsx | 88 +++++--- .../dataset/knowledge-graph/index.module.less | 4 +- .../pages/dataset/knowledge-graph/index.tsx | 20 +- web/src/pages/datasets/index.tsx | 2 +- 32 files changed, 613 insertions(+), 485 deletions(-) diff --git a/web/src/components/document-preview/document-header.tsx b/web/src/components/document-preview/document-header.tsx index f6656da86de..d5d47d72b19 100644 --- a/web/src/components/document-preview/document-header.tsx +++ b/web/src/components/document-preview/document-header.tsx @@ -1,21 +1,35 @@ import { formatDate } from '@/utils/date'; import { formatBytes } from '@/utils/file-util'; +import { useTranslation } from 'react-i18next'; type Props = { size: number; name: string; create_date: string; + className?: string; }; -export default ({ size, name, create_date }: Props) => { +export default ({ size, name, create_date, className }: Props) => { const sizeName = formatBytes(size); const dateStr = formatDate(create_date); + + const { t } = useTranslation(); + return ( -
-

{name}

-
- Size:{sizeName} Uploaded Time:{dateStr} -
-
+
+

{name}

+
+
{t('chunk.size')}
+
{sizeName}
+ +
{t('chunk.uploadedTime')}
+
{dateStr}
+
+
); }; diff --git a/web/src/components/document-preview/index.tsx b/web/src/components/document-preview/index.tsx index 62a9078c43f..9e31111d300 100644 --- a/web/src/components/document-preview/index.tsx +++ b/web/src/components/document-preview/index.tsx @@ -25,7 +25,7 @@ const Preview = ({ return ( <> {fileType === 'pdf' && highlights && setWidthAndHeight && ( -
+
& { httpHeaders?: Record; }; @@ -69,7 +69,11 @@ const PdfPreview = ({ return (
void; @@ -59,28 +56,35 @@ const EditTag = React.forwardRef( const forMap = (tag: string) => { return ( - - {tag} - -
+ + {tag} + +
{tag}
{!disabled && ( - { e.preventDefault(); handleClose(tag); }} - /> + > + + )}
- - +
+
); }; @@ -107,10 +111,11 @@ const EditTag = React.forwardRef( )}
{Array.isArray(tagChild) && tagChild.length > 0 && <>{tagChild}} + {!inputVisible && !disabled && ( + ); + })} + +
handleOnChange(actualValue)} - > - {isObject ? option.label : option} - - ); - })} -
- ); - }, + style={{ + positionAnchor: `--${anchorNamePrefix}-${String(selectedValue).replace('/', '')}`, + width: 'anchor-size(width)', + height: 'anchor-size(height)', + top: 'anchor(top)', + left: 'anchor(left)', + }} + /> +
+ ); + } + : ( + { + options, + value, + onChange, + className, + activeClassName, + itemClassName, + rounded = 'default', + sizeType = 'default', + buttonSize = 'default', + }, + ref, + ) => { + const [selectedValue, setSelectedValue] = React.useState< + SegmentedValue | undefined + >(value); + React.useEffect(() => { + setSelectedValue(value); + }, [value]); + const handleOnChange = (e: SegmentedValue) => { + if (onChange) { + onChange(e); + } + setSelectedValue(e); + }; + return ( +
+ {options.map((option) => { + const isObject = typeof option === 'object'; + const actualValue = isObject ? option.value : option; + + return ( + + ); + })} +
+ ); + }, ); -Segmented.displayName = 'Segmented'; +export { Segmented }; diff --git a/web/src/components/ui/tooltip.tsx b/web/src/components/ui/tooltip.tsx index 8f687209896..7042d11cf15 100644 --- a/web/src/components/ui/tooltip.tsx +++ b/web/src/components/ui/tooltip.tsx @@ -16,15 +16,17 @@ const TooltipContent = React.forwardRef< React.ElementRef, React.ComponentPropsWithoutRef >(({ className, sideOffset = 4, ...props }, ref) => ( - + + + )); TooltipContent.displayName = TooltipPrimitive.Content.displayName; @@ -35,11 +37,12 @@ export const FormTooltip = ({ tooltip }: { tooltip: React.ReactNode }) => { { e.preventDefault(); // Prevent clicking the tooltip from triggering form save }} > - + {tooltip} diff --git a/web/src/less/mixins.less b/web/src/less/mixins.less index e6c99f601a8..c77eac838c9 100644 --- a/web/src/less/mixins.less +++ b/web/src/less/mixins.less @@ -17,8 +17,8 @@ caption { color: @blurBackground; font-size: 14px; - height: 20px; - line-height: 20px; + // height: 20px; + line-height: 1.25; font-weight: 600; margin-bottom: 6px; } diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index debbcf593cd..99c7955a919 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -584,8 +584,9 @@ Example: A 1 KB message with 1024-dim embedding uses ~9 KB. The 5 MB default lim naive: `

Supported file formats are MD, MDX, DOCX, XLSX, XLS (Excel 97-2003), PPTX, PDF, TXT, JPEG, JPG, PNG, TIF, GIF, CSV, JSON, EML, HTML.

This method chunks files using a 'naive' method:

+

  • Use vision detection model to split the texts into smaller segments.
  • -
  • Then, combine adjacent segments until the token count exceeds the threshold specified by 'Chunk token number for text', at which point a chunk is created.
  • `, +
  • Then, combine adjacent segments until the token count exceeds the threshold specified by 'Chunk token number for text', at which point a chunk is created.

`, paper: `

Only PDF file is supported.

Papers will be split by section, such as abstract, 1.1, 1.2.

This approach enables the LLM to summarize the paper more effectively and to provide more comprehensive, understandable responses. @@ -597,6 +598,7 @@ Example: A 1 KB message with 1024-dim embedding uses ~9 KB. The 5 MB default lim

This chunking method supports XLSX and CSV/TXT file formats.

+
  • If a file is in XLSX or XLS (Excel 97-2003) format, it should contain two columns without headers: one for questions and the other for answers, with the question column preceding the answer column. Multiple sheets are acceptable, provided the columns are properly structured. @@ -604,6 +606,7 @@ Example: A 1 KB message with 1024-dim embedding uses ~9 KB. The 5 MB default lim
  • If a file is in CSV/TXT format, it must be UTF-8 encoded with TAB as the delimiter to separate questions and answers.
  • +

Lines of texts that fail to follow the above rules will be ignored, and @@ -726,6 +729,8 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s table: 'Table', text: 'Text', }, + size: 'Size', + uploadedTime: 'Uploaded time', chunk: 'Chunk', bulk: 'Bulk', selectAll: 'Select all', diff --git a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-card/index.tsx b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-card/index.tsx index 32f7dd2ed1a..4372c421545 100644 --- a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-card/index.tsx +++ b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-card/index.tsx @@ -73,10 +73,11 @@ const ChunkCard = ({ return ( - +

+ {/* Using instead of to avoid flickering when hovering over the image */} {item.image_id && ( @@ -98,41 +103,44 @@ const ChunkCard = ({ + )}
+ className={classNames( + // Keep whitespaces? + 'text-wrap break-words whitespace-pre', + textMode === ChunkTextMode.Ellipse && 'line-clamp-3', + )} + />
-
+
void; removeChunk: (e?: any) => void; switchChunk: (available: number) => void; @@ -13,6 +20,7 @@ type ICheckboxSetProps = { }; export default (props: ICheckboxSetProps) => { const { + className, selectAllChunk, removeChunk, switchChunk, @@ -45,39 +53,28 @@ export default (props: ICheckboxSetProps) => { }, [selectedChunkIds]); return ( -
-
- - -
+
+ + {isSelected && ( <> -
- - {t('chunk.enable')} -
-
- + + +
-
- - {t('chunk.delete')} -
+ + + )}
diff --git a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-result-bar/index.tsx b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-result-bar/index.tsx index e5221a15c9f..e05c4c121a0 100644 --- a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-result-bar/index.tsx +++ b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-result-bar/index.tsx @@ -8,7 +8,8 @@ import { import { Radio } from '@/components/ui/radio'; import { Segmented } from '@/components/ui/segmented'; import { useTranslate } from '@/hooks/common-hooks'; -import { ListFilter, Plus } from 'lucide-react'; +import { cn } from '@/lib/utils'; +import { LucideFilter, Plus } from 'lucide-react'; import { useState } from 'react'; import { ChunkTextMode } from '../../constant'; interface ChunkResultBarProps { @@ -21,6 +22,7 @@ interface ChunkResultBarProps { searchString: string; } export default function ChunkResultBar({ + className, changeChunkTextMode, available, selectAllChunk, @@ -59,42 +61,46 @@ export default function ChunkResultBar({ changeChunkTextMode(value); }; return ( -
+
-
-
- } - onChange={handleInputChange} - value={searchString} - /> - - - - - - {filterContent} - - - -
+ + + + + + + {filterContent} + + + + + + {/*
*/}
diff --git a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/index.tsx b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/index.tsx index 5c3a840404c..3e53d6171e1 100644 --- a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/index.tsx +++ b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/index.tsx @@ -24,6 +24,7 @@ import DocumentHeader from '@/components/document-preview/document-header'; import { useGetDocumentUrl } from '@/components/document-preview/hooks'; import { PageHeader } from '@/components/page-header'; import { Button } from '@/components/ui/button'; +import { Card, CardContent } from '@/components/ui/card'; import message from '@/components/ui/message'; import { RAGFlowPagination, @@ -174,7 +175,7 @@ const Chunk = () => { }, [documentInfo]); return ( - <> +
-
-
-
-
- -
-
+ + + +
+ + +
-
-
-
+
+ + +
- -
-
-

{t('chunk.chunkResult')}

-
- {t('chunk.chunkResultTip')} -
-
+
+

{t('chunk.chunkResult')}

+
+ {t('chunk.chunkResultTip')}
-
- -
+
+ + +
+
+ + { selectedChunkIds={selectedChunkIds} />
-
-
- {chunkList.map((item) => ( - x === item.chunk_id, - )} - handleCheckboxClick={handleSingleCheckboxClick} - switchChunk={handleSwitchChunk} - clickChunkCard={handleChunkCardClick} - selected={item.chunk_id === selectedChunkId} - textMode={textMode} - t={dataUpdatedAt} - > - ))} -
+ +
+ {chunkList.map((item) => ( + x === item.chunk_id, + )} + handleCheckboxClick={handleSingleCheckboxClick} + switchChunk={handleSwitchChunk} + clickChunkCard={handleChunkCardClick} + selected={item.chunk_id === selectedChunkId} + textMode={textMode} + t={dataUpdatedAt} + /> + ))}
-
+ +
{ onChange={(page, pageSize) => { onPaginationChange(page, pageSize); }} - > -
+ /> +
-
-
-
+ + + + {chunkUpdatingVisible && ( { parserId={documentInfo.parser_id} /> )} - +
); }; diff --git a/web/src/pages/dataflow-result/components/time-line/index.tsx b/web/src/pages/dataflow-result/components/time-line/index.tsx index 92a1d236d1d..e153d925027 100644 --- a/web/src/pages/dataflow-result/components/time-line/index.tsx +++ b/web/src/pages/dataflow-result/components/time-line/index.tsx @@ -1,11 +1,11 @@ import { CustomTimeline, TimelineNode } from '@/components/originui/timeline'; import { - Blocks, - File, - FilePlay, - FileStack, - Heading, - ListPlus, + LucideBlocks, + LucideFile, + LucideFilePlay, + LucideFileStack, + LucideHeading, + LucideListPlus, } from 'lucide-react'; import { useMemo } from 'react'; import { TimelineNodeType } from '../../constant'; @@ -21,28 +21,28 @@ export type ITimelineNodeObj = { export const TimelineNodeObj = { [TimelineNodeType.begin]: { title: 'File', - icon: , + icon: , clickable: false, }, [TimelineNodeType.parser]: { title: 'Parser', - icon: , + icon: , }, [TimelineNodeType.contextGenerator]: { title: 'Context Generator', - icon: , + icon: , }, [TimelineNodeType.titleSplitter]: { title: 'Title Splitter', - icon: , + icon: , }, [TimelineNodeType.characterSplitter]: { title: 'Character Splitter', - icon: , + icon: , }, [TimelineNodeType.tokenizer]: { title: 'Tokenizer', - icon: , + icon: , clickable: false, }, }; @@ -80,6 +80,7 @@ const TimelineDataFlow = ({ onStepChange={handleStepChange} orientation="horizontal" lineStyle="solid" + lineColor="rgb(var(--))" nodeSize={24} activeStyle={{ nodeSize: 30, diff --git a/web/src/pages/dataflow-result/index.tsx b/web/src/pages/dataflow-result/index.tsx index 5b7819d4758..ada172a17a3 100644 --- a/web/src/pages/dataflow-result/index.tsx +++ b/web/src/pages/dataflow-result/index.tsx @@ -19,27 +19,21 @@ import { useGetDocumentUrl } from '@/components/document-preview/hooks'; import { TimelineNode } from '@/components/originui/timeline'; import { PageHeader } from '@/components/page-header'; import Spotlight from '@/components/spotlight'; -import { - Breadcrumb, - BreadcrumbItem, - BreadcrumbLink, - BreadcrumbList, - BreadcrumbPage, - BreadcrumbSeparator, -} from '@/components/ui/breadcrumb'; import { Button } from '@/components/ui/button'; import { Modal } from '@/components/ui/modal/modal'; -import { AgentCategory } from '@/constants/agent'; +import { AgentCategory, AgentQuery } from '@/constants/agent'; import { Images } from '@/constants/common'; import { useNavigatePage } from '@/hooks/logic-hooks/navigate-hooks'; import { useGetKnowledgeSearchParams } from '@/hooks/route-hook'; +import { Routes } from '@/routes'; +import { LucideArrowBigLeft } from 'lucide-react'; import TimelineDataFlow from './components/time-line'; import { TimelineNodeType } from './constant'; import styles from './index.module.less'; import { IDslComponent, IPipelineFileLogDetail } from './interface'; import ParserContainer from './parser'; -const Chunk = () => { +const DataflowResult = () => { const { isReadOnly, knowledgeId, agentId, agentTitle, documentExtension } = useGetPipelineResultSearchParams(); @@ -158,46 +152,22 @@ const Chunk = () => { return ( <> - - - - { - if (knowledgeId) { - navigateToDatasetList(); - } - if (agentId) { - navigateToAgents(); - } - }} - > - {knowledgeId ? t('knowledgeDetails.dataset') : t('header.flow')} - - - - - { - if (knowledgeId) { - navigateToDatasetOverview(knowledgeId)(); - } - if (isAgent) { - navigateToAgent(agentId, AgentCategory.DataflowCanvas)(); - } - }} - > - {knowledgeId ? t('knowledgeDetails.overview') : agentTitle} - - - - - - {knowledgeId ? 
documentInfo?.name : t('flow.viewResult')} - - - - + + {type === 'dataflow' && (
{ ); }; -export default Chunk; +export default DataflowResult; diff --git a/web/src/pages/dataset/components/metedata/manage-modal-column.tsx b/web/src/pages/dataset/components/metedata/manage-modal-column.tsx index 266fba97751..bae956c3a14 100644 --- a/web/src/pages/dataset/components/metedata/manage-modal-column.tsx +++ b/web/src/pages/dataset/components/metedata/manage-modal-column.tsx @@ -4,7 +4,7 @@ import { DatePicker } from '@/components/ui/date-picker'; import { Input } from '@/components/ui/input'; import { formatDate } from '@/utils/date'; import { ColumnDef, Row, Table } from '@tanstack/react-table'; -import { ListChevronsDownUp, Settings, Trash2 } from 'lucide-react'; +import { ListChevronsDownUp, LucidePencil, Trash2 } from 'lucide-react'; import { useCallback, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { @@ -147,7 +147,7 @@ export const useMetadataColumns = ({ header: () => {t('knowledgeDetails.metadata.description')}, cell: ({ row }) => (
- {row.getValue('description')} + {row.getValue('description') || '-'}
), }, @@ -347,17 +347,17 @@ export const useMetadataColumns = ({ cell: ({ row }) => (
- )} */} - {isCanAdd && activeTab !== 'built-in' && ( - - )} -
+
+ {secondTitle || t('knowledgeDetails.metadata.metadata')}
{rowSelectionIsEmpty || ( @@ -366,14 +338,43 @@ export const ManageMetadataModal = (props: IManageModalProps) => { value={activeTab} onValueChange={(v) => setActiveTab(v as MetadataSettingsTab)} > - - - {t('knowledgeDetails.metadata.generation')} - - - {t('knowledgeDetails.metadata.builtIn')} - - +
+ + + {t('knowledgeDetails.metadata.generation')} + + + {t('knowledgeDetails.metadata.builtIn')} + + + +
+ {/* {metadataType === MetadataType.Manage && ( + + )} */} + {isCanAdd && activeTab !== 'built-in' && ( + + )} +
+
+ diff --git a/web/src/pages/dataset/components/metedata/manage-values-modal.tsx b/web/src/pages/dataset/components/metedata/manage-values-modal.tsx index f2a74a1c9ef..e54453aa70e 100644 --- a/web/src/pages/dataset/components/metedata/manage-values-modal.tsx +++ b/web/src/pages/dataset/components/metedata/manage-values-modal.tsx @@ -283,8 +283,8 @@ export const ManageValuesModal = (props: IManageValuesProps) => { metaData.valueType === metadataValueTypeEnum['list'] && (
); }; diff --git a/web/src/pages/dataset/dataset-setting/chunk-method-learn-more.tsx b/web/src/pages/dataset/dataset-setting/chunk-method-learn-more.tsx index 22a013c6992..abc3a47cd11 100644 --- a/web/src/pages/dataset/dataset-setting/chunk-method-learn-more.tsx +++ b/web/src/pages/dataset/dataset-setting/chunk-method-learn-more.tsx @@ -1,7 +1,8 @@ import { Button } from '@/components/ui/button'; +import { Card, CardContent } from '@/components/ui/card'; import { cn } from '@/lib/utils'; import { t } from 'i18next'; -import { X } from 'lucide-react'; +import { LucideX } from 'lucide-react'; import { useState } from 'react'; import CategoryPanel from './category-panel'; @@ -20,20 +21,25 @@ const ChunkMethodLearnMore = ({ parserId }: { parserId: string }) => { {t('knowledgeDetails.learnMore')} -
- -
{ - setVisible(false); - }} +
-
+ + + + + + + ); }; diff --git a/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx b/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx index fca3070e0cd..0ce24fc29c7 100644 --- a/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx +++ b/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx @@ -415,8 +415,8 @@ export function AutoMetadata({ avatar={knowledgeBase.avatar} name={knowledgeBase.name} className="size-8" - > -
+ /> +
{knowledgeBase.name}
diff --git a/web/src/pages/dataset/dataset/use-dataset-table-columns.tsx b/web/src/pages/dataset/dataset/use-dataset-table-columns.tsx index b59dcda09dc..f053fb6672a 100644 --- a/web/src/pages/dataset/dataset/use-dataset-table-columns.tsx +++ b/web/src/pages/dataset/dataset/use-dataset-table-columns.tsx @@ -46,7 +46,6 @@ export function useDatasetTableColumns({ id: 'select', header: ({ table }) => ( ( row.toggleSelected(!!value)} aria-label="Select row" diff --git a/web/src/pages/dataset/knowledge-graph/force-graph.tsx b/web/src/pages/dataset/knowledge-graph/force-graph.tsx index d4776b22e4f..2062f6bb4bc 100644 --- a/web/src/pages/dataset/knowledge-graph/force-graph.tsx +++ b/web/src/pages/dataset/knowledge-graph/force-graph.tsx @@ -1,15 +1,16 @@ import { ElementDatum, Graph, IElementEvent } from '@antv/g6'; import isEmpty from 'lodash/isEmpty'; -import { useCallback, useEffect, useMemo, useRef } from 'react'; +import { useCallback, useEffect, useId, useMemo, useRef } from 'react'; import { buildNodesAndCombos, defaultComboLabel } from './util'; import { useIsDarkTheme } from '@/components/theme-provider'; +import { cn } from '@/lib/utils'; import styles from './index.module.less'; const TooltipColorMap = { - combo: 'red', - node: 'black', - edge: 'blue', + combo: 'text-red-600', + node: 'text-black', + edge: 'text-blue-600', }; interface IProps { @@ -18,6 +19,7 @@ interface IProps { } const ForceGraph = ({ data, show }: IProps) => { + const tooltipId = useId(); const containerRef = useRef(null); const graphRef = useRef(null); const isDark = useIsDarkTheme(); @@ -52,23 +54,46 @@ const ForceGraph = ({ data, show }: IProps) => { getContent: (e: IElementEvent, items: ElementDatum) => { if (Array.isArray(items)) { if (items.some((x) => x?.isCombo)) { - return `

${items?.[0]?.data?.label}

`; + return `

${items?.[0]?.data?.label}

`; } - let result = ``; - items.forEach((item) => { - result += `

${item?.id}

`; - if (item?.entity_type) { - result += `
Entity type: ${item?.entity_type}
`; - } - if (item?.weight) { - result += `
Weight: ${item?.weight}
`; - } - if (item?.description) { - result += `

${item?.description}

`; - } - }); - return result + '
'; + + return items + .flatMap((item) => { + return [ + '
', + `

${item?.id}

`, + '
', + ...(item?.entity_type + ? [ + '
', + '
Entity type:
', + `
${item.entity_type}
`, + '
', + ] + : []), + ...(item?.weight + ? [ + '
', + '
Weight:
', + `
${item.weight}
`, + '
', + ] + : []), + '
', + item.description + ? `

${item.description}

` + : '', + '
', + ]; + }) + .join(''); } + return undefined; }, }, @@ -82,34 +107,32 @@ const ForceGraph = ({ data, show }: IProps) => { node: { style: { size: (d) => { - let size = 100 + ((d.rank as number) || 0) * 5; - size = size > 300 ? 300 : size; - return size; + const size = 100 + ((d.rank as number) || 0) * 5; + return Math.min(size, 300); }, + labelText: (d) => d.id, labelFill: isDark ? 'rgba(255,255,255,1)' : 'rgba(0,0,0,1)', // labelPadding: 30, labelFontSize: 40, - // labelOffsetX: 20, + // labelOffsetX: 20, labelOffsetY: 20, labelPlacement: 'center', labelWordWrap: true, }, palette: { type: 'group', - field: (d) => { - return d?.entity_type as string; - }, + field: (d) => d?.entity_type as string, }, }, edge: { style: (model) => { const weight: number = Number(model?.weight) || 2; - const lineWeight = weight * 4; + return { stroke: isDark ? 'rgba(255,255,255,0.5)' : 'rgba(0,0,0,0.5)', lineDash: [10, 10], - lineWidth: lineWeight > 8 ? 8 : lineWeight, + lineWidth: Math.min(weight * 4, 8), }; }, }, @@ -149,12 +172,9 @@ const ForceGraph = ({ data, show }: IProps) => { return (
); }; diff --git a/web/src/pages/dataset/knowledge-graph/index.module.less b/web/src/pages/dataset/knowledge-graph/index.module.less index 7c5d1f5a869..6af1b11c9af 100644 --- a/web/src/pages/dataset/knowledge-graph/index.module.less +++ b/web/src/pages/dataset/knowledge-graph/index.module.less @@ -1,5 +1,7 @@ .forceContainer { :global(.tooltip) { - border-radius: 10px !important; + padding: 0.5rem 0.75rem !important; + border-radius: 0.5rem !important; + font-family: var(--font-sans) !important; } } diff --git a/web/src/pages/dataset/knowledge-graph/index.tsx b/web/src/pages/dataset/knowledge-graph/index.tsx index 539b752d2c6..6b31f1fc4d5 100644 --- a/web/src/pages/dataset/knowledge-graph/index.tsx +++ b/web/src/pages/dataset/knowledge-graph/index.tsx @@ -1,7 +1,8 @@ import { ConfirmDeleteDialog } from '@/components/confirm-delete-dialog'; import { Button } from '@/components/ui/button'; +import { Card } from '@/components/ui/card'; import { useFetchKnowledgeGraph } from '@/hooks/use-knowledge-request'; -import { Trash2 } from 'lucide-react'; +import { LucideTrash2 } from 'lucide-react'; import React from 'react'; import { useTranslation } from 'react-i18next'; import ForceGraph from './force-graph'; @@ -13,18 +14,23 @@ const KnowledgeGraph: React.FC = () => { const { handleDeleteKnowledgeGraph } = useDeleteKnowledgeGraph(); return ( -
+ - -
+ + + ); }; diff --git a/web/src/pages/datasets/index.tsx b/web/src/pages/datasets/index.tsx index 04f177aca08..9f32c7d847a 100644 --- a/web/src/pages/datasets/index.tsx +++ b/web/src/pages/datasets/index.tsx @@ -91,7 +91,7 @@ export default function Datasets() { value={filterValue} filters={owners} onChange={handleFilterSubmit} - className="px-8" + className="px-8 mb-4" icon={'datasets'} >
-
+
{tips ?? t('knowledgeConfiguration.photoTip')}
diff --git a/web/src/components/file-upload-dialog/index.tsx b/web/src/components/file-upload-dialog/index.tsx index d7239e2e8b6..251c52f4cb9 100644 --- a/web/src/components/file-upload-dialog/index.tsx +++ b/web/src/components/file-upload-dialog/index.tsx @@ -76,7 +76,7 @@ function UploadForm({ submit, showParseOnCreation }: UploadFormProps) { data-testid="parse-on-creation-toggle" onCheckedChange={field.onChange} checked={field.value} - > + /> )} )} @@ -85,7 +85,7 @@ function UploadForm({ submit, showParseOnCreation }: UploadFormProps) { )} @@ -124,10 +124,7 @@ export function FileUploadDialog({ {t('common.comingSoon')} */} - + {t('common.save')} diff --git a/web/src/components/file-uploader.tsx b/web/src/components/file-uploader.tsx index 56163df10f2..813787a7ca0 100644 --- a/web/src/components/file-uploader.tsx +++ b/web/src/components/file-uploader.tsx @@ -2,7 +2,7 @@ 'use client'; -import { FileText, FolderUp, Upload, X } from 'lucide-react'; +import { FileText, FolderUp, LucideTrash2, Upload } from 'lucide-react'; import * as React from 'react'; import Dropzone, { type DropzoneProps, @@ -80,12 +80,12 @@ function FileCard({ file, progress, onRemove }: FileCardProps) {
@@ -300,71 +300,67 @@ export function FileUploader(props: FileUploaderProps) { const isDisabled = disabled || (files?.length ?? 0) >= maxFileCount; - const renderDropzone = (isFolderMode: boolean = false) => ( - 1 || multiple} - disabled={isDisabled} - noClick={isFolderMode} - noDrag={isFolderMode} - > - {({ getRootProps, getInputProps, isDragActive }) => ( -
- {!isFolderMode && } - {isDragActive && !isFolderMode ? ( -
-
-
-

- {t('fileManager.dropFilesHere', 'Drop the files here')} -

-
- ) : ( -
{ - if (isFolderMode && !isDisabled) { - folderInputRef.current?.click(); - } - }} - > -
- {isFolderMode ? ( -
+ )} + + ); + }; return (
@@ -417,7 +413,7 @@ export function FileUploader(props: FileUploaderProps) { )} {files?.length ? ( -
+
{files?.map((file, index) => ( - - - - + + + ); } diff --git a/web/src/components/ui/button.tsx b/web/src/components/ui/button.tsx index 93f366e0164..76ee1682d3c 100644 --- a/web/src/components/ui/button.tsx +++ b/web/src/components/ui/button.tsx @@ -68,6 +68,13 @@ const buttonVariants = cva( hover:bg-state-error/10 focus-visible:bg-state-error/10 `, + 'danger-hover': ` + bg-bg-input border-border-button + hover:bg-state-error/10 focus-visible:bg-state-error/10 + hover:text-state-error focus-visible:text-state-error + hover:border-state-error focus-visible:border-state-error + `, + // Ghost variant series // Button has transparent background, without borders ghost: ` @@ -91,11 +98,11 @@ const buttonVariants = cva( size: { auto: '', - xl: 'h-12 rounded-xl px-5', + xl: 'h-12 rounded-xl px-5 gap-3', lg: 'h-10 rounded-lg px-4', default: 'h-8 rounded px-3', - sm: 'h-7 rounded-sm px-2', - xs: 'h-6 rounded-xs px-1', + sm: 'h-7 rounded-sm px-2 gap-1', + xs: 'h-6 rounded-xs px-1 gap-0.5', 'icon-xl': 'size-12 rounded-xl', 'icon-lg': 'size-10 rounded-lg', diff --git a/web/src/components/ui/checkbox.tsx b/web/src/components/ui/checkbox.tsx index 50f937cfb0f..362b7d5ca68 100644 --- a/web/src/components/ui/checkbox.tsx +++ b/web/src/components/ui/checkbox.tsx @@ -1,7 +1,7 @@ 'use client'; import * as CheckboxPrimitive from '@radix-ui/react-checkbox'; -import { LucideCheck } from 'lucide-react'; +import { LucideCheck, LucideMinus } from 'lucide-react'; import * as React from 'react'; import { cn } from '@/lib/utils'; @@ -23,7 +23,12 @@ const Checkbox = React.forwardRef< {...props} > - + {props.checked === 'indeterminate' && ( + + )} + {props.checked === true && ( + + )} )); diff --git a/web/src/constants/common.ts b/web/src/constants/common.ts index f988d535f45..e34692a5694 100644 --- a/web/src/constants/common.ts +++ b/web/src/constants/common.ts @@ -41,6 +41,7 @@ export const fileIconMap = { xml: 'xml.svg', }; +// TODO: Need to migrate to standard BCP 47 language tag export const LanguageList = [ 'English', 'Chinese', @@ -68,7 +69,7 @@ export const LanguageMap = { Vietnamese: 'Tiếng việt', Japanese: '日本語', 'Portuguese BR': 'Português BR', - German: 'German', + German: 'Deutsch', French: 'Français', Italian: 'Italiano', Bulgarian: 'Български', diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 1f12cc927b1..e58b958923d 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -360,11 +360,11 @@ Example: A 1 KB message with 1024-dim embedding uses ~9 KB. 
The 5 MB default lim filesSelected: 'Files selected', upload: 'Upload', run: 'Parse', - runningStatus0: 'PENDING', - runningStatus1: 'PARSING', - runningStatus2: 'CANCELED', - runningStatus3: 'SUCCESS', - runningStatus4: 'FAIL', + runningStatus0: 'Pending', + runningStatus1: 'Parsing', + runningStatus2: 'Cancelled', + runningStatus3: 'Success', + runningStatus4: 'Fail', pageRanges: 'Page ranges', pageRangesTip: 'Range of pages to be parsed; pages outside this range will not be processed.', diff --git a/web/src/pages/dataset/dataset-overview/index.tsx b/web/src/pages/dataset/dataset-overview/index.tsx index 5d60359396b..a535960bc3f 100644 --- a/web/src/pages/dataset/dataset-overview/index.tsx +++ b/web/src/pages/dataset/dataset-overview/index.tsx @@ -1,4 +1,3 @@ -import FileStatusBadge from '@/components/file-status-badge'; import { FilterCollection } from '@/components/list-filter-bar/interface'; import SvgIcon from '@/components/svg-icon'; import { useIsDarkTheme } from '@/components/theme-provider'; @@ -192,12 +191,7 @@ const FileLogsPage: FC = () => { return { id: value, // label: RunningStatusMap[value].label, - label: ( - - ), + label: RunningStatusMap[value], }; }), }, diff --git a/web/src/pages/dataset/sidebar/index.tsx b/web/src/pages/dataset/sidebar/index.tsx index 77bb55a3bc6..3fb9a6311cc 100644 --- a/web/src/pages/dataset/sidebar/index.tsx +++ b/web/src/pages/dataset/sidebar/index.tsx @@ -72,9 +72,9 @@ export function SideBar({ refreshCount }: PropType) { }, [t, routerData]); return ( -
); }, meta: { cellClassName: 'max-w-[20vw]' }, @@ -147,21 +151,18 @@ export function FilesTable({ return ( -
- - - - - {name} - -
+
- -

{name}

-
+ + {name}
); }, @@ -170,13 +171,18 @@ export function FilesTable({ accessorKey: 'create_time', header: ({ column }) => { return ( - + +
); }, cell: ({ row }) => ( @@ -189,13 +195,18 @@ export function FilesTable({ accessorKey: 'size', header: ({ column }) => { return ( - + +
); }, cell: ({ row }) => ( @@ -213,6 +224,9 @@ export function FilesTable({ { id: 'actions', header: t('action'), + meta: { + headerCellClassName: 'w-0', + }, enableHiding: false, enablePinning: true, cell: ({ row }) => { @@ -222,7 +236,7 @@ export function FilesTable({ showConnectToKnowledgeModal={showConnectToKnowledgeModal} showFileRenameModal={showFileRenameModal} showMoveFileModal={showMoveFileModal} - > + /> ); }, }, @@ -271,7 +285,12 @@ export function FilesTable({ {headerGroup.headers.map((header) => { return ( - + {header.isPlaceholder ? null : flexRender( diff --git a/web/src/pages/files/index.tsx b/web/src/pages/files/index.tsx index 3f8581ff95a..db2c57a383e 100644 --- a/web/src/pages/files/index.tsx +++ b/web/src/pages/files/index.tsx @@ -11,7 +11,7 @@ import { } from '@/components/ui/dropdown-menu'; import { useRowSelection } from '@/hooks/logic-hooks/use-row-selection'; import { useFetchFileList } from '@/hooks/use-file-request'; -import { Upload } from 'lucide-react'; +import { LucidePlus } from 'lucide-react'; import { useTranslation } from 'react-i18next'; import { CreateFolderDialog } from './create-folder-dialog'; import { FileBreadcrumb } from './file-breadcrumb'; @@ -87,35 +87,38 @@ export default function Files() { ); return ( -
- - - - - - - - {t('fileManager.uploadFile')} - - - - {t('fileManager.newFolder')} - - - - - {!rowSelectionIsEmpty && ( - - )} +
+
+ + + + + + + + {t('fileManager.uploadFile')} + + + + {t('fileManager.newFolder')} + + + + + {!rowSelectionIsEmpty && ( + + )} +
)} -
+ ); } diff --git a/web/src/pages/user-setting/components/user-setting-header/index.tsx b/web/src/pages/user-setting/components/user-setting-header/index.tsx index d54e5776920..4f6026f434a 100644 --- a/web/src/pages/user-setting/components/user-setting-header/index.tsx +++ b/web/src/pages/user-setting/components/user-setting-header/index.tsx @@ -1,26 +1,7 @@ +import Spotlight from '@/components/spotlight'; import { Card, CardContent, CardHeader } from '@/components/ui/card'; import { PropsWithChildren } from 'react'; -export const UserSettingHeader = ({ - name, - description, -}: { - name: string; - description?: string; -}) => { - return ( - <> -
-
{name}
- {description && ( -
{description}
- )} -
- {/* */} - - ); -}; - export function Title({ children }: PropsWithChildren) { return {children}; } @@ -34,11 +15,17 @@ export function ProfileSettingWrapperCard({ children, }: ProfileSettingWrapperCardProps) { return ( - - + + {header} - {children} + + {children} + + ); } diff --git a/web/src/pages/user-setting/data-source/index.tsx b/web/src/pages/user-setting/data-source/index.tsx index 801b919db39..d4da96d7bf6 100644 --- a/web/src/pages/user-setting/data-source/index.tsx +++ b/web/src/pages/user-setting/data-source/index.tsx @@ -1,19 +1,62 @@ -import { CardTitle } from '@/components/ui/card'; import { useTranslation } from 'react-i18next'; -import Spotlight from '@/components/spotlight'; import { Button } from '@/components/ui/button'; import { Plus } from 'lucide-react'; -import { - ProfileSettingWrapperCard, - UserSettingHeader, -} from '../components/user-setting-header'; +import { ProfileSettingWrapperCard } from '../components/user-setting-header'; import AddDataSourceModal from './add-datasource-modal'; import { AddedSourceCard } from './component/added-source-card'; import { DataSourceKey, useDataSourceInfo } from './constant'; import { useAddDataSource, useListDataSource } from './hooks'; import { IDataSorceInfo } from './interface'; +const AvailableSourceCard = ({ + name, + description, + icon, + onAdd, +}: IDataSorceInfo & { onAdd: () => void }) => { + const { t } = useTranslation(); + + return ( +
onAdd()} + > + + {icon} + + +
+

{name}

+ + +
+ +

+ {description} +

+
+ ); +}; + const DataSource = () => { const { t } = useTranslation(); const { dataSourceInfo } = useDataSourceInfo(); @@ -38,97 +81,68 @@ const DataSource = () => { showAddingModal, } = useAddDataSource(); - const AbailableSourceCard = ({ - id, - name, - description, - icon, - }: IDataSorceInfo) => { - return ( -
- showAddingModal({ - id, - name, - description, - icon, - }) - } - > -
-
{icon}
-
-
{name}
-
{description}
-
-
-
- -
-
- ); - }; - return ( +
+

+ {t('setting.dataSources')} +

+

+ {t('setting.datasourceDescription')} +

+
} > - -
-
-
- {categorizedList?.length <= 0 && ( -
- {t('setting.sourceEmptyTip')} -
- )} - {categorizedList.map((item, index) => ( - - ))} -
-
-
- {/* */} - - {t('setting.availableSources')} -
- {t('setting.availableSourcesDescription')} -
-
-
-
- {/* */} -
- {dataSourceTemplates.map((item, index) => ( - - ))} +
+
+ {categorizedList?.length <= 0 && ( +
+ {t('setting.sourceEmptyTip')} +
+ )} + {categorizedList.map((item, index) => ( + + ))} +
+ +
+
+ {/* */} +

+ {t('setting.availableSources')} +
+ {t('setting.availableSourcesDescription')}
-

-
-
+ + - {addingModalVisible && ( - { - console.log(data); - handleAddOk(data); - }} - sourceData={addSource} - > - )} + {/* */} +
    + {dataSourceTemplates.map((item) => ( +
  • + showAddingModal(item)} + /> +
  • + ))} +
+
+ + {addingModalVisible && ( + { + console.log(data); + handleAddOk(data); + }} + sourceData={addSource} + /> + )}
); }; diff --git a/web/src/pages/user-setting/mcp/edit-mcp-dialog.tsx b/web/src/pages/user-setting/mcp/edit-mcp-dialog.tsx index ca9391884c1..b472d21965f 100644 --- a/web/src/pages/user-setting/mcp/edit-mcp-dialog.tsx +++ b/web/src/pages/user-setting/mcp/edit-mcp-dialog.tsx @@ -196,7 +196,7 @@ export function EditMcpDialog({ form={form} setFieldChanged={setFieldChanged} > - + -
- {t('mcp.mcpServers')} -
-
-
+
+
+

+ {t('mcp.mcpServers')} +

+ +

{t('mcp.customizeTheListOfMcpServers')} -

-
- - - - -
-
- +

+
+ +
+ + + + +
+ } > - {!data.mcp_servers?.length && ( -
-
{t('empty.noMCP')}
- -
- )} - {!!data.mcp_servers?.length && ( - <> - {isSelectionMode && ( -
- - - - {t('mcp.selected')} {selectedList.length} - -
- - - ), - }} +
+ {data.mcp_servers?.length ? ( + <> + {isSelectionMode && ( +
+ +
-
- )} - - {data.mcp_servers.map((item) => ( - - ))} - -
- + + ), + }} + > + + +
+ + )} + + {data.mcp_servers.map((item) => ( + + ))} + +
+ +
+ + ) : ( +
+
+ {t('empty.noMCP')} +
+
- - )} + )} +
+ {editVisible && ( )} - ); } diff --git a/web/src/pages/user-setting/profile/index.tsx b/web/src/pages/user-setting/profile/index.tsx index 8b496a9d503..4b0d897ffb8 100644 --- a/web/src/pages/user-setting/profile/index.tsx +++ b/web/src/pages/user-setting/profile/index.tsx @@ -22,10 +22,7 @@ import { Loader2Icon, PenLine } from 'lucide-react'; import { FC, useEffect } from 'react'; import { useForm } from 'react-hook-form'; import { z } from 'zod'; -import { - ProfileSettingWrapperCard, - UserSettingHeader, -} from '../components/user-setting-header'; +import { ProfileSettingWrapperCard } from '../components/user-setting-header'; import { EditType, modalTitle, useProfile } from './hooks/use-profile'; const baseSchema = z.object({ @@ -123,10 +120,14 @@ const ProfilePage: FC = () => { //
+
+

+ {t('profile')} +

+

+ {t('profileDescription')} +

+
} > @@ -142,11 +143,11 @@ const ProfilePage: FC = () => {
{profile.userName}
+ @@ -175,10 +176,9 @@ const ProfilePage: FC = () => { {profile.timeZone}
@@ -208,10 +208,9 @@ const ProfilePage: FC = () => { {profile.currPasswd ? '********' : ''}
diff --git a/web/src/pages/user-setting/setting-model/components/modal-card.tsx b/web/src/pages/user-setting/setting-model/components/modal-card.tsx index b27c2ba17a1..f688cc9f6c5 100644 --- a/web/src/pages/user-setting/setting-model/components/modal-card.tsx +++ b/web/src/pages/user-setting/setting-model/components/modal-card.tsx @@ -93,26 +93,22 @@ export const ModelProviderCard: FC = ({
@@ -161,9 +157,9 @@ export const ModelProviderCard: FC = ({ ))}
-
+
    {item.llm.map((model) => ( -
    @@ -176,36 +172,38 @@ export const ModelProviderCard: FC = ({
    -
    +
    {isLocalLlmFactory(item.name) && ( )} + { handleEnableLlm(model.name, value); }} /> +
    -
    + ))} -
+
)} diff --git a/web/src/pages/user-setting/setting-model/components/system-setting.tsx b/web/src/pages/user-setting/setting-model/components/system-setting.tsx index 607f2a99f22..f1fbdaaf519 100644 --- a/web/src/pages/user-setting/setting-model/components/system-setting.tsx +++ b/web/src/pages/user-setting/setting-model/components/system-setting.tsx @@ -173,13 +173,16 @@ const SystemSetting = ({ onOk, loading }: IProps) => { }; return ( -
-
-
{t('systemModelSettings')}
-
+
+
+

+ {t('systemModelSettings')} +

+

{t('systemModelDescription')} -

-
+

+ +
{llmList.map((item) => ( @@ -194,7 +197,7 @@ const SystemSetting = ({ onOk, loading }: IProps) => { {t('common:cancel')}
*/} -
+ ); }; diff --git a/web/src/pages/user-setting/setting-model/components/un-add-model.tsx b/web/src/pages/user-setting/setting-model/components/un-add-model.tsx index ba87e3830a5..a78aa686289 100644 --- a/web/src/pages/user-setting/setting-model/components/un-add-model.tsx +++ b/web/src/pages/user-setting/setting-model/components/un-add-model.tsx @@ -79,97 +79,97 @@ export const AvailableModels: FC<{ }; return ( -
-
- {t('availableModels')} -
- {/* Search Bar */} -
- {/*
*/} - setSearchTerm(e.target.value)} - className="w-full px-4 py-2 pl-10 bg-bg-input border border-border-default rounded-lg focus:outline-none focus:ring-1 focus:ring-border-button transition-colors" - /> - {/* */} - {/*
*/} -
+
+

{t('availableModels')}

+ {/* Search Bar */} +
+ {/*
*/} + setSearchTerm(e.target.value)} + className="w-full px-4 py-2 pl-10 bg-bg-input border border-border-default rounded-lg focus:outline-none focus:ring-1 focus:ring-border-button transition-colors" + /> + {/* */} + {/*
*/} +
- {/* Tags Filter */} -
- - {allTags.map((tag) => ( + {/* Tags Filter */} +
- ))} -
+ + {allTags.map((tag) => ( + + ))} +
+
{/* Models List */} -
+
{filteredModels.map((model) => (
handleAddModel(model.name)} >
-
+
{model.name}
{!!APIMapUrl[model.name as keyof typeof APIMapUrl] && ( )}
-
-
+
{sortTags(model.tags).map((tag, index) => ( ))}
-
+ ); }; diff --git a/web/src/pages/user-setting/setting-team/index.tsx b/web/src/pages/user-setting/setting-team/index.tsx index 374400c744a..cd70b4d504e 100644 --- a/web/src/pages/user-setting/setting-team/index.tsx +++ b/web/src/pages/user-setting/setting-team/index.tsx @@ -10,10 +10,7 @@ import Spotlight from '@/components/spotlight'; import { SearchInput } from '@/components/ui/input'; import { UserPlus } from 'lucide-react'; import { useState } from 'react'; -import { - ProfileSettingWrapperCard, - UserSettingHeader, -} from '../components/user-setting-header'; +import { ProfileSettingWrapperCard } from '../components/user-setting-header'; import AddingUserModal from './add-user-modal'; import { useAddUser } from './hooks'; import TenantTable from './tenant-table'; @@ -40,53 +37,60 @@ const UserSettingTeam = () => { // /> +
+

+ {userInfo?.nickname + ' ' + t('setting.workspace')} +

+
} > - - - {/* */} - - {t('setting.teamMembers')} - -
+ +
+ + + {/* */} + + {t('setting.teamMembers')} + + +
+ setSearchUser(e.target.value)} + /> + +
+
+ + + + +
+ + + + {/* */} + + {t('setting.joinedTeams')} + setSearchTerm(e.target.value)} placeholder={t('common.search')} - value={searchUser} - onChange={(e) => setSearchUser(e.target.value)} /> - -
-
- - - -
- - - - {/* */} - - {t('setting.joinedTeams')} - - setSearchTerm(e.target.value)} - placeholder={t('common.search')} - /> - - - - - + + + + + +
{addingTenantModalVisible && ( { const renderSortIcon = () => { if (sortOrder === 'asc') { - return ; + return ; } else if (sortOrder === 'desc') { - return ; + return ; } else { - return ; + return ; } }; @@ -77,13 +77,16 @@ const TenantTable = ({ searchTerm }: { searchTerm: string }) => { {t('common.name')} - -
+ +
{t('setting.updateDate')} - {renderSortIcon()} +
{t('setting.email')} diff --git a/web/src/pages/user-setting/setting-team/user-table.tsx b/web/src/pages/user-setting/setting-team/user-table.tsx index 840e08a898a..a4f6fb569cc 100644 --- a/web/src/pages/user-setting/setting-team/user-table.tsx +++ b/web/src/pages/user-setting/setting-team/user-table.tsx @@ -71,11 +71,11 @@ const UserTable = ({ searchUser }: { searchUser: string }) => { const renderSortIcon = () => { if (sortOrder === 'asc') { - return ; + return ; } else if (sortOrder === 'desc') { - return ; + return ; } else { - return ; + return ; } }; return ( @@ -84,13 +84,16 @@ const UserTable = ({ searchUser }: { searchUser: string }) => { {t('common.name')} - -
+ +
{t('setting.updateDate')} - {renderSortIcon()} +
{t('setting.email')} @@ -110,7 +113,7 @@ const UserTable = ({ searchUser }: { searchUser: string }) => { ) : sortedData && sortedData.length > 0 ? ( sortedData.map((record) => ( - +
[ - { icon: Server, label: t('setting.dataSources'), key: Routes.DataSource }, - { icon: Box, label: t('setting.model'), key: Routes.Model }, - { icon: Banknote, label: 'MCP', key: Routes.Mcp }, - { icon: Users, label: t('setting.team'), key: Routes.Team }, - { icon: User, label: t('setting.profile'), key: Routes.Profile }, - { icon: Unplug, label: t('setting.api'), key: Routes.Api }, + { + icon: , + label: t('setting.dataSources'), + key: Routes.DataSource, + }, + { + icon: , + label: t('setting.model'), + key: Routes.Model, + 'data-testid': 'settings-nav-model-providers', + }, + { + icon: , + label: 'MCP', + key: Routes.Mcp, + }, + { + icon: , + label: t('setting.team'), + key: Routes.Team, + }, + { + icon: , + label: t('setting.profile'), + key: Routes.Profile, + }, + { + icon: , + label: t('setting.api'), + key: Routes.Api, + }, // { // icon: MessageSquareQuote, // label: 'Prompt Templates', @@ -33,10 +63,10 @@ const menuItems = (t: TFunction) => [ // { icon: Cog, label: t('setting.system'), key: Routes.System }, // { icon: Banknote, label: 'Plan', key: Routes.Plan }, ]; + export function SideBar() { - const pathName = useSecondPathName(); const { data: userInfo } = useFetchUserInfo(); - const { handleMenuClick, active } = useHandleMenuClick(); + const { handleMenuClick, active: activeItemKey } = useHandleMenuClick(); const { version, fetchSystemVersion } = useFetchSystemVersion(); const { t } = useTranslation(); useEffect(() => { @@ -48,40 +78,38 @@ export function SideBar() { return ( ); } From 70c1da25ce450b2f8e7d2ffb277b84141a928609 Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Thu, 12 Mar 2026 13:41:08 +0800 Subject: [PATCH 224/565] Update go admin server default port to 9383 (#13559) ### What problem does this PR solve? As title ### Type of change - [x] Refactoring Signed-off-by: Jin Hai --- cmd/admin_server.go | 4 ++-- internal/server/config.go | 10 +++------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/cmd/admin_server.go b/cmd/admin_server.go index f9b095c908d..291f7868eef 100644 --- a/cmd/admin_server.go +++ b/cmd/admin_server.go @@ -137,7 +137,7 @@ func main() { r.Setup(ginEngine) // Create HTTP server - addr := fmt.Sprintf(":9381") + addr := fmt.Sprintf(":%d", cfg.Admin.Port) srv := &http.Server{ Addr: addr, Handler: ginEngine, @@ -160,7 +160,7 @@ func main() { // Start server in a goroutine go func() { logger.Info(fmt.Sprintf("Version: %s", utility.GetRAGFlowVersion())) - logger.Info(fmt.Sprintf("Starting RAGFlow admin server on port: 9381")) + logger.Info(fmt.Sprintf("Starting RAGFlow admin server on port: %d", cfg.Admin.Port)) if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { logger.Fatal("Failed to start server", zap.Error(err)) } diff --git a/internal/server/config.go b/internal/server/config.go index 6db21f9b9e9..acdcf4c0179 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -333,17 +333,13 @@ func Init(configPath string) error { } // Set default values for admin configuration if not configured - if globalConfig.Admin.Host == "" { - globalConfig.Admin.Host = v.GetString("admin.host") - } if globalConfig.Admin.Host == "" { globalConfig.Admin.Host = "127.0.0.1" } if globalConfig.Admin.Port == 0 { - globalConfig.Admin.Port = v.GetInt("admin.http_port") - } - if globalConfig.Admin.Port == 0 { - globalConfig.Admin.Port = 9381 + globalConfig.Admin.Port = 9383 + } else { + globalConfig.Admin.Port += 2 } // Load REGISTER_ENABLED from environment variable (default: 1) From 
d926a7291a92b8b9e989602bb16110d1d8370352 Mon Sep 17 00:00:00 2001 From: cambrianlee <53687180+cambrianlee@users.noreply.github.com> Date: Thu, 12 Mar 2026 15:23:55 +0800 Subject: [PATCH 225/565] Fix typo: documnet_keyword -> document_keyword in Chunk class (#13531) ### What problem does this PR solve? The Chunk class had a typo in the attribute name 'documnet_keyword', which caused the document_name field to remain empty when retrieving chunks via the SDK. This fix corrects the spelling to 'document_keyword'. Changes: - Line 36: Changed self.documnet_keyword to self.document_keyword - Line 52: Updated backward compatibility code to use self.document_keyword ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- sdk/python/ragflow_sdk/modules/chunk.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/python/ragflow_sdk/modules/chunk.py b/sdk/python/ragflow_sdk/modules/chunk.py index 0f5bf596b65..609cb2745ff 100644 --- a/sdk/python/ragflow_sdk/modules/chunk.py +++ b/sdk/python/ragflow_sdk/modules/chunk.py @@ -33,7 +33,7 @@ def __init__(self, rag, res_dict): self.create_timestamp = 0.0 self.dataset_id = None self.document_name = "" - self.documnet_keyword = "" + self.document_keyword = "" self.document_id = "" self.available = True # Additional fields for retrieval results @@ -49,7 +49,7 @@ def __init__(self, rag, res_dict): #for backward compatibility if not self.document_name: - self.document_name = self.documnet_keyword + self.document_name = self.document_keyword def update(self, update_message: dict): From fe36350da3783efb3a6519945e6ecccb7cf8b76f Mon Sep 17 00:00:00 2001 From: Josh <50523060+JosefAschauer@users.noreply.github.com> Date: Thu, 12 Mar 2026 09:03:30 +0100 Subject: [PATCH 226/565] Fix: avoid empty doc filter in knowledge retrieval (#13484) ## Summary Fix knowledge-base chat retrieval when no individual document IDs are selected. ## Root Cause `async_chat()` initialized `doc_ids` as an empty list when the request did not explicitly select documents. That empty list was then forwarded into retrieval as an active `doc_id` filter, effectively becoming `doc_id IN []` and suppressing all chunk matches. 
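As a sketch of the intended rule (hypothetical helper name; the patch below inlines this logic in `async_chat()` rather than factoring it out):

```python
def normalize_doc_ids(raw: str | None) -> list[str] | None:
    """Treat "no documents selected" as "no doc_id filter", never as IN []."""
    if raw is None:
        return None
    selected = [doc_id for doc_id in raw.split(",") if doc_id]
    return selected or None

assert normalize_doc_ids(None) is None           # nothing selected: search all docs
assert normalize_doc_ids("") is None             # empty selection: still no filter
assert normalize_doc_ids("a,b") == ["a", "b"]    # explicit selection is preserved
```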
## Changes - treat missing selected document IDs as `None` instead of `[]` - keep explicit document filtering when IDs are actually provided - add regression coverage for the shared chat retrieval path ## Validation - `python3 -m py_compile api/db/services/dialog_service.py test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py` - `.venv/bin/python -m pytest test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py` - manually verified that chat completions again inject retrieved knowledge into the prompt --------- Co-authored-by: Yingfeng --- api/db/services/dialog_service.py | 8 +- ...t_dialog_service_use_sql_source_columns.py | 95 +++++++++++++++++++ 2 files changed, 100 insertions(+), 3 deletions(-) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 7a549b69d0e..4289f507b51 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -494,12 +494,14 @@ async def async_chat(dialog, messages, stream=True, **kwargs): retriever = settings.retriever questions = [m["content"] for m in messages if m["role"] == "user"][-3:] - attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else [] + attachments = None + if "doc_ids" in kwargs: + attachments = [doc_id for doc_id in kwargs["doc_ids"].split(",") if doc_id] attachments_= "" image_attachments = [] image_files = [] if "doc_ids" in messages[-1]: - attachments = messages[-1]["doc_ids"] + attachments = [doc_id for doc_id in messages[-1]["doc_ids"] if doc_id] if "files" in messages[-1]: if llm_type == "chat": text_attachments, image_attachments = split_file_attachments(messages[-1]["files"]) @@ -559,7 +561,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} knowledges = [] - if attachments is not None and "knowledge" in param_keys: + if "knowledge" in param_keys: logging.debug("Proceeding with retrieval") tenant_ids = list(set([kb.tenant_id for kb in kbs])) knowledges = [] diff --git a/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py b/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py index a79d9358178..71941e3874a 100644 --- a/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py +++ b/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py @@ -17,6 +17,7 @@ import sys import types import warnings +from types import SimpleNamespace import pytest @@ -101,6 +102,19 @@ def sql_retrieval(self, sql, format="json"): return self._results[idx] +class _StubAsyncRetriever: + def __init__(self, result): + self.result = result + self.calls = [] + + async def retrieval(self, *args, **kwargs): + self.calls.append({"args": args, "kwargs": kwargs}) + return self.result + + def retrieval_by_children(self, chunks, tenant_ids): + return chunks + + @pytest.fixture def force_es_engine(monkeypatch): monkeypatch.setattr(dialog_service.settings, "DOC_ENGINE_INFINITY", False) @@ -219,3 +233,84 @@ def test_use_sql_source_repair_is_bounded_to_single_retry(monkeypatch, force_es_ assert "Source" not in result["answer"] assert len(chat_model.calls) == 2 assert len(retriever.sql_calls) == 2 + + +@pytest.mark.p2 +def test_async_chat_uses_all_docs_when_no_doc_ids_selected(monkeypatch): + retriever = _StubAsyncRetriever( + { + "total": 1, + "chunks": [ + { + "chunk_id": "chunk-1", + "content_ltks": "chunk text", + "content_with_weight": "Chunk text from dataset.", + "doc_id": "doc-1", + 
"docnm_kwd": "doc.txt", + "kb_id": "kb-1", + "important_kwd": [], + "positions": [], + "vector": [0.1, 0.2], + } + ], + "doc_aggs": [], + } + ) + chat_model = _StubChatModel(["stub answer"]) + dialog = SimpleNamespace( + kb_ids=["kb-1"], + llm_id="chat-model", + tenant_id="tenant-id", + llm_setting={}, + similarity_threshold=0.1, + vector_similarity_weight=0.2, + top_n=8, + top_k=32, + meta_data_filter=None, + prompt_config={ + "quote": False, + "keyword": False, + "tts": False, + "empty_response": "", + "system": "Use only this knowledge: {knowledge}", + "parameters": [{"key": "knowledge", "optional": False}], + "reasoning": False, + "toc_enhance": False, + "use_kg": False, + }, + ) + + monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False) + monkeypatch.setattr(dialog_service.TenantLLMService, "llm_id2llm_type", lambda _llm_id: "chat") + monkeypatch.setattr( + dialog_service.TenantLLMService, + "get_model_config", + lambda *_args, **_kwargs: {"llm_factory": "unit", "max_tokens": 4096}, + ) + monkeypatch.setattr(dialog_service.TenantLangfuseService, "filter_by_tenant", lambda **_kwargs: None) + monkeypatch.setattr( + dialog_service, + "get_models", + lambda _dialog: ([SimpleNamespace(tenant_id="tenant-id")], object(), None, chat_model, None), + ) + monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {}) + monkeypatch.setattr(dialog_service, "label_question", lambda _question, _kbs: None) + monkeypatch.setattr( + dialog_service, + "kb_prompt", + lambda kbinfos, _max_tokens: ["Chunk text from dataset."] if kbinfos["chunks"] else [], + ) + monkeypatch.setattr(dialog_service, "message_fit_in", lambda msg, _max_tokens: (0, msg)) + + async def _collect(): + items = [] + async for item in dialog_service.async_chat(dialog, [{"role": "user", "content": "What does the dataset say?"}], stream=False): + items.append(item) + return items + + result = asyncio.run(_collect()) + + assert len(retriever.calls) == 1 + assert retriever.calls[0]["kwargs"]["doc_ids"] is None + assert "Chunk text from dataset." in chat_model.calls[0]["system_prompt"] + assert result[0]["answer"] == "stub answer" From 827a1f72d8453721b29b06232d8a358a428c5114 Mon Sep 17 00:00:00 2001 From: NeedmeFordev <124189514+spider-yamet@users.noreply.github.com> Date: Thu, 12 Mar 2026 18:09:03 +0900 Subject: [PATCH 227/565] feat(parser): support external Docling server via DOCLING_SERVER_URL (#13527) ### What problem does this PR solve? This PR adds support for parsing PDFs through an external Docling server, so RAGFlow can connect to remote `docling serve` deployments instead of relying only on local in-process Docling. It addresses the feature request in [#13426](https://github.com/infiniflow/ragflow/issues/13426) and aligns with the external-server usage pattern already used by MinerU. ### Type of change - [ ] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) - [x] Documentation Update - [ ] Refactoring - [ ] Performance Improvement - [ ] Other (please describe): ### What is changed? - Add external Docling server support in `DoclingParser`: - Use `DOCLING_SERVER_URL` to enable remote parsing mode. - Try `POST /v1/convert/source` first, and fallback to `/v1alpha/convert/source`. - Keep existing local Docling behavior when `DOCLING_SERVER_URL` is not set. 
- Wire Docling env settings into parser invocation paths: - `rag/app/naive.py` - `rag/flow/parser/parser.py` - Add Docling env hints in constants and update docs: - `docs/guides/dataset/select_pdf_parser.md` - `docs/guides/agent/agent_component_reference/parser.md` - `docs/faq.mdx` ### Why this approach? This keeps the change focused on one issue and one capability (external Docling connectivity), without introducing unrelated provider-model plumbing. ### Validation - Static checks: - `python -m py_compile` on changed Python files - `python -m ruff check` on changed Python files - Functional checks: - Remote v1 endpoint path works - v1alpha fallback works - Local Docling path remains available when server URL is unset ### Related links - Feature request: [Support external Docling server (issue #13426)](https://github.com/infiniflow/ragflow/issues/13426) - Compare view for this branch: [main...feat/docling-server](https://github.com/infiniflow/ragflow/compare/main...spider-yamet:ragflow:feat/docling-server?expand=1) ##### Fixes [#13426](https://github.com/infiniflow/ragflow/issues/13426) --- common/constants.py | 3 + deepdoc/parser/docling_parser.py | 189 +++++++++++++++++- docs/faq.mdx | 18 ++ .../agent/agent_component_reference/parser.md | 6 + docs/guides/dataset/select_pdf_parser.md | 6 + rag/app/naive.py | 8 +- rag/flow/parser/parser.py | 26 ++- 7 files changed, 246 insertions(+), 10 deletions(-) diff --git a/common/constants.py b/common/constants.py index cbc2f534c95..24530c45737 100644 --- a/common/constants.py +++ b/common/constants.py @@ -219,6 +219,9 @@ class ForgettingPolicy(StrEnum): # ENV_MINERU_OUTPUT_DIR = "MINERU_OUTPUT_DIR" # ENV_MINERU_BACKEND = "MINERU_BACKEND" # ENV_MINERU_DELETE_OUTPUT = "MINERU_DELETE_OUTPUT" +# ENV_DOCLING_SERVER_URL = "DOCLING_SERVER_URL" +# ENV_DOCLING_OUTPUT_DIR = "DOCLING_OUTPUT_DIR" +# ENV_DOCLING_DELETE_OUTPUT = "DOCLING_DELETE_OUTPUT" # ENV_TCADP_OUTPUT_DIR = "TCADP_OUTPUT_DIR" # ENV_LM_TIMEOUT_SECONDS = "LM_TIMEOUT_SECONDS" # ENV_LLM_MAX_RETRIES = "LLM_MAX_RETRIES" diff --git a/deepdoc/parser/docling_parser.py b/deepdoc/parser/docling_parser.py index 2ad1e8d3d31..ccd45bab124 100644 --- a/deepdoc/parser/docling_parser.py +++ b/deepdoc/parser/docling_parser.py @@ -17,6 +17,8 @@ import logging import re +import base64 +import os from dataclasses import dataclass from enum import Enum from io import BytesIO @@ -25,6 +27,7 @@ from typing import Any, Callable, Iterable, Optional import pdfplumber +import requests from PIL import Image try: @@ -74,15 +77,41 @@ def _extract_bbox_from_prov(item, prov_attr: str = "prov") -> Optional[_BBox]: class DoclingParser(RAGFlowPdfParser): - def __init__(self): + def __init__(self, docling_server_url: str = "", request_timeout: int = 600): self.logger = logging.getLogger(self.__class__.__name__) self.page_images: list[Image.Image] = [] self.page_from = 0 self.page_to = 10_000 self.outlines = [] - - - def check_installation(self) -> bool: + self.docling_server_url = (docling_server_url or "").rstrip("/") + self.request_timeout = request_timeout + + def _effective_server_url(self, docling_server_url: Optional[str] = None) -> str: + return (docling_server_url or self.docling_server_url or "").rstrip("/") or ( + os.environ.get("DOCLING_SERVER_URL", "").rstrip("/") + ) + + @staticmethod + def _is_http_endpoint_valid(url: str, timeout: int = 5) -> bool: + try: + response = requests.head(url, timeout=timeout, allow_redirects=True) + return response.status_code in [200, 301, 302, 307, 308] + except Exception: + try: + 
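# Some servers reject HEAD, so probe once more with GET before giving up. +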
response = requests.get(url, timeout=timeout, allow_redirects=True) + return response.status_code in [200, 301, 302, 307, 308] + except Exception: + return False + + def check_installation(self, docling_server_url: Optional[str] = None) -> bool: + server_url = self._effective_server_url(docling_server_url) + if server_url: + for path in ("/openapi.json", "/docs", "/v1/convert/source"): + if self._is_http_endpoint_valid(f"{server_url}{path}", timeout=5): + return True + self.logger.warning(f"[Docling] external server not reachable: {server_url}") + return False + if DocumentConverter is None: self.logger.warning("[Docling] 'docling' is not importable, please: pip install docling") return False @@ -277,6 +306,141 @@ def _transfer_to_tables(self, doc): tables.append(((img, [captions]), positions if positions else "")) return tables + @staticmethod + def _sections_from_remote_text(text: str, parse_method: str) -> list[tuple[str, ...]]: + txt = (text or "").strip() + if not txt: + return [] + if parse_method == "manual": + return [(txt, DoclingContentType.TEXT.value, "")] + if parse_method == "paper": + return [(txt, DoclingContentType.TEXT.value)] + return [(txt, "")] + + @staticmethod + def _extract_remote_document_entries(payload: Any) -> list[dict[str, Any]]: + if not isinstance(payload, dict): + return [] + if isinstance(payload.get("document"), dict): + return [payload["document"]] + if isinstance(payload.get("documents"), list): + return [d for d in payload["documents"] if isinstance(d, dict)] + if isinstance(payload.get("results"), list): + docs = [] + for it in payload["results"]: + if isinstance(it, dict): + if isinstance(it.get("document"), dict): + docs.append(it["document"]) + elif isinstance(it.get("result"), dict): + docs.append(it["result"]) + else: + docs.append(it) + return docs + return [] + + def _parse_pdf_remote( + self, + filepath: str | PathLike[str], + binary: BytesIO | bytes | None = None, + callback: Optional[Callable] = None, + *, + parse_method: str = "raw", + docling_server_url: Optional[str] = None, + request_timeout: Optional[int] = None, + ): + server_url = self._effective_server_url(docling_server_url) + if not server_url: + raise RuntimeError("[Docling] DOCLING_SERVER_URL is not configured.") + + timeout = request_timeout or self.request_timeout + if binary is not None: + if isinstance(binary, (bytes, bytearray)): + pdf_bytes = bytes(binary) + else: + pdf_bytes = bytes(binary.getbuffer()) + else: + src_path = Path(filepath) + if not src_path.exists(): + raise FileNotFoundError(f"PDF not found: {src_path}") + with open(src_path, "rb") as f: + pdf_bytes = f.read() + + if callback: + callback(0.2, f"[Docling] Requesting external server: {server_url}") + + filename = Path(filepath).name or "input.pdf" + b64 = base64.b64encode(pdf_bytes).decode("ascii") + v1_payload = { + "options": { + "from_formats": ["pdf"], + "to_formats": ["json", "md", "text"], + }, + "sources": [ + { + "kind": "file", + "filename": filename, + "base64_string": b64, + } + ], + } + v1alpha_payload = { + "options": { + "from_formats": ["pdf"], + "to_formats": ["json", "md", "text"], + }, + "file_sources": [ + { + "filename": filename, + "base64_string": b64, + } + ], + } + errors = [] + response_json = None + for endpoint, payload in ( + ("/v1/convert/source", v1_payload), + ("/v1alpha/convert/source", v1alpha_payload), + ): + try: + resp = requests.post( + f"{server_url}{endpoint}", + json=payload, + timeout=timeout, + ) + if resp.status_code < 300: + response_json = resp.json() + break + 
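# Not a 2xx: record the failure and fall through to the next endpoint (v1 first, then the legacy v1alpha route). +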
errors.append(f"{endpoint}: HTTP {resp.status_code} {resp.text[:300]}") + except Exception as exc: + errors.append(f"{endpoint}: {exc}") + + if response_json is None: + raise RuntimeError("[Docling] remote convert failed: " + " | ".join(errors)) + + docs = self._extract_remote_document_entries(response_json) + if not docs: + raise RuntimeError("[Docling] remote response does not contain parsed documents.") + + sections: list[tuple[str, ...]] = [] + tables = [] + for doc in docs: + md = doc.get("md_content") + txt = doc.get("text_content") + if isinstance(md, str) and md.strip(): + sections.extend(self._sections_from_remote_text(md, parse_method=parse_method)) + elif isinstance(txt, str) and txt.strip(): + sections.extend(self._sections_from_remote_text(txt, parse_method=parse_method)) + + json_content = doc.get("json_content") + if isinstance(json_content, dict): + md_fallback = json_content.get("md_content") + if isinstance(md_fallback, str) and md_fallback.strip() and not sections: + sections.extend(self._sections_from_remote_text(md_fallback, parse_method=parse_method)) + + if callback: + callback(0.95, f"[Docling] Remote sections: {len(sections)}") + return sections, tables + def parse_pdf( self, filepath: str | PathLike[str], @@ -287,12 +451,25 @@ def parse_pdf( lang: Optional[str] = None, method: str = "auto", delete_output: bool = True, - parse_method: str = "raw" + parse_method: str = "raw", + docling_server_url: Optional[str] = None, + request_timeout: Optional[int] = None, ): - if not self.check_installation(): + if not self.check_installation(docling_server_url=docling_server_url): raise RuntimeError("Docling not available, please install `docling`") + server_url = self._effective_server_url(docling_server_url) + if server_url: + return self._parse_pdf_remote( + filepath=filepath, + binary=binary, + callback=callback, + parse_method=parse_method, + docling_server_url=server_url, + request_timeout=request_timeout, + ) + if binary is not None: tmpdir = Path(output_dir) if output_dir else Path.cwd() / ".docling_tmp" tmpdir.mkdir(parents=True, exist_ok=True) diff --git a/docs/faq.mdx b/docs/faq.mdx index 965aa16dab2..9c45e0fe61a 100644 --- a/docs/faq.mdx +++ b/docs/faq.mdx @@ -567,6 +567,24 @@ RAGFlow supports MinerU's `vlm-http-client` backend, enabling you to delegate do When using the `vlm-http-client` backend, the RAGFlow server requires no GPU, only network connectivity. This enables cost-effective distributed deployment with multiple RAGFlow instances sharing one remote vLLM server. ::: +### How to use an external Docling Serve server for document parsing? + +RAGFlow supports Docling in two modes: + +1. **Local Docling** (existing mode): install Docling in the RAGFlow runtime (`USE_DOCLING=true`) and parse in-process. +2. **External Docling Serve** (remote mode): point RAGFlow to a Docling Serve endpoint. + +To enable remote mode, set: + +```bash +DOCLING_SERVER_URL=http://your-docling-serve-host:5001 +``` + +Behavior: + +- When `DOCLING_SERVER_URL` is set, RAGFlow sends PDFs to Docling Serve using `/v1/convert/source` (and falls back to `/v1alpha/convert/source` for older servers). +- When `DOCLING_SERVER_URL` is not set, RAGFlow uses local in-process Docling. + ### How to use PaddleOCR for document parsing? From v0.24.0 onwards, RAGFlow includes PaddleOCR as an optional PDF parser. Please note that RAGFlow acts only as a *remote client* for PaddleOCR, calling the PaddleOCR API to parse PDFs and reading the returned files. 
diff --git a/docs/guides/agent/agent_component_reference/parser.md b/docs/guides/agent/agent_component_reference/parser.md index cdc0a9e1750..75b6341cb23 100644 --- a/docs/guides/agent/agent_component_reference/parser.md +++ b/docs/guides/agent/agent_component_reference/parser.md @@ -65,6 +65,12 @@ Starting from v0.22.0, RAGFlow includes MinerU (≥ 2.6.3) as an optional PDF p - If you decide to use a chunking method from the **Built-in** dropdown, ensure it supports PDF parsing, then select **MinerU** from the **PDF parser** dropdown. - If you use a custom ingestion pipeline instead, select **MinerU** in the **PDF parser** section of the **Parser** component. +To use an external Docling Serve instance (instead of local in-process Docling), set: + +- `DOCLING_SERVER_URL`: The Docling Serve API endpoint (for example, `http://docling-host:5001`). + +When `DOCLING_SERVER_URL` is set, RAGFlow sends PDF content to Docling Serve (`/v1/convert/source`, with fallback to `/v1alpha/convert/source`) and ingests the returned markdown/text. If the variable is not set, RAGFlow keeps using local Docling (`USE_DOCLING=true` + installed package) behavior. + :::note All MinerU environment variables are optional. When set, these values are used to auto-provision a MinerU OCR model for the tenant on first use. To avoid auto-provisioning, skip the environment variable settings and only configure MinerU from the **Model providers** page in the UI. ::: diff --git a/docs/guides/dataset/select_pdf_parser.md b/docs/guides/dataset/select_pdf_parser.md index fa2d068cb42..d96992f5af7 100644 --- a/docs/guides/dataset/select_pdf_parser.md +++ b/docs/guides/dataset/select_pdf_parser.md @@ -65,6 +65,12 @@ Starting from v0.22.0, RAGFlow includes MinerU (≥ 2.6.3) as an optional PDF p - If you decide to use a chunking method from the **Built-in** dropdown, ensure it supports PDF parsing, then select **MinerU** from the **PDF parser** dropdown. - If you use a custom ingestion pipeline instead, select **MinerU** in the **PDF parser** section of the **Parser** component. +To use an external Docling Serve instance (instead of local in-process Docling), set: + +- `DOCLING_SERVER_URL`: The Docling Serve API endpoint (for example, `http://docling-host:5001`). + +When `DOCLING_SERVER_URL` is set, RAGFlow sends PDF content to Docling Serve (`/v1/convert/source`, with fallback to `/v1alpha/convert/source`) and ingests the returned markdown/text. If the variable is not set, RAGFlow keeps using local Docling (`USE_DOCLING=true` + installed package) behavior. + :::note All MinerU environment variables are optional. When set, these values are used to auto-provision a MinerU OCR model for the tenant on first use. To avoid auto-provisioning, skip the environment variable settings and only configure MinerU from the **Model providers** page in the UI. 
::: diff --git a/rag/app/naive.py b/rag/app/naive.py index 1d2d0ebbf7f..3eec55df036 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -153,15 +153,17 @@ def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese parse_method = kwargs.get("parse_method", "raw") if not pdf_parser.check_installation(): - callback(-1, "Docling not found.") + if callback: + callback(-1, "Docling not found.") return None, None, pdf_parser sections, tables = pdf_parser.parse_pdf( filepath=filename, binary=binary, callback=callback, - output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""), - delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))), + output_dir=os.environ.get("DOCLING_OUTPUT_DIR", ""), + delete_output=bool(int(os.environ.get("DOCLING_DELETE_OUTPUT", 1))), + docling_server_url=os.environ.get("DOCLING_SERVER_URL", ""), parse_method=parse_method, ) return sections, tables, pdf_parser diff --git a/rag/flow/parser/parser.py b/rag/flow/parser/parser.py index 9ee69948413..3f779e252ca 100644 --- a/rag/flow/parser/parser.py +++ b/rag/flow/parser/parser.py @@ -32,6 +32,7 @@ from common.constants import LLMType from common.misc_utils import get_uuid from deepdoc.parser import ExcelParser +from deepdoc.parser.docling_parser import DoclingParser from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser from deepdoc.parser.tcadp_parser import TCADPParser from rag.app.naive import Docx @@ -173,7 +174,7 @@ def check(self): pdf_parse_method = pdf_config.get("parse_method", "") self.check_empty(pdf_parse_method, "Parse method abnormal.") - if pdf_parse_method.lower() not in ["deepdoc", "plain_text", "mineru", "tcadp parser", "paddleocr"]: + if pdf_parse_method.lower() not in ["deepdoc", "plain_text", "mineru", "docling", "tcadp parser", "paddleocr"]: self.check_empty(pdf_config.get("lang", ""), "PDF VLM language") pdf_output_format = pdf_config.get("output_format", "") @@ -371,6 +372,29 @@ def resolve_mineru_llm_name(): "text": t, } bboxes.append(box) + elif parse_method.lower() == "docling": + pdf_parser = DoclingParser(docling_server_url=os.environ.get("DOCLING_SERVER_URL", "")) + lines, _ = pdf_parser.parse_pdf( + filepath=name, + binary=blob, + callback=self.callback, + parse_method=conf.get("docling_parse_method", "raw"), + docling_server_url=os.environ.get("DOCLING_SERVER_URL", ""), + ) + bboxes = [] + for item in lines: + if not isinstance(item, tuple) or not item: + continue + text = item[0] + poss = item[-1] if len(item) >= 2 else "" + box = { + "text": text, + "image": pdf_parser.crop(poss, 1) if isinstance(poss, str) and poss else None, + "positions": [[pos[0][-1], *pos[1:]] for pos in pdf_parser.extract_positions(poss)] + if isinstance(poss, str) and poss + else [], + } + bboxes.append(box) elif parse_method.lower() == "tcadp parser": # ADP is a document parsing tool using Tencent Cloud API table_result_type = conf.get("table_result_type", "1") From 819483b990d0f2da0ee4bf9034d0f668fece57ca Mon Sep 17 00:00:00 2001 From: Magicbook1108 Date: Thu, 12 Mar 2026 17:49:02 +0800 Subject: [PATCH 228/565] Fix: image pdf in ingestion pipeline (#13563) ### What problem does this PR solve? 
Fix: image pdf in ingestion pipeline #13550 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- deepdoc/parser/pdf_parser.py | 27 ++++++++++++++++----------- rag/flow/tokenizer/tokenizer.py | 3 ++- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index 6020361c07b..5e8f9694a05 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -1785,21 +1785,26 @@ def min_rectangle_distance(rect1, rect2): logging.debug("No valid local positions for table/figure; skip insertion.") continue - bboxes = [(i, (b["page_number"], b["x0"], b["x1"], b["top"], b["bottom"])) for i, b in enumerate(self.boxes)] - dists = [ - (min_rectangle_distance((pn, left, right, top + self.page_cum_height[pn], bott + self.page_cum_height[pn]), rect), i) - for i, rect in bboxes - for pn, left, right, top, bott in local_poss - ] - min_i = np.argmin(dists, axis=0)[0] - min_i, rect = bboxes[dists[min_i][-1]] if isinstance(txt, list): txt = "\n".join(txt) pn, left, right, top, bott = local_poss[0] - if self.boxes[min_i]["bottom"] < top + self.page_cum_height[pn]: - min_i += 1 + insert_at = len(self.boxes) + bboxes = [(i, (b["page_number"], b["x0"], b["x1"], b["top"], b["bottom"])) for i, b in enumerate(self.boxes)] + if bboxes: + dists = [ + (min_rectangle_distance((cand_pn, cand_left, cand_right, cand_top + self.page_cum_height[cand_pn], cand_bott + self.page_cum_height[cand_pn]), rect), i) + for i, rect in bboxes + for cand_pn, cand_left, cand_right, cand_top, cand_bott in local_poss + ] + if dists: + nearest_bbox_idx = int(np.argmin([dist for dist, _ in dists])) + insert_at, _ = bboxes[dists[nearest_bbox_idx][-1]] + if self.boxes[insert_at]["bottom"] < top + self.page_cum_height[pn]: + insert_at += 1 + else: + logging.debug("No text boxes available; append %s block directly.", layout_type) self.boxes.insert( - min_i, + insert_at, { "page_number": pn + 1, "x0": left, diff --git a/rag/flow/tokenizer/tokenizer.py b/rag/flow/tokenizer/tokenizer.py index dcf4751064f..0d213c512e8 100644 --- a/rag/flow/tokenizer/tokenizer.py +++ b/rag/flow/tokenizer/tokenizer.py @@ -108,7 +108,8 @@ def batch_encode(txts): async def _invoke(self, **kwargs): try: chunks = kwargs.get("chunks") - kwargs["chunks"] = [c for c in chunks if c is not None] + if chunks is not None: + kwargs["chunks"] = [c for c in chunks if c is not None] from_upstream = TokenizerFromUpstream.model_validate(kwargs) except Exception as e: From aaca5504b55c180abd9da95a5e6b132fbde97e5f Mon Sep 17 00:00:00 2001 From: Yongteng Lei Date: Thu, 12 Mar 2026 17:49:13 +0800 Subject: [PATCH 229/565] Feat: inject sys.date into canvas (#13567) ### What problem does this PR solve? Inject sys.date into canvas. 
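For context, a minimal sketch of the resulting behavior (simplified; the real `Canvas` keeps many more globals): `sys.date` is stamped in UTC when the canvas is created and refreshed at the start of every run, so prompts referencing it always see the current timestamp.

```python
import datetime


def utc_now_str() -> str:
    return datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S")


class MiniCanvas:
    """Toy stand-in for agent/canvas.py's Canvas, globals only."""

    def __init__(self):
        self.globals = {"sys.query": "", "sys.date": utc_now_str()}

    def run(self, query: str) -> None:
        # Refresh per run, mirroring Canvas.run() in the diff below.
        self.globals["sys.date"] = utc_now_str()
        self.globals["sys.query"] = query
```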
### Type of change - [x] New Feature (non-breaking change which adds functionality) --- agent/canvas.py | 10 ++++++++-- web/src/constants/agent.tsx | 2 ++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/agent/canvas.py b/agent/canvas.py index 48c4a6cbbc4..15983024b9b 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -15,6 +15,7 @@ # import asyncio import base64 +import datetime import inspect import binascii import json @@ -287,7 +288,8 @@ def __init__(self, dsl: str, tenant_id=None, task_id=None, canvas_id=None, custo "sys.user_id": tenant_id, "sys.conversation_turns": 0, "sys.files": [], - "sys.history": [] + "sys.history": [], + "sys.date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S") } self.variables = {} super().__init__(dsl, tenant_id, task_id, custom_header=custom_header) @@ -300,13 +302,16 @@ def load(self): self.globals = self.dsl["globals"] if "sys.history" not in self.globals: self.globals["sys.history"] = [] + if "sys.date" not in self.globals: + self.globals["sys.date"] = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S") else: self.globals = { "sys.query": "", "sys.user_id": "", "sys.conversation_turns": 0, "sys.files": [], - "sys.history": [] + "sys.history": [], + "sys.date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S") } if "variables" in self.dsl: self.variables = self.dsl["variables"] @@ -368,6 +373,7 @@ def reset(self, mem=False): self.globals[k] = "" async def run(self, **kwargs): + self.globals["sys.date"] = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S") st = time.perf_counter() self._loop = asyncio.get_running_loop() self.message_id = get_uuid() diff --git a/web/src/constants/agent.tsx b/web/src/constants/agent.tsx index ad487e01f45..cdb1df91547 100644 --- a/web/src/constants/agent.tsx +++ b/web/src/constants/agent.tsx @@ -33,6 +33,7 @@ export enum AgentGlobals { SysConversationTurns = 'sys.conversation_turns', SysFiles = 'sys.files', SysHistory = 'sys.history', + SysDate = 'sys.date', } export const AgentGlobalsSysQueryWithBrace = `{${AgentGlobals.SysQuery}}`; @@ -272,5 +273,6 @@ export const EmptyDsl = { [AgentGlobals.SysConversationTurns]: 0, [AgentGlobals.SysFiles]: [], [AgentGlobals.SysHistory]: [], + [AgentGlobals.SysDate]: '', }, }; From 23c50af671ed84702305c09c95276249bb7a1270 Mon Sep 17 00:00:00 2001 From: Jinghan Xu <2827092384@qq.com> Date: Thu, 12 Mar 2026 18:02:12 +0800 Subject: [PATCH 230/565] Fix: allow document parsing status recovery after transient errors (#13341) ### What problem does this PR solve? Fixes #13285 When an LLM returns a transient error (e.g. overloaded) during parsing, the task progress is set to -1. Previously, the progress could never be updated again, leaving the document permanently stuck in FAIL status even after the task successfully recovered and completed. Three coordinated changes address this: 1. task_service.update_progress: relax the progress update guard to accept prog >= 1 even when current progress is -1, so a task that recovers from a transient failure can report completion. 2. document_service.get_unfinished_docs: include documents that are marked FAIL (progress == -1) but still have at least one non-failed task (task.progress >= 0) in the polling set, so their status can be re-synced once a task recovers. Documents where all tasks have permanently failed are excluded to avoid unnecessary polling. 3. 
document_service.update_progress: explicitly set document status to RUNNING when not all tasks have finished, instead of preserving whatever stale status (potentially FAIL) the document previously had. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- api/db/services/document_service.py | 16 ++++++++++++---- api/db/services/task_service.py | 17 ++++++----------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 8671aa98909..4782bf85de1 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -522,15 +522,21 @@ def get_newly_uploaded(cls): @classmethod @DB.connection_context() def get_unfinished_docs(cls): - fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run, cls.model.parser_id] - unfinished_task_query = Task.select(Task.doc_id).where((Task.progress >= 0) & (Task.progress < 1)) + fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, + cls.model.run, cls.model.parser_id] + unfinished_task_query = Task.select(Task.doc_id).where( + (Task.progress >= 0) & (Task.progress < 1) + ) + docs_with_non_failed_tasks = Task.select(Task.doc_id).where(Task.progress >= 0).distinct() docs = cls.model.select(*fields).where( cls.model.status == StatusEnum.VALID.value, ~(cls.model.type == FileType.VIRTUAL.value), ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)), - (((cls.model.progress < 1) & (cls.model.progress > 0)) | (cls.model.id.in_(unfinished_task_query))), - ) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap + (((cls.model.progress < 1) & (cls.model.progress > 0)) | + (cls.model.id.in_(unfinished_task_query)) | + ((cls.model.progress == -1) & (cls.model.run == TaskStatus.FAIL.value) & + (cls.model.id.in_(docs_with_non_failed_tasks))))) # including GraphRAG/RAPTOR/Mindmap; re-sync failed docs return list(docs.dicts()) @classmethod @@ -850,6 +856,8 @@ def _sync_progress(cls, docs: list[dict]): elif finished: prg = 1 status = TaskStatus.DONE.value + elif not finished: + status = TaskStatus.RUNNING.value # only for special task and parsed docs and unfinished freeze_progress = special_task_running and doc_progress >= 1 and not finished diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index 3975c0ec3fc..80817323076 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -304,9 +304,8 @@ def update_progress(cls, id, info): Update Rules: - progress_msg: Always appends the new message to the existing one, and trims the result to max 3000 lines. - - progress: Only updates if the current progress is not -1 AND - (the new progress is -1 OR greater than the existing progress), - to avoid overwriting valid progress with invalid or regressive values. + - progress: Updates when (a) new progress >= 1 (allows recovery from -1), or + (b) current progress != -1 AND (new progress is -1 OR greater than existing). Args: id (str): The unique identifier of the task to update. 
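The relaxed guard reads more easily as a standalone predicate (a sketch only; in the patch the same condition lives inside a peewee `UPDATE ... WHERE` clause):

```python
def progress_update_allowed(current: float, new: float) -> bool:
    if new >= 1:
        # Completion always wins, which lets a task recover from -1.
        return True
    # Otherwise keep the old rule: never touch an already-failed task,
    # and accept only a failure marker or strictly increasing progress.
    return current != -1 and (new == -1 or new > current)


assert progress_update_allowed(-1, 1.0)       # recovery after a transient error
assert not progress_update_allowed(-1, 0.5)   # partial progress is still blocked
assert progress_update_allowed(0.5, -1)       # failures can always be recorded
assert not progress_update_allowed(0.6, 0.3)  # no regressions
```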
@@ -327,10 +326,8 @@ def update_progress(cls, id, info): prog = info["progress"] cls.model.update(progress=prog).where( (cls.model.id == id) & - ( - (cls.model.progress != -1) & - ((prog == -1) | (prog > cls.model.progress)) - ) + ((prog >= 1) | ((cls.model.progress != -1) & + ((prog == -1) | (prog > cls.model.progress)))) ).execute() else: with DB.lock("update_progress", -1): @@ -341,10 +338,8 @@ def update_progress(cls, id, info): prog = info["progress"] cls.model.update(progress=prog).where( (cls.model.id == id) & - ( - (cls.model.progress != -1) & - ((prog == -1) | (prog > cls.model.progress)) - ) + ((prog >= 1) | ((cls.model.progress != -1) & + ((prog == -1) | (prog > cls.model.progress)))) ).execute() process_duration = (datetime.now() - task.begin_at).total_seconds() From 79c4d9fefa4ff9a72d40e75a0c3d795eb14dfe35 Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Thu, 12 Mar 2026 18:58:25 +0800 Subject: [PATCH 231/565] Create go version storage component, but not used (#13561) ### What problem does this PR solve? Implement: minio, s3, oss, azure_sas, azure_spn, gcs, opendal ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Signed-off-by: Jin Hai --- go.mod | 73 ++++- go.sum | 246 ++++++++++++++-- internal/storage/minio.go | 397 ++++++++++++++++++++++++++ internal/storage/oss.go | 415 +++++++++++++++++++++++++++ internal/storage/s3.go | 425 ++++++++++++++++++++++++++++ internal/storage/storage_factory.go | 298 +++++++++++++++++++ internal/storage/types.go | 102 +++++++ 7 files changed, 1932 insertions(+), 24 deletions(-) create mode 100644 internal/storage/minio.go create mode 100644 internal/storage/oss.go create mode 100644 internal/storage/s3.go create mode 100644 internal/storage/storage_factory.go create mode 100644 internal/storage/types.go diff --git a/go.mod b/go.mod index 3a47d85eb4b..7aa72113e29 100644 --- a/go.mod +++ b/go.mod @@ -1,51 +1,99 @@ module ragflow -go 1.24.0 +go 1.25 require ( + cloud.google.com/go/storage v1.35.1 + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 + github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.4 + github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake v1.4.4 + github.com/aws/aws-sdk-go-v2 v1.41.3 + github.com/aws/aws-sdk-go-v2/config v1.32.11 + github.com/aws/aws-sdk-go-v2/credentials v1.19.11 + github.com/aws/aws-sdk-go-v2/service/s3 v1.96.4 + github.com/aws/smithy-go v1.24.2 github.com/elastic/go-elasticsearch/v8 v8.19.1 github.com/gin-gonic/gin v1.9.1 - github.com/google/uuid v1.4.0 + github.com/go-sql-driver/mysql v1.7.0 + github.com/google/uuid v1.6.0 github.com/iromli/go-itsdangerous v0.0.0-20220223194502-9c8bef8dac6a + github.com/minio/minio-go/v7 v7.0.99 + github.com/peterh/liner v1.2.2 github.com/redis/go-redis/v9 v9.18.0 github.com/siongui/gojianfan v0.0.0-20210926212422-2f175ac615de github.com/spf13/viper v1.18.2 go.uber.org/zap v1.27.1 golang.org/x/crypto v0.47.0 + google.golang.org/api v0.153.0 gorm.io/driver/mysql v1.5.2 gorm.io/gorm v1.25.5 ) require ( + cloud.google.com/go v0.110.10 // indirect + cloud.google.com/go/compute v1.23.3 // indirect + cloud.google.com/go/compute/metadata v0.2.3 // indirect + cloud.google.com/go/iam v1.1.5 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds 
v1.18.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.20 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.11 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.19 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.7 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.12 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/elastic/elastic-transport-go/v8 v8.8.0 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-ini/ini v1.67.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.16.0 // indirect - github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/goccy/go-json v0.10.2 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect + github.com/golang/protobuf v1.5.3 // indirect + github.com/google/s2a-go v0.1.7 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect + github.com/googleapis/gax-go/v2 v2.12.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/cpuid/v2 v2.2.4 // indirect + github.com/klauspost/compress v1.18.2 // indirect + github.com/klauspost/cpuid/v2 v2.2.11 // indirect + github.com/klauspost/crc32 v1.3.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.3 // indirect + github.com/minio/crc64nvme v1.1.1 // indirect + github.com/minio/md5-simd v1.1.2 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.1.1 // indirect - github.com/peterh/liner v1.2.2 // indirect + github.com/philhofer/fwd v1.2.0 // indirect + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect + github.com/rs/xid v1.6.0 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect @@ -53,19 +101,30 @@ require ( github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag 
v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/tinylib/msgp v1.6.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect + go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/otel v1.28.0 // indirect go.opentelemetry.io/otel/metric v1.28.0 // indirect go.opentelemetry.io/otel/trace v1.28.0 // indirect go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.10.0 // indirect + go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/arch v0.6.0 // indirect golang.org/x/exp v0.0.0-20231226003508-02704c960a9b // indirect - golang.org/x/net v0.48.0 // indirect + golang.org/x/net v0.49.0 // indirect + golang.org/x/oauth2 v0.15.0 // indirect + golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.41.0 // indirect - golang.org/x/term v0.40.0 // indirect golang.org/x/text v0.33.0 // indirect + golang.org/x/time v0.5.0 // indirect + golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect + google.golang.org/appengine v1.6.7 // indirect + google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20231106174013-bbf56f31fb17 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20231120223509-83a465c0220f // indirect + google.golang.org/grpc v1.59.0 // indirect google.golang.org/protobuf v1.32.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 23b9cdd0d15..965ceba2d24 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,71 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.110.10 h1:LXy9GEO+timppncPIAZoOj3l58LIU9k+kn48AN7IO3Y= +cloud.google.com/go v0.110.10/go.mod h1:v1OoFqYxiBkUrruItNM3eT4lLByNjxmJSV/xDKJNnic= +cloud.google.com/go/compute v1.23.3 h1:6sVlXXBmbd7jNX0Ipq0trII3e4n1/MsADLK6a+aiVlk= +cloud.google.com/go/compute v1.23.3/go.mod h1:VCgBUoMnIVIR0CscqQiPJLAG25E3ZRZMzcFZeQ+h8CI= +cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY= +cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= +cloud.google.com/go/iam v1.1.5 h1:1jTsCu4bcsNsE4iiqNT5SHwrDRCfRmIaaaVFhRveTJI= +cloud.google.com/go/iam v1.1.5/go.mod h1:rB6P/Ic3mykPbFio+vo7403drjlgvoWfYpJhMXEbzv8= +cloud.google.com/go/storage v1.35.1 h1:B59ahL//eDfx2IIKFBeT5Atm9wnNmj3+8xG/W4WB//w= +cloud.google.com/go/storage v1.35.1/go.mod h1:M6M/3V/D3KpzMTJyPOR/HU6n2Si5QdaXYEsng2xgOs8= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 h1:fou+2+WFTib47nS+nz/ozhEBnvU96bKHy6LjRsY4E28= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0/go.mod h1:t76Ruy8AHvUAC8GfMWJMa0ElSbuIcO03NLpynfbgsPA= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1 h1:/Zt+cDPnpC3OVDm/JKLOs7M2DKmLRIIp3XIx9pHHiig= 
+github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1/go.mod h1:Ng3urmn6dYe8gnbCMoHHVl5APYz2txho3koEkV2o2HA= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.4 h1:jWQK1GI+LeGGUKBADtcH2rRqPxYB1Ljwms5gFA2LqrM= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.4/go.mod h1:8mwH4klAm9DUgR2EEHyEEAQlRDvLPyg5fQry3y+cDew= +github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake v1.4.4 h1:7QtoGxKm6mPhsWzEZtrn3tQF1hmMMZblngnqNoE61I8= +github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake v1.4.4/go.mod h1:juYrzH1q6A+g9ZZbGh0OmjS7zaMq3rFDrPhVnYSgFMA= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= +github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6 h1:N4lRUXZpZ1KVEUn6hxtco/1d2lgYhNn1fHkkl8WhlyQ= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= +github.com/aws/aws-sdk-go-v2/config v1.32.11 h1:ftxI5sgz8jZkckuUHXfC/wMUc8u3fG1vQS0plr2F2Zs= +github.com/aws/aws-sdk-go-v2/config v1.32.11/go.mod h1:twF11+6ps9aNRKEDimksp923o44w/Thk9+8YIlzWMmo= +github.com/aws/aws-sdk-go-v2/credentials v1.19.11 h1:NdV8cwCcAXrCWyxArt58BrvZJ9pZ9Fhf9w6Uh5W3Uyc= +github.com/aws/aws-sdk-go-v2/credentials v1.19.11/go.mod h1:30yY2zqkMPdrvxBqzI9xQCM+WrlrZKSOpSJEsylVU+8= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19 h1:INUvJxmhdEbVulJYHI061k4TVuS3jzzthNvjqvVvTKM= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19/go.mod h1:FpZN2QISLdEBWkayloda+sZjVJL+e9Gl0k1SyTgcswU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 h1:/sECfyq2JTifMI2JPyZ4bdRN77zJmr6SrS1eL3augIA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19/go.mod h1:dMf8A5oAqr9/oxOfLkC/c2LU/uMcALP0Rgn2BD5LWn0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 h1:AWeJMk33GTBf6J20XJe6qZoRSJo0WfUhsMdUKhoODXE= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19/go.mod h1:+GWrYoaAsV7/4pNHpwh1kiNLXkKaSoppxQq9lbH8Ejw= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 h1:clHU5fm//kWS1C2HgtgWxfQbFbx4b6rx+5jzhgX9HrI= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.20 h1:qi3e/dmpdONhj1RyIZdi6DKKpDXS5Lb8ftr3p7cyHJc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.20/go.mod h1:V1K+TeJVD5JOk3D9e5tsX2KUdL7BlB+FV6cBhdobN8c= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6 h1:XAq62tBTJP/85lFD5oqOOe7YYgWxY9LvWq8plyDvDVg= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.11 h1:BYf7XNsJMzl4mObARUBUib+j2tf0U//JAAtTnYqvqCw= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.11/go.mod h1:aEUS4WrNk/+FxkBZZa7tVgp4pGH+kFGW40Y8rCPqt5g= 
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19 h1:X1Tow7suZk9UCJHE1Iw9GMZJJl0dAnKXXP1NaSDHwmw= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19/go.mod h1:/rARO8psX+4sfjUQXp5LLifjUt8DuATZ31WptNJTyQA= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.19 h1:JnQeStZvPHFHeyky/7LbMlyQjUa+jIBj36OlWm0pzIk= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.19/go.mod h1:HGyasyHvYdFQeJhvDHfH7HXkHh57htcJGKDZ+7z+I24= +github.com/aws/aws-sdk-go-v2/service/s3 v1.96.4 h1:4ExZyubQ6LQQVuF2Qp9OsfEvsTdAWh5Gfwf6PgIdLdk= +github.com/aws/aws-sdk-go-v2/service/s3 v1.96.4/go.mod h1:NF3JcMGOiARAss1ld3WGORCw71+4ExDD2cbbdKS5PpA= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.7 h1:Y2cAXlClHsXkkOvWZFXATr34b0hxxloeQu/pAZz2row= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.7/go.mod h1:idzZ7gmDeqeNrSPkdbtMp9qWMgcBwykA7P7Rzh5DXVU= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.12 h1:iSsvB9EtQ09YrsmIc44Heqlx5ByGErqhPK1ZQLppias= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.12/go.mod h1:fEWYKTRGoZNl8tZ77i61/ccwOMJdGxwOhWCkp6TXAr0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16 h1:EnUdUqRP1CNzt2DkV67tJx6XDN4xlfBFm+bzeNOQVb0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16/go.mod h1:Jic/xv0Rq/pFNCh3WwpH4BEqdbSAl+IyHro8LbibHD8= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 h1:XQTQTF75vnug2TXS8m7CVJfC2nniYPZnO1D4Np761Oo= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.8/go.mod h1:Xgx+PR1NUOjNmQY+tRMnouRp83JRM8pRMw/vCaVhPkI= +github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= +github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -5,21 +73,30 @@ github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0 github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/elastic/elastic-transport-go/v8 v8.8.0 h1:7k1Ua+qluFr6p1jfJjGDl97ssJS/P7cHNInzfxgBQAo= github.com/elastic/elastic-transport-go/v8 v8.8.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk= github.com/elastic/go-elasticsearch/v8 v8.19.1 h1:0iEGt5/Ds9MNVxEp3hqLsXdbe6SjleaVHONg/FuR09Q= github.com/elastic/go-elasticsearch/v8 v8.19.1/go.mod h1:tHJQdInFa6abmDbDCEH2LJja07l/SIpaGpJcm13nt7s= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= @@ -30,6 +107,8 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= +github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -47,11 +126,47 @@ github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod 
h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= -github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/martian/v3 v3.3.2 h1:IqNFLAmvJOgVlpdEBiQbDc2EwKW77amAycfTuWKdfvw= +github.com/google/martian/v3 v3.3.2/go.mod h1:oBOf6HBosgwRXnUGWUB05QECsc6uvmMiJ3+6W4l/CUk= +github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= +github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= +github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= +github.com/googleapis/gax-go/v2 v2.12.0 h1:A+gCJKdRfqXkr+BIRGtZLibNXf0m1f9E4HG56etFpas= +github.com/googleapis/gax-go/v2 v2.12.0/go.mod h1:y+aIqrI5eb1YGMVJfuV3185Ts/D7qKpsEkdD5+I6QGU= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/iromli/go-itsdangerous v0.0.0-20220223194502-9c8bef8dac6a h1:Inib12UR9HAfBubrGNraPjKt/Cu8xPbTJbC50+0wP5U= @@ -62,13 +177,22 @@ github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now 
v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= +github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k= +github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= +github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= -github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/klauspost/cpuid/v2 v2.2.11 h1:0OwqZRYI2rFrjS4kvkDnqJkKHdHaRnCm68/DY4OxRzU= +github.com/klauspost/cpuid/v2 v2.2.11/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/klauspost/crc32 v1.3.0 h1:sSmTt3gUt81RP655XGZPElI0PelVTZ6YwCRnPSupoFM= +github.com/klauspost/crc32 v1.3.0/go.mod h1:D7kQaZhnkX/Y0tstFGf8VUzv2UofNGqCjnC3zdHB0Hw= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= @@ -77,6 +201,12 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.3 h1:a+kO+98RDGEfo6asOGMmpodZq4FNtnGP54yps8BzLR4= github.com/mattn/go-runewidth v0.0.3/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= +github.com/minio/crc64nvme v1.1.1 h1:8dwx/Pz49suywbO+auHCBpCtlW1OfpcLN7wYgVR6wAI= +github.com/minio/crc64nvme v1.1.1/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg= +github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= +github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM= +github.com/minio/minio-go/v7 v7.0.99 h1:2vH/byrwUkIpFQFOilvTfaUpvAX3fEFhEzO+DR3DlCE= +github.com/minio/minio-go/v7 v7.0.99/go.mod h1:EtGNKtlX20iL2yaYnxEigaIvj0G0GwSDnifnG8ClIdw= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -88,13 +218,20 @@ github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOS github.com/pelletier/go-toml/v2 v2.1.1/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= github.com/peterh/liner v1.2.2 h1:aJ4AOodmL+JxOZZEL2u9iJf8omNRpqHc/EbrK+3mAXw= 
github.com/peterh/liner v1.2.2/go.mod h1:xFwJyiKIXJZUKItq5dGHZSTBRAuG/CpeNpWLyiNRNwI= +github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM= +github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs= github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= @@ -121,16 +258,20 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/tinylib/msgp v1.6.1 h1:ESRv8eL3u+DNHUoSAAQRE50Hm162zqAnBoGv9PzScPY= +github.com/tinylib/msgp v1.6.1/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= +go.opencensus.io v0.24.0 
h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= +go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/otel v1.28.0 h1:/SqNcYk+idO0CxKEUOtKQClMK/MimZihKYMruSMViUo= go.opentelemetry.io/otel v1.28.0/go.mod h1:q68ijF8Fc8CnMHKyzqL6akLO46ePnjkgfIMIjUIX9z4= go.opentelemetry.io/otel/metric v1.28.0 h1:f0HGvSl1KRAU1DLgLGFjrwVyismPlnuU6JD6bOeuA5Q= @@ -147,31 +288,100 @@ go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.6.0 h1:S0JTfE48HbRj80+4tbvZDYsJ3tGv6BUU3XxyZ7CirAc= golang.org/x/arch v0.6.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20231226003508-02704c960a9b h1:kLiC65FbiHWFAOu+lxwNPujcsl8VYyTYYEZnsOO1WK4= golang.org/x/exp v0.0.0-20231226003508-02704c960a9b/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= -golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= -golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.15.0 h1:s8pnnxNVzjWyrvYdFUQq5llS1PX2zhPXmccZv99h7uQ= +golang.org/x/oauth2 v0.15.0/go.mod h1:q48ptWNTY5XWf+JNten23lcvHpLJ0ZSxF5ttTHKVCAM= +golang.org/x/sync 
v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20211117180635-dee7805ff2e1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= -golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= -golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk= +golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= +google.golang.org/api v0.153.0 h1:N1AwGhielyKFaUqH07/ZSIQR3uNPcV7NVw0vj+j4iR4= +google.golang.org/api v0.153.0/go.mod h1:3qNJX5eOmhiWYc67jRA/3GsDw97UFb5ivv7Y2PrriAY= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.6.7 
h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= +google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= +google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17/go.mod h1:J7XzRzVy1+IPwWHZUzoD0IccYZIrXILAQpc+Qy9CMhY= +google.golang.org/genproto/googleapis/api v0.0.0-20231106174013-bbf56f31fb17 h1:JpwMPBpFN3uKhdaekDpiNlImDdkUAyiJ6ez/uxGaUSo= +google.golang.org/genproto/googleapis/api v0.0.0-20231106174013-bbf56f31fb17/go.mod h1:0xJLfVdJqpAPl8tDg1ujOCGzx6LFLttXT5NhllGOXY4= +google.golang.org/genproto/googleapis/rpc v0.0.0-20231120223509-83a465c0220f h1:ultW7fxlIvee4HYrtnaRPon9HpEgFk5zYpmfMgtKB5I= +google.golang.org/genproto/googleapis/rpc v0.0.0-20231120223509-83a465c0220f/go.mod h1:L9KNLi232K1/xB6f7AlSX692koaRnKaWSR0stBki0Yc= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= +google.golang.org/grpc v1.59.0 h1:Z5Iec2pjwb+LEOqzpB2MR12/eKFhDPhuqW91O+4bwUk= +google.golang.org/grpc v1.59.0/go.mod h1:aUPDwccQo6OTjy7Hct4AfBPD1GptF4fyUjIkQ9YtF98= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c 
h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= @@ -182,4 +392,6 @@ gorm.io/driver/mysql v1.5.2/go.mod h1:pQLhh1Ut/WUAySdTHwBpBv6+JKcj+ua4ZFx1QQTBzb gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/internal/storage/minio.go b/internal/storage/minio.go new file mode 100644 index 00000000000..280694efc4d --- /dev/null +++ b/internal/storage/minio.go @@ -0,0 +1,397 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package storage + +import ( + "bytes" + "context" + "crypto/tls" + "fmt" + "net/http" + "time" + + "github.com/minio/minio-go/v7" + "github.com/minio/minio-go/v7/pkg/credentials" + "go.uber.org/zap" +) + +// MinioConfig holds MinIO storage configuration +type MinioConfig struct { + Host string `mapstructure:"host"` // MinIO server host (e.g., "localhost:9000") + User string `mapstructure:"user"` // Access key + Password string `mapstructure:"password"` // Secret key + Secure bool `mapstructure:"secure"` // Use HTTPS + Verify bool `mapstructure:"verify"` // Verify SSL certificates + Bucket string `mapstructure:"bucket"` // Default bucket (optional) + PrefixPath string `mapstructure:"prefix_path"` // Path prefix (optional) +} + +// MinioStorage implements Storage interface for MinIO +type MinioStorage struct { + client *minio.Client + bucket string + prefixPath string + config *MinioConfig +} + +// NewMinioStorage creates a new MinIO storage instance +func NewMinioStorage(config *MinioConfig) (*MinioStorage, error) { + storage := &MinioStorage{ + bucket: config.Bucket, + prefixPath: config.PrefixPath, + config: config, + } + + if err := storage.connect(); err != nil { + return nil, err + } + + return storage, nil +} + +func (m *MinioStorage) connect() error { + var transport http.RoundTripper + + // Configure transport for SSL/TLS verification + if m.config.Secure { + verify := m.config.Verify + transport = &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: !verify, + }, + } + } + + client, err := minio.New(m.config.Host, &minio.Options{ + Creds: credentials.NewStaticV4(m.config.User, m.config.Password, ""), + Secure: m.config.Secure, + Transport: transport, + }) + if err != nil { + return fmt.Errorf("failed to connect to MinIO: %w", err) + } + + m.client = client + return nil +} + +func (m *MinioStorage) reconnect() { + if err := m.connect(); err != nil { + zap.L().Error("Failed to reconnect to MinIO", zap.Error(err)) + } +} + +func (m *MinioStorage) resolveBucketAndPath(bucket, fnm string) (string, string) { + actualBucket := bucket + if m.bucket != "" { + actualBucket = m.bucket + } + + actualPath := fnm + if m.bucket != "" { + if m.prefixPath != "" { + actualPath = fmt.Sprintf("%s/%s/%s", m.prefixPath, bucket, fnm) + } else { + actualPath = fmt.Sprintf("%s/%s", bucket, fnm) + } + } else if m.prefixPath != "" { + actualPath = fmt.Sprintf("%s/%s", m.prefixPath, fnm) + } + + return actualBucket, actualPath +} + +// Health checks MinIO service availability +func (m *MinioStorage) Health() bool { + ctx := context.Background() + + if m.bucket != "" { + exists, err := m.client.BucketExists(ctx, m.bucket) + if err != nil { + zap.L().Warn("MinIO health check failed", zap.Error(err)) + return false + } + return exists + } + + _, err := m.client.ListBuckets(ctx) + if err != nil { + zap.L().Warn("MinIO health check failed", zap.Error(err)) + return false + } + return true +} + +// Put uploads an object to MinIO +func (m *MinioStorage) Put(bucket, fnm string, binary []byte, tenantID ...string) error { + bucket, fnm = m.resolveBucketAndPath(bucket, fnm) + + ctx := context.Background() + + for i := 0; i < 3; i++ { + // Ensure bucket exists + if m.bucket == "" { + exists, err := m.client.BucketExists(ctx, bucket) + if err != nil { + zap.L().Error("Failed to check bucket existence", zap.String("bucket", bucket), zap.Error(err)) + m.reconnect() + time.Sleep(time.Second) + continue + } + if !exists { + if err := m.client.MakeBucket(ctx, bucket, minio.MakeBucketOptions{}); err != 
nil { + zap.L().Error("Failed to create bucket", zap.String("bucket", bucket), zap.Error(err)) + m.reconnect() + time.Sleep(time.Second) + continue + } + } + } + + reader := bytes.NewReader(binary) + _, err := m.client.PutObject(ctx, bucket, fnm, reader, int64(len(binary)), minio.PutObjectOptions{}) + if err != nil { + zap.L().Error("Failed to put object", zap.String("bucket", bucket), zap.String("key", fnm), zap.Error(err)) + m.reconnect() + time.Sleep(time.Second) + continue + } + + return nil + } + + return fmt.Errorf("failed to put object after 3 retries") +} + +// Get retrieves an object from MinIO +func (m *MinioStorage) Get(bucket, fnm string, tenantID ...string) ([]byte, error) { + bucket, fnm = m.resolveBucketAndPath(bucket, fnm) + + ctx := context.Background() + + for i := 0; i < 2; i++ { + obj, err := m.client.GetObject(ctx, bucket, fnm, minio.GetObjectOptions{}) + if err != nil { + zap.L().Error("Failed to get object", zap.String("bucket", bucket), zap.String("key", fnm), zap.Error(err)) + m.reconnect() + time.Sleep(time.Second) + continue + } + defer obj.Close() + + buf := new(bytes.Buffer) + if _, err := buf.ReadFrom(obj); err != nil { + zap.L().Error("Failed to read object data", zap.String("bucket", bucket), zap.String("key", fnm), zap.Error(err)) + m.reconnect() + time.Sleep(time.Second) + continue + } + + return buf.Bytes(), nil + } + + return nil, fmt.Errorf("failed to get object after retries") +} + +// Rm removes an object from MinIO +func (m *MinioStorage) Rm(bucket, fnm string, tenantID ...string) error { + bucket, fnm = m.resolveBucketAndPath(bucket, fnm) + + ctx := context.Background() + + if err := m.client.RemoveObject(ctx, bucket, fnm, minio.RemoveObjectOptions{}); err != nil { + zap.L().Error("Failed to remove object", zap.String("bucket", bucket), zap.String("key", fnm), zap.Error(err)) + return err + } + + return nil +} + +// ObjExist checks if an object exists in MinIO +func (m *MinioStorage) ObjExist(bucket, fnm string, tenantID ...string) bool { + bucket, fnm = m.resolveBucketAndPath(bucket, fnm) + + ctx := context.Background() + + exists, err := m.client.BucketExists(ctx, bucket) + if err != nil || !exists { + return false + } + + _, err = m.client.StatObject(ctx, bucket, fnm, minio.StatObjectOptions{}) + if err != nil { + errResponse := minio.ToErrorResponse(err) + if errResponse.Code == "NoSuchKey" || errResponse.Code == "NoSuchBucket" { + return false + } + zap.L().Error("Failed to stat object", zap.String("bucket", bucket), zap.String("key", fnm), zap.Error(err)) + return false + } + + return true +} + +// GetPresignedURL generates a presigned URL for accessing an object +func (m *MinioStorage) GetPresignedURL(bucket, fnm string, expires time.Duration, tenantID ...string) (string, error) { + bucket, fnm = m.resolveBucketAndPath(bucket, fnm) + + ctx := context.Background() + + for i := 0; i < 10; i++ { + url, err := m.client.PresignedGetObject(ctx, bucket, fnm, expires, nil) + if err != nil { + zap.L().Error("Failed to get presigned URL", zap.String("bucket", bucket), zap.String("key", fnm), zap.Error(err)) + m.reconnect() + time.Sleep(time.Second) + continue + } + + return url.String(), nil + } + + return "", fmt.Errorf("failed to get presigned URL after 10 retries") +} + +// BucketExists checks if a bucket exists +func (m *MinioStorage) BucketExists(bucket string) bool { + actualBucket := bucket + if m.bucket != "" { + actualBucket = m.bucket + } + + ctx := context.Background() + + exists, err := m.client.BucketExists(ctx, actualBucket) + if err != 
nil { + zap.L().Error("Failed to check bucket existence", zap.String("bucket", actualBucket), zap.Error(err)) + return false + } + + return exists +} + +// RemoveBucket removes a bucket and all its objects +func (m *MinioStorage) RemoveBucket(bucket string) error { + actualBucket := bucket + origBucket := bucket + + if m.bucket != "" { + actualBucket = m.bucket + } + + ctx := context.Background() + + // Build prefix for single-bucket mode + prefix := "" + if m.bucket != "" { + if m.prefixPath != "" { + prefix = fmt.Sprintf("%s/", m.prefixPath) + } + prefix += fmt.Sprintf("%s/", origBucket) + } + + // List and delete objects with prefix + objectsCh := make(chan minio.ObjectInfo) + + go func() { + defer close(objectsCh) + for obj := range m.client.ListObjects(ctx, actualBucket, minio.ListObjectsOptions{ + Prefix: prefix, + Recursive: true, + }) { + if obj.Err != nil { + zap.L().Error("Error listing objects", zap.Error(obj.Err)) + return + } + objectsCh <- obj + } + }() + + for err := range m.client.RemoveObjects(ctx, actualBucket, objectsCh, minio.RemoveObjectsOptions{}) { + zap.L().Error("Failed to remove object", zap.String("key", err.ObjectName), zap.Error(err.Err)) + } + + // Only remove the actual bucket if not in single-bucket mode + if m.bucket == "" { + if err := m.client.RemoveBucket(ctx, actualBucket); err != nil { + zap.L().Error("Failed to remove bucket", zap.String("bucket", actualBucket), zap.Error(err)) + return err + } + } + + return nil +} + +// Copy copies an object from source to destination +func (m *MinioStorage) Copy(srcBucket, srcPath, destBucket, destPath string) bool { + srcBucket, srcPath = m.resolveBucketAndPath(srcBucket, srcPath) + destBucket, destPath = m.resolveBucketAndPath(destBucket, destPath) + + ctx := context.Background() + + // Ensure destination bucket exists + if m.bucket == "" { + exists, err := m.client.BucketExists(ctx, destBucket) + if err != nil { + zap.L().Error("Failed to check bucket existence", zap.String("bucket", destBucket), zap.Error(err)) + return false + } + if !exists { + if err := m.client.MakeBucket(ctx, destBucket, minio.MakeBucketOptions{}); err != nil { + zap.L().Error("Failed to create bucket", zap.String("bucket", destBucket), zap.Error(err)) + return false + } + } + } + + // Check if source object exists + _, err := m.client.StatObject(ctx, srcBucket, srcPath, minio.StatObjectOptions{}) + if err != nil { + zap.L().Error("Source object not found", zap.String("bucket", srcBucket), zap.String("key", srcPath), zap.Error(err)) + return false + } + + // Copy object + srcOpts := minio.CopySrcOptions{ + Bucket: srcBucket, + Object: srcPath, + } + destOpts := minio.CopyDestOptions{ + Bucket: destBucket, + Object: destPath, + } + + _, err = m.client.CopyObject(ctx, destOpts, srcOpts) + if err != nil { + zap.L().Error("Failed to copy object", zap.String("src", fmt.Sprintf("%s/%s", srcBucket, srcPath)), zap.String("dest", fmt.Sprintf("%s/%s", destBucket, destPath)), zap.Error(err)) + return false + } + + return true +} + +// Move moves an object from source to destination +func (m *MinioStorage) Move(srcBucket, srcPath, destBucket, destPath string) bool { + if m.Copy(srcBucket, srcPath, destBucket, destPath) { + if err := m.Rm(srcBucket, srcPath); err != nil { + zap.L().Error("Failed to remove source object after copy", zap.String("bucket", srcBucket), zap.String("key", srcPath), zap.Error(err)) + return false + } + return true + } + return false +} diff --git a/internal/storage/oss.go b/internal/storage/oss.go new file mode 100644 index 
00000000000..d74aa9da7a2 --- /dev/null +++ b/internal/storage/oss.go @@ -0,0 +1,415 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package storage + +import ( + "bytes" + "context" + "errors" + "fmt" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/smithy-go" + "go.uber.org/zap" +) + +// OSSConfig holds Aliyun OSS storage configuration +// OSS is compatible with S3 API +type OSSConfig struct { + AccessKeyID string `mapstructure:"access_key"` // OSS Access Key ID + SecretAccessKey string `mapstructure:"secret_key"` // OSS Secret Access Key + EndpointURL string `mapstructure:"endpoint_url"` // OSS Endpoint (e.g., "https://oss-cn-hangzhou.aliyuncs.com") + Region string `mapstructure:"region"` // Region (e.g., "cn-hangzhou") + Bucket string `mapstructure:"bucket"` // Default bucket (optional) + PrefixPath string `mapstructure:"prefix_path"` // Path prefix (optional) + SignatureVersion string `mapstructure:"signature_version"` // Signature version + AddressingStyle string `mapstructure:"addressing_style"` // Addressing style +} + +// OSSStorage implements Storage interface for Aliyun OSS +// OSS uses S3-compatible API +type OSSStorage struct { + client *s3.Client + bucket string + prefixPath string + config *OSSConfig +} + +// NewOSSStorage creates a new OSS storage instance +func NewOSSStorage(config *OSSConfig) (*OSSStorage, error) { + storage := &OSSStorage{ + bucket: config.Bucket, + prefixPath: config.PrefixPath, + config: config, + } + + if err := storage.connect(); err != nil { + return nil, err + } + + return storage, nil +} + +func (o *OSSStorage) connect() error { + ctx := context.Background() + + // Create static credentials + creds := credentials.NewStaticCredentialsProvider( + o.config.AccessKeyID, + o.config.SecretAccessKey, + "", + ) + + // Load configuration + cfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion(o.config.Region), + config.WithCredentialsProvider(creds), + ) + if err != nil { + return fmt.Errorf("failed to load OSS config: %w", err) + } + + // Create S3 client with OSS endpoint + o.client = s3.NewFromConfig(cfg, func(opts *s3.Options) { + opts.BaseEndpoint = aws.String(o.config.EndpointURL) + }) + + return nil +} + +func (o *OSSStorage) reconnect() { + if err := o.connect(); err != nil { + zap.L().Error("Failed to reconnect to OSS", zap.Error(err)) + } +} + +func (o *OSSStorage) resolveBucketAndPath(bucket, fnm string) (string, string) { + actualBucket := bucket + if o.bucket != "" { + actualBucket = o.bucket + } + + actualPath := fnm + if o.prefixPath != "" { + actualPath = fmt.Sprintf("%s/%s", o.prefixPath, fnm) + } + + return actualBucket, actualPath +} + +// Health checks OSS service availability +func (o *OSSStorage) Health() bool { + bucket := o.bucket + if bucket == "" { + bucket = "health-check-bucket" + } + + fnm := 
"txtxtxtxt1" + if o.prefixPath != "" { + fnm = fmt.Sprintf("%s/%s", o.prefixPath, fnm) + } + binary := []byte("_t@@@1") + + ctx := context.Background() + + // Ensure bucket exists + if !o.BucketExists(bucket) { + _, err := o.client.CreateBucket(ctx, &s3.CreateBucketInput{ + Bucket: aws.String(bucket), + }) + if err != nil { + zap.L().Error("Failed to create bucket for health check", zap.String("bucket", bucket), zap.Error(err)) + return false + } + } + + // Try to upload a test object + reader := bytes.NewReader(binary) + _, err := o.client.PutObject(ctx, &s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(fnm), + Body: reader, + }) + + if err != nil { + zap.L().Error("Health check failed", zap.Error(err)) + return false + } + + return true +} + +// Put uploads an object to OSS +func (o *OSSStorage) Put(bucket, fnm string, binary []byte, tenantID ...string) error { + bucket, fnm = o.resolveBucketAndPath(bucket, fnm) + + ctx := context.Background() + + for i := 0; i < 2; i++ { + // Ensure bucket exists + if !o.BucketExists(bucket) { + _, err := o.client.CreateBucket(ctx, &s3.CreateBucketInput{ + Bucket: aws.String(bucket), + }) + if err != nil { + zap.L().Error("Failed to create bucket", zap.String("bucket", bucket), zap.Error(err)) + o.reconnect() + time.Sleep(time.Second) + continue + } + zap.L().Info("Created bucket", zap.String("bucket", bucket)) + } + + reader := bytes.NewReader(binary) + _, err := o.client.PutObject(ctx, &s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(fnm), + Body: reader, + }) + if err != nil { + zap.L().Error("Failed to put object", zap.String("bucket", bucket), zap.String("key", fnm), zap.Error(err)) + o.reconnect() + time.Sleep(time.Second) + continue + } + + return nil + } + + return fmt.Errorf("failed to put object after retries") +} + +// Get retrieves an object from OSS +func (o *OSSStorage) Get(bucket, fnm string, tenantID ...string) ([]byte, error) { + bucket, fnm = o.resolveBucketAndPath(bucket, fnm) + + ctx := context.Background() + + for i := 0; i < 2; i++ { + result, err := o.client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(fnm), + }) + if err != nil { + zap.L().Error("Failed to get object", zap.String("bucket", bucket), zap.String("key", fnm), zap.Error(err)) + o.reconnect() + time.Sleep(time.Second) + continue + } + defer result.Body.Close() + + buf := new(bytes.Buffer) + if _, err := buf.ReadFrom(result.Body); err != nil { + zap.L().Error("Failed to read object data", zap.String("bucket", bucket), zap.String("key", fnm), zap.Error(err)) + o.reconnect() + time.Sleep(time.Second) + continue + } + + return buf.Bytes(), nil + } + + return nil, fmt.Errorf("failed to get object after retries") +} + +// Rm removes an object from OSS +func (o *OSSStorage) Rm(bucket, fnm string, tenantID ...string) error { + bucket, fnm = o.resolveBucketAndPath(bucket, fnm) + + ctx := context.Background() + + _, err := o.client.DeleteObject(ctx, &s3.DeleteObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(fnm), + }) + if err != nil { + zap.L().Error("Failed to remove object", zap.String("bucket", bucket), zap.String("key", fnm), zap.Error(err)) + return err + } + + return nil +} + +// ObjExist checks if an object exists in OSS +func (o *OSSStorage) ObjExist(bucket, fnm string, tenantID ...string) bool { + bucket, fnm = o.resolveBucketAndPath(bucket, fnm) + + ctx := context.Background() + + _, err := o.client.HeadObject(ctx, &s3.HeadObjectInput{ + Bucket: aws.String(bucket), + Key: 
aws.String(fnm),
+	})
+	if err != nil {
+		if !isOSSNotFound(err) {
+			zap.L().Error("Failed to stat object", zap.String("bucket", bucket), zap.String("key", fnm), zap.Error(err))
+		}
+		return false
+	}
+
+	return true
+}
+
+// GetPresignedURL generates a presigned URL for accessing an object
+func (o *OSSStorage) GetPresignedURL(bucket, fnm string, expires time.Duration, tenantID ...string) (string, error) {
+	bucket, fnm = o.resolveBucketAndPath(bucket, fnm)
+
+	ctx := context.Background()
+
+	presignClient := s3.NewPresignClient(o.client)
+
+	for i := 0; i < 10; i++ {
+		req, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
+			Bucket: aws.String(bucket),
+			Key:    aws.String(fnm),
+		}, s3.WithPresignExpires(expires))
+		if err != nil {
+			zap.L().Error("Failed to generate presigned URL", zap.String("bucket", bucket), zap.String("key", fnm), zap.Error(err))
+			o.reconnect()
+			time.Sleep(time.Second)
+			continue
+		}
+
+		return req.URL, nil
+	}
+
+	return "", fmt.Errorf("failed to generate presigned URL after 10 retries")
+}
+
+// BucketExists checks if a bucket exists
+func (o *OSSStorage) BucketExists(bucket string) bool {
+	actualBucket := bucket
+	if o.bucket != "" {
+		actualBucket = o.bucket
+	}
+
+	ctx := context.Background()
+
+	_, err := o.client.HeadBucket(ctx, &s3.HeadBucketInput{
+		Bucket: aws.String(actualBucket),
+	})
+	if err != nil {
+		zap.L().Debug("Bucket does not exist or error", zap.String("bucket", actualBucket), zap.Error(err))
+		return false
+	}
+
+	return true
+}
+
+// RemoveBucket removes a bucket and all its objects
+func (o *OSSStorage) RemoveBucket(bucket string) error {
+	actualBucket := bucket
+	if o.bucket != "" {
+		actualBucket = o.bucket
+	}
+
+	ctx := context.Background()
+
+	// Check if bucket exists
+	if !o.BucketExists(actualBucket) {
+		return nil
+	}
+
+	// List and delete all objects
+	listInput := &s3.ListObjectsV2Input{
+		Bucket: aws.String(actualBucket),
+	}
+
+	for {
+		result, err := o.client.ListObjectsV2(ctx, listInput)
+		if err != nil {
+			zap.L().Error("Failed to list objects", zap.String("bucket", actualBucket), zap.Error(err))
+			return err
+		}
+
+		for _, obj := range result.Contents {
+			_, err := o.client.DeleteObject(ctx, &s3.DeleteObjectInput{
+				Bucket: aws.String(actualBucket),
+				Key:    obj.Key,
+			})
+			if err != nil {
+				zap.L().Error("Failed to delete object", zap.String("bucket", actualBucket), zap.Error(err))
+			}
+		}
+
+		if result.IsTruncated == nil || !*result.IsTruncated {
+			break
+		}
+		listInput.ContinuationToken = result.NextContinuationToken
+	}
+
+	// Delete bucket
+	_, err := o.client.DeleteBucket(ctx, &s3.DeleteBucketInput{
+		Bucket: aws.String(actualBucket),
+	})
+	if err != nil {
+		zap.L().Error("Failed to delete bucket", zap.String("bucket", actualBucket), zap.Error(err))
+		return err
+	}
+
+	return nil
+}
+
+// Copy copies an object from source to destination
+func (o *OSSStorage) Copy(srcBucket, srcPath, destBucket, destPath string) bool {
+	srcBucket, srcPath = o.resolveBucketAndPath(srcBucket, srcPath)
+	destBucket, destPath = o.resolveBucketAndPath(destBucket, destPath)
+
+	ctx := context.Background()
+
+	copySource := fmt.Sprintf("%s/%s", srcBucket, srcPath)
+
+	_, err := o.client.CopyObject(ctx, &s3.CopyObjectInput{
+		Bucket:     aws.String(destBucket),
+		Key:        aws.String(destPath),
+		CopySource: aws.String(copySource),
+	})
+	if err != nil {
+		zap.L().Error("Failed to copy object", zap.String("src", copySource), zap.String("dest", fmt.Sprintf("%s/%s", destBucket, destPath)), zap.Error(err))
+		return false
+	}
+
+	return true
+}
+
+// Move moves an object from source to destination
+func (o *OSSStorage) Move(srcBucket, 
srcPath, destBucket, destPath string) bool { + if o.Copy(srcBucket, srcPath, destBucket, destPath) { + if err := o.Rm(srcBucket, srcPath); err != nil { + zap.L().Error("Failed to remove source object after copy", zap.String("bucket", srcBucket), zap.String("key", srcPath), zap.Error(err)) + return false + } + return true + } + return false +} + +// Helper functions +func isOSSNotFound(err error) bool { + if err == nil { + return false + } + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + return apiErr.ErrorCode() == "NotFound" || apiErr.ErrorCode() == "404" || apiErr.ErrorCode() == "NoSuchKey" + } + return false +} diff --git a/internal/storage/s3.go b/internal/storage/s3.go new file mode 100644 index 00000000000..af49fbd4619 --- /dev/null +++ b/internal/storage/s3.go @@ -0,0 +1,425 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package storage + +import ( + "bytes" + "context" + "errors" + "fmt" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/smithy-go" + "go.uber.org/zap" +) + +// S3Config holds AWS S3 storage configuration +type S3Config struct { + AccessKeyID string `mapstructure:"access_key"` // AWS Access Key ID + SecretAccessKey string `mapstructure:"secret_key"` // AWS Secret Access Key + SessionToken string `mapstructure:"session_token"` // AWS Session Token (optional) + Region string `mapstructure:"region_name"` // AWS Region + EndpointURL string `mapstructure:"endpoint_url"` // Custom endpoint (optional) + SignatureVersion string `mapstructure:"signature_version"` // Signature version + AddressingStyle string `mapstructure:"addressing_style"` // Addressing style + Bucket string `mapstructure:"bucket"` // Default bucket (optional) + PrefixPath string `mapstructure:"prefix_path"` // Path prefix (optional) +} + +// S3Storage implements Storage interface for AWS S3 +type S3Storage struct { + client *s3.Client + bucket string + prefixPath string + config *S3Config +} + +// NewS3Storage creates a new S3 storage instance +func NewS3Storage(config *S3Config) (*S3Storage, error) { + storage := &S3Storage{ + bucket: config.Bucket, + prefixPath: config.PrefixPath, + config: config, + } + + if err := storage.connect(); err != nil { + return nil, err + } + + return storage, nil +} + +func (s *S3Storage) connect() error { + ctx := context.Background() + + var opts []func(*config.LoadOptions) error + + // Configure region + if s.config.Region != "" { + opts = append(opts, config.WithRegion(s.config.Region)) + } + + // Configure credentials if provided + if s.config.AccessKeyID != "" && s.config.SecretAccessKey != "" { + creds := credentials.NewStaticCredentialsProvider( + s.config.AccessKeyID, + s.config.SecretAccessKey, + s.config.SessionToken, + ) + opts = append(opts, config.WithCredentialsProvider(creds)) + } + + // Load configuration + cfg, err := 
config.LoadDefaultConfig(ctx, opts...) + if err != nil { + return fmt.Errorf("failed to load AWS config: %w", err) + } + + // Create S3 client with custom endpoint if provided + clientOpts := []func(*s3.Options){} + if s.config.EndpointURL != "" { + clientOpts = append(clientOpts, func(o *s3.Options) { + o.BaseEndpoint = aws.String(s.config.EndpointURL) + }) + } + + s.client = s3.NewFromConfig(cfg, clientOpts...) + return nil +} + +func (s *S3Storage) reconnect() { + if err := s.connect(); err != nil { + zap.L().Error("Failed to reconnect to S3", zap.Error(err)) + } +} + +func (s *S3Storage) resolveBucketAndPath(bucket, fnm string) (string, string) { + actualBucket := bucket + if s.bucket != "" { + actualBucket = s.bucket + } + + actualPath := fnm + if s.prefixPath != "" { + actualPath = fmt.Sprintf("%s/%s/%s", s.prefixPath, bucket, fnm) + } + + return actualBucket, actualPath +} + +// Health checks S3 service availability +func (s *S3Storage) Health() bool { + bucket := s.bucket + if bucket == "" { + bucket = "health-check-bucket" + } + + fnm := "txtxtxtxt1" + if s.prefixPath != "" { + fnm = fmt.Sprintf("%s/%s", s.prefixPath, fnm) + } + binary := []byte("_t@@@1") + + ctx := context.Background() + + // Ensure bucket exists + if !s.BucketExists(bucket) { + _, err := s.client.CreateBucket(ctx, &s3.CreateBucketInput{ + Bucket: aws.String(bucket), + }) + if err != nil { + zap.L().Error("Failed to create bucket for health check", zap.String("bucket", bucket), zap.Error(err)) + return false + } + } + + // Try to upload a test object + reader := bytes.NewReader(binary) + _, err := s.client.PutObject(ctx, &s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(fnm), + Body: reader, + }) + + if err != nil { + zap.L().Error("Health check failed", zap.Error(err)) + return false + } + + return true +} + +// Put uploads an object to S3 +func (s *S3Storage) Put(bucket, fnm string, binary []byte, tenantID ...string) error { + bucket, fnm = s.resolveBucketAndPath(bucket, fnm) + + ctx := context.Background() + + for i := 0; i < 2; i++ { + // Ensure bucket exists + if !s.BucketExists(bucket) { + _, err := s.client.CreateBucket(ctx, &s3.CreateBucketInput{ + Bucket: aws.String(bucket), + }) + if err != nil { + zap.L().Error("Failed to create bucket", zap.String("bucket", bucket), zap.Error(err)) + s.reconnect() + time.Sleep(time.Second) + continue + } + zap.L().Info("Created bucket", zap.String("bucket", bucket)) + } + + reader := bytes.NewReader(binary) + _, err := s.client.PutObject(ctx, &s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(fnm), + Body: reader, + }) + if err != nil { + zap.L().Error("Failed to put object", zap.String("bucket", bucket), zap.String("key", fnm), zap.Error(err)) + s.reconnect() + time.Sleep(time.Second) + continue + } + + return nil + } + + return fmt.Errorf("failed to put object after retries") +} + +// Get retrieves an object from S3 +func (s *S3Storage) Get(bucket, fnm string, tenantID ...string) ([]byte, error) { + bucket, fnm = s.resolveBucketAndPath(bucket, fnm) + + ctx := context.Background() + + for i := 0; i < 2; i++ { + result, err := s.client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(fnm), + }) + if err != nil { + zap.L().Error("Failed to get object", zap.String("bucket", bucket), zap.String("key", fnm), zap.Error(err)) + s.reconnect() + time.Sleep(time.Second) + continue + } + defer result.Body.Close() + + buf := new(bytes.Buffer) + if _, err := buf.ReadFrom(result.Body); err != nil { + 
zap.L().Error("Failed to read object data", zap.String("bucket", bucket), zap.String("key", fnm), zap.Error(err))
+			s.reconnect()
+			time.Sleep(time.Second)
+			continue
+		}
+
+		return buf.Bytes(), nil
+	}
+
+	return nil, fmt.Errorf("failed to get object after retries")
+}
+
+// Rm removes an object from S3
+func (s *S3Storage) Rm(bucket, fnm string, tenantID ...string) error {
+	bucket, fnm = s.resolveBucketAndPath(bucket, fnm)
+
+	ctx := context.Background()
+
+	_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
+		Bucket: aws.String(bucket),
+		Key:    aws.String(fnm),
+	})
+	if err != nil {
+		zap.L().Error("Failed to remove object", zap.String("bucket", bucket), zap.String("key", fnm), zap.Error(err))
+		return err
+	}
+
+	return nil
+}
+
+// ObjExist checks if an object exists in S3
+func (s *S3Storage) ObjExist(bucket, fnm string, tenantID ...string) bool {
+	bucket, fnm = s.resolveBucketAndPath(bucket, fnm)
+
+	ctx := context.Background()
+
+	_, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
+		Bucket: aws.String(bucket),
+		Key:    aws.String(fnm),
+	})
+	if err != nil {
+		if !isS3NotFound(err) {
+			zap.L().Error("Failed to stat object", zap.String("bucket", bucket), zap.String("key", fnm), zap.Error(err))
+		}
+		return false
+	}
+
+	return true
+}
+
+// GetPresignedURL generates a presigned URL for accessing an object
+func (s *S3Storage) GetPresignedURL(bucket, fnm string, expires time.Duration, tenantID ...string) (string, error) {
+	bucket, fnm = s.resolveBucketAndPath(bucket, fnm)
+
+	ctx := context.Background()
+
+	presignClient := s3.NewPresignClient(s.client)
+
+	for i := 0; i < 10; i++ {
+		req, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
+			Bucket: aws.String(bucket),
+			Key:    aws.String(fnm),
+		}, s3.WithPresignExpires(expires))
+		if err != nil {
+			zap.L().Error("Failed to generate presigned URL", zap.String("bucket", bucket), zap.String("key", fnm), zap.Error(err))
+			s.reconnect()
+			time.Sleep(time.Second)
+			continue
+		}
+
+		return req.URL, nil
+	}
+
+	return "", fmt.Errorf("failed to generate presigned URL after 10 retries")
+}
+
+// BucketExists checks if a bucket exists
+func (s *S3Storage) BucketExists(bucket string) bool {
+	actualBucket := bucket
+	if s.bucket != "" {
+		actualBucket = s.bucket
+	}
+
+	ctx := context.Background()
+
+	_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
+		Bucket: aws.String(actualBucket),
+	})
+	if err != nil {
+		zap.L().Debug("Bucket does not exist or error", zap.String("bucket", actualBucket), zap.Error(err))
+		return false
+	}
+
+	return true
+}
+
+// RemoveBucket removes a bucket and all its objects
+func (s *S3Storage) RemoveBucket(bucket string) error {
+	actualBucket := bucket
+	if s.bucket != "" {
+		actualBucket = s.bucket
+	}
+
+	ctx := context.Background()
+
+	// Check if bucket exists
+	if !s.BucketExists(actualBucket) {
+		return nil
+	}
+
+	// List and delete all objects
+	listInput := &s3.ListObjectsV2Input{
+		Bucket: aws.String(actualBucket),
+	}
+
+	for {
+		result, err := s.client.ListObjectsV2(ctx, listInput)
+		if err != nil {
+			zap.L().Error("Failed to list objects", zap.String("bucket", actualBucket), zap.Error(err))
+			return err
+		}
+
+		for _, obj := range result.Contents {
+			_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
+				Bucket: aws.String(actualBucket),
+				Key:    obj.Key,
+			})
+			if err != nil {
+				zap.L().Error("Failed to delete object", zap.String("bucket", actualBucket), zap.Error(err))
+			}
+		}
+
+		if result.IsTruncated == nil || !*result.IsTruncated {
+			break
+		}
+		listInput.ContinuationToken = result.NextContinuationToken
+	}
+
+	// Delete bucket
+	_, err := s.client.DeleteBucket(ctx, &s3.DeleteBucketInput{
+		Bucket: aws.String(actualBucket),
+	})
+	if err != nil {
+		zap.L().Error("Failed to delete bucket", zap.String("bucket", actualBucket), zap.Error(err))
+		return err
+	}
+
+	return nil
+}
+
+// Copy copies an object from source to destination
+func (s *S3Storage) Copy(srcBucket, srcPath, destBucket, destPath string) bool {
+	srcBucket, srcPath = s.resolveBucketAndPath(srcBucket, srcPath)
+	destBucket, destPath = s.resolveBucketAndPath(destBucket, destPath)
+
+	ctx := context.Background()
+
+	copySource := fmt.Sprintf("%s/%s", srcBucket, srcPath)
+
+	_, err := s.client.CopyObject(ctx, &s3.CopyObjectInput{
+		Bucket:     aws.String(destBucket),
+		Key:        aws.String(destPath),
+		CopySource: aws.String(copySource),
+	})
+	if err != nil {
+		zap.L().Error("Failed to copy object", zap.String("src", copySource), zap.String("dest", fmt.Sprintf("%s/%s", destBucket, destPath)), zap.Error(err))
+		return false
+	}
+
+	return true
+}
+
+// Move moves an object from source to destination
+func (s *S3Storage) Move(srcBucket, srcPath, destBucket, destPath string) bool {
+	if s.Copy(srcBucket, srcPath, destBucket, destPath) {
+		if err := s.Rm(srcBucket, srcPath); err != nil {
+			zap.L().Error("Failed to remove source object after copy", zap.String("bucket", srcBucket), zap.String("key", srcPath), zap.Error(err))
+			return false
+		}
+		return true
+	}
+	return false
+}
+
+// isS3NotFound checks if the error is a not-found error
+func isS3NotFound(err error) bool {
+	if err == nil {
+		return false
+	}
+	var apiErr smithy.APIError
+	if errors.As(err, &apiErr) {
+		return apiErr.ErrorCode() == "NotFound" || apiErr.ErrorCode() == "404" || apiErr.ErrorCode() == "NoSuchKey"
+	}
+	return false
+}
diff --git a/internal/storage/storage_factory.go b/internal/storage/storage_factory.go
new file mode 100644
index 00000000000..e7f17d0584f
--- /dev/null
+++ b/internal/storage/storage_factory.go
@@ -0,0 +1,298 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License. 
+// + +package storage + +import ( + "fmt" + "os" + "sync" + + "github.com/spf13/viper" + "go.uber.org/zap" +) + +// StorageFactory creates storage instances based on configuration +type StorageFactory struct { + storageType StorageType + storage Storage + config *StorageConfig + mu sync.RWMutex +} + +// StorageConfig holds all storage-related configurations +type StorageConfig struct { + StorageType string `mapstructure:"storage_type"` + Minio *MinioConfig `mapstructure:"minio"` + S3 *S3Config `mapstructure:"s3"` + OSS *OSSConfig `mapstructure:"oss"` +} + +// AzureConfig holds Azure-specific configurations +type AzureConfig struct { + ContainerURL string `mapstructure:"container_url"` + SASToken string `mapstructure:"sas_token"` + AccountURL string `mapstructure:"account_url"` + ClientID string `mapstructure:"client_id"` + Secret string `mapstructure:"secret"` + TenantID string `mapstructure:"tenant_id"` + ContainerName string `mapstructure:"container_name"` + AuthorityHost string `mapstructure:"authority_host"` +} + +var ( + globalFactory *StorageFactory + once sync.Once +) + +// GetStorageFactory returns the singleton storage factory instance +func GetStorageFactory() *StorageFactory { + once.Do(func() { + globalFactory = &StorageFactory{} + }) + return globalFactory +} + +// InitStorageFactory initializes the storage factory with configuration +func InitStorageFactory(v *viper.Viper) error { + factory := GetStorageFactory() + + // Get storage type from environment or config + storageType := os.Getenv("STORAGE_IMPL") + if storageType == "" { + storageType = v.GetString("storage_type") + } + if storageType == "" { + storageType = "MINIO" // Default storage type + } + + storageConfig := &StorageConfig{} + if err := v.UnmarshalKey("storage", storageConfig); err != nil { + return fmt.Errorf("failed to unmarshal storage config: %w", err) + } + storageConfig.StorageType = storageType + + factory.config = storageConfig + + // Initialize storage based on type + if err := factory.initStorage(storageType, v); err != nil { + return err + } + + zap.L().Info("Storage factory initialized", + zap.String("storage_type", storageType), + ) + + return nil +} + +// initStorage initializes the specific storage implementation +func (f *StorageFactory) initStorage(storageType string, v *viper.Viper) error { + switch storageType { + case "MINIO": + return f.initMinio(v) + case "AWS_S3": + return f.initS3(v) + case "OSS": + return f.initOSS(v) + default: + return fmt.Errorf("unsupported storage type: %s", storageType) + } +} + +func (f *StorageFactory) initMinio(v *viper.Viper) error { + config := &MinioConfig{} + + // Try to load from minio section first + if v.IsSet("minio") { + minioConfig := v.Sub("minio") + if minioConfig != nil { + config.Host = minioConfig.GetString("host") + config.User = minioConfig.GetString("user") + config.Password = minioConfig.GetString("password") + config.Secure = minioConfig.GetBool("secure") + config.Verify = minioConfig.GetBool("verify") + config.Bucket = minioConfig.GetString("bucket") + config.PrefixPath = minioConfig.GetString("prefix_path") + } + } + + // Apply defaults + if config.Host == "" { + config.Host = "localhost:9000" + } + if config.User == "" { + config.User = "minioadmin" + } + if config.Password == "" { + config.Password = "minioadmin" + } + + storage, err := NewMinioStorage(config) + if err != nil { + return fmt.Errorf("failed to create MinIO storage: %w", err) + } + + f.mu.Lock() + defer f.mu.Unlock() + f.storageType = StorageMinio + f.storage = storage + 
f.config.Minio = config + + return nil +} + +func (f *StorageFactory) initS3(v *viper.Viper) error { + config := &S3Config{} + + if v.IsSet("s3") { + s3Config := v.Sub("s3") + if s3Config != nil { + config.AccessKeyID = s3Config.GetString("access_key") + config.SecretAccessKey = s3Config.GetString("secret_key") + config.SessionToken = s3Config.GetString("session_token") + config.Region = s3Config.GetString("region_name") + config.EndpointURL = s3Config.GetString("endpoint_url") + config.SignatureVersion = s3Config.GetString("signature_version") + config.AddressingStyle = s3Config.GetString("addressing_style") + config.Bucket = s3Config.GetString("bucket") + config.PrefixPath = s3Config.GetString("prefix_path") + } + } + + storage, err := NewS3Storage(config) + if err != nil { + return fmt.Errorf("failed to create S3 storage: %w", err) + } + + f.mu.Lock() + defer f.mu.Unlock() + f.storageType = StorageAWSS3 + f.storage = storage + f.config.S3 = config + + return nil +} + +func (f *StorageFactory) initOSS(v *viper.Viper) error { + config := &OSSConfig{} + + if v.IsSet("oss") { + ossConfig := v.Sub("oss") + if ossConfig != nil { + config.AccessKeyID = ossConfig.GetString("access_key") + config.SecretAccessKey = ossConfig.GetString("secret_key") + config.EndpointURL = ossConfig.GetString("endpoint_url") + config.Region = ossConfig.GetString("region") + config.Bucket = ossConfig.GetString("bucket") + config.PrefixPath = ossConfig.GetString("prefix_path") + config.SignatureVersion = ossConfig.GetString("signature_version") + config.AddressingStyle = ossConfig.GetString("addressing_style") + } + } + + storage, err := NewOSSStorage(config) + if err != nil { + return fmt.Errorf("failed to create OSS storage: %w", err) + } + + f.mu.Lock() + defer f.mu.Unlock() + f.storageType = StorageOSS + f.storage = storage + f.config.OSS = config + + return nil +} + +// GetStorage returns the current storage instance +func (f *StorageFactory) GetStorage() Storage { + f.mu.RLock() + defer f.mu.RUnlock() + return f.storage +} + +// GetStorageType returns the current storage type +func (f *StorageFactory) GetStorageType() StorageType { + f.mu.RLock() + defer f.mu.RUnlock() + return f.storageType +} + +// Create creates a new storage instance based on the storage type +// This is the factory method equivalent to Python's StorageFactory.create() +func (f *StorageFactory) Create(storageType StorageType) (Storage, error) { + var storage Storage + var err error + + switch storageType { + case StorageMinio: + if f.config.Minio != nil { + storage, err = NewMinioStorage(f.config.Minio) + } else { + return nil, fmt.Errorf("MinIO config not available") + } + case StorageAWSS3: + if f.config.S3 != nil { + storage, err = NewS3Storage(f.config.S3) + } else { + return nil, fmt.Errorf("S3 config not available") + } + case StorageOSS: + if f.config.OSS != nil { + storage, err = NewOSSStorage(f.config.OSS) + } else { + return nil, fmt.Errorf("OSS config not available") + } + default: + return nil, fmt.Errorf("unsupported storage type: %v", storageType) + } + + if err != nil { + return nil, err + } + + return storage, nil +} + +// SetStorage sets the storage instance (useful for testing) +func (f *StorageFactory) SetStorage(storage Storage) { + f.mu.Lock() + defer f.mu.Unlock() + f.storage = storage +} + +// StorageTypeMapping returns the storage type mapping (equivalent to Python's storage_mapping) +var StorageTypeMapping = map[StorageType]func(*StorageConfig) (Storage, error){ + StorageMinio: func(config *StorageConfig) (Storage, 
error) { + if config.Minio == nil { + return nil, fmt.Errorf("MinIO config not available") + } + return NewMinioStorage(config.Minio) + }, + StorageAWSS3: func(config *StorageConfig) (Storage, error) { + if config.S3 == nil { + return nil, fmt.Errorf("S3 config not available") + } + return NewS3Storage(config.S3) + }, + StorageOSS: func(config *StorageConfig) (Storage, error) { + if config.OSS == nil { + return nil, fmt.Errorf("OSS config not available") + } + return NewOSSStorage(config.OSS) + }, +} diff --git a/internal/storage/types.go b/internal/storage/types.go new file mode 100644 index 00000000000..fc777373aff --- /dev/null +++ b/internal/storage/types.go @@ -0,0 +1,102 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package storage + +import ( + "errors" + "time" +) + +var ( + // ErrNotFound is returned when an object is not found + ErrNotFound = errors.New("object not found") + // ErrBucketNotFound is returned when a bucket is not found + ErrBucketNotFound = errors.New("bucket not found") +) + +// StorageType represents the type of storage backend +type StorageType int + +const ( + StorageMinio StorageType = 1 + StorageAzureSpn StorageType = 2 + StorageAzureSas StorageType = 3 + StorageAWSS3 StorageType = 4 + StorageOSS StorageType = 5 + StorageOpenDAL StorageType = 6 + StorageGCS StorageType = 7 +) + +func (s StorageType) String() string { + switch s { + case StorageMinio: + return "MINIO" + case StorageAzureSpn: + return "AZURE_SPN" + case StorageAzureSas: + return "AZURE_SAS" + case StorageAWSS3: + return "AWS_S3" + case StorageOSS: + return "OSS" + case StorageOpenDAL: + return "OPENDAL" + case StorageGCS: + return "GCS" + default: + return "UNKNOWN" + } +} + +// Storage defines the interface for storage operations +type Storage interface { + // Health checks the storage service availability + Health() bool + + // Put uploads an object to storage + // bucket: the bucket/container name + // fnm: the file/object name (key) + // binary: the data to upload + // tenantID: optional tenant identifier + Put(bucket, fnm string, binary []byte, tenantID ...string) error + + // Get retrieves an object from storage + // Returns the data or nil if not found + Get(bucket, fnm string, tenantID ...string) ([]byte, error) + + // Rm removes an object from storage + Rm(bucket, fnm string, tenantID ...string) error + + // ObjExist checks if an object exists + ObjExist(bucket, fnm string, tenantID ...string) bool + + // GetPresignedURL generates a presigned URL for accessing an object + // expires: duration until the URL expires + GetPresignedURL(bucket, fnm string, expires time.Duration, tenantID ...string) (string, error) + + // BucketExists checks if a bucket exists + BucketExists(bucket string) bool + + // RemoveBucket removes a bucket and all its objects + RemoveBucket(bucket string) error + + // Copy copies an object from source to destination + Copy(srcBucket, srcPath, destBucket, destPath string) bool + + 
// Move moves an object from source to destination + Move(srcBucket, srcPath, destBucket, destPath string) bool +} From 3d9e07e6e2e24fb7e2d56f385872412a88fa41b5 Mon Sep 17 00:00:00 2001 From: qinling0210 <88864212+qinling0210@users.noreply.github.com> Date: Thu, 12 Mar 2026 18:59:56 +0800 Subject: [PATCH 232/565] Fix "Result window is too large" during meta data search (#13521) ### What problem does this PR solve? Fix https://github.com/infiniflow/ragflow/issues/13210#issuecomment-3982878498 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- api/db/joint_services/user_account_service.py | 2 +- api/db/services/doc_metadata_service.py | 57 +++++++------------ rag/utils/es_conn.py | 3 +- 3 files changed, 22 insertions(+), 40 deletions(-) diff --git a/api/db/joint_services/user_account_service.py b/api/db/joint_services/user_account_service.py index f9a25b498c5..6f992576a7d 100644 --- a/api/db/joint_services/user_account_service.py +++ b/api/db/joint_services/user_account_service.py @@ -173,7 +173,7 @@ def delete_user_data(user_id: str) -> dict: if doc_ids: for doc in doc_ids: try: - DocMetadataService.delete_document_metadata(doc["id"], doc["kb_id"], tenant_id=None, skip_empty_check=True) + DocMetadataService.delete_document_metadata(doc["id"], doc["kb_id"], tenant_id=None) except Exception as e: logging.warning(f"Failed to delete metadata for document {doc['id']}: {e}") diff --git a/api/db/services/doc_metadata_service.py b/api/db/services/doc_metadata_service.py index f2ee29e6da4..dbb16a5a941 100644 --- a/api/db/services/doc_metadata_service.py +++ b/api/db/services/doc_metadata_service.py @@ -160,7 +160,7 @@ def _iter_search_results(cls, results): def _search_metadata(cls, kb_id: str, condition: Dict = None): """ Common search logic for metadata queries. - Uses pagination internally to retrieve ALL data from the index. + Uses pagination internally to retrieve data from the index. Args: kb_id: Knowledge base ID @@ -188,7 +188,10 @@ def _search_metadata(cls, kb_id: str, condition: Dict = None): if condition is None: condition = {"kb_id": kb_id} + # Add sort by id for ES to enable search_after on large data order_by = OrderByExpr() + if not settings.DOC_ENGINE_INFINITY: + order_by.asc("id") page_size = 1000 all_results = [] @@ -474,7 +477,7 @@ def update_document_metadata(cls, doc_id: str, meta_fields: Dict) -> bool: # For Infinity or as fallback: use delete+insert logging.debug(f"[update_document_metadata] Using delete+insert method for doc_id: {doc_id}") - cls.delete_document_metadata(doc_id, kb_id, tenant_id, skip_empty_check=True) + cls.delete_document_metadata(doc_id, kb_id, tenant_id) return cls.insert_document_metadata(doc_id, processed_meta) except Exception as e: @@ -483,7 +486,7 @@ def update_document_metadata(cls, doc_id: str, meta_fields: Dict) -> bool: @classmethod @DB.connection_context() - def delete_document_metadata(cls, doc_id: str, kb_id: str, tenant_id: str = None, skip_empty_check: bool = False) -> bool: + def delete_document_metadata(cls, doc_id: str, kb_id: str, tenant_id: str = None) -> bool: """ Delete document metadata from ES/Infinity. Also drops the metadata table if it becomes empty (efficiently). 
@@ -493,7 +496,6 @@ def delete_document_metadata(cls, doc_id: str, kb_id: str, tenant_id: str = None doc_id: Document ID kb_id: Knowledge base ID tenant_id: Tenant ID, if not provided, get it from kb_id - skip_empty_check: If True, skip checking/dropping empty table (for bulk deletions) Returns: True if successful (or no metadata to delete), False otherwise @@ -529,9 +531,6 @@ def delete_document_metadata(cls, doc_id: str, kb_id: str, tenant_id: str = None logging.debug(f"[METADATA DELETE] Get result: {existing_metadata is not None}") if not existing_metadata: logging.debug(f"[METADATA DELETE] Document {doc_id} has no metadata in table, skipping deletion") - # Only check/drop table if not skipped (tenant deletion will handle it) - if not skip_empty_check: - cls._drop_empty_metadata_table(index_name, tenant_id) return True # No metadata to delete is success except Exception as e: # If get fails, document might not exist in metadata table, which is fine @@ -548,14 +547,6 @@ def delete_document_metadata(cls, doc_id: str, kb_id: str, tenant_id: str = None kb_id # Pass actual kb_id (delete() will handle metadata tables correctly) ) logging.debug(f"[METADATA DELETE] Deleted count: {deleted_count}") - - # Only check if table should be dropped if not skipped (for bulk operations) - # Note: delete operation already uses refresh=True, so data is immediately available - if not skip_empty_check: - # Check by querying the actual metadata table (not MySQL) - cls._drop_empty_metadata_table(index_name, tenant_id) - - logging.debug(f"Successfully deleted metadata for document {doc_id}") return True except Exception as e: @@ -782,21 +773,18 @@ def get_metadata_for_documents(cls, doc_ids: Optional[List[str]], kb_id: str) -> Dictionary mapping doc_id to meta_fields dict """ try: - results = cls._search_metadata(kb_id, condition={"kb_id": kb_id}) + condition = {"kb_id": kb_id} + if doc_ids: + condition["id"] = doc_ids + results = cls._search_metadata(kb_id, condition=condition) if not results: return {} # Build mapping: doc_id -> meta_fields meta_mapping = {} - # If doc_ids is provided, create a set for efficient lookup - doc_ids_set = set(doc_ids) if doc_ids else None - - # Use helper to iterate over results in any format + # Use helper to iterate over results for doc_id, doc in cls._iter_search_results(results): - # Filter by doc_ids if provided - if doc_ids_set is not None and doc_id not in doc_ids_set: - continue # Extract metadata (handles both JSON strings and dicts) doc_meta = cls._extract_metadata(doc) @@ -850,13 +838,13 @@ def _meta_value_type(value): return "string" try: - results = cls._search_metadata(kb_id, condition={"kb_id": kb_id}) + condition = {"kb_id": kb_id} + if doc_ids: + condition["id"] = doc_ids + results = cls._search_metadata(kb_id, condition=condition) if not results: return {} - # If doc_ids are provided, we'll filter after the search - doc_ids_set = set(doc_ids) if doc_ids else None - # Aggregate metadata summary = {} type_counter = {} @@ -865,9 +853,6 @@ def _meta_value_type(value): # Use helper to iterate over results in any format for doc_id, doc in cls._iter_search_results(results): - # Check doc_ids filter - if doc_ids_set and doc_id not in doc_ids_set: - continue doc_meta = cls._extract_metadata(doc) @@ -1029,22 +1014,17 @@ def _apply_deletes(meta): return changed try: - results = cls._search_metadata(kb_id, condition=None) + results = cls._search_metadata(kb_id, condition={"kb_id": kb_id, "id": doc_ids}) if not results: results = [] # Treat as empty list if None updated_docs 
= 0 - doc_ids_set = set(doc_ids) found_doc_ids = set() logging.debug(f"[batch_update_metadata] Searching for doc_ids: {doc_ids}") - # Use helper to iterate over results in any format + # Use helper to iterate over results for doc_id, doc in cls._iter_search_results(results): - # Filter to only process requested doc_ids - if doc_id not in doc_ids_set: - continue - found_doc_ids.add(doc_id) # Get current metadata @@ -1066,13 +1046,14 @@ def _apply_deletes(meta): logging.debug(f"[batch_update_metadata] Updating doc_id: {doc_id}, meta: {meta}") # If metadata is empty, delete the row entirely instead of keeping empty metadata if not meta: - cls.delete_document_metadata(doc_id, kb_id, tenant_id=None, skip_empty_check=True) + cls.delete_document_metadata(doc_id, kb_id, tenant_id=None) else: cls.update_document_metadata(doc_id, meta) updated_docs += 1 # Handle documents that don't have metadata rows yet # These documents weren't in the search results, so we need to insert new metadata for them + doc_ids_set = set(doc_ids) missing_doc_ids = doc_ids_set - found_doc_ids if missing_doc_ids and updates: logging.debug(f"[batch_update_metadata] Inserting new metadata for documents without metadata rows: {missing_doc_ids}") diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index 6f88c9a44e1..6a3d35eec68 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -264,7 +264,8 @@ def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = assert "id" in d d_copy = copy.deepcopy(d) d_copy["kb_id"] = knowledgebase_id - meta_id = d_copy.pop("id", "") + # Use id as _id for uniqueness, also keep "id" as a regular field for sorting + meta_id = d_copy.get("id", "") operations.append( {"index": {"_index": index_name, "_id": meta_id}}) operations.append(d_copy) From 335fb1dc89535c1f0c9f7eb575ebc0a6718c3af0 Mon Sep 17 00:00:00 2001 From: Ray Zhang Date: Thu, 12 Mar 2026 19:01:25 +0800 Subject: [PATCH 233/565] docs(migration): add project name (-p) usage to backup & migration guide (#13565) ## Summary - Add documentation for the `-p project_name` flag in the migration script, covering all steps (stop, backup, restore, start) - Add a note explaining how Docker volume name prefixes relate to the Compose project name - Update `docker-compose` to `docker compose` (Compose V2 syntax) for consistency - Fix `sh` to `bash` to match the script's shebang line This is the documentation follow-up to #12187 which added `-p` project name support to `docker/migration.sh`. ## Test plan - [ ] Verify the documentation renders correctly on the docs site - [ ] Confirm all example commands are accurate against the current `migration.sh` --- docs/administrator/backup_and_migration.md | 44 ++++++++++++++++++---- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/docs/administrator/backup_and_migration.md b/docs/administrator/backup_and_migration.md index b8e48a9cab2..8a55691b68e 100644 --- a/docs/administrator/backup_and_migration.md +++ b/docs/administrator/backup_and_migration.md @@ -39,15 +39,25 @@ local docker_redis_data These volumes contain all the data you need to migrate. +:::note +The volume name prefix (e.g., `docker_`) comes from the Docker Compose project name. By default it is `docker` (derived from the directory name). If you started RAGFlow with `docker compose -p `, your volumes will be prefixed with `_` instead, for example `ragflow_mysql_data`. 
+::: + ### Step 1: Stop RAGFlow services Before starting the migration, you must stop all running RAGFlow services on the **source machine**. Navigate to the project's root directory and run: ```bash -docker-compose -f docker/docker-compose.yml down +docker compose -f docker/docker-compose.yml down ``` -**Important:** Do **not** use the `-v` flag (e.g., `docker-compose down -v`), as this will delete all your data volumes. The migration script includes a check and will prevent you from running it if services are active. +If you started RAGFlow with a custom project name (e.g., `docker compose -p ragflow`), include it in the command: + +```bash +docker compose -p ragflow -f docker/docker-compose.yml down +``` + +**Important:** Do **not** use the `-v` flag (e.g., `docker compose down -v`), as this will delete all your data volumes. The migration script includes a check and will prevent you from running it if services are active. ### Step 2: Back up your data @@ -74,6 +84,13 @@ bash docker/migration.sh backup my_ragflow_backup This will create a folder named `my_ragflow_backup/` instead. +If you started RAGFlow with a custom project name (e.g., `docker compose -p ragflow`), use the `-p` flag so the script can find the correct volumes: + +```bash +bash docker/migration.sh -p ragflow backup +bash docker/migration.sh -p ragflow backup my_ragflow_backup +``` + ### Step 3: Transfer the backup folder Copy the entire backup folder (e.g., `backup/` or `my_ragflow_backup/`) from your source machine to the RAGFlow project directory on your **target machine**. You can use tools like `scp`, `rsync`, or a physical drive for the transfer. @@ -94,6 +111,13 @@ If you used a custom name, specify it in the command: bash docker/migration.sh restore my_ragflow_backup ``` +If the target machine uses a custom project name, use the `-p` flag to ensure the volumes are created with the correct prefix: + +```bash +bash docker/migration.sh -p ragflow restore +bash docker/migration.sh -p ragflow restore my_ragflow_backup +``` + The script will automatically create the necessary Docker volumes and unpack the data. **Note:** If the script detects that Docker volumes with the same names already exist on the target machine, it will warn you that restoring will overwrite the existing data and ask for confirmation before proceeding. @@ -103,16 +127,22 @@ The script will automatically create the necessary Docker volumes and unpack the Once the restore process is complete, you can start the RAGFlow services on your new machine: ```bash -docker-compose -f docker/docker-compose.yml up -d +docker compose -f docker/docker-compose.yml up -d +``` + +If you use a custom project name: + +```bash +docker compose -p ragflow -f docker/docker-compose.yml up -d ``` -**Note:** If you already have built a service by docker-compose before, you may need to backup your data for target machine like this guide above and run like: +**Note:** If you already have built a service by docker compose before, you may need to backup your data for target machine like this guide above and run like: ```bash -# Please backup by `sh docker/migration.sh backup backup_dir_name` before you do the following line. +# Please backup by `bash docker/migration.sh backup backup_dir_name` before you do the following line. # !!! 
this line -v flag will delete the original docker volume -docker-compose -f docker/docker-compose.yml down -v -docker-compose -f docker/docker-compose.yml up -d +docker compose -f docker/docker-compose.yml down -v +docker compose -f docker/docker-compose.yml up -d ``` Your RAGFlow instance is now running with all the data from your original machine. From 2cd4ffb7e9e1678b5cf6e6ac8c463b13d1902964 Mon Sep 17 00:00:00 2001 From: guptas6est Date: Thu, 12 Mar 2026 11:04:26 +0000 Subject: [PATCH 234/565] fix(web): upgrade lodash to 4.17.23 and dompurify to 3.3.2 to fix CVE-2026-0540 and CVE-2025-13465 (#13488) ### What problem does this PR solve? This PR fixes two security vulnerabilities in web dependencies identified by Trivy: 1. CVE-2025-13465 (lodash): Prototype pollution vulnerability in _.unset and _.omit functions 2. CVE-2026-0540 (dompurify): Cross-site scripting (XSS) vulnerability **Changes:** - Upgraded lodash from 4.17.21 to 4.17.23 - Upgraded dompurify from 3.3.1 to 3.3.2 - Added npm override to force monaco-editor's transitive dependency on dompurify to use 3.3.2 (monaco-editor still depends on vulnerable 3.2.7) Both upgrades are backward-compatible patch versions. Build verified successfully with no breaking changes. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- web/package-lock.json | 35 ++++++++++++++--------------------- web/package.json | 9 ++++++--- 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/web/package-lock.json b/web/package-lock.json index 4e6406fd011..811fd580c59 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -60,7 +60,7 @@ "clsx": "^2.1.1", "cmdk": "^1.0.4", "dayjs": "^1.11.10", - "dompurify": "^3.1.6", + "dompurify": "^3.3.2", "embla-carousel-react": "^8.6.0", "eventsource-parser": "^1.1.2", "human-id": "^4.1.1", @@ -72,7 +72,7 @@ "jsencrypt": "^3.3.2", "jsoneditor": "^10.4.2", "lexical": "^0.23.1", - "lodash": "^4.17.21", + "lodash": "^4.17.23", "lucide-react": "^0.546.0", "mammoth": "^1.7.2", "next-themes": "^0.4.6", @@ -12829,10 +12829,13 @@ "license": "MIT" }, "node_modules/dompurify": { - "version": "3.3.1", - "resolved": "https://registry.npmmirror.com/dompurify/-/dompurify-3.3.1.tgz", - "integrity": "sha512-qkdCKzLNtrgPFP1Vo+98FRzJnBRGe4ffyCea9IwHB1fyxPOeNTHpLKYGd4Uk9xvNoH0ZoOjwZxNptyMwqrId1Q==", + "version": "3.3.2", + "resolved": "https://registry.npmmirror.com/dompurify/-/dompurify-3.3.2.tgz", + "integrity": "sha512-6obghkliLdmKa56xdbLOpUZ43pAR6xFy1uOrxBaIDjT+yaRuuybLjGS9eVBoSR/UPU5fq3OXClEHLJNGvbxKpQ==", "license": "(MPL-2.0 OR Apache-2.0)", + "engines": { + "node": ">=20" + }, "optionalDependencies": { "@types/trusted-types": "^2.0.7" } @@ -18509,15 +18512,15 @@ } }, "node_modules/lodash": { - "version": "4.17.21", - "resolved": "https://registry.npmmirror.com/lodash/-/lodash-4.17.21.tgz", - "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==", + "version": "4.17.23", + "resolved": "https://registry.npmmirror.com/lodash/-/lodash-4.17.23.tgz", + "integrity": "sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==", "license": "MIT" }, "node_modules/lodash-es": { - "version": "4.17.22", - "resolved": "https://registry.npmmirror.com/lodash-es/-/lodash-es-4.17.22.tgz", - "integrity": "sha512-XEawp1t0gxSi9x01glktRZ5HDy0HXqrM0x5pXQM98EaI0NxO6jVM7omDOxsuEo5UIASAnm2bRp1Jt/e0a2XU8Q==", + "version": "4.17.23", + "resolved": "https://registry.npmmirror.com/lodash-es/-/lodash-es-4.17.23.tgz", + 
"integrity": "sha512-kVI48u3PZr38HdYz98UmfPnXl2DXrpdctLrFLCd3kOx1xUkOmpFPx7gCWWM5MPkL/fD8zb+Ph0QzjGFs4+hHWg==", "license": "MIT" }, "node_modules/lodash.debounce": { @@ -20122,16 +20125,6 @@ "marked": "14.0.0" } }, - "node_modules/monaco-editor/node_modules/dompurify": { - "version": "3.2.7", - "resolved": "https://registry.npmmirror.com/dompurify/-/dompurify-3.2.7.tgz", - "integrity": "sha512-WhL/YuveyGXJaerVlMYGWhvQswa7myDG17P7Vu65EWC05o8vfeNbvNf4d/BOvH99+ZW+LlQsc1GDKMa1vNK6dw==", - "license": "(MPL-2.0 OR Apache-2.0)", - "peer": true, - "optionalDependencies": { - "@types/trusted-types": "^2.0.7" - } - }, "node_modules/mri": { "version": "1.2.0", "resolved": "https://registry.npmmirror.com/mri/-/mri-1.2.0.tgz", diff --git a/web/package.json b/web/package.json index 40c30f134cb..439097982bd 100644 --- a/web/package.json +++ b/web/package.json @@ -20,7 +20,10 @@ ] }, "overrides": { - "@radix-ui/react-dismissable-layer": "1.1.4" + "@radix-ui/react-dismissable-layer": "1.1.4", + "monaco-editor": { + "dompurify": "3.3.2" + } }, "dependencies": { "@ant-design/icons": "^5.2.6", @@ -76,7 +79,7 @@ "clsx": "^2.1.1", "cmdk": "^1.0.4", "dayjs": "^1.11.10", - "dompurify": "^3.1.6", + "dompurify": "^3.3.2", "embla-carousel-react": "^8.6.0", "eventsource-parser": "^1.1.2", "human-id": "^4.1.1", @@ -88,7 +91,7 @@ "jsencrypt": "^3.3.2", "jsoneditor": "^10.4.2", "lexical": "^0.23.1", - "lodash": "^4.17.21", + "lodash": "^4.17.23", "lucide-react": "^0.546.0", "mammoth": "^1.7.2", "next-themes": "^0.4.6", From d952e924431dfd2a11388a4bd51a5fe00e957ff1 Mon Sep 17 00:00:00 2001 From: chanx <1243304602@qq.com> Date: Thu, 12 Mar 2026 19:06:20 +0800 Subject: [PATCH 235/565] Feature (System Settings): Implemented system settings management functionality (#13556) ### What problem does this PR solve? Feature (System Settings): Implemented system settings management functionality - Added a new SystemSettings model, including creation and update time fields. - Implemented SystemSettingsDAO, providing CRUD operations and transaction support. - Implemented management interfaces for variables, configurations, and environment variables in the admin service. 
### Type of change - [x] New Feature (non-breaking change which adds functionality) Co-authored-by: Yingfeng --- internal/admin/handler.go | 74 ++++++++++--- internal/admin/service.go | 173 +++++++++++++++++++++++++---- internal/dao/system_settings.go | 188 ++++++++++++++++++++++++++++++++ internal/model/system.go | 14 ++- 4 files changed, 410 insertions(+), 39 deletions(-) create mode 100644 internal/dao/system_settings.go diff --git a/internal/admin/handler.go b/internal/admin/handler.go index 8acbb8cb19a..2dc0eb8803c 100644 --- a/internal/admin/handler.go +++ b/internal/admin/handler.go @@ -767,28 +767,46 @@ func (h *Handler) RestartService(c *gin.Context) { } // GetVariables handle get variables +// Python logic: if request body is empty, list all variables; otherwise get single variable by var_name from body func (h *Handler) GetVariables(c *gin.Context) { - varName := c.Query("var_name") - - if varName != "" { - // Get single variable - variable, err := h.service.GetVariable(varName) + // Check if request has body content + if c.Request.ContentLength == 0 || c.Request.ContentLength == -1 { + // List all variables + variables, err := h.service.GetAllVariables() if err != nil { - errorResponse(c, err.Error(), 400) + errorResponse(c, err.Error(), 500) return } - success(c, variable, "") + success(c, variables, "") + return + } + + // Get single variable by var_name from request body + var req struct { + VarName string `json:"var_name"` + } + if err := c.ShouldBindJSON(&req); err != nil { + errorResponse(c, "Invalid request body", 400) return } - // List all variables - variables, err := h.service.GetAllVariables() + if req.VarName == "" { + errorResponse(c, "Var name is required", 400) + return + } + + variable, err := h.service.GetVariable(req.VarName) if err != nil { + // Check if it's an AdminException + if adminErr, ok := err.(*AdminException); ok { + errorResponse(c, adminErr.Message, 400) + return + } errorResponse(c, err.Error(), 500) return } - success(c, variables, "") + success(c, variable, "") } // SetVariableHTTPRequest set variable request @@ -798,15 +816,31 @@ type SetVariableHTTPRequest struct { } // SetVariable handle set variable +// Python logic: update or create a system setting with the given name and value func (h *Handler) SetVariable(c *gin.Context) { var req SetVariableHTTPRequest if err := c.ShouldBindJSON(&req); err != nil { - errorResponse(c, "Var name and value are required", 400) + errorResponse(c, "Var name is required", 400) + return + } + + if req.VarName == "" { + errorResponse(c, "Var name is required", 400) + return + } + + if req.VarValue == "" { + errorResponse(c, "Var value is required", 400) return } if err := h.service.SetVariable(req.VarName, req.VarValue); err != nil { - errorResponse(c, err.Error(), 400) + // Check if it's an AdminException + if adminErr, ok := err.(*AdminException); ok { + errorResponse(c, adminErr.Message, 400) + return + } + errorResponse(c, err.Error(), 500) return } @@ -814,10 +848,16 @@ func (h *Handler) SetVariable(c *gin.Context) { } // GetConfigs handle get configs +// Python logic: return all service configurations func (h *Handler) GetConfigs(c *gin.Context) { configs, err := h.service.GetAllConfigs() if err != nil { - errorResponse(c, err.Error(), 400) + // Check if it's an AdminException + if adminErr, ok := err.(*AdminException); ok { + errorResponse(c, adminErr.Message, 400) + return + } + errorResponse(c, err.Error(), 500) return } @@ -825,10 +865,16 @@ func (h *Handler) GetConfigs(c *gin.Context) { } // 
GetEnvironments handle get environments +// Python logic: return important environment variables func (h *Handler) GetEnvironments(c *gin.Context) { environments, err := h.service.GetAllEnvironments() if err != nil { - errorResponse(c, err.Error(), 400) + // Check if it's an AdminException + if adminErr, ok := err.(*AdminException); ok { + errorResponse(c, adminErr.Message, 400) + return + } + errorResponse(c, err.Error(), 500) return } diff --git a/internal/admin/service.go b/internal/admin/service.go index dc1fb74061d..8d7cc2f59d4 100644 --- a/internal/admin/service.go +++ b/internal/admin/service.go @@ -46,15 +46,17 @@ var ( // Service admin service layer type Service struct { - userDAO *dao.UserDAO - licenseDAO *dao.LicenseDAO + userDAO *dao.UserDAO + licenseDAO *dao.LicenseDAO + systemSettingsDAO *dao.SystemSettingsDAO } // NewService create admin service func NewService() *Service { return &Service{ - userDAO: dao.NewUserDAO(), - licenseDAO: dao.NewLicenseDAO(), + userDAO: dao.NewUserDAO(), + licenseDAO: dao.NewLicenseDAO(), + systemSettingsDAO: dao.NewSystemSettingsDAO(), } } @@ -888,43 +890,172 @@ func (s *Service) RestartService(serviceID string) (map[string]interface{}, erro // Variable/Settings methods -// GetVariable get variable -func (s *Service) GetVariable(varName string) (map[string]interface{}, error) { - // TODO: Implement with settings manager - return map[string]interface{}{ - "var_name": varName, - "var_value": "", - }, nil +// AdminException admin exception error +type AdminException struct { + Message string + Code int +} + +// Error implement error interface +func (e *AdminException) Error() string { + return e.Message +} + +// NewAdminException create admin exception +func NewAdminException(message string) *AdminException { + return &AdminException{ + Message: message, + Code: 400, + } +} + +// GetVariable get variable by name +// Returns the system setting with the given name +// Returns AdminException if the setting is not found +func (s *Service) GetVariable(varName string) ([]map[string]interface{}, error) { + settings, err := s.systemSettingsDAO.GetByName(varName) + if err != nil { + return nil, err + } + + if len(settings) == 0 { + return nil, NewAdminException("Can't get setting: " + varName) + } + + result := make([]map[string]interface{}, 0, len(settings)) + for _, setting := range settings { + result = append(result, map[string]interface{}{ + "name": setting.Name, + "source": setting.Source, + "data_type": setting.DataType, + "value": setting.Value, + }) + } + return result, nil } // GetAllVariables get all variables +// Returns all system settings from database func (s *Service) GetAllVariables() ([]map[string]interface{}, error) { - // TODO: Implement with settings manager - return []map[string]interface{}{}, nil + settings, err := s.systemSettingsDAO.GetAll() + if err != nil { + return nil, err + } + + result := make([]map[string]interface{}, 0, len(settings)) + for _, setting := range settings { + result = append(result, map[string]interface{}{ + "name": setting.Name, + "source": setting.Source, + "data_type": setting.DataType, + "value": setting.Value, + }) + } + return result, nil } // SetVariable set variable +// Creates or updates a system setting +// If the setting exists, updates it; otherwise creates a new one func (s *Service) SetVariable(varName, varValue string) error { - // TODO: Implement with settings manager - _ = varName - _ = varValue - return nil + settings, err := s.systemSettingsDAO.GetByName(varName) + if err != nil { + return err 
+ } + + if len(settings) == 1 { + setting := &settings[0] + setting.Value = varValue + return s.systemSettingsDAO.UpdateByName(varName, setting) + } else if len(settings) > 1 { + return NewAdminException("Can't update more than 1 setting: " + varName) + } + + // Create new setting if it doesn't exist + // Determine data_type based on name and value + dataType := "string" + if len(varName) >= 7 && varName[:7] == "sandbox" { + dataType = "json" + } else if len(varName) >= 9 && varName[len(varName)-9:] == ".enabled" { + dataType = "boolean" + } + + newSetting := &model.SystemSettings{ + Name: varName, + Value: varValue, + Source: "admin", + DataType: dataType, + } + return s.systemSettingsDAO.Create(newSetting) } // Config methods // GetAllConfigs get all configs +// Returns all service configurations from the config file func (s *Service) GetAllConfigs() ([]map[string]interface{}, error) { - // TODO: Implement with config manager - return []map[string]interface{}{}, nil + result := server.GetAllConfigs() + return result, nil } // Environment methods // GetAllEnvironments get all environments +// Returns important environment variables func (s *Service) GetAllEnvironments() ([]map[string]interface{}, error) { - // TODO: Implement with environment manager - return []map[string]interface{}{}, nil + result := make([]map[string]interface{}, 0) + + // DOC_ENGINE + docEngine := os.Getenv("DOC_ENGINE") + if docEngine == "" { + docEngine = "elasticsearch" + } + result = append(result, map[string]interface{}{ + "env": "DOC_ENGINE", + "value": docEngine, + }) + + // DEFAULT_SUPERUSER_EMAIL + defaultSuperuserEmail := os.Getenv("DEFAULT_SUPERUSER_EMAIL") + if defaultSuperuserEmail == "" { + defaultSuperuserEmail = "admin@ragflow.io" + } + result = append(result, map[string]interface{}{ + "env": "DEFAULT_SUPERUSER_EMAIL", + "value": defaultSuperuserEmail, + }) + + // DB_TYPE + dbType := os.Getenv("DB_TYPE") + if dbType == "" { + dbType = "mysql" + } + result = append(result, map[string]interface{}{ + "env": "DB_TYPE", + "value": dbType, + }) + + // DEVICE + device := os.Getenv("DEVICE") + if device == "" { + device = "cpu" + } + result = append(result, map[string]interface{}{ + "env": "DEVICE", + "value": device, + }) + + // STORAGE_IMPL + storageImpl := os.Getenv("STORAGE_IMPL") + if storageImpl == "" { + storageImpl = "MINIO" + } + result = append(result, map[string]interface{}{ + "env": "STORAGE_IMPL", + "value": storageImpl, + }) + + return result, nil } // Version methods diff --git a/internal/dao/system_settings.go b/internal/dao/system_settings.go new file mode 100644 index 00000000000..858d63ba505 --- /dev/null +++ b/internal/dao/system_settings.go @@ -0,0 +1,188 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package dao + +import ( + "errors" + "time" + + "ragflow/internal/model" + + "gorm.io/gorm" +) + +// SystemSettingsDAO system settings data access object +type SystemSettingsDAO struct{} + +// NewSystemSettingsDAO create system settings DAO instance +func NewSystemSettingsDAO() *SystemSettingsDAO { + return &SystemSettingsDAO{} +} + +// GetAll get all system settings +// Returns all system settings records from database +func (d *SystemSettingsDAO) GetAll() ([]model.SystemSettings, error) { + var settings []model.SystemSettings + err := DB.Find(&settings).Error + if err != nil { + return nil, err + } + return settings, nil +} + +// GetByName get system settings by name +// Returns settings records that match the given name +func (d *SystemSettingsDAO) GetByName(name string) ([]model.SystemSettings, error) { + var settings []model.SystemSettings + err := DB.Where("name = ?", name).Find(&settings).Error + if err != nil { + return nil, err + } + return settings, nil +} + +// UpdateByName update system settings by name +// Updates the setting with the given name using the provided data +func (d *SystemSettingsDAO) UpdateByName(name string, setting *model.SystemSettings) error { + now := time.Now().Unix() + nowDate := time.Now() + + return DB.Model(&model.SystemSettings{}). + Where("name = ?", name). + Updates(map[string]interface{}{ + "value": setting.Value, + "source": setting.Source, + "data_type": setting.DataType, + "update_time": now, + "update_date": nowDate, + }).Error +} + +// Create create a new system setting +// Inserts a new system setting record into database +func (d *SystemSettingsDAO) Create(setting *model.SystemSettings) error { + now := time.Now().Unix() + nowDate := time.Now() + + setting.CreateTime = &now + setting.CreateDate = &nowDate + setting.UpdateTime = &now + setting.UpdateDate = &nowDate + + return DB.Create(setting).Error +} + +// SaveOrCreate update existing setting or create new one +// If setting exists, updates it; otherwise creates a new record +func (d *SystemSettingsDAO) SaveOrCreate(name string, value string, source string, dataType string) error { + settings, err := d.GetByName(name) + if err != nil { + return err + } + + if len(settings) == 1 { + setting := &settings[0] + setting.Value = value + return d.UpdateByName(name, setting) + } else if len(settings) > 1 { + return errors.New("can't update more than 1 setting: " + name) + } + + newSetting := &model.SystemSettings{ + Name: name, + Value: value, + Source: source, + DataType: dataType, + } + return d.Create(newSetting) +} + +// Count get total count of system settings +func (d *SystemSettingsDAO) Count() (int64, error) { + var count int64 + err := DB.Model(&model.SystemSettings{}).Count(&count).Error + return count, err +} + +// DeleteByName delete system setting by name +func (d *SystemSettingsDAO) DeleteByName(name string) error { + return DB.Where("name = ?", name).Delete(&model.SystemSettings{}).Error +} + +// Exists check if setting exists by name +func (d *SystemSettingsDAO) Exists(name string) (bool, error) { + var count int64 + err := DB.Model(&model.SystemSettings{}).Where("name = ?", name).Count(&count).Error + if err != nil { + return false, err + } + return count > 0, nil +} + +// GetBySource get system settings by source +func (d *SystemSettingsDAO) GetBySource(source string) ([]model.SystemSettings, error) { + var settings []model.SystemSettings + err := DB.Where("source = ?", source).Find(&settings).Error + if err != nil { + return nil, err + } + return settings, nil +} + +// 
GetByDataType get system settings by data type +func (d *SystemSettingsDAO) GetByDataType(dataType string) ([]model.SystemSettings, error) { + var settings []model.SystemSettings + err := DB.Where("data_type = ?", dataType).Find(&settings).Error + if err != nil { + return nil, err + } + return settings, nil +} + +// Transaction execute operations in a transaction +func (d *SystemSettingsDAO) Transaction(fn func(tx *gorm.DB) error) error { + return DB.Transaction(fn) +} + +// CreateWithTx create setting within transaction +func (d *SystemSettingsDAO) CreateWithTx(tx *gorm.DB, setting *model.SystemSettings) error { + now := time.Now().Unix() + nowDate := time.Now() + + setting.CreateTime = &now + setting.CreateDate = &nowDate + setting.UpdateTime = &now + setting.UpdateDate = &nowDate + + return tx.Create(setting).Error +} + +// UpdateByNameWithTx update setting within transaction +func (d *SystemSettingsDAO) UpdateByNameWithTx(tx *gorm.DB, name string, setting *model.SystemSettings) error { + now := time.Now().Unix() + nowDate := time.Now() + + return tx.Model(&model.SystemSettings{}). + Where("name = ?", name). + Updates(map[string]interface{}{ + "value": setting.Value, + "source": setting.Source, + "data_type": setting.DataType, + "update_time": now, + "update_date": nowDate, + }).Error +} diff --git a/internal/model/system.go b/internal/model/system.go index be94f1653a6..a67645a6b82 100644 --- a/internal/model/system.go +++ b/internal/model/system.go @@ -16,12 +16,18 @@ package model +import "time" + // SystemSettings system settings model type SystemSettings struct { - Name string `gorm:"column:name;primaryKey;size:128" json:"name"` - Source string `gorm:"column:source;size:32;not null" json:"source"` - DataType string `gorm:"column:data_type;size:32;not null" json:"data_type"` - Value string `gorm:"column:value;type:longtext;not null" json:"value"` + Name string `gorm:"column:name;primaryKey;size:128" json:"name"` + Source string `gorm:"column:source;size:32;not null" json:"source"` + DataType string `gorm:"column:data_type;size:32;not null" json:"data_type"` + Value string `gorm:"column:value;type:longtext;not null" json:"value"` + CreateTime *int64 `gorm:"column:create_time" json:"create_time"` + CreateDate *time.Time `gorm:"column:create_date" json:"create_date"` + UpdateTime *int64 `gorm:"column:update_time" json:"update_time"` + UpdateDate *time.Time `gorm:"column:update_date" json:"update_date"` } // TableName specify table name From c1d51c87c6cb3829ceda4fdbbf6e2cc3bc563988 Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Thu, 12 Mar 2026 20:02:50 +0800 Subject: [PATCH 236/565] Go: Add admin server status checking (#13571) ### What problem does this PR solve? RAGFlow server isn't available when admin server isn't connected. 
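In outline, the server now keeps a mutex-guarded availability flag: it boots as unavailable, the heartbeat loop clears it on each successful report, and the auth middleware rejects requests while it is set. A minimal, self-contained sketch of that pattern, simplified from the `internal/server/local` code in this patch (names and the 503 response path are abbreviated here):

```go
// Sketch of the availability gate: a heartbeat loop updates a shared
// status that HTTP handlers consult before serving requests.
package main

import (
	"fmt"
	"sync"
)

// adminStatus mirrors the flag this PR adds: 0 means the admin server
// is reachable, 1 means requests should be rejected.
type adminStatus struct {
	mu     sync.RWMutex
	status int
	reason string
}

func (a *adminStatus) set(status int, reason string) {
	a.mu.Lock()
	defer a.mu.Unlock()
	a.status, a.reason = status, reason
}

func (a *adminStatus) available() bool {
	a.mu.RLock()
	defer a.mu.RUnlock()
	return a.status == 0
}

func main() {
	// Starts unavailable, as InitAdminStatus(1, ...) does at boot.
	st := &adminStatus{status: 1, reason: "admin server not connected"}
	fmt.Println(st.available()) // false: middleware would answer 503

	// A successful heartbeat clears the flag.
	st.set(0, "")
	fmt.Println(st.available()) // true: requests pass through
}
```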
### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Signed-off-by: Jin Hai --- admin/client/parser.py | 6 ++ admin/client/ragflow_client.py | 13 +++++ build.sh | 3 +- cmd/server_main.go | 10 +++- docker/entrypoint.sh | 1 + internal/handler/auth.go | 19 +++++-- internal/handler/user.go | 43 +++++++++++++-- internal/router/router.go | 19 ++++--- internal/server/local/admin_status.go | 79 +++++++++++++++++++++++++++ internal/service/heartbeat_sender.go | 14 ++--- 10 files changed, 181 insertions(+), 26 deletions(-) create mode 100644 internal/server/local/admin_status.go diff --git a/admin/client/parser.py b/admin/client/parser.py index e82517f16b2..788e3459926 100644 --- a/admin/client/parser.py +++ b/admin/client/parser.py @@ -96,6 +96,7 @@ | show_fingerprint | set_license | show_license + | check_license | benchmark // meta command definition @@ -183,6 +184,7 @@ SERVER: "SERVER"i FINGERPRINT: "FINGERPRINT"i LICENSE: "LICENSE"i +CHECK: "CHECK"i login_user: LOGIN USER quoted_string ";" list_services: LIST SERVICES ";" @@ -231,6 +233,7 @@ show_fingerprint: SHOW FINGERPRINT ";" set_license: SET LICENSE quoted_string ";" show_license: SHOW LICENSE ";" +check_license: CHECK LICENSE ";" list_server_configs: LIST SERVER CONFIGS ";" @@ -496,6 +499,9 @@ def set_license(self, items): def show_license(self, items): return {"type": "show_license"} + def check_license(self, items): + return {"type": "check_license"} + def list_server_configs(self, items): return {"type": "list_server_configs"} diff --git a/admin/client/ragflow_client.py b/admin/client/ragflow_client.py index b59b931bfe2..d878e9e4aa6 100644 --- a/admin/client/ragflow_client.py +++ b/admin/client/ragflow_client.py @@ -614,6 +614,17 @@ def show_license(self, command): else: print(f"Fail to show license, code: {res_json['code']}, message: {res_json['message']}") + def check_license(self, command): + if self.server_type != "admin": + print("This command is only allowed in ADMIN mode") + response = self.http_client.request("GET", "/admin/license?check=true", use_api_base=True, auth_kind="admin") + res_json = response.json() + if response.status_code == 200: + print(res_json["data"]) + else: + print(f"Fail to show license, code: {res_json['code']}, message: {res_json['message']}") + + def list_server_configs(self, command): """List server configs by calling /system/configs API and flattening the JSON response.""" response = self.http_client.request("GET", "/system/configs", use_api_base=False, auth_kind="web") @@ -1551,6 +1562,8 @@ def run_command(client: RAGFlowClient, command_dict: dict): client.set_license(command_dict) case "show_license": client.show_license(command_dict) + case "check_license": + client.check_license(command_dict) case "list_server_configs": client.list_server_configs(command_dict) case "create_model_provider": diff --git a/build.sh b/build.sh index 70fe162437b..5c075120d15 100755 --- a/build.sh +++ b/build.sh @@ -92,7 +92,8 @@ build_go() { echo "Building Go binary: $OUTPUT_BINARY" GOPROXY=${GOPROXY:-https://goproxy.cn,https://proxy.golang.org,direct} CGO_ENABLED=1 go build -o "$OUTPUT_BINARY" ./cmd/server_main.go - + GOPROXY=${GOPROXY:-https://goproxy.cn,https://proxy.golang.org,direct} CGO_ENABLED=1 go build -o "$OUTPUT_BINARY" ./cmd/admin_server.go + if [ ! 
-f "$OUTPUT_BINARY" ]; then echo -e "${RED}Error: Failed to build Go binary${NC}" exit 1 diff --git a/cmd/server_main.go b/cmd/server_main.go index 81e66080e2c..a6eb3408bbe 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -9,6 +9,7 @@ import ( "os/signal" "ragflow/internal/common" "ragflow/internal/server" + "ragflow/internal/server/local" "ragflow/internal/utility" "strings" "syscall" @@ -123,6 +124,9 @@ func main() { logger.Warn("Failed to initialize server variables from Redis, using defaults", zap.String("error", err.Error())) } + // Initialize admin status (default: unavailable=1) + local.InitAdminStatus(1, "admin server not connected") + // Initialize tokenizer (rag_analyzer) tokenizerCfg := &tokenizer.PoolConfig{ DictPath: "/usr/share/infinity/resource", @@ -238,7 +242,11 @@ func startServer(config *server.Config) { } else { // Start heartbeat reporter with 30 seconds interval heartbeatReporter := utility.NewScheduledTask("Heartbeat reporter", 3*time.Second, func() { - if err := heartbeatService.SendHeartbeat(); err != nil { + var message string + if err, message = heartbeatService.SendHeartbeat(); err == nil { + local.SetAdminStatus(0, "") + } else { + local.SetAdminStatus(1, message) logger.Warn("Failed to send heartbeat", zap.Error(err)) } }) diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index f6557b3e4e7..9497d8c6a14 100755 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -249,6 +249,7 @@ if [[ "${ENABLE_ADMIN_SERVER}" -eq 1 ]]; then echo "Starting admin_server..." while true; do "$PY" admin/server/admin_server.py & + bin/admin_server & wait; sleep 1; done & diff --git a/internal/handler/auth.go b/internal/handler/auth.go index ca232645a0c..57e7a29ccd2 100644 --- a/internal/handler/auth.go +++ b/internal/handler/auth.go @@ -17,8 +17,11 @@ package handler import ( + "fmt" "net/http" "ragflow/internal/common" + "ragflow/internal/logger" + "ragflow/internal/server/local" "ragflow/internal/service" "github.com/gin-gonic/gin" @@ -69,13 +72,21 @@ func (h *AuthHandler) AuthMiddleware() gin.HandlerFunc { return } + if !local.IsAdminAvailable() { + license := local.GetAdminStatus() + errMsg := fmt.Sprintf("server license %s, check admin server status", license.Reason) + logger.Warn(errMsg) + c.JSON(http.StatusServiceUnavailable, gin.H{ + "code": common.CodeUnauthorized, + "message": errMsg, + "data": "No", + }) + return + } + c.Set("user", user) c.Set("user_id", user.ID) c.Set("email", user.Email) c.Next() } } - -func (h *AuthHandler) LoginByEmail1(c *gin.Context) { - println("hello") -} diff --git a/internal/handler/user.go b/internal/handler/user.go index 2678ecf1bf1..96f34498044 100644 --- a/internal/handler/user.go +++ b/internal/handler/user.go @@ -21,6 +21,7 @@ import ( "net/http" "ragflow/internal/common" "ragflow/internal/server" + "ragflow/internal/server/local" "ragflow/internal/utility" "strconv" @@ -164,6 +165,16 @@ func (h *UserHandler) LoginByEmail(c *gin.Context) { return } + if !local.IsAdminAvailable() { + license := local.GetAdminStatus() + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeAuthenticationError, + "message": license.Reason, + "data": "No", + }) + return + } + user, code, err := h.userService.LoginByEmail(&req, false) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -291,14 +302,38 @@ func (h *UserHandler) ListUsers(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/user/logout [post] func (h *UserHandler) Logout(c *gin.Context) { - user, errorCode, errorMessage := GetUser(c) - if errorCode != 
common.CodeSuccess { - jsonError(c, errorCode, errorMessage) + // Same as AuthMiddleware@auth.go + token := c.GetHeader("Authorization") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Missing Authorization header", + }) + c.Abort() + return + } + + // Get user by access token + user, code, err := h.userService.GetUserByToken(token) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": code, + "message": "Invalid access token", + }) + c.Abort() + return + } + + if *user.IsSuperuser { + c.JSON(http.StatusForbidden, gin.H{ + "code": common.CodeForbidden, + "message": "Super user should access the URL", + }) return } // Logout user - code, err := h.userService.Logout(user) + code, err = h.userService.Logout(user) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": code, diff --git a/internal/router/router.go b/internal/router/router.go index b7f8b0a6714..1085e0c4c97 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -93,12 +93,13 @@ func (r *Router) Setup(engine *gin.Engine) { // User login by email endpoint engine.POST("/v1/user/login", r.userHandler.LoginByEmail) + // User logout endpoint + engine.GET("/v1/user/logout", r.userHandler.Logout) + // Protected routes authorized := engine.Group("") authorized.Use(r.authHandler.AuthMiddleware()) { - // User logout endpoint - authorized.GET("/v1/user/logout", r.userHandler.Logout) // User info endpoint authorized.GET("/v1/user/info", r.userHandler.Info) // User tenant info endpoint @@ -116,13 +117,13 @@ func (r *Router) Setup(engine *gin.Engine) { v1 := authorized.Group("/api/v1") { // User routes - users := v1.Group("/users") - { - users.POST("/register", r.userHandler.Register) - users.POST("/login", r.userHandler.Login) - users.GET("", r.userHandler.ListUsers) - users.GET("/:id", r.userHandler.GetUserByID) - } + //users := v1.Group("/users") + //{ + // users.POST("/register", r.userHandler.Register) + // users.POST("/login", r.userHandler.Login) + // users.GET("", r.userHandler.ListUsers) + // users.GET("/:id", r.userHandler.GetUserByID) + //} // Document routes documents := v1.Group("/documents") diff --git a/internal/server/local/admin_status.go b/internal/server/local/admin_status.go new file mode 100644 index 00000000000..5c2e8ab2984 --- /dev/null +++ b/internal/server/local/admin_status.go @@ -0,0 +1,79 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package local + +import ( + "sync" +) + +// AdminStatus represents the admin status +// 0 = valid, 1 = invalid +type AdminStatus struct { + Status int `json:"status"` // 0 = available, 1 = not available + Reason string `json:"reason"` // reason for invalid status +} + +var ( + adminStatus *AdminStatus + adminStatusMu sync.RWMutex + adminStatusOnce sync.Once +) + +// InitAdminStatus initializes the global admin status +// status: 0 = valid, 1 = invalid (default) +func InitAdminStatus(status int, reason string) { + adminStatusOnce.Do(func() { + adminStatus = &AdminStatus{ + Status: status, + Reason: reason, + } + }) +} + +// GetAdminStatus returns the current admin status +func GetAdminStatus() AdminStatus { + adminStatusMu.RLock() + defer adminStatusMu.RUnlock() + if adminStatus == nil { + return AdminStatus{Status: 1, Reason: "not initialized"} + } + return AdminStatus{ + Status: adminStatus.Status, + Reason: adminStatus.Reason, + } +} + +// SetAdminStatus updates the admin status +func SetAdminStatus(status int, reason string) { + adminStatusMu.Lock() + defer adminStatusMu.Unlock() + if adminStatus == nil { + adminStatus = &AdminStatus{} + } + adminStatus.Status = status + adminStatus.Reason = reason +} + +// IsAdminAvailable returns true if admin is valid (Status == 0) +func IsAdminAvailable() bool { + adminStatusMu.RLock() + defer adminStatusMu.RUnlock() + if adminStatus == nil { + return false + } + return adminStatus.Status == 0 +} diff --git a/internal/service/heartbeat_sender.go b/internal/service/heartbeat_sender.go index 3d7539848ba..47c4d67550e 100644 --- a/internal/service/heartbeat_sender.go +++ b/internal/service/heartbeat_sender.go @@ -76,12 +76,12 @@ func (h *HeartbeatSender) InitHTTPClient() error { } // SendHeartbeat sends a heartbeat message to the admin server -func (h *HeartbeatSender) SendHeartbeat() error { +func (h *HeartbeatSender) SendHeartbeat() (error, string) { if h.attemptCount < 10 { if h.lastSuccess { h.attemptCount++ - return nil + return nil, "" } } h.attemptCount = 0 @@ -90,7 +90,7 @@ func (h *HeartbeatSender) SendHeartbeat() error { if h.client == nil { if err := h.InitHTTPClient(); err != nil { h.logger.Error("Failed to initialize HTTP client", zap.Error(err)) - return err + return err, "internal error, fail to initialize HTTP client" } } @@ -109,19 +109,19 @@ func (h *HeartbeatSender) SendHeartbeat() error { jsonData, err := json.Marshal(message) if err != nil { h.logger.Error("Failed to marshal heartbeat message", zap.Error(err)) - return err + return err, "fail to parse the message" } resp, err := h.client.PostJSON("/api/v1/admin/reports", jsonData) if err != nil { - return err + return err, "can't connect with admin server" } defer resp.Body.Close() if resp.StatusCode != 200 { errMsg := fmt.Errorf("Heartbeat request failed with status code: %d", resp.StatusCode) h.logger.Warn(errMsg.Error()) - return errMsg + return errMsg, errMsg.Error() } h.logger.Debug("Heartbeat sent successfully", @@ -131,5 +131,5 @@ func (h *HeartbeatSender) SendHeartbeat() error { h.lastSuccess = true - return nil + return nil, "" } From f9a925e0a36fe70d1df743e6eed4ba4e5e081f85 Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Thu, 12 Mar 2026 20:39:57 +0800 Subject: [PATCH 237/565] Expose go version server and admin server port out of docker in CI (#13572) ### What problem does this PR solve? 
- Print Go version log when start server - Expose the server port in CI docker container ### Type of change - [x] Other (please describe): For CI Signed-off-by: Jin Hai --- .github/workflows/tests.yml | 4 ++++ cmd/admin_server.go | 2 +- cmd/server_main.go | 2 +- docker/.env | 2 ++ internal/server/config.go | 6 +++--- 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 47d94c212f8..eda2ea3f794 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -185,6 +185,8 @@ jobs: SVR_HTTP_PORT=$((9380 + RUNNER_NUM * 10)) ADMIN_SVR_HTTP_PORT=$((9381 + RUNNER_NUM * 10)) SVR_MCP_PORT=$((9382 + RUNNER_NUM * 10)) + GO_HTTP_PORT=$((9384 + RUNNER_NUM * 10)) + GO_ADMIN_PORT=$((9385 + RUNNER_NUM * 10)) SANDBOX_EXECUTOR_MANAGER_PORT=$((9385 + RUNNER_NUM * 10)) SVR_WEB_HTTP_PORT=$((80 + RUNNER_NUM * 10)) SVR_WEB_HTTPS_PORT=$((443 + RUNNER_NUM * 10)) @@ -205,6 +207,8 @@ jobs: echo -e "SVR_HTTP_PORT=${SVR_HTTP_PORT}" >> docker/.env echo -e "ADMIN_SVR_HTTP_PORT=${ADMIN_SVR_HTTP_PORT}" >> docker/.env echo -e "SVR_MCP_PORT=${SVR_MCP_PORT}" >> docker/.env + echo -e "GO_HTTP_PORT=${GO_HTTP_PORT}" >> docker/.env + echo -e "GO_ADMIN_PORT=${GO_ADMIN_PORT}" >> docker/.env echo -e "SANDBOX_EXECUTOR_MANAGER_PORT=${SANDBOX_EXECUTOR_MANAGER_PORT}" >> docker/.env echo -e "SVR_WEB_HTTP_PORT=${SVR_WEB_HTTP_PORT}" >> docker/.env echo -e "SVR_WEB_HTTPS_PORT=${SVR_WEB_HTTPS_PORT}" >> docker/.env diff --git a/cmd/admin_server.go b/cmd/admin_server.go index 291f7868eef..fae63a0e888 100644 --- a/cmd/admin_server.go +++ b/cmd/admin_server.go @@ -159,7 +159,7 @@ func main() { // Start server in a goroutine go func() { - logger.Info(fmt.Sprintf("Version: %s", utility.GetRAGFlowVersion())) + logger.Info(fmt.Sprintf("Admin Go Version: %s", utility.GetRAGFlowVersion())) logger.Info(fmt.Sprintf("Starting RAGFlow admin server on port: %d", cfg.Admin.Port)) if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { logger.Fatal("Failed to start server", zap.Error(err)) diff --git a/cmd/server_main.go b/cmd/server_main.go index a6eb3408bbe..cc2510ec857 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -216,7 +216,7 @@ func startServer(config *server.Config) { " / _, _// ___ |/ /_/ // __/ / // /_/ /| |/ |/ /\n" + " /_/ |_|/_/ |_|\\____//_/ /_/ \\____/ |__/|__/\n", ) - logger.Info(fmt.Sprintf("Version: %s", utility.GetRAGFlowVersion())) + logger.Info(fmt.Sprintf("RAGFlow Go Version: %s", utility.GetRAGFlowVersion())) logger.Info(fmt.Sprintf("Server starting on port: %d", config.Server.Port)) if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { logger.Fatal("Failed to start server", zap.Error(err)) diff --git a/docker/.env b/docker/.env index 79f4f91b34c..c7e68a756a1 100644 --- a/docker/.env +++ b/docker/.env @@ -152,6 +152,8 @@ SVR_WEB_HTTPS_PORT=443 SVR_HTTP_PORT=9380 ADMIN_SVR_HTTP_PORT=9381 SVR_MCP_PORT=9382 +GO_HTTP_PORT=9384 +GO_ADMIN_PORT=9385 # The RAGFlow Docker image to download. v0.22+ doesn't include embedding models. 
From 99726cf2450b5d8c3744b4fb5622255fe4b55081 Mon Sep 17 00:00:00 2001
From: Ethan Clarke 
Date: Thu, 12 Mar 2026 20:41:46 +0800
Subject: [PATCH 238/565] feat: add MiniMax-M2.5 and M2.5-highspeed models
 (#13557)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

Add MiniMax's latest M2.5 model family to the model registry and update the default API base URL to the international endpoint for broader accessibility.

## Changes

- **Add MiniMax-M2.5 models** to `conf/llm_factories.json`:
  - `MiniMax-M2.5` — Peak Performance. Ultimate Value. Master the Complex.
  - `MiniMax-M2.5-highspeed` — Same performance, faster and more agile.
  - Both support a 204,800-token context window and tool calling (`is_tools: true`).
- **Update the default MiniMax API base URL** in `rag/llm/__init__.py`:
  - From `https://api.minimaxi.com/v1` (domestic) to `https://api.minimax.io/v1` (international).
  - Chinese users can still override it via the Base URL field in the UI settings (as documented in existing i18n strings).

## Supported Models

| Model | Context Window | Tool Calling | Description |
|-------|---------------|-------------|-------------|
| `MiniMax-M2.5` | 204,800 tokens | Yes | Peak Performance. Ultimate Value. |
| `MiniMax-M2.5-highspeed` | 204,800 tokens | Yes | Same performance, faster and more agile. |

## API Documentation

- OpenAI Compatible API: https://platform.minimax.io/docs/api-reference/text-openai-api

## Testing

- [x] JSON validation passes
- [x] Python syntax validation passes
- [x] Ruff lint passes
- [x] MiniMax-M2.5 API call verified (returns valid response)
- [x] MiniMax-M2.5-highspeed API call verified (returns valid response)

Co-authored-by: PR Bot 
Co-authored-by: Jin Hai 
Co-authored-by: Yingfeng 
---
 conf/llm_factories.json | 14 ++++++++++++++
 rag/llm/__init__.py     |  2 +-
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/conf/llm_factories.json b/conf/llm_factories.json
index 20ef720f1a3..170e22340c4 100644
--- a/conf/llm_factories.json
+++ b/conf/llm_factories.json
@@ -1300,6 +1300,20 @@
         "rank": "810",
         "url": "https://api.minimaxi.com/v1",
         "llm": [
+            {
+                "llm_name": "MiniMax-M2.5",
+                "tags": "LLM,CHAT,200k",
+                "max_tokens": 204800,
+                "model_type": "chat",
+                "is_tools": true
+            },
+            {
+                "llm_name": "MiniMax-M2.5-highspeed",
+                "tags": "LLM,CHAT,200k",
+                "max_tokens": 204800,
+                "model_type": "chat",
+                "is_tools": true
+            },
             {
                 "llm_name": "MiniMax-M2.1",
                 "tags": "LLM,CHAT,200k",
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index 77b1ff2b0e2..9cbce5acd9c 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -81,7 +81,7 @@ class SupportedLiteLLMProvider(StrEnum):
     SupportedLiteLLMProvider.Anthropic: "https://api.anthropic.com/",
     SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
     SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4",
-    SupportedLiteLLMProvider.MiniMax: "https://api.minimaxi.com/v1",
+    SupportedLiteLLMProvider.MiniMax: "https://api.minimax.io/v1",
     SupportedLiteLLMProvider.DeerAPI: "https://api.deerapi.com/v1",
     SupportedLiteLLMProvider.OpenAI: "https://api.openai.com/v1",
     SupportedLiteLLMProvider.n1n: "https://api.n1n.ai/v1",
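A quick end-to-end check of the new default endpoint is a single chat completion against the base URL above. This is a minimal sketch assuming the OpenAI-compatible /chat/completions route and a MINIMAX_API_KEY environment variable; see the API documentation linked above for the authoritative request shape:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"os"
)

// Smoke-test sketch against the new international base URL. The request
// shape follows the generic OpenAI-compatible format; MINIMAX_API_KEY is
// an assumed environment variable, not part of the patch.
func main() {
	payload, _ := json.Marshal(map[string]any{
		"model":    "MiniMax-M2.5",
		"messages": []map[string]string{{"role": "user", "content": "ping"}},
	})
	req, err := http.NewRequest("POST", "https://api.minimax.io/v1/chat/completions", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Authorization", "Bearer "+os.Getenv("MINIMAX_API_KEY"))
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status) // the PR's checklist reports valid responses for both M2.5 models
}
```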
From 8ca1ef5b4003767ab9f4a4fdf71f43688d14c37e Mon Sep 17 00:00:00 2001
From: Jimmy Ben Klieve 
Date: Thu, 12 Mar 2026 21:01:09 +0800
Subject: [PATCH 239/565] refactor(ui): unify top level pages structure, use
 standard language codes and time zones (#13573)

### What problem does this PR solve?

- Unify the top-level pages structure
- Standardize locale language codes (BCP 47) and time zones (IANA tz)

> **Note:**
> Newly created user info carries the non-standard default values `timezone: "UTC+8\tAsia/Shanghai"` and `language: "English"`.
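The legacy default joins an offset label and an IANA zone name with a tab, so normalization only needs the IANA half. Here is a minimal sketch of that mapping, written in Go for brevity even though the actual fix lands in the web client's use-profile.ts below; the UTC fallback is an assumption:

```go
package main

import (
	"fmt"
	"strings"
	"time"
)

// normalizeTimezone keeps only the IANA half of the legacy default
// ("UTC+8\tAsia/Shanghai") and validates it with time.LoadLocation.
// The "UTC" fallback is an assumption for this sketch.
func normalizeTimezone(legacy string) string {
	parts := strings.Split(legacy, "\t")
	candidate := strings.TrimSpace(parts[len(parts)-1]) // e.g. "Asia/Shanghai"
	if _, err := time.LoadLocation(candidate); err == nil {
		return candidate
	}
	return "UTC"
}

func main() {
	fmt.Println(normalizeTimezone("UTC+8\tAsia/Shanghai")) // Asia/Shanghai
	fmt.Println(normalizeTimezone("not-a-zone"))           // UTC
}
```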
### Type of change - [x] Refactoring --- .../chat-overview-modal/api-content.tsx | 5 +- web/src/components/avatar-upload.tsx | 79 +-- web/src/components/card-container.tsx | 6 +- .../card-singleline-container/index.tsx | 22 +- web/src/components/empty/empty.tsx | 8 +- web/src/components/list-filter-bar/index.tsx | 6 +- .../originui/select-with-search.tsx | 10 +- web/src/components/ui/button.tsx | 3 +- web/src/components/ui/ragflow-pagination.tsx | 9 +- web/src/components/ui/segmented.tsx | 2 - web/src/constants/common.ts | 10 +- web/src/constants/setting.ts | 453 ++---------------- web/src/hooks/logic-hooks.ts | 7 +- web/src/hooks/use-user-setting-request.tsx | 23 +- web/src/layouts/components/header.tsx | 31 +- web/src/lib/utils.ts | 17 +- web/src/locales/ar.ts | 2 +- web/src/locales/bg.ts | 2 +- web/src/locales/config.ts | 55 ++- web/src/locales/de.ts | 2 +- web/src/locales/en.ts | 2 +- web/src/locales/it.ts | 2 +- web/src/locales/until.ts | 60 --- web/src/pages/admin/service-status.tsx | 15 +- web/src/pages/admin/users.tsx | 8 +- web/src/pages/agents/index.tsx | 282 +++++------ .../dataset-overview/dataset-filter.tsx | 43 +- .../pages/dataset/dataset-overview/index.tsx | 2 - .../dataset-overview/overview-table.tsx | 48 +- web/src/pages/datasets/index.tsx | 135 +++--- web/src/pages/files/files-table.tsx | 28 +- web/src/pages/files/index.tsx | 32 +- web/src/pages/memories/index.tsx | 106 ++-- web/src/pages/next-chats/index.tsx | 94 ++-- web/src/pages/next-searches/index.tsx | 119 +++-- web/src/pages/user-setting/index.tsx | 4 +- .../user-setting/profile/hooks/use-profile.ts | 6 +- web/src/pages/user-setting/profile/index.tsx | 21 +- .../user-setting/setting-locale/index.tsx | 23 - .../setting-locale/translation-table.tsx | 238 --------- web/src/pages/user-setting/sidebar/index.tsx | 3 +- web/src/routes.tsx | 4 +- 42 files changed, 668 insertions(+), 1359 deletions(-) delete mode 100644 web/src/locales/until.ts delete mode 100644 web/src/pages/user-setting/setting-locale/index.tsx delete mode 100644 web/src/pages/user-setting/setting-locale/translation-table.tsx diff --git a/web/src/components/api-service/chat-overview-modal/api-content.tsx b/web/src/components/api-service/chat-overview-modal/api-content.tsx index be19e016b68..b7127f53e9d 100644 --- a/web/src/components/api-service/chat-overview-modal/api-content.tsx +++ b/web/src/components/api-service/chat-overview-modal/api-content.tsx @@ -26,8 +26,9 @@ const ApiContent = ({ id, idKey }: { id?: string; idKey: string }) => { const isDarkTheme = useIsDarkTheme(); return ( -
- +
+ +
) : (
- - - - - -
{ + innerInputRef.current?.click(); + }} > - -
+ + + + + +
+ +
+
)} - -
+
{tips ?? t('knowledgeConfiguration.photoTip')}
@@ -357,7 +372,7 @@ export const AvatarUpload = forwardRef( height: '300px', touchAction: 'none', }} - // onWheel={handleWheel} + onWheel={handleWheel} > {children} - +
); } diff --git a/web/src/components/card-singleline-container/index.tsx b/web/src/components/card-singleline-container/index.tsx index a35a37603be..f5e8e232d1e 100644 --- a/web/src/components/card-singleline-container/index.tsx +++ b/web/src/components/card-singleline-container/index.tsx @@ -1,5 +1,5 @@ import { cn } from '@/lib/utils'; -import { isValidElement, PropsWithChildren, ReactNode } from 'react'; +import { PropsWithChildren } from 'react'; import './index.less'; type CardContainerProps = { className?: string } & PropsWithChildren; @@ -8,26 +8,6 @@ export function CardSineLineContainer({ children, className, }: CardContainerProps) { - const flattenChildren = (children: ReactNode): ReactNode[] => { - const result: ReactNode[] = []; - - const traverse = (child: ReactNode) => { - if (Array.isArray(child)) { - child.forEach(traverse); - } else if (isValidElement(child) && child.props.children) { - result.push(child); - } else { - result.push(child); - } - }; - - traverse(children); - return result; - }; - const childArray = flattenChildren(children); - const childCount = childArray.length; - console.log(childArray, childCount); - return (
{ - const { type, showIcon, className, isSearch, children, testId } = props; + const { type, showIcon, className, isSearch, children, testId, tabIndex } = + props; const { t } = useTranslation(); let defaultClass = ''; let style = {}; @@ -110,10 +112,10 @@ export const EmptyAppCard = (props: { diff --git a/web/src/components/list-filter-bar/index.tsx b/web/src/components/list-filter-bar/index.tsx index b84271328c3..9cd12d4c428 100644 --- a/web/src/components/list-filter-bar/index.tsx +++ b/web/src/components/list-filter-bar/index.tsx @@ -94,7 +94,7 @@ export default function ListFilterBar({ return (
-
+

{typeof icon === 'string' ? ( // +

{preChildren} - {showFilter && ( + {filters?.length && showFilter && ( {selectLabel || value ? ( - - - {selectLabel || value} - + + {selectLabel || value} ) : ( {placeholder} @@ -209,10 +207,10 @@ export const SelectWithSearch = forwardRef<
- {options.map((group, idx) => { + {options.map((group) => { if (group.options) { return ( - + {group.options.map((option) => ( {t('pagination.page', { page: x })}
, value: x.toString(), })); - }, []); + }, [t]); const pages = useMemo(() => { const num = Math.ceil(total / pageSize); @@ -134,7 +135,7 @@ export function RAGFlowPagination({ }, [pages, currentPage]); return ( -
+
{t('pagination.total', { total: total })} @@ -181,6 +182,6 @@ export function RAGFlowPagination({ triggerClassName="bg-bg-card border-transparent" /> )} -
+
); } diff --git a/web/src/components/ui/segmented.tsx b/web/src/components/ui/segmented.tsx index 5b3eba54d99..7318c017dd1 100644 --- a/web/src/components/ui/segmented.tsx +++ b/web/src/components/ui/segmented.tsx @@ -105,8 +105,6 @@ const Segmented = React.forwardRef( const isObject = typeof option === 'object'; const actualValue = isObject ? option.value : option; - console.log(actualValue); - return ( - {langItems.map((x) => ( + {supportedLanguages.map((x) => ( changeLanguage(x.key)} + key={x.code} + onClick={() => changeLanguage(x.code)} > - {x.label} + {x.displayName} ))} diff --git a/web/src/lib/utils.ts b/web/src/lib/utils.ts index 3b3f77e8a6d..646e0534f65 100644 --- a/web/src/lib/utils.ts +++ b/web/src/lib/utils.ts @@ -1,4 +1,5 @@ import { clsx, type ClassValue } from 'clsx'; +import React from 'react'; import { twMerge } from 'tailwind-merge'; export function cn(...inputs: ClassValue[]) { @@ -19,6 +20,20 @@ export function formatBytes( if (bytes === 0) return '0 Byte'; const i = Math.floor(Math.log(bytes) / Math.log(1024)); return `${(bytes / Math.pow(1024, i)).toFixed(decimals)} ${ - sizeType === 'accurate' ? accurateSizes[i] ?? 'Bytes' : sizes[i] ?? 'Bytes' + sizeType === 'accurate' + ? (accurateSizes[i] ?? 'Bytes') + : (sizes[i] ?? 'Bytes') }`; } + +export function combineRefs(...refs: React.ForwardedRef[]) { + return (node: T) => { + refs.forEach((ref) => { + if (typeof ref === 'function') { + ref(node); + } else if (ref) { + ref.current = node; + } + }); + }; +} diff --git a/web/src/locales/ar.ts b/web/src/locales/ar.ts index 9e1334f7799..f6336c47341 100644 --- a/web/src/locales/ar.ts +++ b/web/src/locales/ar.ts @@ -2300,7 +2300,7 @@ export default { }, pagination: { total: 'الإجمالي {{total}}', - page: '{{page}} /الصفحة', + page: '{{page}} / الصفحة', }, dataflowParser: { result: 'نتيجة', diff --git a/web/src/locales/bg.ts b/web/src/locales/bg.ts index 7a0a91baedc..078ece92d70 100644 --- a/web/src/locales/bg.ts +++ b/web/src/locales/bg.ts @@ -2385,7 +2385,7 @@ Important structured information may include: names, dates, locations, events, k }, pagination: { total: 'Общо {{total}}', - page: '{{page}} /Страница', + page: '{{page}} / Страница', }, dataflowParser: { result: 'Резултат', diff --git a/web/src/locales/config.ts b/web/src/locales/config.ts index 73469dcaabc..2a625b622c4 100644 --- a/web/src/locales/config.ts +++ b/web/src/locales/config.ts @@ -1,13 +1,14 @@ import i18n from 'i18next'; import LanguageDetector from 'i18next-browser-languagedetector'; +import { upperFirst } from 'lodash'; import { initReactI18next } from 'react-i18next'; import { LanguageAbbreviation } from '@/constants/common'; -import { createTranslationTable, flattenObject } from './until'; import translation_en from './en'; const languageImports: Record Promise<{ default: any }>> = { + [LanguageAbbreviation.En]: () => import('./en'), [LanguageAbbreviation.Zh]: () => import('./zh'), [LanguageAbbreviation.ZhTraditional]: () => import('./zh-traditional'), [LanguageAbbreviation.Id]: () => import('./id'), @@ -23,20 +24,22 @@ const languageImports: Record Promise<{ default: any }>> = { [LanguageAbbreviation.Ar]: () => import('./ar'), }; -const languageAliases: Record = { - 'pt-br': LanguageAbbreviation.PtBr, -}; +const supportedLanguageCodes: Intl.UnicodeBCP47LocaleIdentifier[] = + Object.keys(languageImports); -const normalizeLanguageCode = (lng: string): string => { - return languageAliases[lng] ?? 
lng; -}; +export const supportedLanguages = supportedLanguageCodes.map((code) => { + const locale = new Intl.Locale(code); -const enFlattened = flattenObject(translation_en); + return { + code, + locale, + displayName: upperFirst( + new Intl.DisplayNames(locale, { type: 'language' }).of(code)!, + ), + }; +}); -export const translationTable = createTranslationTable( - [enFlattened], - ['English'], -); +export const DEFAULT_LANGUAGE_CODE = LanguageAbbreviation.En; const resources = { [LanguageAbbreviation.En]: translation_en, @@ -49,16 +52,17 @@ i18n detection: { lookupLocalStorage: 'lng', }, - supportedLngs: Object.values(LanguageAbbreviation), + supportedLngs: supportedLanguageCodes, resources, - fallbackLng: 'en', + fallbackLng: DEFAULT_LANGUAGE_CODE, interpolation: { escapeValue: false, }, }); export const loadLanguageAsync = async (lng: string): Promise => { - const normalizedLng = normalizeLanguageCode(lng); + // const normalizedLng = normalizeLanguageCode(lng); + const normalizedLng = lng; if (i18n.hasResourceBundle(normalizedLng, 'translation')) { return; @@ -74,16 +78,15 @@ export const loadLanguageAsync = async (lng: string): Promise => { const module = await importFn(); const translationData = module.default?.translation || module.default; i18n.addResourceBundle(normalizedLng, 'translation', translationData); - - const flattened = flattenObject({ translation: translationData }); - translationTable.push(flattened); } catch (error) { console.error(`Failed to load language ${lng}:`, error); } }; export const changeLanguageAsync = async (lng: string): Promise => { - const normalizedLng = normalizeLanguageCode(lng); + // const normalizedLng = normalizeLanguageCode(lng); + const normalizedLng = lng; + if ( normalizedLng !== LanguageAbbreviation.En && !i18n.hasResourceBundle(normalizedLng, 'translation') @@ -94,14 +97,14 @@ export const changeLanguageAsync = async (lng: string): Promise => { }; export const initLanguage = async (): Promise => { - const currentLng = normalizeLanguageCode( - i18n.language || localStorage.getItem('lng') || LanguageAbbreviation.En, - ); + // const currentLng = normalizeLanguageCode( + // i18n.language || localStorage.getItem('lng') || LanguageAbbreviation.En, + // ); - if (currentLng !== LanguageAbbreviation.En && languageImports[currentLng]) { - await loadLanguageAsync(currentLng); - await i18n.changeLanguage(currentLng); - } + const currentLng = + i18n.language || localStorage.getItem('lng') || DEFAULT_LANGUAGE_CODE; + + await changeLanguageAsync(currentLng); }; export default i18n; diff --git a/web/src/locales/de.ts b/web/src/locales/de.ts index 5e17b24e5f3..5cef98ba33e 100644 --- a/web/src/locales/de.ts +++ b/web/src/locales/de.ts @@ -2444,7 +2444,7 @@ Wichtige strukturierte Informationen können sein: Namen, Daten, Orte, Ereigniss }, pagination: { total: 'Gesamt {{total}}', - page: '{{page}} /Seite', + page: '{{page}} / Seite', }, dataflowParser: { result: 'Ergebnis', diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index e58b958923d..0671ac33a07 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -2482,7 +2482,7 @@ Important structured information may include: names, dates, locations, events, k }, pagination: { total: 'Total {{total}}', - page: '{{page}} /Page', + page: '{{page}} / Page', }, dataflowParser: { result: 'Result', diff --git a/web/src/locales/it.ts b/web/src/locales/it.ts index 2a16b2ea012..febe4c9f6e1 100644 --- a/web/src/locales/it.ts +++ b/web/src/locales/it.ts @@ -1203,7 +1203,7 @@ Quanto sopra è il contenuto che 
devi riassumere.`, }, pagination: { total: 'Totale {{total}}', - page: '{{page}} /Pagina', + page: '{{page}} / Pagina', }, deleteModal: { delAgent: 'Elimina agente', diff --git a/web/src/locales/until.ts b/web/src/locales/until.ts deleted file mode 100644 index 9934a97570b..00000000000 --- a/web/src/locales/until.ts +++ /dev/null @@ -1,60 +0,0 @@ -type NestedObject = { - [key: string]: string | NestedObject; -}; - -type FlattenedObject = { - [key: string]: string; -}; - -export function flattenObject( - obj: NestedObject, - parentKey: string = '', -): FlattenedObject { - const result: FlattenedObject = {}; - - for (const [key, value] of Object.entries(obj)) { - const newKey = parentKey ? `${parentKey}.${key}` : key; - - if (typeof value === 'object' && value !== null) { - Object.assign(result, flattenObject(value as NestedObject, newKey)); - } else { - result[newKey] = value as string; - } - } - - return result; -} -type TranslationTableRow = { - key: string; - [language: string]: string; -}; - -/** - * Creates a translation table from multiple flattened language objects. - * @param langs - A list of flattened language objects. - * @param langKeys - A list of language identifiers (e.g., 'English', 'Vietnamese'). - * @returns An array representing the translation table. - */ -export function createTranslationTable( - langs: FlattenedObject[], - langKeys: string[], -): TranslationTableRow[] { - const keys = new Set(); - - // Collect all unique keys from the language objects - langs.forEach((lang) => { - Object.keys(lang).forEach((key) => keys.add(key)); - }); - - // Build the table - return Array.from(keys).map((key) => { - const row: TranslationTableRow = { key }; - - langs.forEach((lang, index) => { - const langKey = langKeys[index]; - row[langKey] = lang[key] || ''; // Use empty string if key is missing - }); - - return row; - }); -} diff --git a/web/src/pages/admin/service-status.tsx b/web/src/pages/admin/service-status.tsx index 582b38cac64..46da1b50b72 100644 --- a/web/src/pages/admin/service-status.tsx +++ b/web/src/pages/admin/service-status.tsx @@ -229,19 +229,8 @@ function AdminServiceStatus() {
- diff --git a/web/src/pages/admin/users.tsx b/web/src/pages/admin/users.tsx index 4eafa54b0a1..9730eec5764 100644 --- a/web/src/pages/admin/users.tsx +++ b/web/src/pages/admin/users.tsx @@ -481,12 +481,8 @@ function AdminUserManagement() {
- diff --git a/web/src/pages/agents/index.tsx b/web/src/pages/agents/index.tsx index 7e954aff96f..bea639f2f26 100644 --- a/web/src/pages/agents/index.tsx +++ b/web/src/pages/agents/index.tsx @@ -13,6 +13,7 @@ import { import { RAGFlowPagination } from '@/components/ui/ragflow-pagination'; import { useNavigatePage } from '@/hooks/logic-hooks/navigate-hooks'; import { useFetchAgentListByPage } from '@/hooks/use-agent-request'; +import { Routes } from '@/routes'; import { t } from 'i18next'; import { pick } from 'lodash'; import { Clipboard, ClipboardPlus, FileInput, Plus } from 'lucide-react'; @@ -36,6 +37,7 @@ export default function Agents() { filterValue, handleFilterSubmit, } = useFetchAgentListByPage(); + const { navigateToAgentTemplates } = useNavigatePage(); const { @@ -72,6 +74,7 @@ export default function Agents() { ); const [searchUrl, setSearchUrl] = useSearchParams(); const isCreate = searchUrl.get('isCreate') === 'true'; + useEffect(() => { if (isCreate) { showCreatingModal(); @@ -79,153 +82,162 @@ export default function Agents() { setSearchUrl(searchUrl); } }, [isCreate, showCreatingModal, searchUrl, setSearchUrl]); + return ( <> - {(!data?.length || data?.length <= 0) && !searchString && ( -
- showCreatingModal()} - > -
-
- - {t('flow.createFromBlank')} -
-
- - {t('flow.createFromTemplate')} -
-
- - {t('flow.importJsonFile')} -
-
-
-
- )} -
- {(!!data?.length || searchString) && ( - <> -
- - - - - - - - - {t('flow.createFromBlank')} - - - - {t('flow.createFromTemplate')} - - - - {t('flow.importJsonFile')} - - - - -
- {(!data?.length || data?.length <= 0) && searchString && ( -
- showCreatingModal()} - /> -
- )} -
- + {data?.length || searchString ? ( +
+
+ + + + + + + + + {t('flow.createFromBlank')} + + navigateToAgentTemplates()} + > + + {t('flow.createFromTemplate')} + + + + {t('flow.importJsonFile')} + + + + +
+ + {data.length ? ( + <> + {data.map((x) => { return ( + /> ); })} + +
+ +
+ + ) : ( +
+ showCreatingModal()} + />
-
- -
- - )} - {agentRenameVisible && ( - - )} - {creatingVisible && ( - - )} - {fileUploadVisible && ( - - )} -
+ )} + + ) : ( +
+ showCreatingModal()} + > +
    +
  • + +
  • + +
  • + +
  • + +
  • + +
  • +
+
+
+ )} + + {agentRenameVisible && ( + + )} + {creatingVisible && ( + + )} + {fileUploadVisible && ( + + )} ); } diff --git a/web/src/pages/dataset/dataset-overview/dataset-filter.tsx b/web/src/pages/dataset/dataset-overview/dataset-filter.tsx index 4767ce5d715..ab7bb4f25f9 100644 --- a/web/src/pages/dataset/dataset-overview/dataset-filter.tsx +++ b/web/src/pages/dataset/dataset-overview/dataset-filter.tsx @@ -3,9 +3,8 @@ import { CheckboxFormMultipleProps, FilterPopover, } from '@/components/list-filter-bar/filter-popover'; -import { Button } from '@/components/ui/button'; import { SearchInput } from '@/components/ui/input'; -import { cn } from '@/lib/utils'; +import { Segmented } from '@/components/ui/segmented'; import { ChangeEventHandler, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { LogTabs } from './dataset-common'; @@ -40,35 +39,23 @@ const DatasetFilter = ( }, [value]); return (
-
- - + ]} + onChange={(value) => + setActive?.(value as (typeof LogTabs)[keyof typeof LogTabs]) + } + />
-
+ +
{ label: t('knowledgeDetails.status'), list: Object.values(RunningStatus).map((value) => { // const value = key as RunningStatus; - console.log(value); return { id: value, // label: RunningStatusMap[value].label, @@ -245,7 +244,6 @@ const FileLogsPage: FC = () => { page: number; pageSize: number; }) => { - console.log('Pagination changed:', { page, pageSize }); setPagination({ ...pagination, page, diff --git a/web/src/pages/dataset/dataset-overview/overview-table.tsx b/web/src/pages/dataset/dataset-overview/overview-table.tsx index 6c5c93527de..de92a53ef50 100644 --- a/web/src/pages/dataset/dataset-overview/overview-table.tsx +++ b/web/src/pages/dataset/dataset-overview/overview-table.tsx @@ -150,14 +150,19 @@ export const getFileLogsTableColumns = ( accessorKey: 'process_begin_at', header: ({ column }) => { return ( - + + +
); }, cell: ({ row }) => ( @@ -192,8 +197,7 @@ export const getFileLogsTableColumns = (
+ +
); }, cell: ({ row }) => ( @@ -319,11 +326,10 @@ export const getDatasetLogsTableColumns = ( id: 'operations', header: t('operations'), cell: ({ row }) => ( -
+
- {(!kbs?.length || kbs?.length <= 0) && searchString && ( -
- showModal()} - /> -
- )} -
- - {kbs.map((dataset) => { - return ( - - ); - })} + + + {kbs?.length ? ( + <> + + {kbs.map((dataset) => ( + + ))} + +
+ +
+ + ) : ( +
+ showModal()} + />
-
- -
- - )} - {visible && ( - - )} - {datasetRenameVisible && ( - - )} -
+ )} + + ) : ( +
+ showModal()} + /> +
+ )} + {visible && ( + + )} + {datasetRenameVisible && ( + + )} ); } diff --git a/web/src/pages/files/files-table.tsx b/web/src/pages/files/files-table.tsx index cb4a92681d5..abc072185e1 100644 --- a/web/src/pages/files/files-table.tsx +++ b/web/src/pages/files/files-table.tsx @@ -225,7 +225,7 @@ export function FilesTable({ id: 'actions', header: t('action'), meta: { - headerCellClassName: 'w-0', + headerCellClassName: 'w-0 whitespace-nowrap', }, enableHiding: false, enablePinning: true, @@ -278,8 +278,8 @@ export function FilesTable({ return ( <> -
-
+
+
{table.getHeaderGroups().map((headerGroup) => ( @@ -332,17 +332,17 @@ export function FilesTable({
-
-
- { - setPagination({ page, pageSize }); - }} - > -
-
+ +
+ { + setPagination({ page, pageSize }); + }} + /> +
+ {connectToKnowledgeVisible && ( -
+
+
+ {!rowSelectionIsEmpty && ( - + )}
- + +
+ +
+ {fileUploadVisible && ( - {(!list?.data?.memory_list?.length || - list?.data?.memory_list?.length <= 0) && - !searchString && ( -
- openCreateModalFun()} - /> -
- )} - {(!!list?.data?.memory_list?.length || searchString) && ( - <> -
+ <> + {list?.data?.memory_list?.length || searchString ? ( +
+
- -
- {(!list?.data?.memory_list?.length || - list?.data?.memory_list?.length <= 0) && - searchString && ( -
- openCreateModalFun()} - /> -
- )} -
- - {list?.data.memory_list.map((x) => { - return ( +
+ + {list?.data?.memory_list?.length ? ( + <> + + {list?.data.memory_list.map((x) => ( - ); - })} - -
- {list?.data.total_count && list?.data.total_count > 0 && ( -
- + ))} + + +
+ +
+ + ) : ( +
+ openCreateModalFun()} />
)} - + + ) : ( +
+ openCreateModalFun()} + /> +
)} {/* {openCreateModal && ( )} -
+ ); } diff --git a/web/src/pages/next-chats/index.tsx b/web/src/pages/next-chats/index.tsx index 195fd991212..a6d764c5307 100644 --- a/web/src/pages/next-chats/index.tsx +++ b/web/src/pages/next-chats/index.tsx @@ -49,70 +49,72 @@ export default function ChatList() { }, [isCreate, handleShowCreateModal, searchParams, setSearchParams]); return ( -
- {data.dialogs?.length <= 0 && !searchString && ( -
- handleShowCreateModal()} - testId="chats-empty-create" - /> -
- )} - {(data.dialogs?.length > 0 || searchString) && ( - <> -
+ <> + {data.dialogs?.length || searchString ? ( +
+
- -
- {data.dialogs?.length <= 0 && searchString && ( -
+ + + {data.dialogs?.length ? ( + <> + + {data.dialogs.map((x) => ( + + ))} + + +
+ +
+ + ) : ( +
handleShowCreateModal()} testId="chats-empty-create" />
)} -
- - {data.dialogs.map((x) => { - return ( - - ); - })} - -
-
- -
- + + ) : ( +
+ handleShowCreateModal()} + testId="chats-empty-create" + /> +
)} + {chatRenameVisible && ( )} -
+ ); } diff --git a/web/src/pages/next-searches/index.tsx b/web/src/pages/next-searches/index.tsx index f49563a2d5e..70e9a01f438 100644 --- a/web/src/pages/next-searches/index.tsx +++ b/web/src/pages/next-searches/index.tsx @@ -65,25 +65,10 @@ export default function SearchList() { }, [isCreate, openCreateModalFun, searchUrl, setSearchUrl]); return ( -
- {(!list?.data?.search_apps?.length || - list?.data?.search_apps?.length <= 0) && - !searchString && ( -
- openCreateModalFun()} - testId="search-empty-create" - /> -
- )} - {(!!list?.data?.search_apps?.length || searchString) && ( - <> -
+ <> + {list?.data?.search_apps?.length || searchString ? ( +
+
-
- {(!list?.data?.search_apps?.length || - list?.data?.search_apps?.length <= 0) && - searchString && ( -
- openCreateModalFun()} - testId="search-empty-create" + + + {list?.data?.search_apps?.length ? ( + <> + + {list?.data.search_apps.map((x) => { + return ( + { + showSearchRenameModal(x); + }} + /> + ); + })} + + +
+ -
- )} -
- - {list?.data.search_apps.map((x) => { - return ( - { - showSearchRenameModal(x); - }} - > - ); - })} - -
- {list?.data.total && list?.data.total > 0 && ( -
- + + ) : ( +
+
)} - + + ) : ( +
+ openCreateModalFun()} + testId="search-empty-create" + /> +
)} {openCreateModal && ( )} -
+ ); } diff --git a/web/src/pages/user-setting/index.tsx b/web/src/pages/user-setting/index.tsx index f3b85685792..080ed3b8d16 100644 --- a/web/src/pages/user-setting/index.tsx +++ b/web/src/pages/user-setting/index.tsx @@ -6,10 +6,10 @@ import { cn } from '@/lib/utils'; const UserSetting = () => { return (
- +
- +
); diff --git a/web/src/pages/user-setting/profile/hooks/use-profile.ts b/web/src/pages/user-setting/profile/hooks/use-profile.ts index 5b8bdf7b506..ffcd6ae134a 100644 --- a/web/src/pages/user-setting/profile/hooks/use-profile.ts +++ b/web/src/pages/user-setting/profile/hooks/use-profile.ts @@ -1,4 +1,5 @@ // src/hooks/useProfile.ts +import { DEFAULT_TIMEZONE } from '@/constants/setting'; import { useFetchUserInfo, useSaveSetting, @@ -53,7 +54,10 @@ export const useProfile = () => { // form.setValue('currPasswd', ''); // current password const profile = { userName: userInfo.nickname, - timeZone: userInfo.timezone, + timeZone: + userInfo.timezone === ' UTC+8\tAsia/Shanghai' + ? DEFAULT_TIMEZONE.name + : userInfo.timezone, avatar: userInfo.avatar || '', email: userInfo.email, currPasswd: userInfo.password, diff --git a/web/src/pages/user-setting/profile/index.tsx b/web/src/pages/user-setting/profile/index.tsx index 4b0d897ffb8..9f330fab33c 100644 --- a/web/src/pages/user-setting/profile/index.tsx +++ b/web/src/pages/user-setting/profile/index.tsx @@ -19,12 +19,17 @@ import { TimezoneList } from '@/pages/user-setting/constants'; import { zodResolver } from '@hookform/resolvers/zod'; import { t } from 'i18next'; import { Loader2Icon, PenLine } from 'lucide-react'; -import { FC, useEffect } from 'react'; +import { FC, useEffect, useMemo } from 'react'; import { useForm } from 'react-hook-form'; import { z } from 'zod'; import { ProfileSettingWrapperCard } from '../components/user-setting-header'; import { EditType, modalTitle, useProfile } from './hooks/use-profile'; +const timezoneOptions = TimezoneList.map(({ name }) => ({ + value: name, + label: name, +})); + const baseSchema = z.object({ userName: z .string() @@ -75,6 +80,7 @@ const passwordSchema = baseSchema }); } }); + const ProfilePage: FC = () => { const { t } = useTranslate('setting'); @@ -116,6 +122,11 @@ const ProfilePage: FC = () => { // ); // }; + const timezone = useMemo(() => { + const tz = TimezoneList.find((tz) => tz.name === profile.timeZone); + return tz?.name ?? ''; + }, [profile.timeZone]); + return ( //
{ {t('timezone')}
-
- {profile.timeZone} +
+ {timezone}
- {data.release_time ? ( -
-
+ {showReleaseTime ? ( +
+
{`${t('flow.lastSavedAt')}:`}
-
- {`${t('flow.publishedAt')}:`} - -
+ {data.release_time && ( +
+ {`${t('flow.publishedAt')}:`} + +
+ )}
) : ( diff --git a/web/src/interfaces/database/agent.ts b/web/src/interfaces/database/agent.ts index c9a08b7f204..8ca9d81942c 100644 --- a/web/src/interfaces/database/agent.ts +++ b/web/src/interfaces/database/agent.ts @@ -33,6 +33,7 @@ export interface ISwitchForm { import { AgentCategory } from '@/constants/agent'; import { Edge, Node } from '@xyflow/react'; import { IReference, Message } from './chat'; +import { IKnowledge } from './knowledge'; export type DSLComponents = Record; @@ -80,6 +81,7 @@ export declare interface IFlow { release?: boolean; release_time?: number; last_publish_time?: number; + datasets?: Pick[]; } export interface IFlowTemplate { diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 0671ac33a07..fc813671df4 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -2196,9 +2196,11 @@ This process aggregates variables from multiple branches into a single variable production: 'Production', productionTooltip: 'This version is published to production. Access it via the API or the embedded page.', - confirmPublish: 'Confirm Publish', - publishDescription: 'You are about to publish this data pipeline.', - linkedDataset: 'Linked dataset', + confirmPublish: 'Confirm publish', + publishIngestionPipeline: + 'You are about to publish this Ingestion pipeline.', + publishAgent: 'You are about to publish this agent', + linkedDataset: 'Linked dataset:', lastPublished: 'Last published', createFromBlank: 'Create from blank', createFromTemplate: 'Create from template', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index 3d49e890034..9a129df7c13 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -202,7 +202,7 @@ export default { // editMetadataForDataset: '查看和编辑元数据 ', restrictDefinedValues: '限制为已定义的值', metadataGenerationSettings: '元数据生成设置', - // manageMetadataForDataset: '管理此数据集的元数据', + // manageMetadataForDataset: '管理此知识库的元数据', manageMetadata: '管理元数据', metadata: '元数据', values: '值', @@ -247,7 +247,7 @@ export default { startDate: '开始时间', source: '来源', fileName: '文件名', - datasetLogs: '数据集', + datasetLogs: '知识库', fileLogs: '文件', overview: '日志', success: '成功', @@ -390,7 +390,7 @@ export default { knowledgeConfiguration: { randomSeedTip: '种子是伪随机算法的起点,它确保在不同运行中产生相同的输出,从而保证可重复性。', - datasetDescription: '你的数据集描述。', + datasetDescription: '你的知识库描述。', overlappedPercentTip: '相邻两个块之间的重叠百分比', settings: '设置', autoMetadataTip: @@ -430,13 +430,13 @@ export default { baseInfo: '基础信息', globalIndex: '全局索引', dataSource: '数据源', - linkSourceSetTip: '管理与此数据集的数据源链接', + linkSourceSetTip: '管理与此知识库的数据源链接', linkDataSource: '链接数据源', tocExtractionTip: '对于已有的chunk生成层级结构的目录信息(每个文件一个目录)。在查询时,激活`Page Index`后,系统会用大模型去判断用户问题和哪些目录项相关,从而找到相关的chunk。', deleteGenerateModalContent: `

删除生成的 {{type}} 结果 - 将从此数据集中移除所有派生实体和关系。 + 将从此知识库中移除所有派生实体和关系。 您的原始文件将保持不变。


是否要继续? @@ -449,7 +449,7 @@ export default { setDefaultTip: '', setDefault: '设置默认', editLinkDataPipeline: '编辑pipeline', - linkPipelineSetTip: '管理与此数据集的数据管道链接', + linkPipelineSetTip: '管理与此知识库的数据管道链接', default: '默认', dataPipeline: '切换或配置 ingestion pipeline。', linkDataPipeline: '关联pipeline', @@ -630,7 +630,7 @@ export default { tagSetTip: `

请选择一个或多个标签集或标签知识库,用于对知识库中的每个文本块进行标记。

对这些文本块的查询也将自动关联相应标签。

-

此功能基于文本相似度,能够为数据集的文本块批量添加更多领域知识,从而显著提高检索准确性。该功能还能提升大量文本块的操作效率。

+

此功能基于文本相似度,能够为知识库的文本块批量添加更多领域知识,从而显著提高检索准确性。该功能还能提升大量文本块的操作效率。

为了更好地理解标签集的作用,以下是标签集和关键词之间的主要区别: