diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md
index 7e5d0062bb4..5161ded2b6c 100644
--- a/.github/copilot-instructions.md
+++ b/.github/copilot-instructions.md
@@ -1 +1,22 @@
-Refer to [AGENTS.MD](../AGENTS.md) for all repo instructions.
+# Project instructions for Copilot
+
+## How to run (minimum)
+- Install:
+ - python -m venv .venv && source .venv/bin/activate
+ - pip install -r requirements.txt
+- Run:
+ - (fill) e.g. uvicorn app.main:app --reload
+- Verify:
+ - (fill) curl http://127.0.0.1:8000/health
+
+## Project layout (what matters)
+- app/: API entrypoints + routers
+- services/: business logic
+- configs/: config loading (.env)
+- docs/: documents
+- tests/: pytest
+
+## Conventions
+- Prefer small, incremental changes.
+- Add logging for new flows.
+- Add/adjust tests for behavior changes.
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 37c666173a4..934005edec3 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -86,6 +86,9 @@ jobs:
mkdir -p ${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}
echo "${PR_SHA} ${GITHUB_RUN_ID}" > ${PR_SHA_FP}
fi
+ ARTIFACTS_DIR=${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}/${GITHUB_RUN_ID}
+ echo "ARTIFACTS_DIR=${ARTIFACTS_DIR}" >> ${GITHUB_ENV}
+ rm -rf ${ARTIFACTS_DIR} && mkdir -p ${ARTIFACTS_DIR}
# https://github.com/astral-sh/ruff-action
- name: Static check with Ruff
@@ -161,7 +164,7 @@ jobs:
INFINITY_THRIFT_PORT=$((23817 + RUNNER_NUM * 10))
INFINITY_HTTP_PORT=$((23820 + RUNNER_NUM * 10))
INFINITY_PSQL_PORT=$((5432 + RUNNER_NUM * 10))
- MYSQL_PORT=$((5455 + RUNNER_NUM * 10))
+ EXPOSE_MYSQL_PORT=$((5455 + RUNNER_NUM * 10))
MINIO_PORT=$((9000 + RUNNER_NUM * 10))
MINIO_CONSOLE_PORT=$((9001 + RUNNER_NUM * 10))
REDIS_PORT=$((6379 + RUNNER_NUM * 10))
@@ -181,7 +184,7 @@ jobs:
echo -e "INFINITY_THRIFT_PORT=${INFINITY_THRIFT_PORT}" >> docker/.env
echo -e "INFINITY_HTTP_PORT=${INFINITY_HTTP_PORT}" >> docker/.env
echo -e "INFINITY_PSQL_PORT=${INFINITY_PSQL_PORT}" >> docker/.env
- echo -e "MYSQL_PORT=${MYSQL_PORT}" >> docker/.env
+ echo -e "EXPOSE_MYSQL_PORT=${EXPOSE_MYSQL_PORT}" >> docker/.env
echo -e "MINIO_PORT=${MINIO_PORT}" >> docker/.env
echo -e "MINIO_CONSOLE_PORT=${MINIO_CONSOLE_PORT}" >> docker/.env
echo -e "REDIS_PORT=${REDIS_PORT}" >> docker/.env
@@ -199,8 +202,11 @@ jobs:
echo -e "RAGFLOW_IMAGE=${RAGFLOW_IMAGE}" >> docker/.env
echo "HOST_ADDRESS=http://host.docker.internal:${SVR_HTTP_PORT}" >> ${GITHUB_ENV}
+ # Patch entrypoint.sh for coverage
+ sed -i '/"\$PY" api\/ragflow_server.py \${INIT_SUPERUSER_ARGS} &/c\ echo "Ensuring coverage is installed..."\n "$PY" -m pip install coverage\n export COVERAGE_FILE=/ragflow/logs/.coverage\n echo "Starting ragflow_server with coverage..."\n "$PY" -m coverage run --source=./api/apps --omit="*/tests/*,*/migrations/*" -a api/ragflow_server.py ${INIT_SUPERUSER_ARGS} &' docker/entrypoint.sh
+
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
- uv sync --python 3.12 --only-group test --no-default-groups --frozen && uv pip install sdk/python --group test
+ uv sync --python 3.12 --group test --frozen && uv pip install -e sdk/python
- name: Run sdk tests against Elasticsearch
run: |
@@ -209,16 +215,16 @@ jobs:
echo "Waiting for service to be available..."
sleep 5
done
- source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee es_sdk_test.log
+ source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} --junitxml=pytest-infinity-sdk.xml --cov=sdk/python/ragflow_sdk --cov-branch --cov-report=xml:coverage-es-sdk.xml test/testcases/test_sdk_api 2>&1 | tee es_sdk_test.log
- - name: Run frontend api tests against Elasticsearch
+ - name: Run web api tests against Elasticsearch
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
- source .venv/bin/activate && set -o pipefail; pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee es_api_test.log
+ source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_web_api 2>&1 | tee es_web_api_test.log
- name: Run http api tests against Elasticsearch
run: |
@@ -229,6 +235,154 @@ jobs:
done
source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee es_http_api_test.log
+ - name: RAGFlow CLI retrieval test Elasticsearch
+ env:
+ PYTHONPATH: ${{ github.workspace }}
+ run: |
+ set -euo pipefail
+ source .venv/bin/activate
+
+ export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
+
+ EMAIL="ci-${GITHUB_RUN_ID}@example.com"
+ PASS="ci-pass-${GITHUB_RUN_ID}"
+ DATASET="ci_dataset_${GITHUB_RUN_ID}"
+
+ CLI="python admin/client/ragflow_cli.py"
+
+ LOG_FILE="es_cli_test.log"
+ : > "${LOG_FILE}"
+
+ ERROR_RE='Traceback|ModuleNotFoundError|ImportError|Parse error|Bad response|Fail to|code:\\s*[1-9]'
+ run_cli() {
+ local logfile="$1"
+ shift
+ local allow_re=""
+ if [[ "${1:-}" == "--allow" ]]; then
+ allow_re="$2"
+ shift 2
+ fi
+ local cmd_display="$*"
+ echo "===== $(date -u +\"%Y-%m-%dT%H:%M:%SZ\") CMD: ${cmd_display} =====" | tee -a "${logfile}"
+ local tmp_log
+ tmp_log="$(mktemp)"
+ set +e
+ timeout 180s "$@" 2>&1 | tee "${tmp_log}"
+ local status=${PIPESTATUS[0]}
+ set -e
+ cat "${tmp_log}" >> "${logfile}"
+ if grep -qiE "${ERROR_RE}" "${tmp_log}"; then
+ if [[ -n "${allow_re}" ]] && grep -qiE "${allow_re}" "${tmp_log}"; then
+ echo "Allowed CLI error markers in ${logfile}"
+ rm -f "${tmp_log}"
+ return 0
+ fi
+ echo "Detected CLI error markers in ${logfile}"
+ rm -f "${tmp_log}"
+ exit 1
+ fi
+ rm -f "${tmp_log}"
+ return ${status}
+ }
+
+ set -a
+ source docker/.env
+ set +a
+
+ HOST_ADDRESS="http://host.docker.internal:${SVR_HTTP_PORT}"
+ USER_HOST="$(echo "${HOST_ADDRESS}" | sed -E 's#^https?://([^:/]+).*#\1#')"
+ USER_PORT="${SVR_HTTP_PORT}"
+ ADMIN_HOST="${USER_HOST}"
+ ADMIN_PORT="${ADMIN_SVR_HTTP_PORT}"
+
+ until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
+ echo "Waiting for service to be available..."
+ sleep 5
+ done
+
+ admin_ready=0
+ for i in $(seq 1 30); do
+ if run_cli "${LOG_FILE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "ping"; then
+ admin_ready=1
+ break
+ fi
+ sleep 1
+ done
+ if [[ "${admin_ready}" -ne 1 ]]; then
+ echo "Admin service did not become ready"
+ exit 1
+ fi
+
+ run_cli "${LOG_FILE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "show version"
+ ALLOW_USER_EXISTS_RE='already exists|already exist|duplicate|already.*registered|exist(s)?'
+ run_cli "${LOG_FILE}" --allow "${ALLOW_USER_EXISTS_RE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "create user '$EMAIL' '$PASS'"
+
+ user_ready=0
+ for i in $(seq 1 30); do
+ if run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "ping"; then
+ user_ready=1
+ break
+ fi
+ sleep 1
+ done
+ if [[ "${user_ready}" -ne 1 ]]; then
+ echo "User service did not become ready"
+ exit 1
+ fi
+
+ run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "show version"
+ run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "create dataset '$DATASET' with embedding 'BAAI/bge-small-en-v1.5@Builtin' parser 'auto'"
+ run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "import 'test/benchmark/test_docs/Doc1.pdf,test/benchmark/test_docs/Doc2.pdf' into dataset '$DATASET'"
+ run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "parse dataset '$DATASET' sync"
+ run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "Benchmark 16 100 search 'what are these documents about' on datasets '$DATASET'"
+
+ - name: Stop ragflow to save coverage Elasticsearch
+ if: ${{ !cancelled() }}
+ run: |
+ # Send SIGINT to ragflow_server.py to trigger coverage save
+ PID=$(sudo docker exec ${RAGFLOW_CONTAINER} ps aux | grep "ragflow_server.py" | grep -v grep | awk '{print $2}' | head -n 1)
+ if [ -n "$PID" ]; then
+ echo "Sending SIGINT to ragflow_server.py (PID: $PID)..."
+ sudo docker exec ${RAGFLOW_CONTAINER} kill -INT $PID
+ # Wait for process to exit and coverage file to be written
+ sleep 10
+ else
+ echo "ragflow_server.py not found!"
+ fi
+ sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} stop
+
+ - name: Generate server coverage report Elasticsearch
+ if: ${{ !cancelled() }}
+ run: |
+ # .coverage file should be in docker/ragflow-logs/.coverage
+ if [ -f docker/ragflow-logs/.coverage ]; then
+ echo "Found .coverage file"
+ cp docker/ragflow-logs/.coverage .coverage
+ source .venv/bin/activate
+ # Create .coveragerc to map container paths to host paths
+ echo "[paths]" > .coveragerc
+ echo "source =" >> .coveragerc
+ echo " ." >> .coveragerc
+ echo " /ragflow" >> .coveragerc
+ coverage xml -o coverage-es-server.xml
+ rm .coveragerc
+ # Clean up for next run
+ sudo rm docker/ragflow-logs/.coverage
+ else
+ echo ".coverage file not found!"
+ fi
+
+ - name: Collect ragflow log Elasticsearch
+ if: ${{ !cancelled() }}
+ run: |
+ if [ -d docker/ragflow-logs ]; then
+ cp -r docker/ragflow-logs ${ARTIFACTS_DIR}/ragflow-logs-es
+ echo "ragflow log" && tail -n 200 docker/ragflow-logs/ragflow_server.log || true
+ else
+ echo "No docker/ragflow-logs directory found; skipping log collection"
+ fi
+ sudo rm -rf docker/ragflow-logs || true
+
- name: Stop ragflow:nightly
if: always() # always run this step even if previous steps failed
run: |
@@ -247,16 +401,16 @@ jobs:
echo "Waiting for service to be available..."
sleep 5
done
- source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee infinity_sdk_test.log
+ source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} --junitxml=pytest-infinity-sdk.xml --cov=sdk/python/ragflow_sdk --cov-branch --cov-report=xml:coverage-infinity-sdk.xml test/testcases/test_sdk_api 2>&1 | tee infinity_sdk_test.log
- - name: Run frontend api tests against Infinity
+ - name: Run web api tests against Infinity
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
- source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee infinity_api_test.log
+ source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_web_api/test_api_app 2>&1 | tee infinity_web_api_test.log
- name: Run http api tests against Infinity
run: |
@@ -267,6 +421,159 @@ jobs:
done
source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log
+ - name: RAGFlow CLI retrieval test Infinity
+ env:
+ PYTHONPATH: ${{ github.workspace }}
+ run: |
+ set -euo pipefail
+ source .venv/bin/activate
+
+ export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
+
+ EMAIL="ci-${GITHUB_RUN_ID}@example.com"
+ PASS="ci-pass-${GITHUB_RUN_ID}"
+ DATASET="ci_dataset_${GITHUB_RUN_ID}"
+
+ CLI="python admin/client/ragflow_cli.py"
+
+ LOG_FILE="infinity_cli_test.log"
+ : > "${LOG_FILE}"
+
+ ERROR_RE='Traceback|ModuleNotFoundError|ImportError|Parse error|Bad response|Fail to|code:\\s*[1-9]'
+ run_cli() {
+ local logfile="$1"
+ shift
+ local allow_re=""
+ if [[ "${1:-}" == "--allow" ]]; then
+ allow_re="$2"
+ shift 2
+ fi
+ local cmd_display="$*"
+ echo "===== $(date -u +\"%Y-%m-%dT%H:%M:%SZ\") CMD: ${cmd_display} =====" | tee -a "${logfile}"
+ local tmp_log
+ tmp_log="$(mktemp)"
+ set +e
+ timeout 180s "$@" 2>&1 | tee "${tmp_log}"
+ local status=${PIPESTATUS[0]}
+ set -e
+ cat "${tmp_log}" >> "${logfile}"
+ if grep -qiE "${ERROR_RE}" "${tmp_log}"; then
+ if [[ -n "${allow_re}" ]] && grep -qiE "${allow_re}" "${tmp_log}"; then
+ echo "Allowed CLI error markers in ${logfile}"
+ rm -f "${tmp_log}"
+ return 0
+ fi
+ echo "Detected CLI error markers in ${logfile}"
+ rm -f "${tmp_log}"
+ exit 1
+ fi
+ rm -f "${tmp_log}"
+ return ${status}
+ }
+
+ set -a
+ source docker/.env
+ set +a
+
+ HOST_ADDRESS="http://host.docker.internal:${SVR_HTTP_PORT}"
+ USER_HOST="$(echo "${HOST_ADDRESS}" | sed -E 's#^https?://([^:/]+).*#\1#')"
+ USER_PORT="${SVR_HTTP_PORT}"
+ ADMIN_HOST="${USER_HOST}"
+ ADMIN_PORT="${ADMIN_SVR_HTTP_PORT}"
+
+ until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
+ echo "Waiting for service to be available..."
+ sleep 5
+ done
+
+ admin_ready=0
+ for i in $(seq 1 30); do
+ if run_cli "${LOG_FILE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "ping"; then
+ admin_ready=1
+ break
+ fi
+ sleep 1
+ done
+ if [[ "${admin_ready}" -ne 1 ]]; then
+ echo "Admin service did not become ready"
+ exit 1
+ fi
+
+ run_cli "${LOG_FILE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "show version"
+ ALLOW_USER_EXISTS_RE='already exists|already exist|duplicate|already.*registered|exist(s)?'
+ run_cli "${LOG_FILE}" --allow "${ALLOW_USER_EXISTS_RE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "create user '$EMAIL' '$PASS'"
+
+ user_ready=0
+ for i in $(seq 1 30); do
+ if run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "ping"; then
+ user_ready=1
+ break
+ fi
+ sleep 1
+ done
+ if [[ "${user_ready}" -ne 1 ]]; then
+ echo "User service did not become ready"
+ exit 1
+ fi
+
+ run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "show version"
+ run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "create dataset '$DATASET' with embedding 'BAAI/bge-small-en-v1.5@Builtin' parser 'auto'"
+ run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "import 'test/benchmark/test_docs/Doc1.pdf,test/benchmark/test_docs/Doc2.pdf' into dataset '$DATASET'"
+ run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "parse dataset '$DATASET' sync"
+ run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "Benchmark 16 100 search 'what are these documents about' on datasets '$DATASET'"
+
+ - name: Stop ragflow to save coverage Infinity
+ if: ${{ !cancelled() }}
+ run: |
+ # Send SIGINT to ragflow_server.py to trigger coverage save
+ PID=$(sudo docker exec ${RAGFLOW_CONTAINER} ps aux | grep "ragflow_server.py" | grep -v grep | awk '{print $2}' | head -n 1)
+ if [ -n "$PID" ]; then
+ echo "Sending SIGINT to ragflow_server.py (PID: $PID)..."
+ sudo docker exec ${RAGFLOW_CONTAINER} kill -INT $PID
+ # Wait for process to exit and coverage file to be written
+ sleep 10
+ else
+ echo "ragflow_server.py not found!"
+ fi
+ sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} stop
+
+ - name: Generate server coverage report Infinity
+ if: ${{ !cancelled() }}
+ run: |
+ # .coverage file should be in docker/ragflow-logs/.coverage
+ if [ -f docker/ragflow-logs/.coverage ]; then
+ echo "Found .coverage file"
+ cp docker/ragflow-logs/.coverage .coverage
+ source .venv/bin/activate
+ # Create .coveragerc to map container paths to host paths
+ echo "[paths]" > .coveragerc
+ echo "source =" >> .coveragerc
+ echo " ." >> .coveragerc
+ echo " /ragflow" >> .coveragerc
+ coverage xml -o coverage-infinity-server.xml
+ rm .coveragerc
+ else
+ echo ".coverage file not found!"
+ fi
+
+ - name: Upload coverage reports to Codecov
+ uses: codecov/codecov-action@v5
+ if: ${{ !cancelled() }}
+ with:
+ token: ${{ secrets.CODECOV_TOKEN }}
+ fail_ci_if_error: false
+
+ - name: Collect ragflow log
+ if: ${{ !cancelled() }}
+ run: |
+ if [ -d docker/ragflow-logs ]; then
+ cp -r docker/ragflow-logs ${ARTIFACTS_DIR}/ragflow-logs-infinity
+ echo "ragflow log" && tail -n 200 docker/ragflow-logs/ragflow_server.log || true
+ else
+ echo "No docker/ragflow-logs directory found; skipping log collection"
+ fi
+ sudo rm -rf docker/ragflow-logs || true
+
- name: Stop ragflow:nightly
if: always() # always run this step even if previous steps failed
run: |
diff --git a/.gitignore b/.gitignore
index 11aa5449312..bc2bb8abe3a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -44,6 +44,7 @@ cl100k_base.tiktoken
chrome*
huggingface.co/
nltk_data/
+uv-x86_64*.tar.gz
# Exclude hash-like temporary files like 9b5ad71b2ce5302211f9c61530b329a4922fc6a4
*[0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f]*
@@ -51,6 +52,13 @@ nltk_data/
.venv
docker/data
+# OceanBase data and conf
+docker/oceanbase/conf
+docker/oceanbase/data
+
+# SeekDB data and conf
+docker/seekdb
+
#--------------------------------------------------#
# The following was generated with gitignore.nvim: #
@@ -197,4 +205,9 @@ ragflow_cli.egg-info
backup
-.hypothesis
\ No newline at end of file
+.hypothesis
+
+
+# Added by cargo
+
+/target
diff --git a/CLAUDE.md b/CLAUDE.md
index d774fc376c6..58d1217afea 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -27,7 +27,7 @@ RAGFlow is an open-source RAG (Retrieval-Augmented Generation) engine based on d
- **Document Processing**: `deepdoc/` - PDF parsing, OCR, layout analysis
- **LLM Integration**: `rag/llm/` - Model abstractions for chat, embedding, reranking
- **RAG Pipeline**: `rag/flow/` - Chunking, parsing, tokenization
-- **Graph RAG**: `graphrag/` - Knowledge graph construction and querying
+- **Graph RAG**: `rag/graphrag/` - Knowledge graph construction and querying
### Agent System (`/agent/`)
- **Components**: Modular workflow components (LLM, retrieval, categorize, etc.)
@@ -113,4 +113,4 @@ RAGFlow supports switching between Elasticsearch (default) and Infinity:
- Node.js >=18.20.4
- Docker & Docker Compose
- uv package manager
-- 16GB+ RAM, 50GB+ disk space
\ No newline at end of file
+- 16GB+ RAM, 50GB+ disk space
diff --git a/Dockerfile b/Dockerfile
index 5f2c5f6cf8a..d3af16ff05e 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -19,17 +19,16 @@ RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/huggingface.co
# This is the only way to run python-tika without internet access. Without this set, the default is to check the tika version and pull latest every time from Apache.
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \
cp -r /deps/nltk_data /root/ && \
- cp /deps/tika-server-standard-3.0.0.jar /deps/tika-server-standard-3.0.0.jar.md5 /ragflow/ && \
+ cp /deps/tika-server-standard-3.2.3.jar /deps/tika-server-standard-3.2.3.jar.md5 /ragflow/ && \
cp /deps/cl100k_base.tiktoken /ragflow/9b5ad71b2ce5302211f9c61530b329a4922fc6a4
-ENV TIKA_SERVER_JAR="file:///ragflow/tika-server-standard-3.0.0.jar"
+ENV TIKA_SERVER_JAR="file:///ragflow/tika-server-standard-3.2.3.jar"
ENV DEBIAN_FRONTEND=noninteractive
# Setup apt
# Python package and implicit dependencies:
# opencv-python: libglib2.0-0 libglx-mesa0 libgl1
-# aspose-slides: pkg-config libicu-dev libgdiplus libssl1.1_1.1.1f-1ubuntu2_amd64.deb
-# python-pptx: default-jdk tika-server-standard-3.0.0.jar
+# python-pptx: default-jdk tika-server-standard-3.2.3.jar
# selenium: libatk-bridge2.0-0 chrome-linux64-121-0-6167-85
# Building C extensions: libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev
RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
@@ -49,11 +48,21 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
apt install -y libatk-bridge2.0-0 && \
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
apt install -y libjemalloc-dev && \
- apt install -y nginx unzip curl wget git vim less && \
+ apt install -y gnupg unzip curl wget git vim less && \
apt install -y ghostscript && \
apt install -y pandoc && \
apt install -y texlive && \
- apt install -y fonts-freefont-ttf fonts-noto-cjk
+ apt install -y fonts-freefont-ttf fonts-noto-cjk && \
+ apt install -y postgresql-client
+
+ARG NGINX_VERSION=1.29.5-1~noble
+RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
+ mkdir -p /etc/apt/keyrings && \
+ curl -fsSL https://nginx.org/keys/nginx_signing.key | gpg --dearmor -o /etc/apt/keyrings/nginx-archive-keyring.gpg && \
+ echo "deb [signed-by=/etc/apt/keyrings/nginx-archive-keyring.gpg] https://nginx.org/packages/mainline/ubuntu/ noble nginx" > /etc/apt/sources.list.d/nginx.list && \
+ apt update && \
+ apt install -y nginx=${NGINX_VERSION} && \
+ apt-mark hold nginx
# Install uv
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \
@@ -64,10 +73,12 @@ RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps
echo 'url = "https://pypi.tuna.tsinghua.edu.cn/simple"' >> /etc/uv/uv.toml && \
echo 'default = true' >> /etc/uv/uv.toml; \
fi; \
- tar xzf /deps/uv-x86_64-unknown-linux-gnu.tar.gz \
- && cp uv-x86_64-unknown-linux-gnu/* /usr/local/bin/ \
- && rm -rf uv-x86_64-unknown-linux-gnu \
- && uv python install 3.11
+ arch="$(uname -m)"; \
+ if [ "$arch" = "x86_64" ]; then uv_arch="x86_64"; else uv_arch="aarch64"; fi; \
+ tar xzf "/deps/uv-${uv_arch}-unknown-linux-gnu.tar.gz" \
+ && cp "uv-${uv_arch}-unknown-linux-gnu/"* /usr/local/bin/ \
+ && rm -rf "uv-${uv_arch}-unknown-linux-gnu" \
+ && uv python install 3.12
ENV PYTHONDONTWRITEBYTECODE=1 DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
ENV PATH=/root/.local/bin:$PATH
@@ -125,8 +136,6 @@ RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/chromedriver-l
mv chromedriver /usr/local/bin/ && \
rm -f /usr/bin/google-chrome
-# https://forum.aspose.com/t/aspose-slides-for-net-no-usable-version-of-libssl-found-with-linux-server/271344/13
-# aspose-slides on linux/arm64 is unavailable
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \
if [ "$(uname -m)" = "x86_64" ]; then \
dpkg -i /deps/libssl1.1_1.1.1f-1ubuntu2_amd64.deb; \
@@ -152,11 +161,14 @@ RUN --mount=type=cache,id=ragflow_uv,target=/root/.cache/uv,sharing=locked \
else \
sed -i 's|pypi.tuna.tsinghua.edu.cn|pypi.org|g' uv.lock; \
fi; \
- uv sync --python 3.12 --frozen
+ uv sync --python 3.12 --frozen && \
+ # Ensure pip is available in the venv for runtime package installation (fixes #12651)
+ .venv/bin/python3 -m ensurepip --upgrade
COPY web web
COPY docs docs
RUN --mount=type=cache,id=ragflow_npm,target=/root/.npm,sharing=locked \
+ export NODE_OPTIONS="--max-old-space-size=4096" && \
cd web && npm install && npm run build
COPY .git /ragflow/.git
@@ -186,11 +198,8 @@ COPY conf conf
COPY deepdoc deepdoc
COPY rag rag
COPY agent agent
-COPY graphrag graphrag
-COPY agentic_reasoning agentic_reasoning
COPY pyproject.toml uv.lock ./
COPY mcp mcp
-COPY plugin plugin
COPY common common
COPY memory memory
diff --git a/Dockerfile.deps b/Dockerfile.deps
index c683ebf7cb7..591b99eb83e 100644
--- a/Dockerfile.deps
+++ b/Dockerfile.deps
@@ -3,7 +3,7 @@
FROM scratch
# Copy resources downloaded via download_deps.py
-COPY chromedriver-linux64-121-0-6167-85 chrome-linux64-121-0-6167-85 cl100k_base.tiktoken libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb tika-server-standard-3.0.0.jar tika-server-standard-3.0.0.jar.md5 libssl*.deb uv-x86_64-unknown-linux-gnu.tar.gz /
+COPY chromedriver-linux64-121-0-6167-85 chrome-linux64-121-0-6167-85 cl100k_base.tiktoken libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb tika-server-standard-3.2.3.jar tika-server-standard-3.2.3.jar.md5 libssl*.deb uv-x86_64-unknown-linux-gnu.tar.gz uv-aarch64-unknown-linux-gnu.tar.gz /
COPY nltk_data /nltk_data
diff --git a/README.md b/README.md
index 4aa670b2e09..b95fcddc772 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@
-
+
@@ -72,7 +72,7 @@
## 💡 What is RAGFlow?
-[RAGFlow](https://ragflow.io/) is a leading open-source Retrieval-Augmented Generation (RAG) engine that fuses cutting-edge RAG with Agent capabilities to create a superior context layer for LLMs. It offers a streamlined RAG workflow adaptable to enterprises of any scale. Powered by a converged context engine and pre-built agent templates, RAGFlow enables developers to transform complex data into high-fidelity, production-ready AI systems with exceptional efficiency and precision.
+[RAGFlow](https://ragflow.io/) is a leading open-source Retrieval-Augmented Generation ([RAG](https://ragflow.io/basics/what-is-rag)) engine that fuses cutting-edge RAG with Agent capabilities to create a superior context layer for LLMs. It offers a streamlined RAG workflow adaptable to enterprises of any scale. Powered by a converged [context engine](https://ragflow.io/basics/what-is-agent-context-engine) and pre-built agent templates, RAGFlow enables developers to transform complex data into high-fidelity, production-ready AI systems with exceptional efficiency and precision.
## 🎮 Demo
@@ -188,15 +188,15 @@ releases! 🌟
> All Docker images are built for x86 platforms. We don't currently offer Docker images for ARM64.
> If you are on an ARM64 platform, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image compatible with your system.
-> The command below downloads the `v0.23.1` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.23.1`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server.
+> The command below downloads the `v0.24.0` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.24.0`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server.
```bash
$ cd ragflow/docker
-
- # git checkout v0.23.1
+
+ # git checkout v0.24.0
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
# This step ensures the **entrypoint.sh** file in the code matches the Docker image version.
-
+
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d
diff --git a/README_id.md b/README_id.md
index 51fe841175a..c3cfdfcc5d1 100644
--- a/README_id.md
+++ b/README_id.md
@@ -22,7 +22,7 @@
-
+
@@ -72,7 +72,7 @@
## 💡 Apa Itu RAGFlow?
-[RAGFlow](https://ragflow.io/) adalah mesin RAG (Retrieval-Augmented Generation) open-source terkemuka yang mengintegrasikan teknologi RAG mutakhir dengan kemampuan Agent untuk menciptakan lapisan kontekstual superior bagi LLM. Menyediakan alur kerja RAG yang efisien dan dapat diadaptasi untuk perusahaan segala skala. Didukung oleh mesin konteks terkonvergensi dan template Agent yang telah dipra-bangun, RAGFlow memungkinkan pengembang mengubah data kompleks menjadi sistem AI kesetiaan-tinggi dan siap-produksi dengan efisiensi dan presisi yang luar biasa.
+[RAGFlow](https://ragflow.io/) adalah mesin [RAG](https://ragflow.io/basics/what-is-rag) (Retrieval-Augmented Generation) open-source terkemuka yang mengintegrasikan teknologi RAG mutakhir dengan kemampuan Agent untuk menciptakan lapisan kontekstual superior bagi LLM. Menyediakan alur kerja RAG yang efisien dan dapat diadaptasi untuk perusahaan segala skala. Didukung oleh mesin konteks terkonvergensi dan template Agent yang telah dipra-bangun, RAGFlow memungkinkan pengembang mengubah data kompleks menjadi sistem AI kesetiaan-tinggi dan siap-produksi dengan efisiensi dan presisi yang luar biasa.
## 🎮 Demo
@@ -188,12 +188,12 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
> Semua gambar Docker dibangun untuk platform x86. Saat ini, kami tidak menawarkan gambar Docker untuk ARM64.
> Jika Anda menggunakan platform ARM64, [silakan gunakan panduan ini untuk membangun gambar Docker yang kompatibel dengan sistem Anda](https://ragflow.io/docs/dev/build_docker_image).
-> Perintah di bawah ini mengunduh edisi v0.23.1 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.23.1, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server.
+> Perintah di bawah ini mengunduh edisi v0.24.0 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.24.0, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server.
```bash
$ cd ragflow/docker
-
- # git checkout v0.23.1
+
+ # git checkout v0.24.0
# Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases)
# This steps ensures the **entrypoint.sh** file in the code matches the Docker image version.
diff --git a/README_ja.md b/README_ja.md
index cd65acffddb..afff19bc8fd 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -22,7 +22,7 @@
-
+
@@ -53,7 +53,7 @@
## 💡 RAGFlow とは?
-[RAGFlow](https://ragflow.io/) は、先進的なRAG(Retrieval-Augmented Generation)技術と Agent 機能を融合し、大規模言語モデル(LLM)に優れたコンテキスト層を構築する最先端のオープンソース RAG エンジンです。あらゆる規模の企業に対応可能な合理化された RAG ワークフローを提供し、統合型コンテキストエンジンと事前構築されたAgentテンプレートにより、開発者が複雑なデータを驚異的な効率性と精度で高精細なプロダクションレディAIシステムへ変換することを可能にします。
+[RAGFlow](https://ragflow.io/) は、先進的な[RAG](https://ragflow.io/basics/what-is-rag)(Retrieval-Augmented Generation)技術と Agent 機能を融合し、大規模言語モデル(LLM)に優れたコンテキスト層を構築する最先端のオープンソース RAG エンジンです。あらゆる規模の企業に対応可能な合理化された RAG ワークフローを提供し、統合型[コンテキストエンジン](https://ragflow.io/basics/what-is-agent-context-engine)と事前構築されたAgentテンプレートにより、開発者が複雑なデータを驚異的な効率性と精度で高精細なプロダクションレディAIシステムへ変換することを可能にします。
## 🎮 Demo
@@ -168,12 +168,12 @@
> 現在、公式に提供されているすべての Docker イメージは x86 アーキテクチャ向けにビルドされており、ARM64 用の Docker イメージは提供されていません。
> ARM64 アーキテクチャのオペレーティングシステムを使用している場合は、[このドキュメント](https://ragflow.io/docs/dev/build_docker_image)を参照して Docker イメージを自分でビルドしてください。
-> 以下のコマンドは、RAGFlow Docker イメージの v0.23.1 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.23.1 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。
+> 以下のコマンドは、RAGFlow Docker イメージの v0.24.0 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.24.0 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。
```bash
$ cd ragflow/docker
- # git checkout v0.23.1
+ # git checkout v0.24.0
# 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases)
# この手順は、コード内の entrypoint.sh ファイルが Docker イメージのバージョンと一致していることを確認します。
@@ -194,8 +194,8 @@
> `v0.22.0` 以降、当プロジェクトでは slim エディションのみを提供し、イメージタグに **-slim** サフィックスを付けなくなりました。
- 1. サーバーを立ち上げた後、サーバーの状態を確認する:
-
+ 1. サーバーを立ち上げた後、サーバーの状態を確認する:
+
```bash
$ docker logs -f docker-ragflow-cpu-1
```
diff --git a/README_ko.md b/README_ko.md
index b6551fa264b..91978a72a5d 100644
--- a/README_ko.md
+++ b/README_ko.md
@@ -22,7 +22,7 @@
-
+
@@ -54,7 +54,7 @@
## 💡 RAGFlow란?
-[RAGFlow](https://ragflow.io/) 는 최첨단 RAG(Retrieval-Augmented Generation)와 Agent 기능을 융합하여 대규모 언어 모델(LLM)을 위한 우수한 컨텍스트 계층을 생성하는 선도적인 오픈소스 RAG 엔진입니다. 모든 규모의 기업에 적용 가능한 효율적인 RAG 워크플로를 제공하며, 통합 컨텍스트 엔진과 사전 구축된 Agent 템플릿을 통해 개발자들이 복잡한 데이터를 예외적인 효율성과 정밀도로 고급 구현도의 프로덕션 준비 완료 AI 시스템으로 변환할 수 있도록 지원합니다.
+[RAGFlow](https://ragflow.io/) 는 최첨단 [RAG](https://ragflow.io/basics/what-is-rag)(Retrieval-Augmented Generation)와 Agent 기능을 융합하여 대규모 언어 모델(LLM)을 위한 우수한 컨텍스트 계층을 생성하는 선도적인 오픈소스 RAG 엔진입니다. 모든 규모의 기업에 적용 가능한 효율적인 RAG 워크플로를 제공하며, 통합 [컨텍스트 엔진](https://ragflow.io/basics/what-is-agent-context-engine)과 사전 구축된 Agent 템플릿을 통해 개발자들이 복잡한 데이터를 예외적인 효율성과 정밀도로 고급 구현도의 프로덕션 준비 완료 AI 시스템으로 변환할 수 있도록 지원합니다.
## 🎮 데모
@@ -170,12 +170,12 @@
> 모든 Docker 이미지는 x86 플랫폼을 위해 빌드되었습니다. 우리는 현재 ARM64 플랫폼을 위한 Docker 이미지를 제공하지 않습니다.
> ARM64 플랫폼을 사용 중이라면, [시스템과 호환되는 Docker 이미지를 빌드하려면 이 가이드를 사용해 주세요](https://ragflow.io/docs/dev/build_docker_image).
- > 아래 명령어는 RAGFlow Docker 이미지의 v0.23.1 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.23.1과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오.
+ > 아래 명령어는 RAGFlow Docker 이미지의 v0.24.0 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.24.0과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오.
```bash
$ cd ragflow/docker
-
- # git checkout v0.23.1
+
+ # git checkout v0.24.0
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
# 이 단계는 코드의 entrypoint.sh 파일이 Docker 이미지 버전과 일치하도록 보장합니다.
diff --git a/README_pt_br.md b/README_pt_br.md
index bd196bf6dae..8fa5b6692e1 100644
--- a/README_pt_br.md
+++ b/README_pt_br.md
@@ -22,7 +22,7 @@
-
+
@@ -73,7 +73,7 @@
## 💡 O que é o RAGFlow?
-[RAGFlow](https://ragflow.io/) é um mecanismo de RAG (Retrieval-Augmented Generation) open-source líder que fusiona tecnologias RAG de ponta com funcionalidades Agent para criar uma camada contextual superior para LLMs. Oferece um fluxo de trabalho RAG otimizado adaptável a empresas de qualquer escala. Alimentado por um motor de contexto convergente e modelos Agent pré-construídos, o RAGFlow permite que desenvolvedores transformem dados complexos em sistemas de IA de alta fidelidade e pronto para produção com excepcional eficiência e precisão.
+[RAGFlow](https://ragflow.io/) é um mecanismo de [RAG](https://ragflow.io/basics/what-is-rag) (Retrieval-Augmented Generation) open-source líder que fusiona tecnologias RAG de ponta com funcionalidades Agent para criar uma camada contextual superior para LLMs. Oferece um fluxo de trabalho RAG otimizado adaptável a empresas de qualquer escala. Alimentado por [um motor de contexto](https://ragflow.io/basics/what-is-agent-context-engine) convergente e modelos Agent pré-construídos, o RAGFlow permite que desenvolvedores transformem dados complexos em sistemas de IA de alta fidelidade e pronto para produção com excepcional eficiência e precisão.
## 🎮 Demo
@@ -188,12 +188,12 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
> Todas as imagens Docker são construídas para plataformas x86. Atualmente, não oferecemos imagens Docker para ARM64.
> Se você estiver usando uma plataforma ARM64, por favor, utilize [este guia](https://ragflow.io/docs/dev/build_docker_image) para construir uma imagem Docker compatível com o seu sistema.
- > O comando abaixo baixa a edição`v0.23.1` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.23.1`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor.
+ > O comando abaixo baixa a edição`v0.24.0` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.24.0`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor.
```bash
$ cd ragflow/docker
-
- # git checkout v0.23.1
+
+ # git checkout v0.24.0
# Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases)
# Esta etapa garante que o arquivo entrypoint.sh no código corresponda à versão da imagem do Docker.
diff --git a/README_tzh.md b/README_tzh.md
index a33f6f8f80c..d46d06077ce 100644
--- a/README_tzh.md
+++ b/README_tzh.md
@@ -22,7 +22,7 @@
-
+
@@ -72,7 +72,7 @@
## 💡 RAGFlow 是什麼?
-[RAGFlow](https://ragflow.io/) 是一款領先的開源 RAG(Retrieval-Augmented Generation)引擎,通過融合前沿的 RAG 技術與 Agent 能力,為大型語言模型提供卓越的上下文層。它提供可適配任意規模企業的端到端 RAG 工作流,憑藉融合式上下文引擎與預置的 Agent 模板,助力開發者以極致效率與精度將複雜數據轉化為高可信、生產級的人工智能系統。
+[RAGFlow](https://ragflow.io/) 是一款領先的開源 [RAG](https://ragflow.io/basics/what-is-rag)(Retrieval-Augmented Generation)引擎,通過融合前沿的 RAG 技術與 Agent 能力,為大型語言模型提供卓越的上下文層。它提供可適配任意規模企業的端到端 RAG 工作流,憑藉融合式[上下文引擎](https://ragflow.io/basics/what-is-agent-context-engine)與預置的 Agent 模板,助力開發者以極致效率與精度將複雜數據轉化為高可信、生產級的人工智能系統。
## 🎮 Demo 試用
@@ -187,12 +187,12 @@
> 所有 Docker 映像檔都是為 x86 平台建置的。目前,我們不提供 ARM64 平台的 Docker 映像檔。
> 如果您使用的是 ARM64 平台,請使用 [這份指南](https://ragflow.io/docs/dev/build_docker_image) 來建置適合您系統的 Docker 映像檔。
-> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.23.1`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.23.1` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。
+> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.24.0`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.24.0` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。
```bash
$ cd ragflow/docker
-
- # git checkout v0.23.1
+
+ # git checkout v0.24.0
# 可選:使用穩定版標籤(查看發佈:https://github.com/infiniflow/ragflow/releases)
# 此步驟確保程式碼中的 entrypoint.sh 檔案與 Docker 映像版本一致。
diff --git a/README_zh.md b/README_zh.md
index 2aa34a788eb..5b194daa0ff 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -22,7 +22,7 @@
-
+
@@ -72,7 +72,7 @@
## 💡 RAGFlow 是什么?
-[RAGFlow](https://ragflow.io/) 是一款领先的开源检索增强生成(RAG)引擎,通过融合前沿的 RAG 技术与 Agent 能力,为大型语言模型提供卓越的上下文层。它提供可适配任意规模企业的端到端 RAG 工作流,凭借融合式上下文引擎与预置的 Agent 模板,助力开发者以极致效率与精度将复杂数据转化为高可信、生产级的人工智能系统。
+[RAGFlow](https://ragflow.io/) 是一款领先的开源检索增强生成([RAG](https://ragflow.io/basics/what-is-rag))引擎,通过融合前沿的 RAG 技术与 Agent 能力,为大型语言模型提供卓越的上下文层。它提供可适配任意规模企业的端到端 RAG 工作流,凭借融合式[上下文引擎](https://ragflow.io/basics/what-is-agent-context-engine)与预置的 Agent 模板,助力开发者以极致效率与精度将复杂数据转化为高可信、生产级的人工智能系统。
## 🎮 Demo 试用
@@ -188,12 +188,12 @@
> 请注意,目前官方提供的所有 Docker 镜像均基于 x86 架构构建,并不提供基于 ARM64 的 Docker 镜像。
> 如果你的操作系统是 ARM64 架构,请参考[这篇文档](https://ragflow.io/docs/dev/build_docker_image)自行构建 Docker 镜像。
- > 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.23.1`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.23.1` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。
+ > 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.24.0`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.24.0` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。
```bash
$ cd ragflow/docker
-
- # git checkout v0.23.1
+
+ # git checkout v0.24.0
# 可选:使用稳定版本标签(查看发布:https://github.com/infiniflow/ragflow/releases)
# 这一步确保代码中的 entrypoint.sh 文件与 Docker 镜像的版本保持一致。
@@ -204,7 +204,7 @@
# sed -i '1i DEVICE=gpu' .env
# docker compose -f docker-compose.yml up -d
```
-
+
> 注意:在 `v0.22.0` 之前的版本,我们会同时提供包含 embedding 模型的镜像和不含 embedding 模型的 slim 镜像。具体如下:
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
diff --git a/admin/build_cli_release.sh b/admin/build_cli_release.sh
index c9fd6d9d909..d9025ff181d 100755
--- a/admin/build_cli_release.sh
+++ b/admin/build_cli_release.sh
@@ -21,7 +21,7 @@ cp pyproject.toml release/$PROJECT_NAME/pyproject.toml
cp README.md release/$PROJECT_NAME/README.md
mkdir release/$PROJECT_NAME/$SOURCE_DIR/$PACKAGE_DIR -p
-cp admin_client.py release/$PROJECT_NAME/$SOURCE_DIR/$PACKAGE_DIR/admin_client.py
+cp ragflow_cli.py release/$PROJECT_NAME/$SOURCE_DIR/$PACKAGE_DIR/ragflow_cli.py
if [ -d "release/$PROJECT_NAME/$SOURCE_DIR" ]; then
echo "✅ source dir: release/$PROJECT_NAME/$SOURCE_DIR"
diff --git a/admin/client/README.md b/admin/client/README.md
index 1f77a45d696..2090a214402 100644
--- a/admin/client/README.md
+++ b/admin/client/README.md
@@ -48,7 +48,7 @@ It consists of a server-side Service and a command-line client (CLI), both imple
1. Ensure the Admin Service is running.
2. Install ragflow-cli.
```bash
- pip install ragflow-cli==0.23.1
+ pip install ragflow-cli==0.24.0
```
3. Launch the CLI client:
```bash
diff --git a/admin/client/admin_client.py b/admin/client/admin_client.py
deleted file mode 100644
index f70e1624e1b..00000000000
--- a/admin/client/admin_client.py
+++ /dev/null
@@ -1,938 +0,0 @@
-#
-# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import argparse
-import base64
-import getpass
-from cmd import Cmd
-from typing import Any, Dict, List
-
-import requests
-from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
-from Cryptodome.PublicKey import RSA
-from lark import Lark, Transformer, Tree
-
-GRAMMAR = r"""
-start: command
-
-command: sql_command | meta_command
-
-sql_command: list_services
- | show_service
- | startup_service
- | shutdown_service
- | restart_service
- | list_users
- | show_user
- | drop_user
- | alter_user
- | create_user
- | activate_user
- | list_datasets
- | list_agents
- | create_role
- | drop_role
- | alter_role
- | list_roles
- | show_role
- | grant_permission
- | revoke_permission
- | alter_user_role
- | show_user_permission
- | show_version
-
-// meta command definition
-meta_command: "\\" meta_command_name [meta_args]
-
-meta_command_name: /[a-zA-Z?]+/
-meta_args: (meta_arg)+
-
-meta_arg: /[^\\s"']+/ | quoted_string
-
-// command definition
-
-LIST: "LIST"i
-SERVICES: "SERVICES"i
-SHOW: "SHOW"i
-CREATE: "CREATE"i
-SERVICE: "SERVICE"i
-SHUTDOWN: "SHUTDOWN"i
-STARTUP: "STARTUP"i
-RESTART: "RESTART"i
-USERS: "USERS"i
-DROP: "DROP"i
-USER: "USER"i
-ALTER: "ALTER"i
-ACTIVE: "ACTIVE"i
-PASSWORD: "PASSWORD"i
-DATASETS: "DATASETS"i
-OF: "OF"i
-AGENTS: "AGENTS"i
-ROLE: "ROLE"i
-ROLES: "ROLES"i
-DESCRIPTION: "DESCRIPTION"i
-GRANT: "GRANT"i
-REVOKE: "REVOKE"i
-ALL: "ALL"i
-PERMISSION: "PERMISSION"i
-TO: "TO"i
-FROM: "FROM"i
-FOR: "FOR"i
-RESOURCES: "RESOURCES"i
-ON: "ON"i
-SET: "SET"i
-VERSION: "VERSION"i
-
-list_services: LIST SERVICES ";"
-show_service: SHOW SERVICE NUMBER ";"
-startup_service: STARTUP SERVICE NUMBER ";"
-shutdown_service: SHUTDOWN SERVICE NUMBER ";"
-restart_service: RESTART SERVICE NUMBER ";"
-
-list_users: LIST USERS ";"
-drop_user: DROP USER quoted_string ";"
-alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
-show_user: SHOW USER quoted_string ";"
-create_user: CREATE USER quoted_string quoted_string ";"
-activate_user: ALTER USER ACTIVE quoted_string status ";"
-
-list_datasets: LIST DATASETS OF quoted_string ";"
-list_agents: LIST AGENTS OF quoted_string ";"
-
-create_role: CREATE ROLE identifier [DESCRIPTION quoted_string] ";"
-drop_role: DROP ROLE identifier ";"
-alter_role: ALTER ROLE identifier SET DESCRIPTION quoted_string ";"
-list_roles: LIST ROLES ";"
-show_role: SHOW ROLE identifier ";"
-
-grant_permission: GRANT action_list ON identifier TO ROLE identifier ";"
-revoke_permission: REVOKE action_list ON identifier FROM ROLE identifier ";"
-alter_user_role: ALTER USER quoted_string SET ROLE identifier ";"
-show_user_permission: SHOW USER PERMISSION quoted_string ";"
-
-show_version: SHOW VERSION ";"
-
-action_list: identifier ("," identifier)*
-
-identifier: WORD
-quoted_string: QUOTED_STRING
-status: WORD
-
-QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
-WORD: /[a-zA-Z0-9_\-\.]+/
-NUMBER: /[0-9]+/
-
-%import common.WS
-%ignore WS
-"""
-
-
-class AdminTransformer(Transformer):
- def start(self, items):
- return items[0]
-
- def command(self, items):
- return items[0]
-
- def list_services(self, items):
- result = {"type": "list_services"}
- return result
-
- def show_service(self, items):
- service_id = int(items[2])
- return {"type": "show_service", "number": service_id}
-
- def startup_service(self, items):
- service_id = int(items[2])
- return {"type": "startup_service", "number": service_id}
-
- def shutdown_service(self, items):
- service_id = int(items[2])
- return {"type": "shutdown_service", "number": service_id}
-
- def restart_service(self, items):
- service_id = int(items[2])
- return {"type": "restart_service", "number": service_id}
-
- def list_users(self, items):
- return {"type": "list_users"}
-
- def show_user(self, items):
- user_name = items[2]
- return {"type": "show_user", "user_name": user_name}
-
- def drop_user(self, items):
- user_name = items[2]
- return {"type": "drop_user", "user_name": user_name}
-
- def alter_user(self, items):
- user_name = items[3]
- new_password = items[4]
- return {"type": "alter_user", "user_name": user_name, "password": new_password}
-
- def create_user(self, items):
- user_name = items[2]
- password = items[3]
- return {"type": "create_user", "user_name": user_name, "password": password, "role": "user"}
-
- def activate_user(self, items):
- user_name = items[3]
- activate_status = items[4]
- return {"type": "activate_user", "activate_status": activate_status, "user_name": user_name}
-
- def list_datasets(self, items):
- user_name = items[3]
- return {"type": "list_datasets", "user_name": user_name}
-
- def list_agents(self, items):
- user_name = items[3]
- return {"type": "list_agents", "user_name": user_name}
-
- def create_role(self, items):
- role_name = items[2]
- if len(items) > 4:
- description = items[4]
- return {"type": "create_role", "role_name": role_name, "description": description}
- else:
- return {"type": "create_role", "role_name": role_name}
-
- def drop_role(self, items):
- role_name = items[2]
- return {"type": "drop_role", "role_name": role_name}
-
- def alter_role(self, items):
- role_name = items[2]
- description = items[5]
- return {"type": "alter_role", "role_name": role_name, "description": description}
-
- def list_roles(self, items):
- return {"type": "list_roles"}
-
- def show_role(self, items):
- role_name = items[2]
- return {"type": "show_role", "role_name": role_name}
-
- def grant_permission(self, items):
- action_list = items[1]
- resource = items[3]
- role_name = items[6]
- return {"type": "grant_permission", "role_name": role_name, "resource": resource, "actions": action_list}
-
- def revoke_permission(self, items):
- action_list = items[1]
- resource = items[3]
- role_name = items[6]
- return {"type": "revoke_permission", "role_name": role_name, "resource": resource, "actions": action_list}
-
- def alter_user_role(self, items):
- user_name = items[2]
- role_name = items[5]
- return {"type": "alter_user_role", "user_name": user_name, "role_name": role_name}
-
- def show_user_permission(self, items):
- user_name = items[3]
- return {"type": "show_user_permission", "user_name": user_name}
-
- def show_version(self, items):
- return {"type": "show_version"}
-
- def action_list(self, items):
- return items
-
- def meta_command(self, items):
- command_name = str(items[0]).lower()
- args = items[1:] if len(items) > 1 else []
-
- # handle quoted parameter
- parsed_args = []
- for arg in args:
- if hasattr(arg, "value"):
- parsed_args.append(arg.value)
- else:
- parsed_args.append(str(arg))
-
- return {"type": "meta", "command": command_name, "args": parsed_args}
-
- def meta_command_name(self, items):
- return items[0]
-
- def meta_args(self, items):
- return items
-
-
-def encrypt(input_string):
- pub = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----"
- pub_key = RSA.importKey(pub)
- cipher = Cipher_pkcs1_v1_5.new(pub_key)
- cipher_text = cipher.encrypt(base64.b64encode(input_string.encode("utf-8")))
- return base64.b64encode(cipher_text).decode("utf-8")
-
-
-def encode_to_base64(input_string):
- base64_encoded = base64.b64encode(input_string.encode("utf-8"))
- return base64_encoded.decode("utf-8")
-
-
-class AdminCLI(Cmd):
- def __init__(self):
- super().__init__()
- self.parser = Lark(GRAMMAR, start="start", parser="lalr", transformer=AdminTransformer())
- self.command_history = []
- self.is_interactive = False
- self.admin_account = "admin@ragflow.io"
- self.admin_password: str = "admin"
- self.session = requests.Session()
- self.access_token: str = ""
- self.host: str = ""
- self.port: int = 0
-
- intro = r"""Type "\h" for help."""
- prompt = "admin> "
-
- def onecmd(self, command: str) -> bool:
- try:
- result = self.parse_command(command)
-
- if isinstance(result, dict):
- if "type" in result and result.get("type") == "empty":
- return False
-
- self.execute_command(result)
-
- if isinstance(result, Tree):
- return False
-
- if result.get("type") == "meta" and result.get("command") in ["q", "quit", "exit"]:
- return True
-
- except KeyboardInterrupt:
- print("\nUse '\\q' to quit")
- except EOFError:
- print("\nGoodbye!")
- return True
- return False
-
- def emptyline(self) -> bool:
- return False
-
- def default(self, line: str) -> bool:
- return self.onecmd(line)
-
- def parse_command(self, command_str: str) -> dict[str, str]:
- if not command_str.strip():
- return {"type": "empty"}
-
- self.command_history.append(command_str)
-
- try:
- result = self.parser.parse(command_str)
- return result
- except Exception as e:
- return {"type": "error", "message": f"Parse error: {str(e)}"}
-
- def verify_admin(self, arguments: dict, single_command: bool):
- self.host = arguments["host"]
- self.port = arguments["port"]
- print("Attempt to access server for admin login")
- url = f"http://{self.host}:{self.port}/api/v1/admin/login"
-
- attempt_count = 3
- if single_command:
- attempt_count = 1
-
- try_count = 0
- while True:
- try_count += 1
- if try_count > attempt_count:
- return False
-
- if single_command:
- admin_passwd = arguments["password"]
- else:
- admin_passwd = getpass.getpass(f"password for {self.admin_account}: ").strip()
- try:
- self.admin_password = encrypt(admin_passwd)
- response = self.session.post(url, json={"email": self.admin_account, "password": self.admin_password})
- if response.status_code == 200:
- res_json = response.json()
- error_code = res_json.get("code", -1)
- if error_code == 0:
- self.session.headers.update({"Content-Type": "application/json", "Authorization": response.headers["Authorization"], "User-Agent": "RAGFlow-CLI/0.23.1"})
- print("Authentication successful.")
- return True
- else:
- error_message = res_json.get("message", "Unknown error")
- print(f"Authentication failed: {error_message}, try again")
- continue
- else:
- print(f"Bad response,status: {response.status_code}, password is wrong")
- except Exception as e:
- print(str(e))
- print("Can't access server for admin login (connection failed)")
-
- def _format_service_detail_table(self, data):
- if isinstance(data, list):
- return data
- if not all([isinstance(v, list) for v in data.values()]):
- # normal table
- return data
- # handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}]
- task_executor_list = []
- for k, v in data.items():
- # display latest status
- heartbeats = sorted(v, key=lambda x: x["now"], reverse=True)
- task_executor_list.append(
- {
- "task_executor_name": k,
- **heartbeats[0],
- }
- if heartbeats
- else {"task_executor_name": k}
- )
- return task_executor_list
-
- def _print_table_simple(self, data):
- if not data:
- print("No data to print")
- return
- if isinstance(data, dict):
- # handle single row data
- data = [data]
-
- columns = list(set().union(*(d.keys() for d in data)))
- columns.sort()
- col_widths = {}
-
- def get_string_width(text):
- half_width_chars = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\t\n\r"
- width = 0
- for char in text:
- if char in half_width_chars:
- width += 1
- else:
- width += 2
- return width
-
- for col in columns:
- max_width = get_string_width(str(col))
- for item in data:
- value_len = get_string_width(str(item.get(col, "")))
- if value_len > max_width:
- max_width = value_len
- col_widths[col] = max(2, max_width)
-
- # Generate delimiter
- separator = "+" + "+".join(["-" * (col_widths[col] + 2) for col in columns]) + "+"
-
- # Print header
- print(separator)
- header = "|" + "|".join([f" {col:<{col_widths[col]}} " for col in columns]) + "|"
- print(header)
- print(separator)
-
- # Print data
- for item in data:
- row = "|"
- for col in columns:
- value = str(item.get(col, ""))
- if get_string_width(value) > col_widths[col]:
- value = value[: col_widths[col] - 3] + "..."
- row += f" {value:<{col_widths[col] - (get_string_width(value) - len(value))}} |"
- print(row)
-
- print(separator)
-
- def run_interactive(self):
- self.is_interactive = True
- print("RAGFlow Admin command line interface - Type '\\?' for help, '\\q' to quit")
-
- while True:
- try:
- command = input("admin> ").strip()
- if not command:
- continue
-
- print(f"command: {command}")
- result = self.parse_command(command)
- self.execute_command(result)
-
- if isinstance(result, Tree):
- continue
-
- if result.get("type") == "meta" and result.get("command") in ["q", "quit", "exit"]:
- break
-
- except KeyboardInterrupt:
- print("\nUse '\\q' to quit")
- except EOFError:
- print("\nGoodbye!")
- break
-
- def run_single_command(self, command: str):
- result = self.parse_command(command)
- self.execute_command(result)
-
- def parse_connection_args(self, args: List[str]) -> Dict[str, Any]:
- parser = argparse.ArgumentParser(description="Admin CLI Client", add_help=False)
- parser.add_argument("-h", "--host", default="localhost", help="Admin service host")
- parser.add_argument("-p", "--port", type=int, default=9381, help="Admin service port")
- parser.add_argument("-w", "--password", default="admin", type=str, help="Superuser password")
- parser.add_argument("command", nargs="?", help="Single command")
- try:
- parsed_args, remaining_args = parser.parse_known_args(args)
- if remaining_args:
- command = remaining_args[0]
- return {"host": parsed_args.host, "port": parsed_args.port, "password": parsed_args.password, "command": command}
- else:
- return {
- "host": parsed_args.host,
- "port": parsed_args.port,
- }
- except SystemExit:
- return {"error": "Invalid connection arguments"}
-
- def execute_command(self, parsed_command: Dict[str, Any]):
- command_dict: dict
- if isinstance(parsed_command, Tree):
- command_dict = parsed_command.children[0]
- else:
- if parsed_command["type"] == "error":
- print(f"Error: {parsed_command['message']}")
- return
- else:
- command_dict = parsed_command
-
- # print(f"Parsed command: {command_dict}")
-
- command_type = command_dict["type"]
-
- match command_type:
- case "list_services":
- self._handle_list_services(command_dict)
- case "show_service":
- self._handle_show_service(command_dict)
- case "restart_service":
- self._handle_restart_service(command_dict)
- case "shutdown_service":
- self._handle_shutdown_service(command_dict)
- case "startup_service":
- self._handle_startup_service(command_dict)
- case "list_users":
- self._handle_list_users(command_dict)
- case "show_user":
- self._handle_show_user(command_dict)
- case "drop_user":
- self._handle_drop_user(command_dict)
- case "alter_user":
- self._handle_alter_user(command_dict)
- case "create_user":
- self._handle_create_user(command_dict)
- case "activate_user":
- self._handle_activate_user(command_dict)
- case "list_datasets":
- self._handle_list_datasets(command_dict)
- case "list_agents":
- self._handle_list_agents(command_dict)
- case "create_role":
- self._create_role(command_dict)
- case "drop_role":
- self._drop_role(command_dict)
- case "alter_role":
- self._alter_role(command_dict)
- case "list_roles":
- self._list_roles(command_dict)
- case "show_role":
- self._show_role(command_dict)
- case "grant_permission":
- self._grant_permission(command_dict)
- case "revoke_permission":
- self._revoke_permission(command_dict)
- case "alter_user_role":
- self._alter_user_role(command_dict)
- case "show_user_permission":
- self._show_user_permission(command_dict)
- case "show_version":
- self._show_version(command_dict)
- case "meta":
- self._handle_meta_command(command_dict)
- case _:
- print(f"Command '{command_type}' would be executed with API")
-
- def _handle_list_services(self, command):
- print("Listing all services")
-
- url = f"http://{self.host}:{self.port}/api/v1/admin/services"
- response = self.session.get(url)
- res_json = response.json()
- if response.status_code == 200:
- self._print_table_simple(res_json["data"])
- else:
- print(f"Fail to get all services, code: {res_json['code']}, message: {res_json['message']}")
-
- def _handle_show_service(self, command):
- service_id: int = command["number"]
- print(f"Showing service: {service_id}")
-
- url = f"http://{self.host}:{self.port}/api/v1/admin/services/{service_id}"
- response = self.session.get(url)
- res_json = response.json()
- if response.status_code == 200:
- res_data = res_json["data"]
- if "status" in res_data and res_data["status"] == "alive":
- print(f"Service {res_data['service_name']} is alive, ")
- if isinstance(res_data["message"], str):
- print(res_data["message"])
- else:
- data = self._format_service_detail_table(res_data["message"])
- self._print_table_simple(data)
- else:
- print(f"Service {res_data['service_name']} is down, {res_data['message']}")
- else:
- print(f"Fail to show service, code: {res_json['code']}, message: {res_json['message']}")
-
- def _handle_restart_service(self, command):
- service_id: int = command["number"]
- print(f"Restart service {service_id}")
-
- def _handle_shutdown_service(self, command):
- service_id: int = command["number"]
- print(f"Shutdown service {service_id}")
-
- def _handle_startup_service(self, command):
- service_id: int = command["number"]
- print(f"Startup service {service_id}")
-
- def _handle_list_users(self, command):
- print("Listing all users")
-
- url = f"http://{self.host}:{self.port}/api/v1/admin/users"
- response = self.session.get(url)
- res_json = response.json()
- if response.status_code == 200:
- self._print_table_simple(res_json["data"])
- else:
- print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
-
- def _handle_show_user(self, command):
- username_tree: Tree = command["user_name"]
- user_name: str = username_tree.children[0].strip("'\"")
- print(f"Showing user: {user_name}")
- url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}"
- response = self.session.get(url)
- res_json = response.json()
- if response.status_code == 200:
- table_data = res_json["data"]
- table_data.pop("avatar")
- self._print_table_simple(table_data)
- else:
- print(f"Fail to get user {user_name}, code: {res_json['code']}, message: {res_json['message']}")
-
- def _handle_drop_user(self, command):
- username_tree: Tree = command["user_name"]
- user_name: str = username_tree.children[0].strip("'\"")
- print(f"Drop user: {user_name}")
- url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}"
- response = self.session.delete(url)
- res_json = response.json()
- if response.status_code == 200:
- print(res_json["message"])
- else:
- print(f"Fail to drop user, code: {res_json['code']}, message: {res_json['message']}")
-
- def _handle_alter_user(self, command):
- user_name_tree: Tree = command["user_name"]
- user_name: str = user_name_tree.children[0].strip("'\"")
- password_tree: Tree = command["password"]
- password: str = password_tree.children[0].strip("'\"")
- print(f"Alter user: {user_name}, password: ******")
- url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/password"
- response = self.session.put(url, json={"new_password": encrypt(password)})
- res_json = response.json()
- if response.status_code == 200:
- print(res_json["message"])
- else:
- print(f"Fail to alter password, code: {res_json['code']}, message: {res_json['message']}")
-
- def _handle_create_user(self, command):
- user_name_tree: Tree = command["user_name"]
- user_name: str = user_name_tree.children[0].strip("'\"")
- password_tree: Tree = command["password"]
- password: str = password_tree.children[0].strip("'\"")
- role: str = command["role"]
- print(f"Create user: {user_name}, password: ******, role: {role}")
- url = f"http://{self.host}:{self.port}/api/v1/admin/users"
- response = self.session.post(url, json={"user_name": user_name, "password": encrypt(password), "role": role})
- res_json = response.json()
- if response.status_code == 200:
- self._print_table_simple(res_json["data"])
- else:
- print(f"Fail to create user {user_name}, code: {res_json['code']}, message: {res_json['message']}")
-
- def _handle_activate_user(self, command):
- user_name_tree: Tree = command["user_name"]
- user_name: str = user_name_tree.children[0].strip("'\"")
- activate_tree: Tree = command["activate_status"]
- activate_status: str = activate_tree.children[0].strip("'\"")
- if activate_status.lower() in ["on", "off"]:
- print(f"Alter user {user_name} activate status, turn {activate_status.lower()}.")
- url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/activate"
- response = self.session.put(url, json={"activate_status": activate_status})
- res_json = response.json()
- if response.status_code == 200:
- print(res_json["message"])
- else:
- print(f"Fail to alter activate status, code: {res_json['code']}, message: {res_json['message']}")
- else:
- print(f"Unknown activate status: {activate_status}.")
-
- def _handle_list_datasets(self, command):
- username_tree: Tree = command["user_name"]
- user_name: str = username_tree.children[0].strip("'\"")
- print(f"Listing all datasets of user: {user_name}")
- url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/datasets"
- response = self.session.get(url)
- res_json = response.json()
- if response.status_code == 200:
- table_data = res_json["data"]
- for t in table_data:
- t.pop("avatar")
- self._print_table_simple(table_data)
- else:
- print(f"Fail to get all datasets of {user_name}, code: {res_json['code']}, message: {res_json['message']}")
-
- def _handle_list_agents(self, command):
- username_tree: Tree = command["user_name"]
- user_name: str = username_tree.children[0].strip("'\"")
- print(f"Listing all agents of user: {user_name}")
- url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/agents"
- response = self.session.get(url)
- res_json = response.json()
- if response.status_code == 200:
- table_data = res_json["data"]
- for t in table_data:
- t.pop("avatar")
- self._print_table_simple(table_data)
- else:
- print(f"Fail to get all agents of {user_name}, code: {res_json['code']}, message: {res_json['message']}")
-
- def _create_role(self, command):
- role_name_tree: Tree = command["role_name"]
- role_name: str = role_name_tree.children[0].strip("'\"")
- desc_str: str = ""
- if "description" in command:
- desc_tree: Tree = command["description"]
- desc_str = desc_tree.children[0].strip("'\"")
-
- print(f"create role name: {role_name}, description: {desc_str}")
- url = f"http://{self.host}:{self.port}/api/v1/admin/roles"
- response = self.session.post(url, json={"role_name": role_name, "description": desc_str})
- res_json = response.json()
- if response.status_code == 200:
- self._print_table_simple(res_json["data"])
- else:
- print(f"Fail to create role {role_name}, code: {res_json['code']}, message: {res_json['message']}")
-
- def _drop_role(self, command):
- role_name_tree: Tree = command["role_name"]
- role_name: str = role_name_tree.children[0].strip("'\"")
- print(f"drop role name: {role_name}")
- url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}"
- response = self.session.delete(url)
- res_json = response.json()
- if response.status_code == 200:
- self._print_table_simple(res_json["data"])
- else:
- print(f"Fail to drop role {role_name}, code: {res_json['code']}, message: {res_json['message']}")
-
- def _alter_role(self, command):
- role_name_tree: Tree = command["role_name"]
- role_name: str = role_name_tree.children[0].strip("'\"")
- desc_tree: Tree = command["description"]
- desc_str: str = desc_tree.children[0].strip("'\"")
-
- print(f"alter role name: {role_name}, description: {desc_str}")
- url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}"
- response = self.session.put(url, json={"description": desc_str})
- res_json = response.json()
- if response.status_code == 200:
- self._print_table_simple(res_json["data"])
- else:
- print(f"Fail to update role {role_name} with description: {desc_str}, code: {res_json['code']}, message: {res_json['message']}")
-
- def _list_roles(self, command):
- print("Listing all roles")
- url = f"http://{self.host}:{self.port}/api/v1/admin/roles"
- response = self.session.get(url)
- res_json = response.json()
- if response.status_code == 200:
- self._print_table_simple(res_json["data"])
- else:
- print(f"Fail to list roles, code: {res_json['code']}, message: {res_json['message']}")
-
- def _show_role(self, command):
- role_name_tree: Tree = command["role_name"]
- role_name: str = role_name_tree.children[0].strip("'\"")
- print(f"show role: {role_name}")
- url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}/permission"
- response = self.session.get(url)
- res_json = response.json()
- if response.status_code == 200:
- self._print_table_simple(res_json["data"])
- else:
- print(f"Fail to list roles, code: {res_json['code']}, message: {res_json['message']}")
-
- def _grant_permission(self, command):
- role_name_tree: Tree = command["role_name"]
- role_name_str: str = role_name_tree.children[0].strip("'\"")
- resource_tree: Tree = command["resource"]
- resource_str: str = resource_tree.children[0].strip("'\"")
- action_tree_list: list = command["actions"]
- actions: list = []
- for action_tree in action_tree_list:
- action_str: str = action_tree.children[0].strip("'\"")
- actions.append(action_str)
- print(f"grant role_name: {role_name_str}, resource: {resource_str}, actions: {actions}")
- url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name_str}/permission"
- response = self.session.post(url, json={"actions": actions, "resource": resource_str})
- res_json = response.json()
- if response.status_code == 200:
- self._print_table_simple(res_json["data"])
- else:
- print(f"Fail to grant role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}")
-
- def _revoke_permission(self, command):
- role_name_tree: Tree = command["role_name"]
- role_name_str: str = role_name_tree.children[0].strip("'\"")
- resource_tree: Tree = command["resource"]
- resource_str: str = resource_tree.children[0].strip("'\"")
- action_tree_list: list = command["actions"]
- actions: list = []
- for action_tree in action_tree_list:
- action_str: str = action_tree.children[0].strip("'\"")
- actions.append(action_str)
- print(f"revoke role_name: {role_name_str}, resource: {resource_str}, actions: {actions}")
- url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name_str}/permission"
- response = self.session.delete(url, json={"actions": actions, "resource": resource_str})
- res_json = response.json()
- if response.status_code == 200:
- self._print_table_simple(res_json["data"])
- else:
- print(f"Fail to revoke role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}")
-
- def _alter_user_role(self, command):
- role_name_tree: Tree = command["role_name"]
- role_name_str: str = role_name_tree.children[0].strip("'\"")
- user_name_tree: Tree = command["user_name"]
- user_name_str: str = user_name_tree.children[0].strip("'\"")
- print(f"alter_user_role user_name: {user_name_str}, role_name: {role_name_str}")
- url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name_str}/role"
- response = self.session.put(url, json={"role_name": role_name_str})
- res_json = response.json()
- if response.status_code == 200:
- self._print_table_simple(res_json["data"])
- else:
- print(f"Fail to alter user: {user_name_str} to role {role_name_str}, code: {res_json['code']}, message: {res_json['message']}")
-
- def _show_user_permission(self, command):
- user_name_tree: Tree = command["user_name"]
- user_name_str: str = user_name_tree.children[0].strip("'\"")
- print(f"show_user_permission user_name: {user_name_str}")
- url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name_str}/permission"
- response = self.session.get(url)
- res_json = response.json()
- if response.status_code == 200:
- self._print_table_simple(res_json["data"])
- else:
- print(f"Fail to show user: {user_name_str} permission, code: {res_json['code']}, message: {res_json['message']}")
-
- def _show_version(self, command):
- print("show_version")
- url = f"http://{self.host}:{self.port}/api/v1/admin/version"
- response = self.session.get(url)
- res_json = response.json()
- if response.status_code == 200:
- self._print_table_simple(res_json["data"])
- else:
- print(f"Fail to show version, code: {res_json['code']}, message: {res_json['message']}")
-
- def _handle_meta_command(self, command):
- meta_command = command["command"]
- args = command.get("args", [])
-
- if meta_command in ["?", "h", "help"]:
- self.show_help()
- elif meta_command in ["q", "quit", "exit"]:
- print("Goodbye!")
- else:
- print(f"Meta command '{meta_command}' with args {args}")
-
- def show_help(self):
- """Help info"""
- help_text = """
-Commands:
- LIST SERVICES
- SHOW SERVICE
- STARTUP SERVICE
- SHUTDOWN SERVICE
- RESTART SERVICE
- LIST USERS
- SHOW USER
- DROP USER
- CREATE USER
- ALTER USER PASSWORD
- ALTER USER ACTIVE
- LIST DATASETS OF
- LIST AGENTS OF
-
-Meta Commands:
- \\?, \\h, \\help Show this help
- \\q, \\quit, \\exit Quit the CLI
- """
- print(help_text)
-
-
-def main():
- import sys
-
- cli = AdminCLI()
-
- args = cli.parse_connection_args(sys.argv)
- if "error" in args:
- print("Error: Invalid connection arguments")
- return
-
- if "command" in args:
- if "password" not in args:
- print("Error: password is missing")
- return
- if cli.verify_admin(args, single_command=True):
- command: str = args["command"]
- # print(f"Run single command: {command}")
- cli.run_single_command(command)
- else:
- if cli.verify_admin(args, single_command=False):
- print(r"""
- ____ ___ ______________ ___ __ _
- / __ \/ | / ____/ ____/ /___ _ __ / | ____/ /___ ___ (_)___
- / /_/ / /| |/ / __/ /_ / / __ \ | /| / / / /| |/ __ / __ `__ \/ / __ \
- / _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / ___ / /_/ / / / / / / / / / /
- /_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ /_/ |_\__,_/_/ /_/ /_/_/_/ /_/
- """)
- cli.cmdloop()
-
-
-if __name__ == "__main__":
- main()
diff --git a/admin/client/http_client.py b/admin/client/http_client.py
new file mode 100644
index 00000000000..bf18466ebcc
--- /dev/null
+++ b/admin/client/http_client.py
@@ -0,0 +1,182 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+import json
+import typing
+from typing import Any, Dict, Optional
+
+import requests
+# from requests.sessions import HTTPAdapter
+
+
+class HttpClient:
+ def __init__(
+ self,
+ host: str = "127.0.0.1",
+ port: int = 9381,
+ api_version: str = "v1",
+ api_key: Optional[str] = None,
+ connect_timeout: float = 5.0,
+ read_timeout: float = 60.0,
+ verify_ssl: bool = False,
+ ) -> None:
+ self.host = host
+ self.port = port
+ self.api_version = api_version
+ self.api_key = api_key
+ self.login_token: str | None = None
+ self.connect_timeout = connect_timeout
+ self.read_timeout = read_timeout
+ self.verify_ssl = verify_ssl
+
+ def api_base(self) -> str:
+ return f"{self.host}:{self.port}/api/{self.api_version}"
+
+ def non_api_base(self) -> str:
+ return f"{self.host}:{self.port}/{self.api_version}"
+
+ def build_url(self, path: str, use_api_base: bool = True) -> str:
+ base = self.api_base() if use_api_base else self.non_api_base()
+ if self.verify_ssl:
+ return f"https://{base}/{path.lstrip('/')}"
+ else:
+ return f"http://{base}/{path.lstrip('/')}"
+
+ def _headers(self, auth_kind: Optional[str], extra: Optional[Dict[str, str]]) -> Dict[str, str]:
+ headers = {}
+ if auth_kind == "api" and self.api_key:
+ headers["Authorization"] = f"Bearer {self.api_key}"
+ elif auth_kind == "web" and self.login_token:
+ headers["Authorization"] = self.login_token
+ elif auth_kind == "admin" and self.login_token:
+ headers["Authorization"] = self.login_token
+ else:
+ pass
+ if extra:
+ headers.update(extra)
+ return headers
+
+ def request(
+ self,
+ method: str,
+ path: str,
+ *,
+ use_api_base: bool = True,
+ auth_kind: Optional[str] = "api",
+ headers: Optional[Dict[str, str]] = None,
+ json_body: Optional[Dict[str, Any]] = None,
+ data: Any = None,
+ files: Any = None,
+ params: Optional[Dict[str, Any]] = None,
+ stream: bool = False,
+ iterations: int = 1,
+ ) -> requests.Response | dict:
+ url = self.build_url(path, use_api_base=use_api_base)
+ merged_headers = self._headers(auth_kind, headers)
+ # timeout: Tuple[float, float] = (self.connect_timeout, self.read_timeout)
+ session = requests.Session()
+ # adapter = HTTPAdapter(pool_connections=100, pool_maxsize=100)
+ # session.mount("http://", adapter)
+ http_function = typing.Any
+ match method:
+ case "GET":
+ http_function = session.get
+ case "POST":
+ http_function = session.post
+ case "PUT":
+ http_function = session.put
+ case "DELETE":
+ http_function = session.delete
+ case "PATCH":
+ http_function = session.patch
+ case _:
+ raise ValueError(f"Invalid HTTP method: {method}")
+
+ if iterations > 1:
+ response_list = []
+ total_duration = 0.0
+ for _ in range(iterations):
+ start_time = time.perf_counter()
+ response = http_function(url, headers=merged_headers, json=json_body, data=data, stream=stream)
+ # response = session.get(url, headers=merged_headers, json=json_body, data=data, stream=stream)
+ # response = requests.request(
+ # method=method,
+ # url=url,
+ # headers=merged_headers,
+ # json=json_body,
+ # data=data,
+ # files=files,
+ # params=params,
+ # stream=stream,
+ # verify=self.verify_ssl,
+ # )
+ end_time = time.perf_counter()
+ total_duration += end_time - start_time
+ response_list.append(response)
+ return {"duration": total_duration, "response_list": response_list}
+ else:
+ return http_function(url, headers=merged_headers, json=json_body, data=data, stream=stream)
+ # return session.get(url, headers=merged_headers, json=json_body, data=data, stream=stream)
+ # return requests.request(
+ # method=method,
+ # url=url,
+ # headers=merged_headers,
+ # json=json_body,
+ # data=data,
+ # files=files,
+ # params=params,
+ # stream=stream,
+ # verify=self.verify_ssl,
+ # )
+
+ def request_json(
+ self,
+ method: str,
+ path: str,
+ *,
+ use_api_base: bool = True,
+ auth_kind: Optional[str] = "api",
+ headers: Optional[Dict[str, str]] = None,
+ json_body: Optional[Dict[str, Any]] = None,
+ data: Any = None,
+ files: Any = None,
+ params: Optional[Dict[str, Any]] = None,
+ stream: bool = False,
+ ) -> Dict[str, Any]:
+ response = self.request(
+ method,
+ path,
+ use_api_base=use_api_base,
+ auth_kind=auth_kind,
+ headers=headers,
+ json_body=json_body,
+ data=data,
+ files=files,
+ params=params,
+ stream=stream,
+ )
+ try:
+ return response.json()
+ except Exception as exc:
+ raise ValueError(f"Non-JSON response from {path}: {exc}") from exc
+
+ @staticmethod
+ def parse_json_bytes(raw: bytes) -> Dict[str, Any]:
+ try:
+ return json.loads(raw.decode("utf-8"))
+ except Exception as exc:
+ raise ValueError(f"Invalid JSON payload: {exc}") from exc
diff --git a/admin/client/parser.py b/admin/client/parser.py
new file mode 100644
index 00000000000..d1d5c626231
--- /dev/null
+++ b/admin/client/parser.py
@@ -0,0 +1,623 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from lark import Transformer
+
+GRAMMAR = r"""
+start: command
+
+command: sql_command | meta_command
+
+sql_command: login_user
+ | ping_server
+ | list_services
+ | show_service
+ | startup_service
+ | shutdown_service
+ | restart_service
+ | register_user
+ | list_users
+ | show_user
+ | drop_user
+ | alter_user
+ | create_user
+ | activate_user
+ | list_datasets
+ | list_agents
+ | create_role
+ | drop_role
+ | alter_role
+ | list_roles
+ | show_role
+ | grant_permission
+ | revoke_permission
+ | alter_user_role
+ | show_user_permission
+ | show_version
+ | grant_admin
+ | revoke_admin
+ | set_variable
+ | show_variable
+ | list_variables
+ | list_configs
+ | list_environments
+ | generate_key
+ | list_keys
+ | drop_key
+ | show_current_user
+ | set_default_llm
+ | set_default_vlm
+ | set_default_embedding
+ | set_default_reranker
+ | set_default_asr
+ | set_default_tts
+ | reset_default_llm
+ | reset_default_vlm
+ | reset_default_embedding
+ | reset_default_reranker
+ | reset_default_asr
+ | reset_default_tts
+ | create_model_provider
+ | drop_model_provider
+ | create_user_dataset_with_parser
+ | create_user_dataset_with_pipeline
+ | drop_user_dataset
+ | list_user_datasets
+ | list_user_dataset_files
+ | list_user_agents
+ | list_user_chats
+ | create_user_chat
+ | drop_user_chat
+ | list_user_model_providers
+ | list_user_default_models
+ | parse_dataset_docs
+ | parse_dataset_sync
+ | parse_dataset_async
+ | import_docs_into_dataset
+ | search_on_datasets
+ | benchmark
+
+// meta command definition
+meta_command: "\\" meta_command_name [meta_args]
+
+meta_command_name: /[a-zA-Z?]+/
+meta_args: (meta_arg)+
+
+meta_arg: /[^\\s"']+/ | quoted_string
+
+// command definition
+
+LOGIN: "LOGIN"i
+REGISTER: "REGISTER"i
+LIST: "LIST"i
+SERVICES: "SERVICES"i
+SHOW: "SHOW"i
+CREATE: "CREATE"i
+SERVICE: "SERVICE"i
+SHUTDOWN: "SHUTDOWN"i
+STARTUP: "STARTUP"i
+RESTART: "RESTART"i
+USERS: "USERS"i
+DROP: "DROP"i
+USER: "USER"i
+ALTER: "ALTER"i
+ACTIVE: "ACTIVE"i
+ADMIN: "ADMIN"i
+PASSWORD: "PASSWORD"i
+DATASET: "DATASET"i
+DATASETS: "DATASETS"i
+OF: "OF"i
+AGENTS: "AGENTS"i
+ROLE: "ROLE"i
+ROLES: "ROLES"i
+DESCRIPTION: "DESCRIPTION"i
+GRANT: "GRANT"i
+REVOKE: "REVOKE"i
+ALL: "ALL"i
+PERMISSION: "PERMISSION"i
+TO: "TO"i
+FROM: "FROM"i
+FOR: "FOR"i
+RESOURCES: "RESOURCES"i
+ON: "ON"i
+SET: "SET"i
+RESET: "RESET"i
+VERSION: "VERSION"i
+VAR: "VAR"i
+VARS: "VARS"i
+CONFIGS: "CONFIGS"i
+ENVS: "ENVS"i
+KEY: "KEY"i
+KEYS: "KEYS"i
+GENERATE: "GENERATE"i
+MODEL: "MODEL"i
+MODELS: "MODELS"i
+PROVIDER: "PROVIDER"i
+PROVIDERS: "PROVIDERS"i
+DEFAULT: "DEFAULT"i
+CHATS: "CHATS"i
+CHAT: "CHAT"i
+FILES: "FILES"i
+AS: "AS"i
+PARSE: "PARSE"i
+IMPORT: "IMPORT"i
+INTO: "INTO"i
+WITH: "WITH"i
+PARSER: "PARSER"i
+PIPELINE: "PIPELINE"i
+SEARCH: "SEARCH"i
+CURRENT: "CURRENT"i
+LLM: "LLM"i
+VLM: "VLM"i
+EMBEDDING: "EMBEDDING"i
+RERANKER: "RERANKER"i
+ASR: "ASR"i
+TTS: "TTS"i
+ASYNC: "ASYNC"i
+SYNC: "SYNC"i
+BENCHMARK: "BENCHMARK"i
+PING: "PING"i
+
+login_user: LOGIN USER quoted_string ";"
+list_services: LIST SERVICES ";"
+show_service: SHOW SERVICE NUMBER ";"
+startup_service: STARTUP SERVICE NUMBER ";"
+shutdown_service: SHUTDOWN SERVICE NUMBER ";"
+restart_service: RESTART SERVICE NUMBER ";"
+
+register_user: REGISTER USER quoted_string AS quoted_string PASSWORD quoted_string ";"
+list_users: LIST USERS ";"
+drop_user: DROP USER quoted_string ";"
+alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
+show_user: SHOW USER quoted_string ";"
+create_user: CREATE USER quoted_string quoted_string ";"
+activate_user: ALTER USER ACTIVE quoted_string status ";"
+
+list_datasets: LIST DATASETS OF quoted_string ";"
+list_agents: LIST AGENTS OF quoted_string ";"
+
+create_role: CREATE ROLE identifier [DESCRIPTION quoted_string] ";"
+drop_role: DROP ROLE identifier ";"
+alter_role: ALTER ROLE identifier SET DESCRIPTION quoted_string ";"
+list_roles: LIST ROLES ";"
+show_role: SHOW ROLE identifier ";"
+
+grant_permission: GRANT identifier_list ON identifier TO ROLE identifier ";"
+revoke_permission: REVOKE identifier_list ON identifier FROM ROLE identifier ";"
+alter_user_role: ALTER USER quoted_string SET ROLE identifier ";"
+show_user_permission: SHOW USER PERMISSION quoted_string ";"
+
+show_version: SHOW VERSION ";"
+
+grant_admin: GRANT ADMIN quoted_string ";"
+revoke_admin: REVOKE ADMIN quoted_string ";"
+
+generate_key: GENERATE KEY FOR USER quoted_string ";"
+list_keys: LIST KEYS OF quoted_string ";"
+drop_key: DROP KEY quoted_string OF quoted_string ";"
+
+set_variable: SET VAR identifier identifier ";"
+show_variable: SHOW VAR identifier ";"
+list_variables: LIST VARS ";"
+list_configs: LIST CONFIGS ";"
+list_environments: LIST ENVS ";"
+
+benchmark: BENCHMARK NUMBER NUMBER user_statement
+
+user_statement: ping_server
+ | show_current_user
+ | create_model_provider
+ | drop_model_provider
+ | set_default_llm
+ | set_default_vlm
+ | set_default_embedding
+ | set_default_reranker
+ | set_default_asr
+ | set_default_tts
+ | reset_default_llm
+ | reset_default_vlm
+ | reset_default_embedding
+ | reset_default_reranker
+ | reset_default_asr
+ | reset_default_tts
+ | create_user_dataset_with_parser
+ | create_user_dataset_with_pipeline
+ | drop_user_dataset
+ | list_user_datasets
+ | list_user_dataset_files
+ | list_user_agents
+ | list_user_chats
+ | create_user_chat
+ | drop_user_chat
+ | list_user_model_providers
+ | list_user_default_models
+ | import_docs_into_dataset
+ | search_on_datasets
+
+ping_server: PING ";"
+show_current_user: SHOW CURRENT USER ";"
+create_model_provider: CREATE MODEL PROVIDER quoted_string quoted_string ";"
+drop_model_provider: DROP MODEL PROVIDER quoted_string ";"
+set_default_llm: SET DEFAULT LLM quoted_string ";"
+set_default_vlm: SET DEFAULT VLM quoted_string ";"
+set_default_embedding: SET DEFAULT EMBEDDING quoted_string ";"
+set_default_reranker: SET DEFAULT RERANKER quoted_string ";"
+set_default_asr: SET DEFAULT ASR quoted_string ";"
+set_default_tts: SET DEFAULT TTS quoted_string ";"
+
+reset_default_llm: RESET DEFAULT LLM ";"
+reset_default_vlm: RESET DEFAULT VLM ";"
+reset_default_embedding: RESET DEFAULT EMBEDDING ";"
+reset_default_reranker: RESET DEFAULT RERANKER ";"
+reset_default_asr: RESET DEFAULT ASR ";"
+reset_default_tts: RESET DEFAULT TTS ";"
+
+list_user_datasets: LIST DATASETS ";"
+create_user_dataset_with_parser: CREATE DATASET quoted_string WITH EMBEDDING quoted_string PARSER quoted_string ";"
+create_user_dataset_with_pipeline: CREATE DATASET quoted_string WITH EMBEDDING quoted_string PIPELINE quoted_string ";"
+drop_user_dataset: DROP DATASET quoted_string ";"
+list_user_dataset_files: LIST FILES OF DATASET quoted_string ";"
+list_user_agents: LIST AGENTS ";"
+list_user_chats: LIST CHATS ";"
+create_user_chat: CREATE CHAT quoted_string ";"
+drop_user_chat: DROP CHAT quoted_string ";"
+list_user_model_providers: LIST MODEL PROVIDERS ";"
+list_user_default_models: LIST DEFAULT MODELS ";"
+import_docs_into_dataset: IMPORT quoted_string INTO DATASET quoted_string ";"
+search_on_datasets: SEARCH quoted_string ON DATASETS quoted_string ";"
+
+parse_dataset_docs: PARSE quoted_string OF DATASET quoted_string ";"
+parse_dataset_sync: PARSE DATASET quoted_string SYNC ";"
+parse_dataset_async: PARSE DATASET quoted_string ASYNC ";"
+
+identifier_list: identifier ("," identifier)*
+
+identifier: WORD
+quoted_string: QUOTED_STRING
+status: WORD
+
+QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
+WORD: /[a-zA-Z0-9_\-\.]+/
+NUMBER: /[0-9]+/
+
+%import common.WS
+%ignore WS
+"""
+
+
+class RAGFlowCLITransformer(Transformer):
+ def start(self, items):
+ return items[0]
+
+ def command(self, items):
+ return items[0]
+
+ def login_user(self, items):
+ email = items[2].children[0].strip("'\"")
+ return {"type": "login_user", "email": email}
+
+ def ping_server(self, items):
+ return {"type": "ping_server"}
+
+ def list_services(self, items):
+ result = {"type": "list_services"}
+ return result
+
+ def show_service(self, items):
+ service_id = int(items[2])
+ return {"type": "show_service", "number": service_id}
+
+ def startup_service(self, items):
+ service_id = int(items[2])
+ return {"type": "startup_service", "number": service_id}
+
+ def shutdown_service(self, items):
+ service_id = int(items[2])
+ return {"type": "shutdown_service", "number": service_id}
+
+ def restart_service(self, items):
+ service_id = int(items[2])
+ return {"type": "restart_service", "number": service_id}
+
+ def register_user(self, items):
+ user_name: str = items[2].children[0].strip("'\"")
+ nickname: str = items[4].children[0].strip("'\"")
+ password: str = items[6].children[0].strip("'\"")
+ return {"type": "register_user", "user_name": user_name, "nickname": nickname, "password": password}
+
+ def list_users(self, items):
+ return {"type": "list_users"}
+
+ def show_user(self, items):
+ user_name = items[2]
+ return {"type": "show_user", "user_name": user_name}
+
+ def drop_user(self, items):
+ user_name = items[2]
+ return {"type": "drop_user", "user_name": user_name}
+
+ def alter_user(self, items):
+ user_name = items[3]
+ new_password = items[4]
+ return {"type": "alter_user", "user_name": user_name, "password": new_password}
+
+ def create_user(self, items):
+ user_name = items[2]
+ password = items[3]
+ return {"type": "create_user", "user_name": user_name, "password": password, "role": "user"}
+
+ def activate_user(self, items):
+ user_name = items[3]
+ activate_status = items[4]
+ return {"type": "activate_user", "activate_status": activate_status, "user_name": user_name}
+
+ def list_datasets(self, items):
+ user_name = items[3]
+ return {"type": "list_datasets", "user_name": user_name}
+
+ def list_agents(self, items):
+ user_name = items[3]
+ return {"type": "list_agents", "user_name": user_name}
+
+ def create_role(self, items):
+ role_name = items[2]
+ if len(items) > 4:
+ description = items[4]
+ return {"type": "create_role", "role_name": role_name, "description": description}
+ else:
+ return {"type": "create_role", "role_name": role_name}
+
+ def drop_role(self, items):
+ role_name = items[2]
+ return {"type": "drop_role", "role_name": role_name}
+
+ def alter_role(self, items):
+ role_name = items[2]
+ description = items[5]
+ return {"type": "alter_role", "role_name": role_name, "description": description}
+
+ def list_roles(self, items):
+ return {"type": "list_roles"}
+
+ def show_role(self, items):
+ role_name = items[2]
+ return {"type": "show_role", "role_name": role_name}
+
+ def grant_permission(self, items):
+ action_list = items[1]
+ resource = items[3]
+ role_name = items[6]
+ return {"type": "grant_permission", "role_name": role_name, "resource": resource, "actions": action_list}
+
+ def revoke_permission(self, items):
+ action_list = items[1]
+ resource = items[3]
+ role_name = items[6]
+ return {"type": "revoke_permission", "role_name": role_name, "resource": resource, "actions": action_list}
+
+ def alter_user_role(self, items):
+ user_name = items[2]
+ role_name = items[5]
+ return {"type": "alter_user_role", "user_name": user_name, "role_name": role_name}
+
+ def show_user_permission(self, items):
+ user_name = items[3]
+ return {"type": "show_user_permission", "user_name": user_name}
+
+ def show_version(self, items):
+ return {"type": "show_version"}
+
+ def grant_admin(self, items):
+ user_name = items[2]
+ return {"type": "grant_admin", "user_name": user_name}
+
+ def revoke_admin(self, items):
+ user_name = items[2]
+ return {"type": "revoke_admin", "user_name": user_name}
+
+ def generate_key(self, items):
+ user_name = items[4]
+ return {"type": "generate_key", "user_name": user_name}
+
+ def list_keys(self, items):
+ user_name = items[3]
+ return {"type": "list_keys", "user_name": user_name}
+
+ def drop_key(self, items):
+ key = items[2]
+ user_name = items[4]
+ return {"type": "drop_key", "key": key, "user_name": user_name}
+
+ def set_variable(self, items):
+ var_name = items[2]
+ var_value = items[3]
+ return {"type": "set_variable", "var_name": var_name, "var_value": var_value}
+
+ def show_variable(self, items):
+ var_name = items[2]
+ return {"type": "show_variable", "var_name": var_name}
+
+ def list_variables(self, items):
+ return {"type": "list_variables"}
+
+ def list_configs(self, items):
+ return {"type": "list_configs"}
+
+ def list_environments(self, items):
+ return {"type": "list_environments"}
+
+ def create_model_provider(self, items):
+ provider_name = items[3].children[0].strip("'\"")
+ provider_key = items[4].children[0].strip("'\"")
+ return {"type": "create_model_provider", "provider_name": provider_name, "provider_key": provider_key}
+
+ def drop_model_provider(self, items):
+ provider_name = items[3].children[0].strip("'\"")
+ return {"type": "drop_model_provider", "provider_name": provider_name}
+
+ def show_current_user(self, items):
+ return {"type": "show_current_user"}
+
+ def set_default_llm(self, items):
+ llm_id = items[3].children[0].strip("'\"")
+ return {"type": "set_default_model", "model_type": "llm_id", "model_id": llm_id}
+
+ def set_default_vlm(self, items):
+ vlm_id = items[3].children[0].strip("'\"")
+ return {"type": "set_default_model", "model_type": "img2txt_id", "model_id": vlm_id}
+
+ def set_default_embedding(self, items):
+ embedding_id = items[3].children[0].strip("'\"")
+ return {"type": "set_default_model", "model_type": "embd_id", "model_id": embedding_id}
+
+ def set_default_reranker(self, items):
+ reranker_id = items[3].children[0].strip("'\"")
+ return {"type": "set_default_model", "model_type": "reranker_id", "model_id": reranker_id}
+
+ def set_default_asr(self, items):
+ asr_id = items[3].children[0].strip("'\"")
+ return {"type": "set_default_model", "model_type": "asr_id", "model_id": asr_id}
+
+ def set_default_tts(self, items):
+ tts_id = items[3].children[0].strip("'\"")
+ return {"type": "set_default_model", "model_type": "tts_id", "model_id": tts_id}
+
+ def reset_default_llm(self, items):
+ return {"type": "reset_default_model", "model_type": "llm_id"}
+
+ def reset_default_vlm(self, items):
+ return {"type": "reset_default_model", "model_type": "img2txt_id"}
+
+ def reset_default_embedding(self, items):
+ return {"type": "reset_default_model", "model_type": "embd_id"}
+
+ def reset_default_reranker(self, items):
+ return {"type": "reset_default_model", "model_type": "reranker_id"}
+
+ def reset_default_asr(self, items):
+ return {"type": "reset_default_model", "model_type": "asr_id"}
+
+ def reset_default_tts(self, items):
+ return {"type": "reset_default_model", "model_type": "tts_id"}
+
+ def list_user_datasets(self, items):
+ return {"type": "list_user_datasets"}
+
+ def create_user_dataset_with_parser(self, items):
+ dataset_name = items[2].children[0].strip("'\"")
+ embedding = items[5].children[0].strip("'\"")
+ parser_type = items[7].children[0].strip("'\"")
+ return {"type": "create_user_dataset", "dataset_name": dataset_name, "embedding": embedding,
+ "parser_type": parser_type}
+
+ def create_user_dataset_with_pipeline(self, items):
+ dataset_name = items[2].children[0].strip("'\"")
+ embedding = items[5].children[0].strip("'\"")
+ pipeline = items[7].children[0].strip("'\"")
+ return {"type": "create_user_dataset", "dataset_name": dataset_name, "embedding": embedding,
+ "pipeline": pipeline}
+
+ def drop_user_dataset(self, items):
+ dataset_name = items[2].children[0].strip("'\"")
+ return {"type": "drop_user_dataset", "dataset_name": dataset_name}
+
+ def list_user_dataset_files(self, items):
+ dataset_name = items[4].children[0].strip("'\"")
+ return {"type": "list_user_dataset_files", "dataset_name": dataset_name}
+
+ def list_user_agents(self, items):
+ return {"type": "list_user_agents"}
+
+ def list_user_chats(self, items):
+ return {"type": "list_user_chats"}
+
+ def create_user_chat(self, items):
+ chat_name = items[2].children[0].strip("'\"")
+ return {"type": "create_user_chat", "chat_name": chat_name}
+
+ def drop_user_chat(self, items):
+ chat_name = items[2].children[0].strip("'\"")
+ return {"type": "drop_user_chat", "chat_name": chat_name}
+
+ def list_user_model_providers(self, items):
+ return {"type": "list_user_model_providers"}
+
+ def list_user_default_models(self, items):
+ return {"type": "list_user_default_models"}
+
+ def parse_dataset_docs(self, items):
+ document_list_str = items[1].children[0].strip("'\"")
+ document_names = document_list_str.split(",")
+ if len(document_names) == 1:
+ document_names = document_names[0]
+ document_names = document_names.split(" ")
+ dataset_name = items[4].children[0].strip("'\"")
+ return {"type": "parse_dataset_docs", "dataset_name": dataset_name, "document_names": document_names}
+
+ def parse_dataset_sync(self, items):
+ dataset_name = items[2].children[0].strip("'\"")
+ return {"type": "parse_dataset", "dataset_name": dataset_name, "method": "sync"}
+
+ def parse_dataset_async(self, items):
+ dataset_name = items[2].children[0].strip("'\"")
+ return {"type": "parse_dataset", "dataset_name": dataset_name, "method": "async"}
+
+ def import_docs_into_dataset(self, items):
+ document_list_str = items[1].children[0].strip("'\"")
+ document_paths = document_list_str.split(",")
+ if len(document_paths) == 1:
+ document_paths = document_paths[0]
+ document_paths = document_paths.split(" ")
+ dataset_name = items[4].children[0].strip("'\"")
+ return {"type": "import_docs_into_dataset", "dataset_name": dataset_name, "document_paths": document_paths}
+
+ def search_on_datasets(self, items):
+ question = items[1].children[0].strip("'\"")
+ datasets_str = items[4].children[0].strip("'\"")
+ datasets = datasets_str.split(",")
+ if len(datasets) == 1:
+ datasets = datasets[0]
+ datasets = datasets.split(" ")
+ return {"type": "search_on_datasets", "datasets": datasets, "question": question}
+
+ def benchmark(self, items):
+ concurrency: int = int(items[1])
+ iterations: int = int(items[2])
+ command = items[3].children[0]
+ return {"type": "benchmark", "concurrency": concurrency, "iterations": iterations, "command": command}
+
+ def action_list(self, items):
+ return items
+
+ def meta_command(self, items):
+ command_name = str(items[0]).lower()
+ args = items[1:] if len(items) > 1 else []
+
+ # handle quoted parameter
+ parsed_args = []
+ for arg in args:
+ if hasattr(arg, "value"):
+ parsed_args.append(arg.value)
+ else:
+ parsed_args.append(str(arg))
+
+ return {"type": "meta", "command": command_name, "args": parsed_args}
+
+ def meta_command_name(self, items):
+ return items[0]
+
+ def meta_args(self, items):
+ return items
diff --git a/admin/client/pyproject.toml b/admin/client/pyproject.toml
index de6bf7bc348..4b5e2cd31b8 100644
--- a/admin/client/pyproject.toml
+++ b/admin/client/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ragflow-cli"
-version = "0.23.1"
+version = "0.24.0"
description = "Admin Service's client of [RAGFlow](https://github.com/infiniflow/ragflow). The Admin Service provides user management and system monitoring. "
authors = [{ name = "Lynn", email = "lynn_inf@hotmail.com" }]
license = { text = "Apache License, Version 2.0" }
@@ -20,5 +20,8 @@ test = [
"requests-toolbelt>=1.0.0",
]
+[tool.setuptools]
+py-modules = ["ragflow_cli", "parser"]
+
[project.scripts]
-ragflow-cli = "admin_client:main"
+ragflow-cli = "ragflow_cli:main"
diff --git a/admin/client/ragflow_cli.py b/admin/client/ragflow_cli.py
new file mode 100644
index 00000000000..38c32ddff4d
--- /dev/null
+++ b/admin/client/ragflow_cli.py
@@ -0,0 +1,322 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+import argparse
+import base64
+import getpass
+from cmd import Cmd
+from typing import Any, Dict, List
+
+import requests
+import warnings
+from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
+from Cryptodome.PublicKey import RSA
+from lark import Lark, Tree
+from parser import GRAMMAR, RAGFlowCLITransformer
+from http_client import HttpClient
+from ragflow_client import RAGFlowClient, run_command
+from user import login_user
+
+warnings.filterwarnings("ignore", category=getpass.GetPassWarning)
+
+def encrypt(input_string):
+ pub = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----"
+ pub_key = RSA.importKey(pub)
+ cipher = Cipher_pkcs1_v1_5.new(pub_key)
+ cipher_text = cipher.encrypt(base64.b64encode(input_string.encode("utf-8")))
+ return base64.b64encode(cipher_text).decode("utf-8")
+
+
+def encode_to_base64(input_string):
+ base64_encoded = base64.b64encode(input_string.encode("utf-8"))
+ return base64_encoded.decode("utf-8")
+
+
+
+
+
+class RAGFlowCLI(Cmd):
+ def __init__(self):
+ super().__init__()
+ self.parser = Lark(GRAMMAR, start="start", parser="lalr", transformer=RAGFlowCLITransformer())
+ self.command_history = []
+ self.account = "admin@ragflow.io"
+ self.account_password: str = "admin"
+ self.session = requests.Session()
+ self.host: str = ""
+ self.port: int = 0
+ self.mode: str = "admin"
+ self.ragflow_client = None
+
+ intro = r"""Type "\h" for help."""
+ prompt = "ragflow> "
+
+ def onecmd(self, command: str) -> bool:
+ try:
+ result = self.parse_command(command)
+
+ if isinstance(result, dict):
+ if "type" in result and result.get("type") == "empty":
+ return False
+
+ self.execute_command(result)
+
+ if isinstance(result, Tree):
+ return False
+
+ if result.get("type") == "meta" and result.get("command") in ["q", "quit", "exit"]:
+ return True
+
+ except KeyboardInterrupt:
+ print("\nUse '\\q' to quit")
+ except EOFError:
+ print("\nGoodbye!")
+ return True
+ return False
+
+ def emptyline(self) -> bool:
+ return False
+
+ def default(self, line: str) -> bool:
+ return self.onecmd(line)
+
+ def parse_command(self, command_str: str) -> dict[str, str]:
+ if not command_str.strip():
+ return {"type": "empty"}
+
+ self.command_history.append(command_str)
+
+ try:
+ result = self.parser.parse(command_str)
+ return result
+ except Exception as e:
+ return {"type": "error", "message": f"Parse error: {str(e)}"}
+
+ def verify_auth(self, arguments: dict, single_command: bool, auth: bool):
+ server_type = arguments.get("type", "admin")
+ http_client = HttpClient(arguments["host"], arguments["port"])
+ if not auth:
+ self.ragflow_client = RAGFlowClient(http_client, server_type)
+ return True
+
+ user_name = arguments["username"]
+ attempt_count = 3
+ if single_command:
+ attempt_count = 1
+
+ try_count = 0
+ while True:
+ try_count += 1
+ if try_count > attempt_count:
+ return False
+
+ if single_command:
+ user_password = arguments["password"]
+ else:
+ user_password = getpass.getpass(f"password for {user_name}: ").strip()
+
+ try:
+ token = login_user(http_client, server_type, user_name, user_password)
+ http_client.login_token = token
+ self.ragflow_client = RAGFlowClient(http_client, server_type)
+ return True
+ except Exception as e:
+ print(str(e))
+ print("Can't access server for login (connection failed)")
+
+ def _format_service_detail_table(self, data):
+ if isinstance(data, list):
+ return data
+ if not all([isinstance(v, list) for v in data.values()]):
+ # normal table
+ return data
+ # handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}]
+ task_executor_list = []
+ for k, v in data.items():
+ # display latest status
+ heartbeats = sorted(v, key=lambda x: x["now"], reverse=True)
+ task_executor_list.append(
+ {
+ "task_executor_name": k,
+ **heartbeats[0],
+ }
+ if heartbeats
+ else {"task_executor_name": k}
+ )
+ return task_executor_list
+
+ def _print_table_simple(self, data):
+ if not data:
+ print("No data to print")
+ return
+ if isinstance(data, dict):
+ # handle single row data
+ data = [data]
+
+ columns = list(set().union(*(d.keys() for d in data)))
+ columns.sort()
+ col_widths = {}
+
+ def get_string_width(text):
+ half_width_chars = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\t\n\r"
+ width = 0
+ for char in text:
+ if char in half_width_chars:
+ width += 1
+ else:
+ width += 2
+ return width
+
+ for col in columns:
+ max_width = get_string_width(str(col))
+ for item in data:
+ value_len = get_string_width(str(item.get(col, "")))
+ if value_len > max_width:
+ max_width = value_len
+ col_widths[col] = max(2, max_width)
+
+ # Generate delimiter
+ separator = "+" + "+".join(["-" * (col_widths[col] + 2) for col in columns]) + "+"
+
+ # Print header
+ print(separator)
+ header = "|" + "|".join([f" {col:<{col_widths[col]}} " for col in columns]) + "|"
+ print(header)
+ print(separator)
+
+ # Print data
+ for item in data:
+ row = "|"
+ for col in columns:
+ value = str(item.get(col, ""))
+ if get_string_width(value) > col_widths[col]:
+ value = value[: col_widths[col] - 3] + "..."
+ row += f" {value:<{col_widths[col] - (get_string_width(value) - len(value))}} |"
+ print(row)
+
+ print(separator)
+
+ def run_interactive(self, args):
+ if self.verify_auth(args, single_command=False, auth=args["auth"]):
+ print(r"""
+ ____ ___ ______________ ________ ____
+ / __ \/ | / ____/ ____/ /___ _ __ / ____/ / / _/
+ / /_/ / /| |/ / __/ /_ / / __ \ | /| / / / / / / / /
+ / _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / /___/ /____/ /
+ /_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ \____/_____/___/
+ """)
+ self.cmdloop()
+
+ print("RAGFlow command line interface - Type '\\?' for help, '\\q' to quit")
+
+ def run_single_command(self, args):
+ if self.verify_auth(args, single_command=True, auth=args["auth"]):
+ command = args["command"]
+ result = self.parse_command(command)
+ self.execute_command(result)
+
+
+ def parse_connection_args(self, args: List[str]) -> Dict[str, Any]:
+ parser = argparse.ArgumentParser(description="RAGFlow CLI Client", add_help=False)
+ parser.add_argument("-h", "--host", default="127.0.0.1", help="Admin or RAGFlow service host")
+ parser.add_argument("-p", "--port", type=int, default=9381, help="Admin or RAGFlow service port")
+ parser.add_argument("-w", "--password", default="admin", type=str, help="Superuser password")
+ parser.add_argument("-t", "--type", default="admin", type=str, help="CLI mode, admin or user")
+ parser.add_argument("-u", "--username", default=None,
+ help="Username (email). In admin mode defaults to admin@ragflow.io, in user mode required.")
+ parser.add_argument("command", nargs="?", help="Single command")
+ try:
+ parsed_args, remaining_args = parser.parse_known_args(args)
+ # Determine username based on mode
+ username = parsed_args.username
+ if parsed_args.type == "admin":
+ if username is None:
+ username = "admin@ragflow.io"
+
+ if remaining_args:
+ if remaining_args[0] == "command":
+ command_str = ' '.join(remaining_args[1:]) + ';'
+ auth = True
+ if remaining_args[1] == "register":
+ auth = False
+ else:
+ if username is None:
+ print("Error: username (-u) is required in user mode")
+ return {"error": "Username required"}
+ return {
+ "host": parsed_args.host,
+ "port": parsed_args.port,
+ "password": parsed_args.password,
+ "type": parsed_args.type,
+ "username": username,
+ "command": command_str,
+ "auth": auth
+ }
+ else:
+ return {"error": "Invalid command"}
+ else:
+ auth = True
+ if username is None:
+ auth = False
+ return {
+ "host": parsed_args.host,
+ "port": parsed_args.port,
+ "type": parsed_args.type,
+ "username": username,
+ "auth": auth
+ }
+ except SystemExit:
+ return {"error": "Invalid connection arguments"}
+
+ def execute_command(self, parsed_command: Dict[str, Any]):
+ command_dict: dict
+ if isinstance(parsed_command, Tree):
+ command_dict = parsed_command.children[0]
+ else:
+ if parsed_command["type"] == "error":
+ print(f"Error: {parsed_command['message']}")
+ return
+ else:
+ command_dict = parsed_command
+
+ # print(f"Parsed command: {command_dict}")
+ run_command(self.ragflow_client, command_dict)
+
+def main():
+
+ cli = RAGFlowCLI()
+
+ args = cli.parse_connection_args(sys.argv)
+ if "error" in args:
+ print("Error: Invalid connection arguments")
+ return
+
+ if "command" in args:
+ # single command mode
+ # for user mode, api key or password is ok
+ # for admin mode, only password
+ if "password" not in args:
+ print("Error: password is missing")
+ return
+
+ cli.run_single_command(args)
+ else:
+ cli.run_interactive(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/admin/client/ragflow_client.py b/admin/client/ragflow_client.py
new file mode 100644
index 00000000000..7433467dedf
--- /dev/null
+++ b/admin/client/ragflow_client.py
@@ -0,0 +1,1508 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+from typing import Any, List, Optional
+import multiprocessing as mp
+from concurrent.futures import ProcessPoolExecutor, as_completed
+import urllib.parse
+from pathlib import Path
+from http_client import HttpClient
+from lark import Tree
+from user import encrypt_password, login_user
+
+import getpass
+import base64
+from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
+from Cryptodome.PublicKey import RSA
+
+try:
+ from requests_toolbelt import MultipartEncoder
+except Exception as e: # pragma: no cover - fallback without toolbelt
+ print(f"Fallback without belt: {e}")
+ MultipartEncoder = None
+
+
+def encrypt(input_string):
+ pub = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----"
+ pub_key = RSA.importKey(pub)
+ cipher = Cipher_pkcs1_v1_5.new(pub_key)
+ cipher_text = cipher.encrypt(base64.b64encode(input_string.encode("utf-8")))
+ return base64.b64encode(cipher_text).decode("utf-8")
+
+
+class RAGFlowClient:
+ def __init__(self, http_client: HttpClient, server_type: str):
+ self.http_client = http_client
+ self.server_type = server_type
+
+ def login_user(self, command):
+ try:
+ response = self.http_client.request("GET", "/system/ping", use_api_base=False, auth_kind="web")
+ if response.status_code == 200 and response.content == b"pong":
+ pass
+ else:
+ print("Server is down")
+ return
+ except Exception as e:
+ print(str(e))
+ print("Can't access server for login (connection failed)")
+ return
+
+ email : str = command["email"]
+ user_password = getpass.getpass(f"password for {email}: ").strip()
+ try:
+ token = login_user(self.http_client, self.server_type, email, user_password)
+ self.http_client.login_token = token
+ print(f"Login user {email} successfully")
+ except Exception as e:
+ print(str(e))
+ print("Can't access server for login (connection failed)")
+
+ def ping_server(self, command):
+ iterations = command.get("iterations", 1)
+ if iterations > 1:
+ response = self.http_client.request("GET", "/system/ping", use_api_base=False, auth_kind="web",
+ iterations=iterations)
+ return response
+ else:
+ response = self.http_client.request("GET", "/system/ping", use_api_base=False, auth_kind="web")
+ if response.status_code == 200 and response.content == b"pong":
+ print("Server is alive")
+ else:
+ print("Server is down")
+ return None
+
+ def register_user(self, command):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+ username: str = command["user_name"]
+ nickname: str = command["nickname"]
+ password: str = command["password"]
+ enc_password = encrypt_password(password)
+ print(f"Register user: {nickname}, email: {username}, password: ******")
+ payload = {"email": username, "nickname": nickname, "password": enc_password}
+ response = self.http_client.request(method="POST", path="/user/register",
+ json_body=payload, use_api_base=False, auth_kind="web")
+ res_json = response.json()
+ if response.status_code == 200:
+ if res_json["code"] == 0:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(f"Fail to register user {username}, code: {res_json['code']}, message: {res_json['message']}")
+ else:
+ print(f"Fail to register user {username}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def list_services(self):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ response = self.http_client.request("GET", "/admin/services", use_api_base=True, auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(f"Fail to get all services, code: {res_json['code']}, message: {res_json['message']}")
+ pass
+
+ def show_service(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ service_id: int = command["number"]
+
+ response = self.http_client.request("GET", f"/admin/services/{service_id}", use_api_base=True,
+ auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ res_data = res_json["data"]
+ if "status" in res_data and res_data["status"] == "alive":
+ print(f"Service {res_data['service_name']} is alive, ")
+ res_message = res_data["message"]
+ if res_message is None:
+ return
+ elif isinstance(res_message, str):
+ print(res_message)
+ else:
+ data = self._format_service_detail_table(res_message)
+ self._print_table_simple(data)
+ else:
+ print(f"Service {res_data['service_name']} is down, {res_data['message']}")
+ else:
+ print(f"Fail to show service, code: {res_json['code']}, message: {res_json['message']}")
+
+ def restart_service(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ # service_id: int = command["number"]
+ print("Restart service isn't implemented")
+
+ def shutdown_service(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ # service_id: int = command["number"]
+ print("Shutdown service isn't implemented")
+
+ def startup_service(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ # service_id: int = command["number"]
+ print("Startup service isn't implemented")
+
+ def list_users(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ response = self.http_client.request("GET", "/admin/users", use_api_base=True, auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
+
+ def show_user(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ username_tree: Tree = command["user_name"]
+ user_name: str = username_tree.children[0].strip("'\"")
+ print(f"Showing user: {user_name}")
+ response = self.http_client.request("GET", f"/admin/users/{user_name}", use_api_base=True, auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ table_data = res_json["data"][0]
+ table_data.pop("avatar")
+ self._print_table_simple(table_data)
+ else:
+ print(f"Fail to get user {user_name}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def drop_user(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ username_tree: Tree = command["user_name"]
+ user_name: str = username_tree.children[0].strip("'\"")
+ print(f"Drop user: {user_name}")
+ response = self.http_client.request("DELETE", f"/admin/users/{user_name}", use_api_base=True, auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ print(res_json["message"])
+ else:
+ print(f"Fail to drop user, code: {res_json['code']}, message: {res_json['message']}")
+
+ def alter_user(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ user_name_tree: Tree = command["user_name"]
+ user_name: str = user_name_tree.children[0].strip("'\"")
+ password_tree: Tree = command["password"]
+ password: str = password_tree.children[0].strip("'\"")
+ print(f"Alter user: {user_name}, password: ******")
+ response = self.http_client.request("PUT", f"/admin/users/{user_name}/password",
+ json_body={"new_password": encrypt_password(password)}, use_api_base=True,
+ auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ print(res_json["message"])
+ else:
+ print(f"Fail to alter password, code: {res_json['code']}, message: {res_json['message']}")
+
+ def create_user(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ user_name_tree: Tree = command["user_name"]
+ user_name: str = user_name_tree.children[0].strip("'\"")
+ password_tree: Tree = command["password"]
+ password: str = password_tree.children[0].strip("'\"")
+ role: str = command["role"]
+ print(f"Create user: {user_name}, password: ******, role: {role}")
+ # enpass1 = encrypt(password)
+ enc_password = encrypt_password(password)
+ response = self.http_client.request(method="POST", path="/admin/users",
+ json_body={"username": user_name, "password": enc_password, "role": role},
+ use_api_base=True, auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(f"Fail to create user {user_name}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def activate_user(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ user_name_tree: Tree = command["user_name"]
+ user_name: str = user_name_tree.children[0].strip("'\"")
+ activate_tree: Tree = command["activate_status"]
+ activate_status: str = activate_tree.children[0].strip("'\"")
+ if activate_status.lower() in ["on", "off"]:
+ print(f"Alter user {user_name} activate status, turn {activate_status.lower()}.")
+ response = self.http_client.request("PUT", f"/admin/users/{user_name}/activate",
+ json_body={"activate_status": activate_status}, use_api_base=True,
+ auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ print(res_json["message"])
+ else:
+ print(f"Fail to alter activate status, code: {res_json['code']}, message: {res_json['message']}")
+ else:
+ print(f"Unknown activate status: {activate_status}.")
+
+ def grant_admin(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ user_name_tree: Tree = command["user_name"]
+ user_name: str = user_name_tree.children[0].strip("'\"")
+ response = self.http_client.request("PUT", f"/admin/users/{user_name}/admin", use_api_base=True,
+ auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ print(res_json["message"])
+ else:
+ print(
+ f"Fail to grant {user_name} admin authorization, code: {res_json['code']}, message: {res_json['message']}")
+
+ def revoke_admin(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ user_name_tree: Tree = command["user_name"]
+ user_name: str = user_name_tree.children[0].strip("'\"")
+ response = self.http_client.request("DELETE", f"/admin/users/{user_name}/admin", use_api_base=True,
+ auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ print(res_json["message"])
+ else:
+ print(
+ f"Fail to revoke {user_name} admin authorization, code: {res_json['code']}, message: {res_json['message']}")
+
+ def create_role(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ role_name_tree: Tree = command["role_name"]
+ role_name: str = role_name_tree.children[0].strip("'\"")
+ desc_str: str = ""
+ if "description" in command and command["description"] is not None:
+ desc_tree: Tree = command["description"]
+ desc_str = desc_tree.children[0].strip("'\"")
+
+ print(f"create role name: {role_name}, description: {desc_str}")
+ response = self.http_client.request("POST", "/admin/roles",
+ json_body={"role_name": role_name, "description": desc_str},
+ use_api_base=True,
+ auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(f"Fail to create role {role_name}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def drop_role(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ role_name_tree: Tree = command["role_name"]
+ role_name: str = role_name_tree.children[0].strip("'\"")
+ print(f"drop role name: {role_name}")
+ response = self.http_client.request("DELETE", f"/admin/roles/{role_name}",
+ use_api_base=True,
+ auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(f"Fail to drop role {role_name}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def alter_role(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ role_name_tree: Tree = command["role_name"]
+ role_name: str = role_name_tree.children[0].strip("'\"")
+ desc_tree: Tree = command["description"]
+ desc_str: str = desc_tree.children[0].strip("'\"")
+
+ print(f"alter role name: {role_name}, description: {desc_str}")
+ response = self.http_client.request("PUT", f"/admin/roles/{role_name}",
+ json_body={"description": desc_str},
+ use_api_base=True,
+ auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(
+ f"Fail to update role {role_name} with description: {desc_str}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def list_roles(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ response = self.http_client.request("GET", "/admin/roles",
+ use_api_base=True,
+ auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(f"Fail to list roles, code: {res_json['code']}, message: {res_json['message']}")
+
+ def show_role(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ role_name_tree: Tree = command["role_name"]
+ role_name: str = role_name_tree.children[0].strip("'\"")
+ print(f"show role: {role_name}")
+ response = self.http_client.request("GET", f"/admin/roles/{role_name}/permission",
+ use_api_base=True,
+ auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(f"Fail to list roles, code: {res_json['code']}, message: {res_json['message']}")
+
+ def grant_permission(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ role_name_tree: Tree = command["role_name"]
+ role_name_str: str = role_name_tree.children[0].strip("'\"")
+ resource_tree: Tree = command["resource"]
+ resource_str: str = resource_tree.children[0].strip("'\"")
+ action_tree_list: list = command["actions"]
+ actions: list = []
+ for action_tree in action_tree_list:
+ action_str: str = action_tree.children[0].strip("'\"")
+ actions.append(action_str)
+ print(f"grant role_name: {role_name_str}, resource: {resource_str}, actions: {actions}")
+ response = self.http_client.request("POST", f"/admin/roles/{role_name_str}/permission",
+ json_body={"actions": actions, "resource": resource_str}, use_api_base=True,
+ auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(
+ f"Fail to grant role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def revoke_permission(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ role_name_tree: Tree = command["role_name"]
+ role_name_str: str = role_name_tree.children[0].strip("'\"")
+ resource_tree: Tree = command["resource"]
+ resource_str: str = resource_tree.children[0].strip("'\"")
+ action_tree_list: list = command["actions"]
+ actions: list = []
+ for action_tree in action_tree_list:
+ action_str: str = action_tree.children[0].strip("'\"")
+ actions.append(action_str)
+ print(f"revoke role_name: {role_name_str}, resource: {resource_str}, actions: {actions}")
+ response = self.http_client.request("DELETE", f"/admin/roles/{role_name_str}/permission",
+ json_body={"actions": actions, "resource": resource_str}, use_api_base=True,
+ auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(
+ f"Fail to revoke role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def alter_user_role(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ role_name_tree: Tree = command["role_name"]
+ role_name_str: str = role_name_tree.children[0].strip("'\"")
+ user_name_tree: Tree = command["user_name"]
+ user_name_str: str = user_name_tree.children[0].strip("'\"")
+ print(f"alter_user_role user_name: {user_name_str}, role_name: {role_name_str}")
+ response = self.http_client.request("PUT", f"/admin/users/{user_name_str}/role",
+ json_body={"role_name": role_name_str}, use_api_base=True,
+ auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(
+ f"Fail to alter user: {user_name_str} to role {role_name_str}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def show_user_permission(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ user_name_tree: Tree = command["user_name"]
+ user_name_str: str = user_name_tree.children[0].strip("'\"")
+ print(f"show_user_permission user_name: {user_name_str}")
+ response = self.http_client.request("GET", f"/admin/users/{user_name_str}/permission", use_api_base=True,
+ auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(
+ f"Fail to show user: {user_name_str} permission, code: {res_json['code']}, message: {res_json['message']}")
+
+ def generate_key(self, command: dict[str, Any]) -> None:
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ username_tree: Tree = command["user_name"]
+ user_name: str = username_tree.children[0].strip("'\"")
+ print(f"Generating API key for user: {user_name}")
+ response = self.http_client.request("POST", f"/admin/users/{user_name}/keys", use_api_base=True,
+ auth_kind="admin")
+ res_json: dict[str, Any] = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(
+ f"Failed to generate key for user {user_name}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def list_keys(self, command: dict[str, Any]) -> None:
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ username_tree: Tree = command["user_name"]
+ user_name: str = username_tree.children[0].strip("'\"")
+ print(f"Listing API keys for user: {user_name}")
+ response = self.http_client.request("GET", f"/admin/users/{user_name}/keys", use_api_base=True,
+ auth_kind="admin")
+ res_json: dict[str, Any] = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(f"Failed to list keys for user {user_name}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def drop_key(self, command: dict[str, Any]) -> None:
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ key_tree: Tree = command["key"]
+ key: str = key_tree.children[0].strip("'\"")
+ username_tree: Tree = command["user_name"]
+ user_name: str = username_tree.children[0].strip("'\"")
+ print(f"Dropping API key for user: {user_name}")
+ # URL encode the key to handle special characters
+ encoded_key: str = urllib.parse.quote(key, safe="")
+ response = self.http_client.request("DELETE", f"/admin/users/{user_name}/keys/{encoded_key}", use_api_base=True,
+ auth_kind="admin")
+ res_json: dict[str, Any] = response.json()
+ if response.status_code == 200:
+ print(res_json["message"])
+ else:
+ print(f"Failed to drop key for user {user_name}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def set_variable(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ var_name_tree: Tree = command["var_name"]
+ var_name = var_name_tree.children[0].strip("'\"")
+ var_value_tree: Tree = command["var_value"]
+ var_value = var_value_tree.children[0].strip("'\"")
+ response = self.http_client.request("PUT", "/admin/variables",
+ json_body={"var_name": var_name, "var_value": var_value}, use_api_base=True,
+ auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ print(res_json["message"])
+ else:
+ print(
+ f"Fail to set variable {var_name} to {var_value}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def show_variable(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ var_name_tree: Tree = command["var_name"]
+ var_name = var_name_tree.children[0].strip("'\"")
+ response = self.http_client.request(method="GET", path="/admin/variables", json_body={"var_name": var_name},
+ use_api_base=True, auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(f"Fail to get variable {var_name}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def list_variables(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ response = self.http_client.request("GET", "/admin/variables", use_api_base=True, auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(f"Fail to list variables, code: {res_json['code']}, message: {res_json['message']}")
+
+ def list_configs(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ response = self.http_client.request("GET", "/admin/configs", use_api_base=True, auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(f"Fail to list variables, code: {res_json['code']}, message: {res_json['message']}")
+
+ def list_environments(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ response = self.http_client.request("GET", "/admin/environments", use_api_base=True, auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(f"Fail to list variables, code: {res_json['code']}, message: {res_json['message']}")
+
+ def handle_list_datasets(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ username_tree: Tree = command["user_name"]
+ user_name: str = username_tree.children[0].strip("'\"")
+ print(f"Listing all datasets of user: {user_name}")
+
+ response = self.http_client.request("GET", f"/admin/users/{user_name}/datasets", use_api_base=True,
+ auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ table_data = res_json["data"]
+ for t in table_data:
+ t.pop("avatar")
+ self._print_table_simple(table_data)
+ else:
+ print(f"Fail to get all datasets of {user_name}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def handle_list_agents(self, command):
+ if self.server_type != "admin":
+ print("This command is only allowed in ADMIN mode")
+
+ username_tree: Tree = command["user_name"]
+ user_name: str = username_tree.children[0].strip("'\"")
+ print(f"Listing all agents of user: {user_name}")
+ response = self.http_client.request("GET", f"/admin/users/{user_name}/agents", use_api_base=True,
+ auth_kind="admin")
+ res_json = response.json()
+ if response.status_code == 200:
+ table_data = res_json["data"]
+ for t in table_data:
+ t.pop("avatar")
+ self._print_table_simple(table_data)
+ else:
+ print(f"Fail to get all agents of {user_name}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def show_current_user(self, command):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+ print("show current user")
+
+ def create_model_provider(self, command):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+ llm_factory: str = command["provider_name"]
+ api_key: str = command["provider_key"]
+ payload = {"api_key": api_key, "llm_factory": llm_factory}
+ response = self.http_client.request("POST", "/llm/set_api_key", json_body=payload, use_api_base=False,
+ auth_kind="web")
+ res_json = response.json()
+ if response.status_code == 200 and res_json["code"] == 0:
+ print(f"Success to add model provider {llm_factory}")
+ else:
+ print(f"Fail to add model provider {llm_factory}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def drop_model_provider(self, command):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+ llm_factory: str = command["provider_name"]
+ payload = {"llm_factory": llm_factory}
+ response = self.http_client.request("POST", "/llm/delete_factory", json_body=payload, use_api_base=False,
+ auth_kind="web")
+ res_json = response.json()
+ if response.status_code == 200 and res_json["code"] == 0:
+ print(f"Success to drop model provider {llm_factory}")
+ else:
+ print(
+ f"Fail to drop model provider {llm_factory}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def set_default_model(self, command):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+
+ model_type: str = command["model_type"]
+ model_id: str = command["model_id"]
+ self._set_default_models(model_type, model_id)
+
+ def reset_default_model(self, command):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+
+ model_type: str = command["model_type"]
+ self._set_default_models(model_type, "")
+
+ def list_user_datasets(self, command):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+
+ iterations = command.get("iterations", 1)
+ if iterations > 1:
+ response = self.http_client.request("POST", "/kb/list", use_api_base=False, auth_kind="web",
+ iterations=iterations)
+ return response
+ else:
+ response = self.http_client.request("POST", "/kb/list", use_api_base=False, auth_kind="web")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"]["kbs"])
+ else:
+ print(f"Fail to list datasets, code: {res_json['code']}, message: {res_json['message']}")
+ return None
+
+ def create_user_dataset(self, command):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+ payload = {
+ "name": command["dataset_name"],
+ "embd_id": command["embedding"]
+ }
+ if "parser_id" in command:
+ payload["parser_id"] = command["parser"]
+ if "pipeline" in command:
+ payload["pipeline_id"] = command["pipeline"]
+ response = self.http_client.request("POST", "/kb/create", json_body=payload, use_api_base=False,
+ auth_kind="web")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(f"Fail to create datasets, code: {res_json['code']}, message: {res_json['message']}")
+
+ def drop_user_dataset(self, command):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+
+ dataset_name = command["dataset_name"]
+ dataset_id = self._get_dataset_id(dataset_name)
+ if dataset_id is None:
+ return
+ payload = {"kb_id": dataset_id}
+ response = self.http_client.request("POST", "/kb/rm", json_body=payload, use_api_base=False, auth_kind="web")
+ res_json = response.json()
+ if response.status_code == 200:
+ print(f"Drop dataset {dataset_name} successfully")
+ else:
+ print(f"Fail to drop datasets, code: {res_json['code']}, message: {res_json['message']}")
+
+ def list_user_dataset_files(self, command_dict):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+
+ dataset_name = command_dict["dataset_name"]
+ dataset_id = self._get_dataset_id(dataset_name)
+ if dataset_id is None:
+ return
+
+ res_json = self._list_documents(dataset_name, dataset_id)
+ if res_json is None:
+ return
+ self._print_table_simple(res_json)
+
+ def list_user_agents(self, command):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+
+ response = self.http_client.request("GET", "/canvas/list", use_api_base=False, auth_kind="web")
+ res_json = response.json()
+ if response.status_code == 200:
+ self._print_table_simple(res_json["data"])
+ else:
+ print(f"Fail to list datasets, code: {res_json['code']}, message: {res_json['message']}")
+
+ def list_user_chats(self, command):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+
+ res_json = self._list_chats(command)
+ if res_json is None:
+ return None
+ if "iterations" in command:
+ # for benchmark
+ return res_json
+ self._print_table_simple(res_json)
+
+ def create_user_chat(self, command):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+ '''
+ description
+ :
+ ""
+ icon
+ :
+ ""
+ language
+ :
+ "English"
+ llm_id
+ :
+ "glm-4-flash@ZHIPU-AI"
+ llm_setting
+ :
+ {}
+ name
+ :
+ "xx"
+ prompt_config
+ :
+ {empty_response: "", prologue: "Hi! I'm your assistant. What can I do for you?", quote: true,…}
+ empty_response
+ :
+ ""
+ keyword
+ :
+ false
+ parameters
+ :
+ [{key: "knowledge", optional: false}]
+ prologue
+ :
+ "Hi! I'm your assistant. What can I do for you?"
+ quote
+ :
+ true
+ reasoning
+ :
+ false
+ refine_multiturn
+ :
+ false
+ system
+ :
+ "You are an intelligent assistant. Your primary function is to answer questions based strictly on the provided knowledge base.\n\n **Essential Rules:**\n - Your answer must be derived **solely** from this knowledge base: `{knowledge}`.\n - **When information is available**: Summarize the content to give a detailed answer.\n - **When information is unavailable**: Your response must contain this exact sentence: \"The answer you are looking for is not found in the knowledge base!\"\n - **Always consider** the entire conversation history."
+ toc_enhance
+ :
+ false
+ tts
+ :
+ false
+ use_kg
+ :
+ false
+ similarity_threshold
+ :
+ 0.2
+ top_n
+ :
+ 8
+ vector_similarity_weight
+ :
+ 0.3
+ '''
+ chat_name = command["chat_name"]
+ payload = {
+ "description": "",
+ "icon": "",
+ "language": "English",
+ "llm_setting": {},
+ "prompt_config": {
+ "empty_response": "",
+ "prologue": "Hi! I'm your assistant. What can I do for you?",
+ "quote": True,
+ "keyword": False,
+ "tts": False,
+ "system": "You are an intelligent assistant. Your primary function is to answer questions based strictly on the provided knowledge base.\n\n **Essential Rules:**\n - Your answer must be derived **solely** from this knowledge base: `{knowledge}`.\n - **When information is available**: Summarize the content to give a detailed answer.\n - **When information is unavailable**: Your response must contain this exact sentence: \"The answer you are looking for is not found in the knowledge base!\"\n - **Always consider** the entire conversation history.",
+ "refine_multiturn": False,
+ "use_kg": False,
+ "reasoning": False,
+ "parameters": [
+ {
+ "key": "knowledge",
+ "optional": False
+ }
+ ],
+ "toc_enhance": False
+ },
+ "similarity_threshold": 0.2,
+ "top_n": 8,
+ "vector_similarity_weight": 0.3
+ }
+
+ payload.update({"name": chat_name})
+ response = self.http_client.request("POST", "/dialog/set", json_body=payload, use_api_base=False,
+ auth_kind="web")
+ res_json = response.json()
+ if response.status_code == 200 and res_json["code"] == 0:
+ print(f"Success to create chat: {chat_name}")
+ else:
+ print(f"Fail to create chat {chat_name}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def drop_user_chat(self, command):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+ chat_name = command["chat_name"]
+ res_json = self._list_chats(command)
+ to_drop_chat_ids = []
+ for elem in res_json:
+ if elem["name"] == chat_name:
+ to_drop_chat_ids.append(elem["id"])
+ payload = {"dialog_ids": to_drop_chat_ids}
+ response = self.http_client.request("POST", "/dialog/rm", json_body=payload, use_api_base=False,
+ auth_kind="web")
+ res_json = response.json()
+ if response.status_code == 200 and res_json["code"] == 0:
+ print(f"Success to drop chat: {chat_name}")
+ else:
+ print(f"Fail to drop chat {chat_name}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def list_user_model_providers(self, command):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+
+ response = self.http_client.request("GET", "/llm/my_llms", use_api_base=False, auth_kind="web")
+ res_json = response.json()
+ if response.status_code == 200:
+ new_input = []
+ for key, value in res_json["data"].items():
+ new_input.append({"model provider": key, "models": value})
+ self._print_table_simple(new_input)
+ else:
+ print(f"Fail to list model provider, code: {res_json['code']}, message: {res_json['message']}")
+
+ def list_user_default_models(self, command):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+
+ res_json = self._get_default_models()
+ if res_json is None:
+ return
+ else:
+ new_input = []
+ for key, value in res_json.items():
+ if key == "asr_id" and value != "":
+ new_input.append({"model_category": "ASR", "model_name": value})
+ elif key == "embd_id" and value != "":
+ new_input.append({"model_category": "Embedding", "model_name": value})
+ elif key == "llm_id" and value != "":
+ new_input.append({"model_category": "LLM", "model_name": value})
+ elif key == "rerank_id" and value != "":
+ new_input.append({"model_category": "Reranker", "model_name": value})
+ elif key == "tts_id" and value != "":
+ new_input.append({"model_category": "TTS", "model_name": value})
+ elif key == "img2txt_id" and value != "":
+ new_input.append({"model_category": "VLM", "model_name": value})
+ else:
+ continue
+ self._print_table_simple(new_input)
+
+ def parse_dataset_docs(self, command_dict):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+
+ dataset_name = command_dict["dataset_name"]
+ dataset_id = self._get_dataset_id(dataset_name)
+ if dataset_id is None:
+ return
+
+ res_json = self._list_documents(dataset_name, dataset_id)
+ if res_json is None:
+ return
+
+ document_names = command_dict["document_names"]
+ document_ids = []
+ to_parse_doc_names = []
+ for doc in res_json:
+ doc_name = doc["name"]
+ if doc_name in document_names:
+ document_ids.append(doc["id"])
+ document_names.remove(doc_name)
+ to_parse_doc_names.append(doc_name)
+
+ if len(document_ids) == 0:
+ print(f"No documents found in {dataset_name}")
+ return
+
+ if len(document_names) != 0:
+ print(f"Documents {document_names} not found in {dataset_name}")
+
+ payload = {"doc_ids": document_ids, "run": 1}
+ response = self.http_client.request("POST", "/document/run", json_body=payload, use_api_base=False,
+ auth_kind="web")
+ res_json = response.json()
+ if response.status_code == 200 and res_json["code"] == 0:
+ print(f"Success to parse {to_parse_doc_names} of {dataset_name}")
+ else:
+ print(
+ f"Fail to parse documents {res_json["data"]["docs"]}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def parse_dataset(self, command_dict):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+
+ dataset_name = command_dict["dataset_name"]
+ dataset_id = self._get_dataset_id(dataset_name)
+ if dataset_id is None:
+ return
+
+ res_json = self._list_documents(dataset_name, dataset_id)
+ if res_json is None:
+ return
+ document_ids = []
+ for doc in res_json:
+ document_ids.append(doc["id"])
+
+ payload = {"doc_ids": document_ids, "run": 1}
+ response = self.http_client.request("POST", "/document/run", json_body=payload, use_api_base=False,
+ auth_kind="web")
+ res_json = response.json()
+ if response.status_code == 200 and res_json["code"] == 0:
+ pass
+ else:
+ print(f"Fail to parse dataset {dataset_name}, code: {res_json['code']}, message: {res_json['message']}")
+
+ if command_dict["method"] == "async":
+ print(f"Success to start parse dataset {dataset_name}")
+ return
+ else:
+ print(f"Start to parse dataset {dataset_name}, please wait...")
+ if self._wait_parse_done(dataset_name, dataset_id):
+ print(f"Success to parse dataset {dataset_name}")
+ else:
+ print(f"Parse dataset {dataset_name} timeout")
+
+ def import_docs_into_dataset(self, command_dict):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+
+ dataset_name = command_dict["dataset_name"]
+ dataset_id = self._get_dataset_id(dataset_name)
+ if dataset_id is None:
+ return
+
+ document_paths = command_dict["document_paths"]
+ paths = [Path(p) for p in document_paths]
+
+ fields = []
+ file_handles = []
+ try:
+ for path in paths:
+ fh = path.open("rb")
+ fields.append(("file", (path.name, fh)))
+ file_handles.append(fh)
+ fields.append(("kb_id", dataset_id))
+ encoder = MultipartEncoder(fields=fields)
+ headers = {"Content-Type": encoder.content_type}
+ response = self.http_client.request(
+ "POST",
+ "/document/upload",
+ headers=headers,
+ data=encoder,
+ json_body=None,
+ params=None,
+ stream=False,
+ auth_kind="web",
+ use_api_base=False
+ )
+ res = response.json()
+ if res.get("code") == 0:
+ print(f"Success to import documents into dataset {dataset_name}")
+ else:
+ print(f"Fail to import documents: code: {res['code']}, message: {res['message']}")
+ except Exception as exc:
+ print(f"Fail to import document into dataset: {dataset_name}, error: {exc}")
+ finally:
+ for fh in file_handles:
+ fh.close()
+
+ def search_on_datasets(self, command_dict):
+ if self.server_type != "user":
+ print("This command is only allowed in USER mode")
+
+ dataset_names = command_dict["datasets"]
+ dataset_ids = []
+ for dataset_name in dataset_names:
+ dataset_id = self._get_dataset_id(dataset_name)
+ if dataset_id is None:
+ return
+ dataset_ids.append(dataset_id)
+
+ payload = {
+ "question": command_dict["question"],
+ "kb_id": dataset_ids,
+ "similarity_threshold": 0.2,
+ "vector_similarity_weight": 0.3,
+ # "top_k": 1024,
+ # "kb_id": command_dict["datasets"][0],
+ }
+ iterations = command_dict.get("iterations", 1)
+ if iterations > 1:
+ response = self.http_client.request("POST", "/chunk/retrieval_test", json_body=payload, use_api_base=False,
+ auth_kind="web", iterations=iterations)
+ return response
+ else:
+ response = self.http_client.request("POST", "/chunk/retrieval_test", json_body=payload, use_api_base=False,
+ auth_kind="web")
+ res_json = response.json()
+ if response.status_code == 200:
+ if res_json["code"] == 0:
+ self._print_table_simple(res_json["data"]["chunks"])
+ else:
+ print(
+ f"Fail to search datasets: {dataset_names}, code: {res_json['code']}, message: {res_json['message']}")
+ else:
+ print(
+ f"Fail to search datasets: {dataset_names}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def show_version(self, command):
+ if self.server_type == "admin":
+ response = self.http_client.request("GET", "/admin/version", use_api_base=True, auth_kind="admin")
+ else:
+ response = self.http_client.request("GET", "/system/version", use_api_base=False, auth_kind="admin")
+
+ res_json = response.json()
+ if response.status_code == 200:
+ if self.server_type == "admin":
+ self._print_table_simple(res_json["data"])
+ else:
+ self._print_table_simple({"version": res_json["data"]})
+ else:
+ print(f"Fail to show version, code: {res_json['code']}, message: {res_json['message']}")
+
+ def _wait_parse_done(self, dataset_name: str, dataset_id: str):
+ start = time.monotonic()
+ while True:
+ docs = self._list_documents(dataset_name, dataset_id)
+ if docs is None:
+ return False
+ all_done = True
+ for doc in docs:
+ if doc.get("run") != "3":
+ print(f"Document {doc["name"]} is not done, status: {doc.get("run")}")
+ all_done = False
+ break
+ if all_done:
+ return True
+ if time.monotonic() - start > 60:
+ return False
+ time.sleep(0.5)
+
+ def _list_documents(self, dataset_name: str, dataset_id: str):
+ response = self.http_client.request("POST", f"/document/list?kb_id={dataset_id}", use_api_base=False,
+ auth_kind="web")
+ res_json = response.json()
+ if response.status_code != 200:
+ print(
+ f"Fail to list files from dataset {dataset_name}, code: {res_json['code']}, message: {res_json['message']}")
+ return None
+ return res_json["data"]["docs"]
+
+ def _get_dataset_id(self, dataset_name: str):
+ response = self.http_client.request("POST", "/kb/list", use_api_base=False, auth_kind="web")
+ res_json = response.json()
+ if response.status_code != 200:
+ print(f"Fail to list datasets, code: {res_json['code']}, message: {res_json['message']}")
+ return None
+
+ dataset_list = res_json["data"]["kbs"]
+ dataset_id: str = ""
+ for dataset in dataset_list:
+ if dataset["name"] == dataset_name:
+ dataset_id = dataset["id"]
+
+ if dataset_id == "":
+ print(f"Dataset {dataset_name} not found")
+ return None
+ return dataset_id
+
+ def _list_chats(self, command):
+ iterations = command.get("iterations", 1)
+ if iterations > 1:
+ response = self.http_client.request("POST", "/dialog/next", use_api_base=False, auth_kind="web",
+ iterations=iterations)
+ return response
+ else:
+ response = self.http_client.request("POST", "/dialog/next", use_api_base=False, auth_kind="web",
+ iterations=iterations)
+ res_json = response.json()
+ if response.status_code == 200 and res_json["code"] == 0:
+ return res_json["data"]["dialogs"]
+ else:
+ print(f"Fail to list datasets, code: {res_json['code']}, message: {res_json['message']}")
+ return None
+
+ def _get_default_models(self):
+ response = self.http_client.request("GET", "/user/tenant_info", use_api_base=False, auth_kind="web")
+ res_json = response.json()
+ if response.status_code == 200:
+ if res_json["code"] == 0:
+ return res_json["data"]
+ else:
+ print(f"Fail to list user default models, code: {res_json['code']}, message: {res_json['message']}")
+ return None
+ else:
+ print(f"Fail to list user default models, HTTP code: {response.status_code}, message: {res_json}")
+ return None
+
+ def _set_default_models(self, model_type, model_id):
+ current_payload = self._get_default_models()
+ if current_payload is None:
+ return
+ else:
+ current_payload.update({model_type: model_id})
+ payload = {
+ "tenant_id": current_payload["tenant_id"],
+ "llm_id": current_payload["llm_id"],
+ "embd_id": current_payload["embd_id"],
+ "img2txt_id": current_payload["img2txt_id"],
+ "asr_id": current_payload["asr_id"],
+ "tts_id": current_payload["tts_id"],
+ }
+ response = self.http_client.request("POST", "/user/set_tenant_info", json_body=payload, use_api_base=False,
+ auth_kind="web")
+ res_json = response.json()
+ if response.status_code == 200 and res_json["code"] == 0:
+ print(f"Success to set default llm to {model_type}")
+ else:
+ print(f"Fail to set default llm to {model_type}, code: {res_json['code']}, message: {res_json['message']}")
+
+ def _format_service_detail_table(self, data):
+ if isinstance(data, list):
+ return data
+ if not all([isinstance(v, list) for v in data.values()]):
+ # normal table
+ return data
+ # handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}]
+ task_executor_list = []
+ for k, v in data.items():
+ # display latest status
+ heartbeats = sorted(v, key=lambda x: x["now"], reverse=True)
+ task_executor_list.append(
+ {
+ "task_executor_name": k,
+ **heartbeats[0],
+ }
+ if heartbeats
+ else {"task_executor_name": k}
+ )
+ return task_executor_list
+
+ def _print_table_simple(self, data):
+ if not data:
+ print("No data to print")
+ return
+ if isinstance(data, dict):
+ # handle single row data
+ data = [data]
+
+ columns = list(set().union(*(d.keys() for d in data)))
+ columns.sort()
+ col_widths = {}
+
+ def get_string_width(text):
+ half_width_chars = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\t\n\r"
+ width = 0
+ for char in text:
+ if char in half_width_chars:
+ width += 1
+ else:
+ width += 2
+ return width
+
+ for col in columns:
+ max_width = get_string_width(str(col))
+ for item in data:
+ value_len = get_string_width(str(item.get(col, "")))
+ if value_len > max_width:
+ max_width = value_len
+ col_widths[col] = max(2, max_width)
+
+ # Generate delimiter
+ separator = "+" + "+".join(["-" * (col_widths[col] + 2) for col in columns]) + "+"
+
+ # Print header
+ print(separator)
+ header = "|" + "|".join([f" {col:<{col_widths[col]}} " for col in columns]) + "|"
+ print(header)
+ print(separator)
+
+ # Print data
+ for item in data:
+ row = "|"
+ for col in columns:
+ value = str(item.get(col, ""))
+ if get_string_width(value) > col_widths[col]:
+ value = value[: col_widths[col] - 3] + "..."
+ row += f" {value:<{col_widths[col] - (get_string_width(value) - len(value))}} |"
+ print(row)
+
+ print(separator)
+
+
+def run_command(client: RAGFlowClient, command_dict: dict):
+ command_type = command_dict["type"]
+
+ match command_type:
+ case "benchmark":
+ run_benchmark(client, command_dict)
+ case "login_user":
+ client.login_user(command_dict)
+ case "ping_server":
+ return client.ping_server(command_dict)
+ case "register_user":
+ client.register_user(command_dict)
+ case "list_services":
+ client.list_services()
+ case "show_service":
+ client.show_service(command_dict)
+ case "restart_service":
+ client.restart_service(command_dict)
+ case "shutdown_service":
+ client.shutdown_service(command_dict)
+ case "startup_service":
+ client.startup_service(command_dict)
+ case "list_users":
+ client.list_users(command_dict)
+ case "show_user":
+ client.show_user(command_dict)
+ case "drop_user":
+ client.drop_user(command_dict)
+ case "alter_user":
+ client.alter_user(command_dict)
+ case "create_user":
+ client.create_user(command_dict)
+ case "activate_user":
+ client.activate_user(command_dict)
+ case "list_datasets":
+ client.handle_list_datasets(command_dict)
+ case "list_agents":
+ client.handle_list_agents(command_dict)
+ case "create_role":
+ client.create_role(command_dict)
+ case "drop_role":
+ client.drop_role(command_dict)
+ case "alter_role":
+ client.alter_role(command_dict)
+ case "list_roles":
+ client.list_roles(command_dict)
+ case "show_role":
+ client.show_role(command_dict)
+ case "grant_permission":
+ client.grant_permission(command_dict)
+ case "revoke_permission":
+ client.revoke_permission(command_dict)
+ case "alter_user_role":
+ client.alter_user_role(command_dict)
+ case "show_user_permission":
+ client.show_user_permission(command_dict)
+ case "show_version":
+ client.show_version(command_dict)
+ case "grant_admin":
+ client.grant_admin(command_dict)
+ case "revoke_admin":
+ client.revoke_admin(command_dict)
+ case "generate_key":
+ client.generate_key(command_dict)
+ case "list_keys":
+ client.list_keys(command_dict)
+ case "drop_key":
+ client.drop_key(command_dict)
+ case "set_variable":
+ client.set_variable(command_dict)
+ case "show_variable":
+ client.show_variable(command_dict)
+ case "list_variables":
+ client.list_variables(command_dict)
+ case "list_configs":
+ client.list_configs(command_dict)
+ case "list_environments":
+ client.list_environments(command_dict)
+ case "create_model_provider":
+ client.create_model_provider(command_dict)
+ case "drop_model_provider":
+ client.drop_model_provider(command_dict)
+ case "show_current_user":
+ client.show_current_user(command_dict)
+ case "set_default_model":
+ client.set_default_model(command_dict)
+ case "reset_default_model":
+ client.reset_default_model(command_dict)
+ case "list_user_datasets":
+ return client.list_user_datasets(command_dict)
+ case "create_user_dataset":
+ client.create_user_dataset(command_dict)
+ case "drop_user_dataset":
+ client.drop_user_dataset(command_dict)
+ case "list_user_dataset_files":
+ return client.list_user_dataset_files(command_dict)
+ case "list_user_agents":
+ return client.list_user_agents(command_dict)
+ case "list_user_chats":
+ return client.list_user_chats(command_dict)
+ case "create_user_chat":
+ client.create_user_chat(command_dict)
+ case "drop_user_chat":
+ client.drop_user_chat(command_dict)
+ case "list_user_model_providers":
+ client.list_user_model_providers(command_dict)
+ case "list_user_default_models":
+ client.list_user_default_models(command_dict)
+ case "parse_dataset_docs":
+ client.parse_dataset_docs(command_dict)
+ case "parse_dataset":
+ client.parse_dataset(command_dict)
+ case "import_docs_into_dataset":
+ client.import_docs_into_dataset(command_dict)
+ case "search_on_datasets":
+ return client.search_on_datasets(command_dict)
+ case "meta":
+ _handle_meta_command(command_dict)
+ case _:
+ print(f"Command '{command_type}' would be executed with API")
+
+
+def _handle_meta_command(command: dict):
+ meta_command = command["command"]
+ args = command.get("args", [])
+
+ if meta_command in ["?", "h", "help"]:
+ show_help()
+ elif meta_command in ["q", "quit", "exit"]:
+ print("Goodbye!")
+ else:
+ print(f"Meta command '{meta_command}' with args {args}")
+
+
+def show_help():
+ """Help info"""
+ help_text = """
+Commands:
+LIST SERVICES
+SHOW SERVICE
+STARTUP SERVICE
+SHUTDOWN SERVICE
+RESTART SERVICE
+LIST USERS
+SHOW USER
+DROP USER
+CREATE USER
+ALTER USER PASSWORD
+ALTER USER ACTIVE
+LIST DATASETS OF
+LIST AGENTS OF
+CREATE ROLE
+DROP ROLE
+ALTER ROLE SET DESCRIPTION
+LIST ROLES
+SHOW ROLE
+GRANT ON TO ROLE
+REVOKE ON TO ROLE
+ALTER USER SET ROLE
+SHOW USER PERMISSION
+SHOW VERSION
+GRANT ADMIN
+REVOKE ADMIN
+GENERATE KEY FOR USER
+LIST KEYS OF
+DROP KEY OF
+
+Meta Commands:
+\\?, \\h, \\help Show this help
+\\q, \\quit, \\exit Quit the CLI
+ """
+ print(help_text)
+
+
+def run_benchmark(client: RAGFlowClient, command_dict: dict):
+ concurrency = command_dict.get("concurrency", 1)
+ iterations = command_dict.get("iterations", 1)
+ command: dict = command_dict["command"]
+ command.update({"iterations": iterations})
+
+ command_type = command["type"]
+ if concurrency < 1:
+ print("Concurrency must be greater than 0")
+ return
+ elif concurrency == 1:
+ result = run_command(client, command)
+ success_count: int = 0
+ response_list = result["response_list"]
+ for response in response_list:
+ match command_type:
+ case "ping_server":
+ if response.status_code == 200:
+ success_count += 1
+ case _:
+ res_json = response.json()
+ if response.status_code == 200 and res_json["code"] == 0:
+ success_count += 1
+
+ total_duration = result["duration"]
+ qps = iterations / total_duration if total_duration > 0 else None
+ print(f"command: {command}, Concurrency: {concurrency}, iterations: {iterations}")
+ print(
+ f"total duration: {total_duration:.4f}s, QPS: {qps}, COMMAND_COUNT: {iterations}, SUCCESS: {success_count}, FAILURE: {iterations - success_count}")
+ pass
+ else:
+ results: List[Optional[dict]] = [None] * concurrency
+ mp_context = mp.get_context("spawn")
+ start_time = time.perf_counter()
+ with ProcessPoolExecutor(max_workers=concurrency, mp_context=mp_context) as executor:
+ future_map = {
+ executor.submit(
+ run_command,
+ client,
+ command
+ ): idx
+ for idx in range(concurrency)
+ }
+ for future in as_completed(future_map):
+ idx = future_map[future]
+ results[idx] = future.result()
+ end_time = time.perf_counter()
+ success_count = 0
+ for result in results:
+ response_list = result["response_list"]
+ for response in response_list:
+ match command_type:
+ case "ping_server":
+ if response.status_code == 200:
+ success_count += 1
+ case _:
+ res_json = response.json()
+ if response.status_code == 200 and res_json["code"] == 0:
+ success_count += 1
+
+ total_duration = end_time - start_time
+ total_command_count = iterations * concurrency
+ qps = total_command_count / total_duration if total_duration > 0 else None
+ print(f"command: {command}, Concurrency: {concurrency} , iterations: {iterations}")
+ print(
+ f"total duration: {total_duration:.4f}s, QPS: {qps}, COMMAND_COUNT: {total_command_count}, SUCCESS: {success_count}, FAILURE: {total_command_count - success_count}")
+
+ pass
diff --git a/admin/client/user.py b/admin/client/user.py
new file mode 100644
index 00000000000..823e2a13001
--- /dev/null
+++ b/admin/client/user.py
@@ -0,0 +1,65 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from http_client import HttpClient
+
+
+class AuthException(Exception):
+ def __init__(self, message, code=401):
+ super().__init__(message)
+ self.code = code
+ self.message = message
+
+
+def encrypt_password(password_plain: str) -> str:
+ try:
+ from api.utils.crypt import crypt
+ except Exception as exc:
+ raise AuthException(
+ "Password encryption unavailable; install pycryptodomex (uv sync --python 3.12 --group test)."
+ ) from exc
+ return crypt(password_plain)
+
+
+def register_user(client: HttpClient, email: str, nickname: str, password: str) -> None:
+ password_enc = encrypt_password(password)
+ payload = {"email": email, "nickname": nickname, "password": password_enc}
+ res = client.request_json("POST", "/user/register", use_api_base=False, auth_kind=None, json_body=payload)
+ if res.get("code") == 0:
+ return
+ msg = res.get("message", "")
+ if "has already registered" in msg:
+ return
+ raise AuthException(f"Register failed: {msg}")
+
+
+def login_user(client: HttpClient, server_type: str, email: str, password: str) -> str:
+ password_enc = encrypt_password(password)
+ payload = {"email": email, "password": password_enc}
+ if server_type == "admin":
+ response = client.request("POST", "/admin/login", use_api_base=True, auth_kind=None, json_body=payload)
+ else:
+ response = client.request("POST", "/user/login", use_api_base=False, auth_kind=None, json_body=payload)
+ try:
+ res = response.json()
+ except Exception as exc:
+ raise AuthException(f"Login failed: invalid JSON response ({exc})") from exc
+ if res.get("code") != 0:
+ raise AuthException(f"Login failed: {res.get('message')}")
+ token = response.headers.get("Authorization")
+ if not token:
+ raise AuthException("Login failed: missing Authorization header")
+ return token
diff --git a/admin/client/uv.lock b/admin/client/uv.lock
index 7e38b7144c0..6a0fa57faf2 100644
--- a/admin/client/uv.lock
+++ b/admin/client/uv.lock
@@ -196,7 +196,7 @@ wheels = [
[[package]]
name = "ragflow-cli"
-version = "0.23.1"
+version = "0.24.0"
source = { virtual = "." }
dependencies = [
{ name = "beartype" },
diff --git a/admin/server/admin_server.py b/admin/server/admin_server.py
index b8c96a62c45..2fbb4174c02 100644
--- a/admin/server/admin_server.py
+++ b/admin/server/admin_server.py
@@ -14,10 +14,12 @@
# limitations under the License.
#
+import time
+start_ts = time.time()
+
import os
import signal
import logging
-import time
import threading
import traceback
import faulthandler
@@ -66,7 +68,7 @@
SERVICE_CONFIGS.configs = load_configurations(SERVICE_CONF)
try:
- logging.info("RAGFlow Admin service start...")
+ logging.info(f"RAGFlow admin is ready after {time.time() - start_ts}s initialization.")
run_simple(
hostname="0.0.0.0",
port=9381,
diff --git a/admin/server/auth.py b/admin/server/auth.py
index 486b9a4fbf7..30d3bd4dd79 100644
--- a/admin/server/auth.py
+++ b/admin/server/auth.py
@@ -27,6 +27,8 @@
from api.common.exceptions import AdminException, UserNotFoundError
from api.common.base64 import encode_to_base64
from api.db.services import UserService
+from api.db import UserTenantRole
+from api.db.services.user_service import TenantService, UserTenantService
from common.constants import ActiveEnum, StatusEnum
from api.utils.crypt import decrypt
from common.misc_utils import get_uuid
@@ -85,8 +87,44 @@ def init_default_admin():
}
if not UserService.save(**default_admin):
raise AdminException("Can't init admin.", 500)
+ add_tenant_for_admin(default_admin, UserTenantRole.OWNER)
elif not any([u.is_active == ActiveEnum.ACTIVE.value for u in users]):
raise AdminException("No active admin. Please update 'is_active' in db manually.", 500)
+ else:
+ default_admin_rows = [u for u in users if u.email == "admin@ragflow.io"]
+ if default_admin_rows:
+ default_admin = default_admin_rows[0].to_dict()
+ exist, default_admin_tenant = TenantService.get_by_id(default_admin["id"])
+ if not exist:
+ add_tenant_for_admin(default_admin, UserTenantRole.OWNER)
+
+
+def add_tenant_for_admin(user_info: dict, role: str):
+ from api.db.services.tenant_llm_service import TenantLLMService
+ from api.db.services.llm_service import get_init_tenant_llm
+
+ tenant = {
+ "id": user_info["id"],
+ "name": user_info["nickname"] + "‘s Kingdom",
+ "llm_id": settings.CHAT_MDL,
+ "embd_id": settings.EMBEDDING_MDL,
+ "asr_id": settings.ASR_MDL,
+ "parser_ids": settings.PARSERS,
+ "img2txt_id": settings.IMAGE2TEXT_MDL
+ }
+ usr_tenant = {
+ "tenant_id": user_info["id"],
+ "user_id": user_info["id"],
+ "invited_by": user_info["id"],
+ "role": role
+ }
+
+ tenant_llm = get_init_tenant_llm(user_info["id"])
+ TenantService.insert(**tenant)
+ UserTenantService.insert(**usr_tenant)
+ TenantLLMService.insert_many(tenant_llm)
+ logging.info(
+ f"Added tenant for email: {user_info['email']}, A default tenant has been set; changing the default models after login is strongly recommended.")
def check_admin_auth(func):
diff --git a/admin/server/routes.py b/admin/server/routes.py
index e83f3ff08e1..53b0f43206e 100644
--- a/admin/server/routes.py
+++ b/admin/server/routes.py
@@ -15,29 +15,34 @@
#
import secrets
+import logging
+from typing import Any
-from flask import Blueprint, request
+from common.time_utils import current_timestamp, datetime_format
+from datetime import datetime
+from flask import Blueprint, Response, request
from flask_login import current_user, login_required, logout_user
from auth import login_verify, login_admin, check_admin_auth
from responses import success_response, error_response
-from services import UserMgr, ServiceMgr, UserServiceMgr
+from services import UserMgr, ServiceMgr, UserServiceMgr, SettingsMgr, ConfigMgr, EnvironmentsMgr, SandboxMgr
from roles import RoleMgr
from api.common.exceptions import AdminException
from common.versions import get_ragflow_version
+from api.utils.api_utils import generate_confirmation_token
-admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')
+admin_bp = Blueprint("admin", __name__, url_prefix="/api/v1/admin")
-@admin_bp.route('/ping', methods=['GET'])
+@admin_bp.route("/ping", methods=["GET"])
def ping():
- return success_response('PONG')
+ return success_response("PONG")
-@admin_bp.route('/login', methods=['POST'])
+@admin_bp.route("/login", methods=["POST"])
def login():
if not request.json:
- return error_response('Authorize admin failed.' ,400)
+ return error_response("Authorize admin failed.", 400)
try:
email = request.json.get("email", "")
password = request.json.get("password", "")
@@ -46,7 +51,7 @@ def login():
return error_response(str(e), 500)
-@admin_bp.route('/logout', methods=['GET'])
+@admin_bp.route("/logout", methods=["GET"])
@login_required
def logout():
try:
@@ -58,7 +63,7 @@ def logout():
return error_response(str(e), 500)
-@admin_bp.route('/auth', methods=['GET'])
+@admin_bp.route("/auth", methods=["GET"])
@login_verify
def auth_admin():
try:
@@ -67,7 +72,7 @@ def auth_admin():
return error_response(str(e), 500)
-@admin_bp.route('/users', methods=['GET'])
+@admin_bp.route("/users", methods=["GET"])
@login_required
@check_admin_auth
def list_users():
@@ -78,18 +83,18 @@ def list_users():
return error_response(str(e), 500)
-@admin_bp.route('/users', methods=['POST'])
+@admin_bp.route("/users", methods=["POST"])
@login_required
@check_admin_auth
def create_user():
try:
data = request.get_json()
- if not data or 'username' not in data or 'password' not in data:
+ if not data or "username" not in data or "password" not in data:
return error_response("Username and password are required", 400)
- username = data['username']
- password = data['password']
- role = data.get('role', 'user')
+ username = data["username"]
+ password = data["password"]
+ role = data.get("role", "user")
res = UserMgr.create_user(username, password, role)
if res["success"]:
@@ -105,7 +110,7 @@ def create_user():
return error_response(str(e))
-@admin_bp.route('/users/', methods=['DELETE'])
+@admin_bp.route("/users/", methods=["DELETE"])
@login_required
@check_admin_auth
def delete_user(username):
@@ -122,16 +127,16 @@ def delete_user(username):
return error_response(str(e), 500)
-@admin_bp.route('/users//password', methods=['PUT'])
+@admin_bp.route("/users//password", methods=["PUT"])
@login_required
@check_admin_auth
def change_password(username):
try:
data = request.get_json()
- if not data or 'new_password' not in data:
+ if not data or "new_password" not in data:
return error_response("New password is required", 400)
- new_password = data['new_password']
+ new_password = data["new_password"]
msg = UserMgr.update_user_password(username, new_password)
return success_response(None, msg)
@@ -141,15 +146,15 @@ def change_password(username):
return error_response(str(e), 500)
-@admin_bp.route('/users//activate', methods=['PUT'])
+@admin_bp.route("/users//activate", methods=["PUT"])
@login_required
@check_admin_auth
def alter_user_activate_status(username):
try:
data = request.get_json()
- if not data or 'activate_status' not in data:
+ if not data or "activate_status" not in data:
return error_response("Activation status is required", 400)
- activate_status = data['activate_status']
+ activate_status = data["activate_status"]
msg = UserMgr.update_user_activate_status(username, activate_status)
return success_response(None, msg)
except AdminException as e:
@@ -158,7 +163,39 @@ def alter_user_activate_status(username):
return error_response(str(e), 500)
-@admin_bp.route('/users/', methods=['GET'])
+@admin_bp.route("/users//admin", methods=["PUT"])
+@login_required
+@check_admin_auth
+def grant_admin(username):
+ try:
+ if current_user.email == username:
+ return error_response(f"can't grant current user: {username}", 409)
+ msg = UserMgr.grant_admin(username)
+ return success_response(None, msg)
+
+ except AdminException as e:
+ return error_response(e.message, e.code)
+ except Exception as e:
+ return error_response(str(e), 500)
+
+
+@admin_bp.route("/users//admin", methods=["DELETE"])
+@login_required
+@check_admin_auth
+def revoke_admin(username):
+ try:
+ if current_user.email == username:
+ return error_response(f"can't grant current user: {username}", 409)
+ msg = UserMgr.revoke_admin(username)
+ return success_response(None, msg)
+
+ except AdminException as e:
+ return error_response(e.message, e.code)
+ except Exception as e:
+ return error_response(str(e), 500)
+
+
+@admin_bp.route("/users/", methods=["GET"])
@login_required
@check_admin_auth
def get_user_details(username):
@@ -172,7 +209,7 @@ def get_user_details(username):
return error_response(str(e), 500)
-@admin_bp.route('/users//datasets', methods=['GET'])
+@admin_bp.route("/users//datasets", methods=["GET"])
@login_required
@check_admin_auth
def get_user_datasets(username):
@@ -186,7 +223,7 @@ def get_user_datasets(username):
return error_response(str(e), 500)
-@admin_bp.route('/users//agents', methods=['GET'])
+@admin_bp.route("/users//agents", methods=["GET"])
@login_required
@check_admin_auth
def get_user_agents(username):
@@ -200,7 +237,7 @@ def get_user_agents(username):
return error_response(str(e), 500)
-@admin_bp.route('/services', methods=['GET'])
+@admin_bp.route("/services", methods=["GET"])
@login_required
@check_admin_auth
def get_services():
@@ -211,7 +248,7 @@ def get_services():
return error_response(str(e), 500)
-@admin_bp.route('/service_types/', methods=['GET'])
+@admin_bp.route("/service_types/", methods=["GET"])
@login_required
@check_admin_auth
def get_services_by_type(service_type_str):
@@ -222,7 +259,7 @@ def get_services_by_type(service_type_str):
return error_response(str(e), 500)
-@admin_bp.route('/services/', methods=['GET'])
+@admin_bp.route("/services/", methods=["GET"])
@login_required
@check_admin_auth
def get_service(service_id):
@@ -233,7 +270,7 @@ def get_service(service_id):
return error_response(str(e), 500)
-@admin_bp.route('/services/', methods=['DELETE'])
+@admin_bp.route("/services/", methods=["DELETE"])
@login_required
@check_admin_auth
def shutdown_service(service_id):
@@ -244,7 +281,7 @@ def shutdown_service(service_id):
return error_response(str(e), 500)
-@admin_bp.route('/services/', methods=['PUT'])
+@admin_bp.route("/services/", methods=["PUT"])
@login_required
@check_admin_auth
def restart_service(service_id):
@@ -255,38 +292,38 @@ def restart_service(service_id):
return error_response(str(e), 500)
-@admin_bp.route('/roles', methods=['POST'])
+@admin_bp.route("/roles", methods=["POST"])
@login_required
@check_admin_auth
def create_role():
try:
data = request.get_json()
- if not data or 'role_name' not in data:
+ if not data or "role_name" not in data:
return error_response("Role name is required", 400)
- role_name: str = data['role_name']
- description: str = data['description']
+ role_name: str = data["role_name"]
+ description: str = data["description"]
res = RoleMgr.create_role(role_name, description)
return success_response(res)
except Exception as e:
return error_response(str(e), 500)
-@admin_bp.route('/roles/', methods=['PUT'])
+@admin_bp.route("/roles/", methods=["PUT"])
@login_required
@check_admin_auth
def update_role(role_name: str):
try:
data = request.get_json()
- if not data or 'description' not in data:
+ if not data or "description" not in data:
return error_response("Role description is required", 400)
- description: str = data['description']
+ description: str = data["description"]
res = RoleMgr.update_role_description(role_name, description)
return success_response(res)
except Exception as e:
return error_response(str(e), 500)
-@admin_bp.route('/roles/', methods=['DELETE'])
+@admin_bp.route("/roles/", methods=["DELETE"])
@login_required
@check_admin_auth
def delete_role(role_name: str):
@@ -297,7 +334,7 @@ def delete_role(role_name: str):
return error_response(str(e), 500)
-@admin_bp.route('/roles', methods=['GET'])
+@admin_bp.route("/roles", methods=["GET"])
@login_required
@check_admin_auth
def list_roles():
@@ -308,7 +345,7 @@ def list_roles():
return error_response(str(e), 500)
-@admin_bp.route('/roles//permission', methods=['GET'])
+@admin_bp.route("/roles//permission", methods=["GET"])
@login_required
@check_admin_auth
def get_role_permission(role_name: str):
@@ -319,54 +356,54 @@ def get_role_permission(role_name: str):
return error_response(str(e), 500)
-@admin_bp.route('/roles//permission', methods=['POST'])
+@admin_bp.route("/roles//permission", methods=["POST"])
@login_required
@check_admin_auth
def grant_role_permission(role_name: str):
try:
data = request.get_json()
- if not data or 'actions' not in data or 'resource' not in data:
+ if not data or "actions" not in data or "resource" not in data:
return error_response("Permission is required", 400)
- actions: list = data['actions']
- resource: str = data['resource']
+ actions: list = data["actions"]
+ resource: str = data["resource"]
res = RoleMgr.grant_role_permission(role_name, actions, resource)
return success_response(res)
except Exception as e:
return error_response(str(e), 500)
-@admin_bp.route('/roles//permission', methods=['DELETE'])
+@admin_bp.route("/roles//permission", methods=["DELETE"])
@login_required
@check_admin_auth
def revoke_role_permission(role_name: str):
try:
data = request.get_json()
- if not data or 'actions' not in data or 'resource' not in data:
+ if not data or "actions" not in data or "resource" not in data:
return error_response("Permission is required", 400)
- actions: list = data['actions']
- resource: str = data['resource']
+ actions: list = data["actions"]
+ resource: str = data["resource"]
res = RoleMgr.revoke_role_permission(role_name, actions, resource)
return success_response(res)
except Exception as e:
return error_response(str(e), 500)
-@admin_bp.route('/users//role', methods=['PUT'])
+@admin_bp.route("/users//role", methods=["PUT"])
@login_required
@check_admin_auth
def update_user_role(user_name: str):
try:
data = request.get_json()
- if not data or 'role_name' not in data:
+ if not data or "role_name" not in data:
return error_response("Role name is required", 400)
- role_name: str = data['role_name']
+ role_name: str = data["role_name"]
res = RoleMgr.update_user_role(user_name, role_name)
return success_response(res)
except Exception as e:
return error_response(str(e), 500)
-@admin_bp.route('/users//permission', methods=['GET'])
+@admin_bp.route("/users//permission", methods=["GET"])
@login_required
@check_admin_auth
def get_user_permission(user_name: str):
@@ -376,7 +413,140 @@ def get_user_permission(user_name: str):
except Exception as e:
return error_response(str(e), 500)
-@admin_bp.route('/version', methods=['GET'])
+
+@admin_bp.route("/variables", methods=["PUT"])
+@login_required
+@check_admin_auth
+def set_variable():
+ try:
+ data = request.get_json()
+ if not data and "var_name" not in data:
+ return error_response("Var name is required", 400)
+
+ if "var_value" not in data:
+ return error_response("Var value is required", 400)
+ var_name: str = data["var_name"]
+ var_value: str = data["var_value"]
+
+ SettingsMgr.update_by_name(var_name, var_value)
+ return success_response(None, "Set variable successfully")
+ except AdminException as e:
+ return error_response(str(e), 400)
+ except Exception as e:
+ return error_response(str(e), 500)
+
+
+@admin_bp.route("/variables", methods=["GET"])
+@login_required
+@check_admin_auth
+def get_variable():
+ try:
+ if request.content_length is None or request.content_length == 0:
+ # list variables
+ res = list(SettingsMgr.get_all())
+ return success_response(res)
+
+ # get var
+ data = request.get_json()
+ if not data and "var_name" not in data:
+ return error_response("Var name is required", 400)
+ var_name: str = data["var_name"]
+ res = SettingsMgr.get_by_name(var_name)
+ return success_response(res)
+ except AdminException as e:
+ return error_response(str(e), 400)
+ except Exception as e:
+ return error_response(str(e), 500)
+
+
+@admin_bp.route("/configs", methods=["GET"])
+@login_required
+@check_admin_auth
+def get_config():
+ try:
+ res = list(ConfigMgr.get_all())
+ return success_response(res)
+ except AdminException as e:
+ return error_response(str(e), 400)
+ except Exception as e:
+ return error_response(str(e), 500)
+
+
+@admin_bp.route("/environments", methods=["GET"])
+@login_required
+@check_admin_auth
+def get_environments():
+ try:
+ res = list(EnvironmentsMgr.get_all())
+ return success_response(res)
+ except AdminException as e:
+ return error_response(str(e), 400)
+ except Exception as e:
+ return error_response(str(e), 500)
+
+
+@admin_bp.route("/users//keys", methods=["POST"])
+@login_required
+@check_admin_auth
+def generate_user_api_key(username: str) -> tuple[Response, int]:
+ try:
+ user_details: list[dict[str, Any]] = UserMgr.get_user_details(username)
+ if not user_details:
+ return error_response("User not found!", 404)
+ tenants: list[dict[str, Any]] = UserServiceMgr.get_user_tenants(username)
+ if not tenants:
+ return error_response("Tenant not found!", 404)
+ tenant_id: str = tenants[0]["tenant_id"]
+ key: str = generate_confirmation_token()
+ obj: dict[str, Any] = {
+ "tenant_id": tenant_id,
+ "token": key,
+ "beta": generate_confirmation_token().replace("ragflow-", "")[:32],
+ "create_time": current_timestamp(),
+ "create_date": datetime_format(datetime.now()),
+ "update_time": None,
+ "update_date": None,
+ }
+
+ if not UserMgr.save_api_key(obj):
+ return error_response("Failed to generate API key!", 500)
+ return success_response(obj, "API key generated successfully")
+ except AdminException as e:
+ return error_response(e.message, e.code)
+ except Exception as e:
+ return error_response(str(e), 500)
+
+
+@admin_bp.route("/users//keys", methods=["GET"])
+@login_required
+@check_admin_auth
+def get_user_api_keys(username: str) -> tuple[Response, int]:
+ try:
+ api_keys: list[dict[str, Any]] = UserMgr.get_user_api_key(username)
+ return success_response(api_keys, "Get user API keys")
+ except AdminException as e:
+ return error_response(e.message, e.code)
+ except Exception as e:
+ return error_response(str(e), 500)
+
+
+@admin_bp.route("/users//keys/", methods=["DELETE"])
+@login_required
+@check_admin_auth
+def delete_user_api_key(username: str, key: str) -> tuple[Response, int]:
+ try:
+ deleted = UserMgr.delete_api_key(username, key)
+ if deleted:
+ return success_response(None, "API key deleted successfully")
+ else:
+ return error_response("API key not found or could not be deleted", 404)
+ except AdminException as e:
+ return error_response(e.message, e.code)
+ except Exception as e:
+ return error_response(str(e), 500)
+
+
+@admin_bp.route("/version", methods=["GET"])
@login_required
@check_admin_auth
def show_version():
@@ -385,3 +555,100 @@ def show_version():
return success_response(res)
except Exception as e:
return error_response(str(e), 500)
+
+
+@admin_bp.route("/sandbox/providers", methods=["GET"])
+@login_required
+@check_admin_auth
+def list_sandbox_providers():
+ """List all available sandbox providers."""
+ try:
+ res = SandboxMgr.list_providers()
+ return success_response(res)
+ except AdminException as e:
+ return error_response(str(e), 400)
+ except Exception as e:
+ return error_response(str(e), 500)
+
+
+@admin_bp.route("/sandbox/providers//schema", methods=["GET"])
+@login_required
+@check_admin_auth
+def get_sandbox_provider_schema(provider_id: str):
+ """Get configuration schema for a specific provider."""
+ try:
+ res = SandboxMgr.get_provider_config_schema(provider_id)
+ return success_response(res)
+ except AdminException as e:
+ return error_response(str(e), 400)
+ except Exception as e:
+ return error_response(str(e), 500)
+
+
+@admin_bp.route("/sandbox/config", methods=["GET"])
+@login_required
+@check_admin_auth
+def get_sandbox_config():
+ """Get current sandbox configuration."""
+ try:
+ res = SandboxMgr.get_config()
+ return success_response(res)
+ except AdminException as e:
+ return error_response(str(e), 400)
+ except Exception as e:
+ return error_response(str(e), 500)
+
+
+@admin_bp.route("/sandbox/config", methods=["POST"])
+@login_required
+@check_admin_auth
+def set_sandbox_config():
+ """Set sandbox provider configuration."""
+ try:
+ data = request.get_json()
+ if not data:
+ logging.error("set_sandbox_config: Request body is required")
+ return error_response("Request body is required", 400)
+
+ provider_type = data.get("provider_type")
+ if not provider_type:
+ logging.error("set_sandbox_config: provider_type is required")
+ return error_response("provider_type is required", 400)
+
+ config = data.get("config", {})
+ set_active = data.get("set_active", True) # Default to True for backward compatibility
+
+ logging.info(f"set_sandbox_config: provider_type={provider_type}, set_active={set_active}")
+ logging.info(f"set_sandbox_config: config keys={list(config.keys())}")
+
+ res = SandboxMgr.set_config(provider_type, config, set_active)
+ return success_response(res, "Sandbox configuration updated successfully")
+ except AdminException as e:
+ logging.exception("set_sandbox_config AdminException")
+ return error_response(str(e), 400)
+ except Exception as e:
+ logging.exception("set_sandbox_config unexpected error")
+ return error_response(str(e), 500)
+
+
+@admin_bp.route("/sandbox/test", methods=["POST"])
+@login_required
+@check_admin_auth
+def test_sandbox_connection():
+ """Test connection to sandbox provider."""
+ try:
+ data = request.get_json()
+ if not data:
+ return error_response("Request body is required", 400)
+
+ provider_type = data.get("provider_type")
+ if not provider_type:
+ return error_response("provider_type is required", 400)
+
+ config = data.get("config", {})
+ res = SandboxMgr.test_connection(provider_type, config)
+ return success_response(res)
+ except AdminException as e:
+ return error_response(str(e), 400)
+ except Exception as e:
+ return error_response(str(e), 500)
diff --git a/admin/server/services.py b/admin/server/services.py
index c394dae3a65..43646d7918a 100644
--- a/admin/server/services.py
+++ b/admin/server/services.py
@@ -14,16 +14,22 @@
# limitations under the License.
#
+import json
import os
import logging
import re
+from typing import Any
+
from werkzeug.security import check_password_hash
from common.constants import ActiveEnum
from api.db.services import UserService
from api.db.joint_services.user_account_service import create_new_user, delete_user_data
from api.db.services.canvas_service import UserCanvasService
-from api.db.services.user_service import TenantService
+from api.db.services.user_service import TenantService, UserTenantService
from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.system_settings_service import SystemSettingsService
+from api.db.services.api_service import APITokenService
+from api.db.db_models import APIToken
from api.utils.crypt import decrypt
from api.utils import health_utils
@@ -37,13 +43,15 @@ def get_all_users():
users = UserService.get_all_users()
result = []
for user in users:
- result.append({
- 'email': user.email,
- 'nickname': user.nickname,
- 'create_date': user.create_date,
- 'is_active': user.is_active,
- 'is_superuser': user.is_superuser,
- })
+ result.append(
+ {
+ "email": user.email,
+ "nickname": user.nickname,
+ "create_date": user.create_date,
+ "is_active": user.is_active,
+ "is_superuser": user.is_superuser,
+ }
+ )
return result
@staticmethod
@@ -52,19 +60,21 @@ def get_user_details(username):
users = UserService.query_user_by_email(username)
result = []
for user in users:
- result.append({
- 'avatar': user.avatar,
- 'email': user.email,
- 'language': user.language,
- 'last_login_time': user.last_login_time,
- 'is_active': user.is_active,
- 'is_anonymous': user.is_anonymous,
- 'login_channel': user.login_channel,
- 'status': user.status,
- 'is_superuser': user.is_superuser,
- 'create_date': user.create_date,
- 'update_date': user.update_date
- })
+ result.append(
+ {
+ "avatar": user.avatar,
+ "email": user.email,
+ "language": user.language,
+ "last_login_time": user.last_login_time,
+ "is_active": user.is_active,
+ "is_anonymous": user.is_anonymous,
+ "login_channel": user.login_channel,
+ "status": user.status,
+ "is_superuser": user.is_superuser,
+ "create_date": user.create_date,
+ "update_date": user.update_date,
+ }
+ )
return result
@staticmethod
@@ -126,8 +136,8 @@ def update_user_activate_status(username, activate_status: str):
# format activate_status before handle
_activate_status = activate_status.lower()
target_status = {
- 'on': ActiveEnum.ACTIVE.value,
- 'off': ActiveEnum.INACTIVE.value,
+ "on": ActiveEnum.ACTIVE.value,
+ "off": ActiveEnum.INACTIVE.value,
}.get(_activate_status)
if not target_status:
raise AdminException(f"Invalid activate_status: {activate_status}")
@@ -137,9 +147,84 @@ def update_user_activate_status(username, activate_status: str):
UserService.update_user(usr.id, {"is_active": target_status})
return f"Turn {_activate_status} user activate status successfully!"
+ @staticmethod
+ def get_user_api_key(username: str) -> list[dict[str, Any]]:
+ # use email to find user. check exist and unique.
+ user_list: list[Any] = UserService.query_user_by_email(username)
+ if not user_list:
+ raise UserNotFoundError(username)
+ elif len(user_list) > 1:
+ raise AdminException(f"More than one user with username '{username}' found!")
-class UserServiceMgr:
+ usr: Any = user_list[0]
+ # tenant_id is typically the same as user_id for the owner tenant
+ tenant_id: str = usr.id
+
+ # Query all API keys for this tenant
+ api_keys: Any = APITokenService.query(tenant_id=tenant_id)
+
+ result: list[dict[str, Any]] = []
+ for key in api_keys:
+ result.append(key.to_dict())
+
+ return result
+
+ @staticmethod
+ def save_api_key(api_key: dict[str, Any]) -> bool:
+ return APITokenService.save(**api_key)
+
+ @staticmethod
+ def delete_api_key(username: str, key: str) -> bool:
+ # use email to find user. check exist and unique.
+ user_list: list[Any] = UserService.query_user_by_email(username)
+ if not user_list:
+ raise UserNotFoundError(username)
+ elif len(user_list) > 1:
+ raise AdminException(f"Exist more than 1 user: {username}!")
+
+ usr: Any = user_list[0]
+ # tenant_id is typically the same as user_id for the owner tenant
+ tenant_id: str = usr.id
+
+ # Delete the API key
+ deleted_count: int = APITokenService.filter_delete([APIToken.tenant_id == tenant_id, APIToken.token == key])
+ return deleted_count > 0
+
+ @staticmethod
+ def grant_admin(username: str):
+ # use email to find user. check exist and unique.
+ user_list = UserService.query_user_by_email(username)
+ if not user_list:
+ raise UserNotFoundError(username)
+ elif len(user_list) > 1:
+ raise AdminException(f"Exist more than 1 user: {username}!")
+
+ # check activate status different from new
+ usr = user_list[0]
+ if usr.is_superuser:
+ return f"{usr} is already superuser!"
+ # update is_active
+ UserService.update_user(usr.id, {"is_superuser": True})
+ return "Grant successfully!"
+ @staticmethod
+ def revoke_admin(username: str):
+ # use email to find user. check exist and unique.
+ user_list = UserService.query_user_by_email(username)
+ if not user_list:
+ raise UserNotFoundError(username)
+ elif len(user_list) > 1:
+ raise AdminException(f"Exist more than 1 user: {username}!")
+ # check activate status different from new
+ usr = user_list[0]
+ if not usr.is_superuser:
+ return f"{usr} isn't superuser, yet!"
+ # update is_active
+ UserService.update_user(usr.id, {"is_superuser": False})
+ return "Revoke successfully!"
+
+
+class UserServiceMgr:
@staticmethod
def get_user_datasets(username):
# use email to find user.
@@ -169,39 +254,43 @@ def get_user_agents(username):
tenant_ids = [m["tenant_id"] for m in tenants]
# filter permitted agents and owned agents
res = UserCanvasService.get_all_agents_by_tenant_ids(tenant_ids, usr.id)
- return [{
- 'title': r['title'],
- 'permission': r['permission'],
- 'canvas_category': r['canvas_category'].split('_')[0],
- 'avatar': r['avatar']
- } for r in res]
+ return [{"title": r["title"], "permission": r["permission"], "canvas_category": r["canvas_category"].split("_")[0], "avatar": r["avatar"]} for r in res]
+ @staticmethod
+ def get_user_tenants(email: str) -> list[dict[str, Any]]:
+ users: list[Any] = UserService.query_user_by_email(email)
+ if not users:
+ raise UserNotFoundError(email)
+ user: Any = users[0]
-class ServiceMgr:
+ tenants: list[dict[str, Any]] = UserTenantService.get_tenants_by_user_id(user.id)
+ return tenants
+
+class ServiceMgr:
@staticmethod
def get_all_services():
- doc_engine = os.getenv('DOC_ENGINE', 'elasticsearch')
+ doc_engine = os.getenv("DOC_ENGINE", "elasticsearch")
result = []
configs = SERVICE_CONFIGS.configs
for service_id, config in enumerate(configs):
config_dict = config.to_dict()
- if config_dict['service_type'] == 'retrieval':
- if config_dict['extra']['retrieval_type'] != doc_engine:
+ if config_dict["service_type"] == "retrieval":
+ if config_dict["extra"]["retrieval_type"] != doc_engine:
continue
try:
service_detail = ServiceMgr.get_service_details(service_id)
if "status" in service_detail:
- config_dict['status'] = service_detail['status']
+ config_dict["status"] = service_detail["status"]
else:
- config_dict['status'] = 'timeout'
+ config_dict["status"] = "timeout"
except Exception as e:
logging.warning(f"Can't get service details, error: {e}")
- config_dict['status'] = 'timeout'
- if not config_dict['host']:
- config_dict['host'] = '-'
- if not config_dict['port']:
- config_dict['port'] = '-'
+ config_dict["status"] = "timeout"
+ if not config_dict["host"]:
+ config_dict["host"] = "-"
+ if not config_dict["port"]:
+ config_dict["port"] = "-"
result.append(config_dict)
return result
@@ -217,11 +306,18 @@ def get_service_details(service_id: int):
raise AdminException(f"invalid service_index: {service_idx}")
service_config = configs[service_idx]
- service_info = {'name': service_config.name, 'detail_func_name': service_config.detail_func_name}
- detail_func = getattr(health_utils, service_info.get('detail_func_name'))
+ # exclude retrieval service if retrieval_type is not matched
+ doc_engine = os.getenv("DOC_ENGINE", "elasticsearch")
+ if service_config.service_type == "retrieval":
+ if service_config.retrieval_type != doc_engine:
+ raise AdminException(f"invalid service_index: {service_idx}")
+
+ service_info = {"name": service_config.name, "detail_func_name": service_config.detail_func_name}
+
+ detail_func = getattr(health_utils, service_info.get("detail_func_name"))
res = detail_func()
- res.update({'service_name': service_info.get('name')})
+ res.update({"service_name": service_info.get("name")})
return res
@staticmethod
@@ -231,3 +327,397 @@ def shutdown_service(service_id: int):
@staticmethod
def restart_service(service_id: int):
raise AdminException("restart_service: not implemented")
+
+
+class SettingsMgr:
+ @staticmethod
+ def get_all():
+ settings = SystemSettingsService.get_all()
+ result = []
+ for setting in settings:
+ result.append(
+ {
+ "name": setting.name,
+ "source": setting.source,
+ "data_type": setting.data_type,
+ "value": setting.value,
+ }
+ )
+ return result
+
+ @staticmethod
+ def get_by_name(name: str):
+ settings = SystemSettingsService.get_by_name(name)
+ if len(settings) == 0:
+ raise AdminException(f"Can't get setting: {name}")
+ result = []
+ for setting in settings:
+ result.append(
+ {
+ "name": setting.name,
+ "source": setting.source,
+ "data_type": setting.data_type,
+ "value": setting.value,
+ }
+ )
+ return result
+
+ @staticmethod
+ def update_by_name(name: str, value: str):
+ settings = SystemSettingsService.get_by_name(name)
+ if len(settings) == 1:
+ setting = settings[0]
+ setting.value = value
+ setting_dict = setting.to_dict()
+ SystemSettingsService.update_by_name(name, setting_dict)
+ elif len(settings) > 1:
+ raise AdminException(f"Can't update more than 1 setting: {name}")
+ else:
+ # Create new setting if it doesn't exist
+
+ # Determine data_type based on name and value
+ if name.startswith("sandbox."):
+ data_type = "json"
+ elif name.endswith(".enabled"):
+ data_type = "boolean"
+ else:
+ data_type = "string"
+
+ new_setting = {
+ "name": name,
+ "value": str(value),
+ "source": "admin",
+ "data_type": data_type,
+ }
+ SystemSettingsService.save(**new_setting)
+
+
+class ConfigMgr:
+ @staticmethod
+ def get_all():
+ result = []
+ configs = SERVICE_CONFIGS.configs
+ for config in configs:
+ config_dict = config.to_dict()
+ result.append(config_dict)
+ return result
+
+
+class EnvironmentsMgr:
+ @staticmethod
+ def get_all():
+ result = []
+
+ env_kv = {"env": "DOC_ENGINE", "value": os.getenv("DOC_ENGINE")}
+ result.append(env_kv)
+
+ env_kv = {"env": "DEFAULT_SUPERUSER_EMAIL", "value": os.getenv("DEFAULT_SUPERUSER_EMAIL", "admin@ragflow.io")}
+ result.append(env_kv)
+
+ env_kv = {"env": "DB_TYPE", "value": os.getenv("DB_TYPE", "mysql")}
+ result.append(env_kv)
+
+ env_kv = {"env": "DEVICE", "value": os.getenv("DEVICE", "cpu")}
+ result.append(env_kv)
+
+ env_kv = {"env": "STORAGE_IMPL", "value": os.getenv("STORAGE_IMPL", "MINIO")}
+ result.append(env_kv)
+
+ return result
+
+
+class SandboxMgr:
+ """Manager for sandbox provider configuration and operations."""
+
+ # Provider registry with metadata
+ PROVIDER_REGISTRY = {
+ "self_managed": {
+ "name": "Self-Managed",
+ "description": "On-premise deployment using Daytona/Docker",
+ "tags": ["self-hosted", "low-latency", "secure"],
+ },
+ "aliyun_codeinterpreter": {
+ "name": "Aliyun Code Interpreter",
+ "description": "Aliyun Function Compute Code Interpreter - Code execution in serverless microVMs",
+ "tags": ["saas", "cloud", "scalable", "aliyun"],
+ },
+ "e2b": {
+ "name": "E2B",
+ "description": "E2B Cloud - Code Execution Sandboxes",
+ "tags": ["saas", "fast", "global"],
+ },
+ }
+
+ @staticmethod
+ def list_providers():
+ """List all available sandbox providers."""
+ result = []
+ for provider_id, metadata in SandboxMgr.PROVIDER_REGISTRY.items():
+ result.append({
+ "id": provider_id,
+ **metadata
+ })
+ return result
+
+ @staticmethod
+ def get_provider_config_schema(provider_id: str):
+ """Get configuration schema for a specific provider."""
+ from agent.sandbox.providers import (
+ SelfManagedProvider,
+ AliyunCodeInterpreterProvider,
+ E2BProvider,
+ )
+
+ schemas = {
+ "self_managed": SelfManagedProvider.get_config_schema(),
+ "aliyun_codeinterpreter": AliyunCodeInterpreterProvider.get_config_schema(),
+ "e2b": E2BProvider.get_config_schema(),
+ }
+
+ if provider_id not in schemas:
+ raise AdminException(f"Unknown provider: {provider_id}")
+
+ return schemas.get(provider_id, {})
+
+ @staticmethod
+ def get_config():
+ """Get current sandbox configuration."""
+ try:
+ # Get active provider type
+ provider_type_settings = SystemSettingsService.get_by_name("sandbox.provider_type")
+ if not provider_type_settings:
+ # Return default config if not set
+ provider_type = "self_managed"
+ else:
+ provider_type = provider_type_settings[0].value
+
+ # Get provider-specific config
+ provider_config_settings = SystemSettingsService.get_by_name(f"sandbox.{provider_type}")
+ if not provider_config_settings:
+ provider_config = {}
+ else:
+ try:
+ provider_config = json.loads(provider_config_settings[0].value)
+ except json.JSONDecodeError:
+ provider_config = {}
+
+ return {
+ "provider_type": provider_type,
+ "config": provider_config,
+ }
+ except Exception as e:
+ raise AdminException(f"Failed to get sandbox config: {str(e)}")
+
+ @staticmethod
+ def set_config(provider_type: str, config: dict, set_active: bool = True):
+ """
+ Set sandbox provider configuration.
+
+ Args:
+ provider_type: Provider identifier (e.g., "self_managed", "e2b")
+ config: Provider configuration dictionary
+ set_active: If True, also update the active provider. If False,
+ only update the configuration without switching providers.
+ Default: True
+
+ Returns:
+ Dictionary with updated provider_type and config
+ """
+ from agent.sandbox.providers import (
+ SelfManagedProvider,
+ AliyunCodeInterpreterProvider,
+ E2BProvider,
+ )
+
+ try:
+ # Validate provider type
+ if provider_type not in SandboxMgr.PROVIDER_REGISTRY:
+ raise AdminException(f"Unknown provider type: {provider_type}")
+
+ # Get provider schema for validation
+ schema = SandboxMgr.get_provider_config_schema(provider_type)
+
+ # Validate config against schema
+ for field_name, field_schema in schema.items():
+ if field_schema.get("required", False) and field_name not in config:
+ raise AdminException(f"Required field '{field_name}' is missing")
+
+ # Type validation
+ if field_name in config:
+ field_type = field_schema.get("type")
+ if field_type == "integer":
+ if not isinstance(config[field_name], int):
+ raise AdminException(f"Field '{field_name}' must be an integer")
+ elif field_type == "string":
+ if not isinstance(config[field_name], str):
+ raise AdminException(f"Field '{field_name}' must be a string")
+ elif field_type == "bool":
+ if not isinstance(config[field_name], bool):
+ raise AdminException(f"Field '{field_name}' must be a boolean")
+
+ # Range validation for integers
+ if field_type == "integer" and field_name in config:
+ min_val = field_schema.get("min")
+ max_val = field_schema.get("max")
+ if min_val is not None and config[field_name] < min_val:
+ raise AdminException(f"Field '{field_name}' must be >= {min_val}")
+ if max_val is not None and config[field_name] > max_val:
+ raise AdminException(f"Field '{field_name}' must be <= {max_val}")
+
+ # Provider-specific custom validation
+ provider_classes = {
+ "self_managed": SelfManagedProvider,
+ "aliyun_codeinterpreter": AliyunCodeInterpreterProvider,
+ "e2b": E2BProvider,
+ }
+ provider = provider_classes[provider_type]()
+ is_valid, error_msg = provider.validate_config(config)
+ if not is_valid:
+ raise AdminException(f"Provider validation failed: {error_msg}")
+
+ # Update provider_type only if set_active is True
+ if set_active:
+ SettingsMgr.update_by_name("sandbox.provider_type", provider_type)
+
+ # Always update the provider config
+ config_json = json.dumps(config)
+ SettingsMgr.update_by_name(f"sandbox.{provider_type}", config_json)
+
+ return {"provider_type": provider_type, "config": config}
+ except AdminException:
+ raise
+ except Exception as e:
+ raise AdminException(f"Failed to set sandbox config: {str(e)}")
+
+ @staticmethod
+ def test_connection(provider_type: str, config: dict):
+ """
+ Test connection to sandbox provider by executing a simple Python script.
+
+ This creates a temporary sandbox instance and runs a test code to verify:
+ - Connection credentials are valid
+ - Sandbox can be created
+ - Code execution works correctly
+
+ Args:
+ provider_type: Provider identifier
+ config: Provider configuration dictionary
+
+ Returns:
+ dict with test results including stdout, stderr, exit_code, execution_time
+ """
+ try:
+ from agent.sandbox.providers import (
+ SelfManagedProvider,
+ AliyunCodeInterpreterProvider,
+ E2BProvider,
+ )
+
+ # Instantiate provider based on type
+ provider_classes = {
+ "self_managed": SelfManagedProvider,
+ "aliyun_codeinterpreter": AliyunCodeInterpreterProvider,
+ "e2b": E2BProvider,
+ }
+
+ if provider_type not in provider_classes:
+ raise AdminException(f"Unknown provider type: {provider_type}")
+
+ provider = provider_classes[provider_type]()
+
+ # Initialize with config
+ if not provider.initialize(config):
+ raise AdminException(f"Failed to initialize provider '{provider_type}'")
+
+ # Create a temporary sandbox instance for testing
+ instance = provider.create_instance(template="python")
+
+ if not instance or instance.status != "READY":
+ raise AdminException(f"Failed to create sandbox instance. Status: {instance.status if instance else 'None'}")
+
+ # Simple test code that exercises basic Python functionality
+ test_code = """
+# Test basic Python functionality
+import sys
+import json
+import math
+
+print("Python version:", sys.version)
+print("Platform:", sys.platform)
+
+# Test basic calculations
+result = 2 + 2
+print(f"2 + 2 = {result}")
+
+# Test JSON operations
+data = {"test": "data", "value": 123}
+print(f"JSON dump: {json.dumps(data)}")
+
+# Test math operations
+print(f"Math.sqrt(16) = {math.sqrt(16)}")
+
+# Test error handling
+try:
+ x = 1 / 1
+ print("Division test: OK")
+except Exception as e:
+ print(f"Error: {e}")
+
+# Return success indicator
+print("TEST_PASSED")
+"""
+
+ # Execute test code with timeout
+ execution_result = provider.execute_code(
+ instance_id=instance.instance_id,
+ code=test_code,
+ language="python",
+ timeout=10 # 10 seconds timeout
+ )
+
+ # Clean up the test instance (if provider supports it)
+ try:
+ if hasattr(provider, 'terminate_instance'):
+ provider.terminate_instance(instance.instance_id)
+ logging.info(f"Cleaned up test instance {instance.instance_id}")
+ else:
+ logging.warning(f"Provider {provider_type} does not support terminate_instance, test instance may leak")
+ except Exception as cleanup_error:
+ logging.warning(f"Failed to cleanup test instance {instance.instance_id}: {cleanup_error}")
+
+ # Build detailed result message
+ success = execution_result.exit_code == 0 and "TEST_PASSED" in execution_result.stdout
+
+ message_parts = [
+ f"Test {success and 'PASSED' or 'FAILED'}",
+ f"Exit code: {execution_result.exit_code}",
+ f"Execution time: {execution_result.execution_time:.2f}s"
+ ]
+
+ if execution_result.stdout.strip():
+ stdout_preview = execution_result.stdout.strip()[:200]
+ message_parts.append(f"Output: {stdout_preview}...")
+
+ if execution_result.stderr.strip():
+ stderr_preview = execution_result.stderr.strip()[:200]
+ message_parts.append(f"Errors: {stderr_preview}...")
+
+ message = " | ".join(message_parts)
+
+ return {
+ "success": success,
+ "message": message,
+ "details": {
+ "exit_code": execution_result.exit_code,
+ "execution_time": execution_result.execution_time,
+ "stdout": execution_result.stdout,
+ "stderr": execution_result.stderr,
+ }
+ }
+
+ except AdminException:
+ raise
+ except Exception as e:
+ import traceback
+ error_details = traceback.format_exc()
+ raise AdminException(f"Connection test failed: {str(e)}\\n\\nStack trace:\\n{error_details}")
diff --git a/agent/canvas.py b/agent/canvas.py
index 6368e10e355..7a1d3bd234e 100644
--- a/agent/canvas.py
+++ b/agent/canvas.py
@@ -78,13 +78,14 @@ class Graph:
}
"""
- def __init__(self, dsl: str, tenant_id=None, task_id=None):
+ def __init__(self, dsl: str, tenant_id=None, task_id=None, custom_header=None):
self.path = []
self.components = {}
self.error = ""
self.dsl = json.loads(dsl)
self._tenant_id = tenant_id
self.task_id = task_id if task_id else get_uuid()
+ self.custom_header = custom_header
self._thread_pool = ThreadPoolExecutor(max_workers=5)
self.load()
@@ -94,6 +95,7 @@ def load(self):
for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"])
param = component_class(cpn["obj"]["component_name"] + "Param")()
+ cpn["obj"]["params"]["custom_header"] = self.custom_header
param.update(cpn["obj"]["params"])
try:
param.check()
@@ -278,15 +280,16 @@ def cancel_task(self) -> bool:
class Canvas(Graph):
- def __init__(self, dsl: str, tenant_id=None, task_id=None, canvas_id=None):
+ def __init__(self, dsl: str, tenant_id=None, task_id=None, canvas_id=None, custom_header=None):
self.globals = {
"sys.query": "",
"sys.user_id": tenant_id,
"sys.conversation_turns": 0,
- "sys.files": []
+ "sys.files": [],
+ "sys.history": []
}
self.variables = {}
- super().__init__(dsl, tenant_id, task_id)
+ super().__init__(dsl, tenant_id, task_id, custom_header=custom_header)
self._id = canvas_id
def load(self):
@@ -294,12 +297,15 @@ def load(self):
self.history = self.dsl["history"]
if "globals" in self.dsl:
self.globals = self.dsl["globals"]
+ if "sys.history" not in self.globals:
+ self.globals["sys.history"] = []
else:
self.globals = {
"sys.query": "",
"sys.user_id": "",
"sys.conversation_turns": 0,
- "sys.files": []
+ "sys.files": [],
+ "sys.history": []
}
if "variables" in self.dsl:
self.variables = self.dsl["variables"]
@@ -340,21 +346,23 @@ def reset(self, mem=False):
key = k[4:]
if key in self.variables:
variable = self.variables[key]
- if variable["value"]:
- self.globals[k] = variable["value"]
+ if variable["type"] == "string":
+ self.globals[k] = ""
+ variable["value"] = ""
+ elif variable["type"] == "number":
+ self.globals[k] = 0
+ variable["value"] = 0
+ elif variable["type"] == "boolean":
+ self.globals[k] = False
+ variable["value"] = False
+ elif variable["type"] == "object":
+ self.globals[k] = {}
+ variable["value"] = {}
+ elif variable["type"].startswith("array"):
+ self.globals[k] = []
+ variable["value"] = []
else:
- if variable["type"] == "string":
- self.globals[k] = ""
- elif variable["type"] == "number":
- self.globals[k] = 0
- elif variable["type"] == "boolean":
- self.globals[k] = False
- elif variable["type"] == "object":
- self.globals[k] = {}
- elif variable["type"].startswith("array"):
- self.globals[k] = []
- else:
- self.globals[k] = ""
+ self.globals[k] = ""
else:
self.globals[k] = ""
@@ -419,9 +427,15 @@ async def _run_batch(f, t):
loop = asyncio.get_running_loop()
tasks = []
+ max_concurrency = getattr(self._thread_pool, "_max_workers", 5)
+ sem = asyncio.Semaphore(max_concurrency)
- def _run_async_in_thread(coro_func, **call_kwargs):
- return asyncio.run(coro_func(**call_kwargs))
+ async def _invoke_one(cpn_obj, sync_fn, call_kwargs, use_async: bool):
+ async with sem:
+ if use_async:
+ await cpn_obj.invoke_async(**(call_kwargs or {}))
+ return
+ await loop.run_in_executor(self._thread_pool, partial(sync_fn, **(call_kwargs or {})))
i = f
while i < t:
@@ -447,11 +461,9 @@ def _run_async_in_thread(coro_func, **call_kwargs):
if task_fn is None:
continue
- invoke_async = getattr(cpn, "invoke_async", None)
- if invoke_async and asyncio.iscoroutinefunction(invoke_async):
- tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
- else:
- tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))
+ fn_invoke_async = getattr(cpn, "_invoke_async", None)
+ use_async = (fn_invoke_async and asyncio.iscoroutinefunction(fn_invoke_async)) or asyncio.iscoroutinefunction(getattr(cpn, "_invoke", None))
+ tasks.append(asyncio.create_task(_invoke_one(cpn, task_fn, call_kwargs, use_async)))
if tasks:
await asyncio.gather(*tasks)
@@ -638,6 +650,7 @@ def _extend_path(cpn_ids):
"created_at": st,
})
self.history.append(("assistant", self.get_component_obj(self.path[-1]).output()))
+ self.globals["sys.history"].append(f"{self.history[-1][0]}: {self.history[-1][1]}")
elif "Task has been canceled" in self.error:
yield decorate("workflow_finished",
{
@@ -715,6 +728,7 @@ def get_history(self, window_size):
def add_user_input(self, question):
self.history.append(("user", question))
+ self.globals["sys.history"].append(f"{self.history[-1][0]}: {self.history[-1][1]}")
def get_prologue(self):
return self.components["begin"]["obj"]._param.prologue
@@ -740,13 +754,16 @@ async def get_files_async(self, files: Union[None, list[dict]]) -> list[str]:
def image_to_base64(file):
return "data:{};base64,{}".format(file["mime_type"],
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
+ def parse_file(file):
+ blob = FileService.get_blob(file["created_by"], file["id"])
+ return FileService.parse(file["name"], blob, True, file["created_by"])
loop = asyncio.get_running_loop()
tasks = []
for file in files:
if file["mime_type"].find("image") >=0:
tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file))
continue
- tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
+ tasks.append(loop.run_in_executor(self._thread_pool, parse_file, file))
return await asyncio.gather(*tasks)
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py
index 5ff55adf93e..4ff09420ae3 100644
--- a/agent/component/agent_with_tools.py
+++ b/agent/component/agent_with_tools.py
@@ -76,6 +76,8 @@ def __init__(self):
self.mcp = []
self.max_rounds = 5
self.description = ""
+ self.custom_header = {}
+
class Agent(LLM, ToolBase):
@@ -105,7 +107,8 @@ def __init__(self, canvas, id, param: LLMParam):
for mcp in self._param.mcp:
_, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
- tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
+ custom_header = self._param.custom_header
+ tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables, custom_header)
for tnm, meta in mcp["tools"].items():
self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta))
self.tools[tnm] = tool_call_session
diff --git a/agent/component/base.py b/agent/component/base.py
index 264f3972a34..9bceb4ce6d9 100644
--- a/agent/component/base.py
+++ b/agent/component/base.py
@@ -27,6 +27,10 @@
from agent import settings
from common.connection_utils import timeout
+
+
+from common.misc_utils import thread_pool_exec
+
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params"
_USER_FEEDED_PARAMS = "_user_feeded_params"
@@ -379,6 +383,7 @@ def __str__(self):
def __init__(self, canvas, id, param: ComponentParamBase):
from agent.canvas import Graph # Local import to avoid cyclic dependency
+
assert isinstance(canvas, Graph), "canvas must be an instance of Canvas"
self._canvas = canvas
self._id = id
@@ -430,7 +435,7 @@ async def invoke_async(self, **kwargs) -> dict[str, Any]:
elif asyncio.iscoroutinefunction(self._invoke):
await self._invoke(**kwargs)
else:
- await asyncio.to_thread(self._invoke, **kwargs)
+ await thread_pool_exec(self._invoke, **kwargs)
except Exception as e:
if self.get_exception_default_value():
self.set_exception_default_value()
diff --git a/agent/component/begin.py b/agent/component/begin.py
index bcbfdbf24b7..819e46c2540 100644
--- a/agent/component/begin.py
+++ b/agent/component/begin.py
@@ -45,11 +45,14 @@ def _invoke(self, **kwargs):
if self.check_if_canceled("Begin processing"):
return
- if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0:
+ if isinstance(v, dict) and v.get("type", "").lower().find("file") >= 0:
if v.get("optional") and v.get("value", None) is None:
v = None
else:
- v = FileService.get_files([v["value"]])
+ file_value = v["value"]
+ # Support both single file (backward compatibility) and multiple files
+ files = file_value if isinstance(file_value, list) else [file_value]
+ v = FileService.get_files(files)
else:
v = v.get("value")
self.set_output(k, v)
diff --git a/agent/component/categorize.py b/agent/component/categorize.py
index 27cffb91c88..b5a6a4b9c6a 100644
--- a/agent/component/categorize.py
+++ b/agent/component/categorize.py
@@ -97,6 +97,13 @@ def update_prompt(self):
class Categorize(LLM, ABC):
component_name = "Categorize"
+ def get_input_elements(self) -> dict[str, dict]:
+ query_key = self._param.query or "sys.query"
+ elements = self.get_input_elements_from_text(f"{{{query_key}}}")
+ if not elements:
+ logging.warning(f"[Categorize] input element not detected for query key: {query_key}")
+ return elements
+
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
async def _invoke_async(self, **kwargs):
if self.check_if_canceled("Categorize processing"):
@@ -105,12 +112,15 @@ async def _invoke_async(self, **kwargs):
msg = self._canvas.get_history(self._param.message_history_window_size)
if not msg:
msg = [{"role": "user", "content": ""}]
- if kwargs.get("sys.query"):
- msg[-1]["content"] = kwargs["sys.query"]
- self.set_input_value("sys.query", kwargs["sys.query"])
+ query_key = self._param.query or "sys.query"
+ if query_key in kwargs:
+ query_value = kwargs[query_key]
else:
- msg[-1]["content"] = self._canvas.get_variable_value(self._param.query)
- self.set_input_value(self._param.query, msg[-1]["content"])
+ query_value = self._canvas.get_variable_value(query_key)
+ if query_value is None:
+ query_value = ""
+ msg[-1]["content"] = query_value
+ self.set_input_value(query_key, msg[-1]["content"])
self._param.update_prompt()
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
@@ -137,7 +147,7 @@ async def _invoke_async(self, **kwargs):
category_counts[c] = count
cpn_ids = list(self._param.category_description.items())[-1][1]["to"]
- max_category = list(self._param.category_description.keys())[0]
+ max_category = list(self._param.category_description.keys())[-1]
if any(category_counts.values()):
max_category = max(category_counts.items(), key=lambda x: x[1])[0]
cpn_ids = self._param.category_description[max_category]["to"]
diff --git a/agent/component/fillup.py b/agent/component/fillup.py
index 10163d10c0b..b97e6ca526b 100644
--- a/agent/component/fillup.py
+++ b/agent/component/fillup.py
@@ -64,11 +64,14 @@ def _invoke(self, **kwargs):
for k, v in kwargs.get("inputs", {}).items():
if self.check_if_canceled("UserFillUp processing"):
return
- if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0:
+ if isinstance(v, dict) and v.get("type", "").lower().find("file") >= 0:
if v.get("optional") and v.get("value", None) is None:
v = None
else:
- v = FileService.get_files([v["value"]])
+ file_value = v["value"]
+ # Support both single file (backward compatibility) and multiple files
+ files = file_value if isinstance(file_value, list) else [file_value]
+ v = FileService.get_files(files)
else:
v = v.get("value")
self.set_output(k, v)
diff --git a/plugin/README.md b/agent/plugin/README.md
similarity index 98%
rename from plugin/README.md
rename to agent/plugin/README.md
index cd11e91dbc9..4f1ac152c15 100644
--- a/plugin/README.md
+++ b/agent/plugin/README.md
@@ -23,7 +23,7 @@ All the execution logic of this tool should go into this method.
When you start RAGFlow, you can see your plugin was loaded in the log:
```
-2025-05-15 19:29:08,959 INFO 34670 Recursively importing plugins from path `/some-path/ragflow/plugin/embedded_plugins`
+2025-05-15 19:29:08,959 INFO 34670 Recursively importing plugins from path `/some-path/ragflow/agent/plugin/embedded_plugins`
2025-05-15 19:29:08,960 INFO 34670 Loaded llm_tools plugin BadCalculatorPlugin version 1.0.0
```
diff --git a/plugin/README_zh.md b/agent/plugin/README_zh.md
similarity index 98%
rename from plugin/README_zh.md
rename to agent/plugin/README_zh.md
index 17b3dd703d7..eb9910ba40e 100644
--- a/plugin/README_zh.md
+++ b/agent/plugin/README_zh.md
@@ -23,7 +23,7 @@ RAGFlow将会从`embedded_plugins`子文件夹中递归加载所有的插件。
当你启动RAGFlow时,你会在日志中看见你的插件被加载了:
```
-2025-05-15 19:29:08,959 INFO 34670 Recursively importing plugins from path `/some-path/ragflow/plugin/embedded_plugins`
+2025-05-15 19:29:08,959 INFO 34670 Recursively importing plugins from path `/some-path/ragflow/agent/plugin/embedded_plugins`
2025-05-15 19:29:08,960 INFO 34670 Loaded llm_tools plugin BadCalculatorPlugin version 1.0.0
```
diff --git a/plugin/__init__.py b/agent/plugin/__init__.py
similarity index 100%
rename from plugin/__init__.py
rename to agent/plugin/__init__.py
diff --git a/plugin/common.py b/agent/plugin/common.py
similarity index 100%
rename from plugin/common.py
rename to agent/plugin/common.py
diff --git a/plugin/embedded_plugins/llm_tools/bad_calculator.py b/agent/plugin/embedded_plugins/llm_tools/bad_calculator.py
similarity index 94%
rename from plugin/embedded_plugins/llm_tools/bad_calculator.py
rename to agent/plugin/embedded_plugins/llm_tools/bad_calculator.py
index 04c3b815a38..38376aa984f 100644
--- a/plugin/embedded_plugins/llm_tools/bad_calculator.py
+++ b/agent/plugin/embedded_plugins/llm_tools/bad_calculator.py
@@ -1,5 +1,5 @@
import logging
-from plugin.llm_tool_plugin import LLMToolMetadata, LLMToolPlugin
+from agent.plugin.llm_tool_plugin import LLMToolMetadata, LLMToolPlugin
class BadCalculatorPlugin(LLMToolPlugin):
diff --git a/plugin/llm_tool_plugin.py b/agent/plugin/llm_tool_plugin.py
similarity index 100%
rename from plugin/llm_tool_plugin.py
rename to agent/plugin/llm_tool_plugin.py
diff --git a/plugin/plugin_manager.py b/agent/plugin/plugin_manager.py
similarity index 100%
rename from plugin/plugin_manager.py
rename to agent/plugin/plugin_manager.py
diff --git a/sandbox/.env.example b/agent/sandbox/.env.example
similarity index 100%
rename from sandbox/.env.example
rename to agent/sandbox/.env.example
diff --git a/sandbox/Makefile b/agent/sandbox/Makefile
similarity index 100%
rename from sandbox/Makefile
rename to agent/sandbox/Makefile
diff --git a/sandbox/README.md b/agent/sandbox/README.md
similarity index 100%
rename from sandbox/README.md
rename to agent/sandbox/README.md
diff --git a/sandbox/asserts/code_executor_manager.svg b/agent/sandbox/asserts/code_executor_manager.svg
similarity index 100%
rename from sandbox/asserts/code_executor_manager.svg
rename to agent/sandbox/asserts/code_executor_manager.svg
diff --git a/agent/sandbox/client.py b/agent/sandbox/client.py
new file mode 100644
index 00000000000..4d49ae734c6
--- /dev/null
+++ b/agent/sandbox/client.py
@@ -0,0 +1,239 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Sandbox client for agent components.
+
+This module provides a unified interface for agent components to interact
+with the configured sandbox provider.
+"""
+
+import json
+import logging
+from typing import Dict, Any, Optional
+
+from api.db.services.system_settings_service import SystemSettingsService
+from agent.sandbox.providers import ProviderManager
+from agent.sandbox.providers.base import ExecutionResult
+
+logger = logging.getLogger(__name__)
+
+
+# Global provider manager instance
+_provider_manager: Optional[ProviderManager] = None
+
+
+def get_provider_manager() -> ProviderManager:
+ """
+ Get the global provider manager instance.
+
+ Returns:
+ ProviderManager instance with active provider loaded
+ """
+ global _provider_manager
+
+ if _provider_manager is not None:
+ return _provider_manager
+
+ # Initialize provider manager with system settings
+ _provider_manager = ProviderManager()
+ _load_provider_from_settings()
+
+ return _provider_manager
+
+
+def _load_provider_from_settings() -> None:
+ """
+ Load sandbox provider from system settings and configure the provider manager.
+
+ This function reads the system settings to determine which provider is active
+ and initializes it with the appropriate configuration.
+ """
+ global _provider_manager
+
+ if _provider_manager is None:
+ return
+
+ try:
+ # Get active provider type
+ provider_type_settings = SystemSettingsService.get_by_name("sandbox.provider_type")
+ if not provider_type_settings:
+ raise RuntimeError(
+ "Sandbox provider type not configured. Please set 'sandbox.provider_type' in system settings."
+ )
+ provider_type = provider_type_settings[0].value
+
+ # Get provider configuration
+ provider_config_settings = SystemSettingsService.get_by_name(f"sandbox.{provider_type}")
+
+ if not provider_config_settings:
+ logger.warning(f"No configuration found for provider: {provider_type}")
+ config = {}
+ else:
+ try:
+ config = json.loads(provider_config_settings[0].value)
+ except json.JSONDecodeError as e:
+ logger.error(f"Failed to parse sandbox config for {provider_type}: {e}")
+ config = {}
+
+ # Import and instantiate the provider
+ from agent.sandbox.providers import (
+ SelfManagedProvider,
+ AliyunCodeInterpreterProvider,
+ E2BProvider,
+ )
+
+ provider_classes = {
+ "self_managed": SelfManagedProvider,
+ "aliyun_codeinterpreter": AliyunCodeInterpreterProvider,
+ "e2b": E2BProvider,
+ }
+
+ if provider_type not in provider_classes:
+ logger.error(f"Unknown provider type: {provider_type}")
+ return
+
+ provider_class = provider_classes[provider_type]
+ provider = provider_class()
+
+ # Initialize the provider
+ if not provider.initialize(config):
+ logger.error(f"Failed to initialize sandbox provider: {provider_type}. Config keys: {list(config.keys())}")
+ return
+
+ # Set the active provider
+ _provider_manager.set_provider(provider_type, provider)
+ logger.info(f"Sandbox provider '{provider_type}' initialized successfully")
+
+ except Exception as e:
+ logger.error(f"Failed to load sandbox provider from settings: {e}")
+ import traceback
+ traceback.print_exc()
+
+
+def reload_provider() -> None:
+ """
+ Reload the sandbox provider from system settings.
+
+ Use this function when sandbox settings have been updated.
+ """
+ global _provider_manager
+ _provider_manager = None
+ _load_provider_from_settings()
+
+
+def execute_code(
+ code: str,
+ language: str = "python",
+ timeout: int = 30,
+ arguments: Optional[Dict[str, Any]] = None
+) -> ExecutionResult:
+ """
+ Execute code in the configured sandbox.
+
+ This is the main entry point for agent components to execute code.
+
+ Args:
+ code: Source code to execute
+ language: Programming language (python, nodejs, javascript)
+ timeout: Maximum execution time in seconds
+ arguments: Optional arguments dict to pass to main() function
+
+ Returns:
+ ExecutionResult containing stdout, stderr, exit_code, and metadata
+
+ Raises:
+ RuntimeError: If no provider is configured or execution fails
+ """
+ provider_manager = get_provider_manager()
+
+ if not provider_manager.is_configured():
+ raise RuntimeError(
+ "No sandbox provider configured. Please configure sandbox settings in the admin panel."
+ )
+
+ provider = provider_manager.get_provider()
+
+ # Create a sandbox instance
+ instance = provider.create_instance(template=language)
+
+ try:
+ # Execute the code
+ result = provider.execute_code(
+ instance_id=instance.instance_id,
+ code=code,
+ language=language,
+ timeout=timeout,
+ arguments=arguments
+ )
+
+ return result
+
+ finally:
+ # Clean up the instance
+ try:
+ provider.destroy_instance(instance.instance_id)
+ except Exception as e:
+ logger.warning(f"Failed to destroy sandbox instance {instance.instance_id}: {e}")
+
+
+def health_check() -> bool:
+ """
+ Check if the sandbox provider is healthy.
+
+ Returns:
+ True if provider is configured and healthy, False otherwise
+ """
+ try:
+ provider_manager = get_provider_manager()
+
+ if not provider_manager.is_configured():
+ return False
+
+ provider = provider_manager.get_provider()
+ return provider.health_check()
+
+ except Exception as e:
+ logger.error(f"Sandbox health check failed: {e}")
+ return False
+
+
+def get_provider_info() -> Dict[str, Any]:
+ """
+ Get information about the current sandbox provider.
+
+ Returns:
+ Dictionary with provider information:
+ - provider_type: Type of the active provider
+ - configured: Whether provider is configured
+ - healthy: Whether provider is healthy
+ """
+ try:
+ provider_manager = get_provider_manager()
+
+ return {
+ "provider_type": provider_manager.get_provider_name(),
+ "configured": provider_manager.is_configured(),
+ "healthy": health_check(),
+ }
+
+ except Exception as e:
+ logger.error(f"Failed to get provider info: {e}")
+ return {
+ "provider_type": None,
+ "configured": False,
+ "healthy": False,
+ }
diff --git a/sandbox/docker-compose.yml b/agent/sandbox/docker-compose.yml
similarity index 100%
rename from sandbox/docker-compose.yml
rename to agent/sandbox/docker-compose.yml
diff --git a/agent/sandbox/executor_manager/Dockerfile b/agent/sandbox/executor_manager/Dockerfile
new file mode 100644
index 00000000000..9444a848763
--- /dev/null
+++ b/agent/sandbox/executor_manager/Dockerfile
@@ -0,0 +1,37 @@
+FROM python:3.11-slim-bookworm
+
+RUN grep -rl 'deb.debian.org' /etc/apt/ | xargs sed -i 's|http[s]*://deb.debian.org|https://mirrors.tuna.tsinghua.edu.cn|g' && \
+ apt-get update && \
+ apt-get install -y curl gcc && \
+ rm -rf /var/lib/apt/lists/*
+
+ARG TARGETARCH
+ARG TARGETVARIANT
+
+RUN set -eux; \
+ case "${TARGETARCH}${TARGETVARIANT}" in \
+ amd64) DOCKER_ARCH=x86_64 ;; \
+ arm64) DOCKER_ARCH=aarch64 ;; \
+ armv7) DOCKER_ARCH=armhf ;; \
+ armv6) DOCKER_ARCH=armel ;; \
+ arm64v8) DOCKER_ARCH=aarch64 ;; \
+ arm64v7) DOCKER_ARCH=armhf ;; \
+ arm*) DOCKER_ARCH=armhf ;; \
+ ppc64le) DOCKER_ARCH=ppc64le ;; \
+ s390x) DOCKER_ARCH=s390x ;; \
+ *) echo "Unsupported architecture: ${TARGETARCH}${TARGETVARIANT}" && exit 1 ;; \
+ esac; \
+ echo "Downloading Docker for architecture: ${DOCKER_ARCH}"; \
+ curl -fsSL "https://download.docker.com/linux/static/stable/${DOCKER_ARCH}/docker-29.1.0.tgz" | \
+ tar xz -C /usr/local/bin --strip-components=1 docker/docker; \
+ ln -sf /usr/local/bin/docker /usr/bin/docker
+
+COPY --from=ghcr.io/astral-sh/uv:0.7.5 /uv /uvx /bin/
+ENV UV_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple
+
+WORKDIR /app
+COPY . .
+
+RUN uv pip install --system -r requirements.txt
+
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "9385"]
diff --git a/graphrag/light/__init__.py b/agent/sandbox/executor_manager/api/__init__.py
similarity index 100%
rename from graphrag/light/__init__.py
rename to agent/sandbox/executor_manager/api/__init__.py
diff --git a/sandbox/executor_manager/api/handlers.py b/agent/sandbox/executor_manager/api/handlers.py
similarity index 100%
rename from sandbox/executor_manager/api/handlers.py
rename to agent/sandbox/executor_manager/api/handlers.py
diff --git a/sandbox/executor_manager/api/routes.py b/agent/sandbox/executor_manager/api/routes.py
similarity index 100%
rename from sandbox/executor_manager/api/routes.py
rename to agent/sandbox/executor_manager/api/routes.py
diff --git a/sandbox/executor_manager/api/__init__.py b/agent/sandbox/executor_manager/core/__init__.py
similarity index 100%
rename from sandbox/executor_manager/api/__init__.py
rename to agent/sandbox/executor_manager/core/__init__.py
diff --git a/sandbox/executor_manager/core/config.py b/agent/sandbox/executor_manager/core/config.py
similarity index 100%
rename from sandbox/executor_manager/core/config.py
rename to agent/sandbox/executor_manager/core/config.py
diff --git a/sandbox/executor_manager/core/container.py b/agent/sandbox/executor_manager/core/container.py
similarity index 100%
rename from sandbox/executor_manager/core/container.py
rename to agent/sandbox/executor_manager/core/container.py
diff --git a/sandbox/executor_manager/core/logger.py b/agent/sandbox/executor_manager/core/logger.py
similarity index 100%
rename from sandbox/executor_manager/core/logger.py
rename to agent/sandbox/executor_manager/core/logger.py
diff --git a/sandbox/executor_manager/main.py b/agent/sandbox/executor_manager/main.py
similarity index 100%
rename from sandbox/executor_manager/main.py
rename to agent/sandbox/executor_manager/main.py
diff --git a/sandbox/executor_manager/core/__init__.py b/agent/sandbox/executor_manager/models/__init__.py
similarity index 100%
rename from sandbox/executor_manager/core/__init__.py
rename to agent/sandbox/executor_manager/models/__init__.py
diff --git a/sandbox/executor_manager/models/enums.py b/agent/sandbox/executor_manager/models/enums.py
similarity index 100%
rename from sandbox/executor_manager/models/enums.py
rename to agent/sandbox/executor_manager/models/enums.py
diff --git a/sandbox/executor_manager/models/schemas.py b/agent/sandbox/executor_manager/models/schemas.py
similarity index 100%
rename from sandbox/executor_manager/models/schemas.py
rename to agent/sandbox/executor_manager/models/schemas.py
diff --git a/sandbox/executor_manager/requirements.txt b/agent/sandbox/executor_manager/requirements.txt
similarity index 100%
rename from sandbox/executor_manager/requirements.txt
rename to agent/sandbox/executor_manager/requirements.txt
diff --git a/sandbox/executor_manager/seccomp-profile-default.json b/agent/sandbox/executor_manager/seccomp-profile-default.json
similarity index 100%
rename from sandbox/executor_manager/seccomp-profile-default.json
rename to agent/sandbox/executor_manager/seccomp-profile-default.json
diff --git a/sandbox/executor_manager/models/__init__.py b/agent/sandbox/executor_manager/services/__init__.py
similarity index 100%
rename from sandbox/executor_manager/models/__init__.py
rename to agent/sandbox/executor_manager/services/__init__.py
diff --git a/sandbox/executor_manager/services/execution.py b/agent/sandbox/executor_manager/services/execution.py
similarity index 100%
rename from sandbox/executor_manager/services/execution.py
rename to agent/sandbox/executor_manager/services/execution.py
diff --git a/sandbox/executor_manager/services/limiter.py b/agent/sandbox/executor_manager/services/limiter.py
similarity index 100%
rename from sandbox/executor_manager/services/limiter.py
rename to agent/sandbox/executor_manager/services/limiter.py
diff --git a/sandbox/executor_manager/services/security.py b/agent/sandbox/executor_manager/services/security.py
similarity index 100%
rename from sandbox/executor_manager/services/security.py
rename to agent/sandbox/executor_manager/services/security.py
diff --git a/sandbox/executor_manager/util.py b/agent/sandbox/executor_manager/util.py
similarity index 100%
rename from sandbox/executor_manager/util.py
rename to agent/sandbox/executor_manager/util.py
diff --git a/sandbox/executor_manager/services/__init__.py b/agent/sandbox/executor_manager/utils/__init__.py
similarity index 100%
rename from sandbox/executor_manager/services/__init__.py
rename to agent/sandbox/executor_manager/utils/__init__.py
diff --git a/sandbox/executor_manager/utils/common.py b/agent/sandbox/executor_manager/utils/common.py
similarity index 100%
rename from sandbox/executor_manager/utils/common.py
rename to agent/sandbox/executor_manager/utils/common.py
diff --git a/agent/sandbox/providers/__init__.py b/agent/sandbox/providers/__init__.py
new file mode 100644
index 00000000000..7be1463b9ca
--- /dev/null
+++ b/agent/sandbox/providers/__init__.py
@@ -0,0 +1,43 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Sandbox providers package.
+
+This package contains:
+- base.py: Base interface for all sandbox providers
+- manager.py: Provider manager for managing active provider
+- self_managed.py: Self-managed provider implementation (wraps existing executor_manager)
+- aliyun_codeinterpreter.py: Aliyun Code Interpreter provider implementation
+ Official Documentation: https://help.aliyun.com/zh/functioncompute/fc/sandbox-sandbox-code-interepreter
+- e2b.py: E2B provider implementation
+"""
+
+from .base import SandboxProvider, SandboxInstance, ExecutionResult
+from .manager import ProviderManager
+from .self_managed import SelfManagedProvider
+from .aliyun_codeinterpreter import AliyunCodeInterpreterProvider
+from .e2b import E2BProvider
+
+__all__ = [
+ "SandboxProvider",
+ "SandboxInstance",
+ "ExecutionResult",
+ "ProviderManager",
+ "SelfManagedProvider",
+ "AliyunCodeInterpreterProvider",
+ "E2BProvider",
+]
diff --git a/agent/sandbox/providers/aliyun_codeinterpreter.py b/agent/sandbox/providers/aliyun_codeinterpreter.py
new file mode 100644
index 00000000000..56e66977a3e
--- /dev/null
+++ b/agent/sandbox/providers/aliyun_codeinterpreter.py
@@ -0,0 +1,512 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Aliyun Code Interpreter provider implementation.
+
+This provider integrates with Aliyun Function Compute Code Interpreter service
+for secure code execution in serverless microVMs using the official agentrun-sdk.
+
+Official Documentation: https://help.aliyun.com/zh/functioncompute/fc/sandbox-sandbox-code-interepreter
+Official SDK: https://github.com/Serverless-Devs/agentrun-sdk-python
+
+https://api.aliyun.com/api/AgentRun/2025-09-10/CreateTemplate?lang=PYTHON
+https://api.aliyun.com/api/AgentRun/2025-09-10/CreateSandbox?lang=PYTHON
+"""
+
+import logging
+import os
+import time
+from typing import Dict, Any, List, Optional
+from datetime import datetime, timezone
+
+from agentrun.sandbox import TemplateType, CodeLanguage, Template, TemplateInput, Sandbox
+from agentrun.utils.config import Config
+from agentrun.utils.exception import ServerError
+
+from .base import SandboxProvider, SandboxInstance, ExecutionResult
+
+logger = logging.getLogger(__name__)
+
+
+class AliyunCodeInterpreterProvider(SandboxProvider):
+ """
+ Aliyun Code Interpreter provider implementation.
+
+ This provider uses the official agentrun-sdk to interact with
+ Aliyun Function Compute's Code Interpreter service.
+ """
+
+ def __init__(self):
+ self.access_key_id: Optional[str] = None
+ self.access_key_secret: Optional[str] = None
+ self.account_id: Optional[str] = None
+ self.region: str = "cn-hangzhou"
+ self.template_name: str = ""
+ self.timeout: int = 30
+ self._initialized: bool = False
+ self._config: Optional[Config] = None
+
+ def initialize(self, config: Dict[str, Any]) -> bool:
+ """
+ Initialize the provider with Aliyun credentials.
+
+ Args:
+ config: Configuration dictionary with keys:
+ - access_key_id: Aliyun AccessKey ID
+ - access_key_secret: Aliyun AccessKey Secret
+ - account_id: Aliyun primary account ID (主账号ID)
+ - region: Region (default: "cn-hangzhou")
+ - template_name: Optional sandbox template name
+ - timeout: Request timeout in seconds (default: 30, max 30)
+
+ Returns:
+ True if initialization successful, False otherwise
+ """
+ # Get values from config or environment variables
+ access_key_id = config.get("access_key_id") or os.getenv("AGENTRUN_ACCESS_KEY_ID")
+ access_key_secret = config.get("access_key_secret") or os.getenv("AGENTRUN_ACCESS_KEY_SECRET")
+ account_id = config.get("account_id") or os.getenv("AGENTRUN_ACCOUNT_ID")
+ region = config.get("region") or os.getenv("AGENTRUN_REGION", "cn-hangzhou")
+
+ self.access_key_id = access_key_id
+ self.access_key_secret = access_key_secret
+ self.account_id = account_id
+ self.region = region
+ self.template_name = config.get("template_name", "")
+ self.timeout = min(config.get("timeout", 30), 30) # Max 30 seconds
+
+ logger.info(f"Aliyun Code Interpreter: Initializing with account_id={self.account_id}, region={self.region}")
+
+ # Validate required fields
+ if not self.access_key_id or not self.access_key_secret:
+ logger.error("Aliyun Code Interpreter: Missing access_key_id or access_key_secret")
+ return False
+
+ if not self.account_id:
+ logger.error("Aliyun Code Interpreter: Missing account_id (主账号ID)")
+ return False
+
+ # Create SDK configuration
+ try:
+ logger.info(f"Aliyun Code Interpreter: Creating Config object with account_id={self.account_id}")
+ self._config = Config(
+ access_key_id=self.access_key_id,
+ access_key_secret=self.access_key_secret,
+ account_id=self.account_id,
+ region_id=self.region,
+ timeout=self.timeout,
+ )
+ logger.info("Aliyun Code Interpreter: Config object created successfully")
+
+ # Verify connection with health check
+ if not self.health_check():
+ logger.error(f"Aliyun Code Interpreter: Health check failed for region {self.region}")
+ return False
+
+ self._initialized = True
+ logger.info(f"Aliyun Code Interpreter: Initialized successfully for region {self.region}")
+ return True
+
+ except Exception as e:
+ logger.error(f"Aliyun Code Interpreter: Initialization failed - {str(e)}")
+ return False
+
+ def create_instance(self, template: str = "python") -> SandboxInstance:
+ """
+ Create a new sandbox instance in Aliyun Code Interpreter.
+
+ Args:
+ template: Programming language (python, javascript)
+
+ Returns:
+ SandboxInstance object
+
+ Raises:
+ RuntimeError: If instance creation fails
+ """
+ if not self._initialized or not self._config:
+ raise RuntimeError("Provider not initialized. Call initialize() first.")
+
+ # Normalize language
+ language = self._normalize_language(template)
+
+ try:
+ # Get or create template
+ from agentrun.sandbox import Sandbox
+
+ if self.template_name:
+ # Use existing template
+ template_name = self.template_name
+ else:
+ # Try to get default template, or create one if it doesn't exist
+ default_template_name = f"ragflow-{language}-default"
+ try:
+ # Check if template exists
+ Template.get_by_name(default_template_name, config=self._config)
+ template_name = default_template_name
+ except Exception:
+ # Create default template if it doesn't exist
+ template_input = TemplateInput(
+ template_name=default_template_name,
+ template_type=TemplateType.CODE_INTERPRETER,
+ )
+ Template.create(template_input, config=self._config)
+ template_name = default_template_name
+
+ # Create sandbox directly
+ sandbox = Sandbox.create(
+ template_type=TemplateType.CODE_INTERPRETER,
+ template_name=template_name,
+ sandbox_idle_timeout_seconds=self.timeout,
+ config=self._config,
+ )
+
+ instance_id = sandbox.sandbox_id
+
+ return SandboxInstance(
+ instance_id=instance_id,
+ provider="aliyun_codeinterpreter",
+ status="READY",
+ metadata={
+ "language": language,
+ "region": self.region,
+ "account_id": self.account_id,
+ "template_name": template_name,
+ "created_at": datetime.now(timezone.utc).isoformat(),
+ },
+ )
+
+ except ServerError as e:
+ raise RuntimeError(f"Failed to create sandbox instance: {str(e)}")
+ except Exception as e:
+ raise RuntimeError(f"Unexpected error creating instance: {str(e)}")
+
+ def execute_code(self, instance_id: str, code: str, language: str, timeout: int = 10, arguments: Optional[Dict[str, Any]] = None) -> ExecutionResult:
+ """
+ Execute code in the Aliyun Code Interpreter instance.
+
+ Args:
+ instance_id: ID of the sandbox instance
+ code: Source code to execute
+ language: Programming language (python, javascript)
+ timeout: Maximum execution time in seconds (max 30)
+ arguments: Optional arguments dict to pass to main() function
+
+ Returns:
+ ExecutionResult containing stdout, stderr, exit_code, and metadata
+
+ Raises:
+ RuntimeError: If execution fails
+ TimeoutError: If execution exceeds timeout
+ """
+ if not self._initialized or not self._config:
+ raise RuntimeError("Provider not initialized. Call initialize() first.")
+
+ # Normalize language
+ normalized_lang = self._normalize_language(language)
+
+ # Enforce 30-second hard limit
+ timeout = min(timeout or self.timeout, 30)
+
+ try:
+ # Connect to existing sandbox instance
+ sandbox = Sandbox.connect(sandbox_id=instance_id, config=self._config)
+
+ # Convert language string to CodeLanguage enum
+ code_language = CodeLanguage.PYTHON if normalized_lang == "python" else CodeLanguage.JAVASCRIPT
+
+ # Wrap code to call main() function
+ # Matches self_managed provider behavior: call main(**arguments)
+ if normalized_lang == "python":
+ # Build arguments string for main() call
+ if arguments:
+ import json as json_module
+ args_json = json_module.dumps(arguments)
+ wrapped_code = f'''{code}
+
+if __name__ == "__main__":
+ import json
+ result = main(**{args_json})
+ print(json.dumps(result) if isinstance(result, dict) else result)
+'''
+ else:
+ wrapped_code = f'''{code}
+
+if __name__ == "__main__":
+ import json
+ result = main()
+ print(json.dumps(result) if isinstance(result, dict) else result)
+'''
+ else: # javascript
+ if arguments:
+ import json as json_module
+ args_json = json_module.dumps(arguments)
+ wrapped_code = f'''{code}
+
+// Call main and output result
+const result = main({args_json});
+console.log(typeof result === 'object' ? JSON.stringify(result) : String(result));
+'''
+ else:
+ wrapped_code = f'''{code}
+
+// Call main and output result
+const result = main();
+console.log(typeof result === 'object' ? JSON.stringify(result) : String(result));
+'''
+ logger.debug(f"Aliyun Code Interpreter: Wrapped code (first 200 chars): {wrapped_code[:200]}")
+
+ start_time = time.time()
+
+ # Execute code using SDK's simplified execute endpoint
+ logger.info(f"Aliyun Code Interpreter: Executing code (language={normalized_lang}, timeout={timeout})")
+ logger.debug(f"Aliyun Code Interpreter: Original code (first 200 chars): {code[:200]}")
+ result = sandbox.context.execute(
+ code=wrapped_code,
+ language=code_language,
+ timeout=timeout,
+ )
+
+ execution_time = time.time() - start_time
+ logger.info(f"Aliyun Code Interpreter: Execution completed in {execution_time:.2f}s")
+ logger.debug(f"Aliyun Code Interpreter: Raw SDK result: {result}")
+
+ # Parse execution result
+ results = result.get("results", []) if isinstance(result, dict) else []
+ logger.info(f"Aliyun Code Interpreter: Parsed {len(results)} result items")
+
+ # Extract stdout and stderr from results
+ stdout_parts = []
+ stderr_parts = []
+ exit_code = 0
+ execution_status = "ok"
+
+ for item in results:
+ result_type = item.get("type", "")
+ text = item.get("text", "")
+
+ if result_type == "stdout":
+ stdout_parts.append(text)
+ elif result_type == "stderr":
+ stderr_parts.append(text)
+ exit_code = 1 # Error occurred
+ elif result_type == "endOfExecution":
+ execution_status = item.get("status", "ok")
+ if execution_status != "ok":
+ exit_code = 1
+ elif result_type == "error":
+ stderr_parts.append(text)
+ exit_code = 1
+
+ stdout = "\n".join(stdout_parts)
+ stderr = "\n".join(stderr_parts)
+
+ logger.info(f"Aliyun Code Interpreter: stdout length={len(stdout)}, stderr length={len(stderr)}, exit_code={exit_code}")
+ if stdout:
+ logger.debug(f"Aliyun Code Interpreter: stdout (first 200 chars): {stdout[:200]}")
+ if stderr:
+ logger.debug(f"Aliyun Code Interpreter: stderr (first 200 chars): {stderr[:200]}")
+
+ return ExecutionResult(
+ stdout=stdout,
+ stderr=stderr,
+ exit_code=exit_code,
+ execution_time=execution_time,
+ metadata={
+ "instance_id": instance_id,
+ "language": normalized_lang,
+ "context_id": result.get("contextId") if isinstance(result, dict) else None,
+ "timeout": timeout,
+ },
+ )
+
+ except ServerError as e:
+ if "timeout" in str(e).lower():
+ raise TimeoutError(f"Execution timed out after {timeout} seconds")
+ raise RuntimeError(f"Failed to execute code: {str(e)}")
+ except Exception as e:
+ raise RuntimeError(f"Unexpected error during execution: {str(e)}")
+
+ def destroy_instance(self, instance_id: str) -> bool:
+ """
+ Destroy an Aliyun Code Interpreter instance.
+
+ Args:
+ instance_id: ID of the instance to destroy
+
+ Returns:
+ True if destruction successful, False otherwise
+ """
+ if not self._initialized or not self._config:
+ raise RuntimeError("Provider not initialized. Call initialize() first.")
+
+ try:
+ # Delete sandbox by ID directly
+ Sandbox.delete_by_id(sandbox_id=instance_id)
+
+ logger.info(f"Successfully destroyed sandbox instance {instance_id}")
+ return True
+
+ except ServerError as e:
+ logger.error(f"Failed to destroy instance {instance_id}: {str(e)}")
+ return False
+ except Exception as e:
+ logger.error(f"Unexpected error destroying instance {instance_id}: {str(e)}")
+ return False
+
+ def health_check(self) -> bool:
+ """
+ Check if the Aliyun Code Interpreter service is accessible.
+
+ Returns:
+ True if provider is healthy, False otherwise
+ """
+ if not self._initialized and not (self.access_key_id and self.account_id):
+ return False
+
+ try:
+ # Try to list templates to verify connection
+ from agentrun.sandbox import Template
+
+ templates = Template.list(config=self._config)
+ return templates is not None
+
+ except Exception as e:
+ logger.warning(f"Aliyun Code Interpreter health check failed: {str(e)}")
+ # If we get any response (even an error), the service is reachable
+ return "connection" not in str(e).lower()
+
+ def get_supported_languages(self) -> List[str]:
+ """
+ Get list of supported programming languages.
+
+ Returns:
+ List of language identifiers
+ """
+ return ["python", "javascript"]
+
+ @staticmethod
+ def get_config_schema() -> Dict[str, Dict]:
+ """
+ Return configuration schema for Aliyun Code Interpreter provider.
+
+ Returns:
+ Dictionary mapping field names to their schema definitions
+ """
+ return {
+ "access_key_id": {
+ "type": "string",
+ "required": True,
+ "label": "Access Key ID",
+ "placeholder": "LTAI5t...",
+ "description": "Aliyun AccessKey ID for authentication",
+ "secret": False,
+ },
+ "access_key_secret": {
+ "type": "string",
+ "required": True,
+ "label": "Access Key Secret",
+ "placeholder": "••••••••••••••••",
+ "description": "Aliyun AccessKey Secret for authentication",
+ "secret": True,
+ },
+ "account_id": {
+ "type": "string",
+ "required": True,
+ "label": "Account ID",
+ "placeholder": "1234567890...",
+ "description": "Aliyun primary account ID (主账号ID), required for API calls",
+ },
+ "region": {
+ "type": "string",
+ "required": False,
+ "label": "Region",
+ "default": "cn-hangzhou",
+ "description": "Aliyun region for Code Interpreter service",
+ "options": ["cn-hangzhou", "cn-beijing", "cn-shanghai", "cn-shenzhen", "cn-guangzhou"],
+ },
+ "template_name": {
+ "type": "string",
+ "required": False,
+ "label": "Template Name",
+ "placeholder": "my-interpreter",
+ "description": "Optional sandbox template name for pre-configured environments",
+ },
+ "timeout": {
+ "type": "integer",
+ "required": False,
+ "label": "Execution Timeout (seconds)",
+ "default": 30,
+ "min": 1,
+ "max": 30,
+ "description": "Code execution timeout (max 30 seconds - hard limit)",
+ },
+ }
+
+ def validate_config(self, config: Dict[str, Any]) -> tuple[bool, Optional[str]]:
+ """
+ Validate Aliyun-specific configuration.
+
+ Args:
+ config: Configuration dictionary to validate
+
+ Returns:
+ Tuple of (is_valid, error_message)
+ """
+ # Validate access key format
+ access_key_id = config.get("access_key_id", "")
+ if access_key_id and not access_key_id.startswith("LTAI"):
+ return False, "Invalid AccessKey ID format (should start with 'LTAI')"
+
+ # Validate account ID
+ account_id = config.get("account_id", "")
+ if not account_id:
+ return False, "Account ID is required"
+
+ # Validate region
+ valid_regions = ["cn-hangzhou", "cn-beijing", "cn-shanghai", "cn-shenzhen", "cn-guangzhou"]
+ region = config.get("region", "cn-hangzhou")
+ if region and region not in valid_regions:
+ return False, f"Invalid region. Must be one of: {', '.join(valid_regions)}"
+
+ # Validate timeout range (max 30 seconds)
+ timeout = config.get("timeout", 30)
+ if isinstance(timeout, int) and (timeout < 1 or timeout > 30):
+ return False, "Timeout must be between 1 and 30 seconds"
+
+ return True, None
+
+ def _normalize_language(self, language: str) -> str:
+ """
+ Normalize language identifier to Aliyun format.
+
+ Args:
+ language: Language identifier (python, python3, javascript, nodejs)
+
+ Returns:
+ Normalized language identifier
+ """
+ if not language:
+ return "python"
+
+ lang_lower = language.lower()
+ if lang_lower in ("python", "python3"):
+ return "python"
+ elif lang_lower in ("javascript", "nodejs"):
+ return "javascript"
+ else:
+ return language
diff --git a/agent/sandbox/providers/base.py b/agent/sandbox/providers/base.py
new file mode 100644
index 00000000000..c21b583e02b
--- /dev/null
+++ b/agent/sandbox/providers/base.py
@@ -0,0 +1,212 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Base interface for sandbox providers.
+
+Each sandbox provider (self-managed, SaaS) implements this interface
+to provide code execution capabilities.
+"""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Dict, Any, Optional, List
+
+
+@dataclass
+class SandboxInstance:
+ """Represents a sandbox execution instance"""
+ instance_id: str
+ provider: str
+ status: str # running, stopped, error
+ metadata: Dict[str, Any]
+
+ def __post_init__(self):
+ if self.metadata is None:
+ self.metadata = {}
+
+
+@dataclass
+class ExecutionResult:
+ """Result of code execution in a sandbox"""
+ stdout: str
+ stderr: str
+ exit_code: int
+ execution_time: float # in seconds
+ metadata: Dict[str, Any]
+
+ def __post_init__(self):
+ if self.metadata is None:
+ self.metadata = {}
+
+
+class SandboxProvider(ABC):
+ """
+ Base interface for all sandbox providers.
+
+ Each provider implementation (self-managed, Aliyun OpenSandbox, E2B, etc.)
+ must implement these methods to provide code execution capabilities.
+ """
+
+ @abstractmethod
+ def initialize(self, config: Dict[str, Any]) -> bool:
+ """
+ Initialize the provider with configuration.
+
+ Args:
+ config: Provider-specific configuration dictionary
+
+ Returns:
+ True if initialization successful, False otherwise
+ """
+ pass
+
+ @abstractmethod
+ def create_instance(self, template: str = "python") -> SandboxInstance:
+ """
+ Create a new sandbox instance.
+
+ Args:
+ template: Programming language/template for the instance
+ (e.g., "python", "nodejs", "bash")
+
+ Returns:
+ SandboxInstance object representing the created instance
+
+ Raises:
+ RuntimeError: If instance creation fails
+ """
+ pass
+
+ @abstractmethod
+ def execute_code(
+ self,
+ instance_id: str,
+ code: str,
+ language: str,
+ timeout: int = 10,
+ arguments: Optional[Dict[str, Any]] = None
+ ) -> ExecutionResult:
+ """
+ Execute code in a sandbox instance.
+
+ Args:
+ instance_id: ID of the sandbox instance
+ code: Source code to execute
+ language: Programming language (python, javascript, etc.)
+ timeout: Maximum execution time in seconds
+ arguments: Optional arguments dict to pass to main() function
+
+ Returns:
+ ExecutionResult containing stdout, stderr, exit_code, and metadata
+
+ Raises:
+ RuntimeError: If execution fails
+ TimeoutError: If execution exceeds timeout
+ """
+ pass
+
+ @abstractmethod
+ def destroy_instance(self, instance_id: str) -> bool:
+ """
+ Destroy a sandbox instance.
+
+ Args:
+ instance_id: ID of the instance to destroy
+
+ Returns:
+ True if destruction successful, False otherwise
+
+ Raises:
+ RuntimeError: If destruction fails
+ """
+ pass
+
+ @abstractmethod
+ def health_check(self) -> bool:
+ """
+ Check if the provider is healthy and accessible.
+
+ Returns:
+ True if provider is healthy, False otherwise
+ """
+ pass
+
+ @abstractmethod
+ def get_supported_languages(self) -> List[str]:
+ """
+ Get list of supported programming languages.
+
+ Returns:
+ List of language identifiers (e.g., ["python", "javascript", "go"])
+ """
+ pass
+
+ @staticmethod
+ def get_config_schema() -> Dict[str, Dict]:
+ """
+ Return configuration schema for this provider.
+
+ The schema defines what configuration fields are required/optional,
+ their types, validation rules, and UI labels.
+
+ Returns:
+ Dictionary mapping field names to their schema definitions.
+
+ Example:
+ {
+ "endpoint": {
+ "type": "string",
+ "required": True,
+ "label": "API Endpoint",
+ "placeholder": "http://localhost:9385"
+ },
+ "timeout": {
+ "type": "integer",
+ "default": 30,
+ "label": "Timeout (seconds)",
+ "min": 5,
+ "max": 300
+ }
+ }
+ """
+ return {}
+
+ def validate_config(self, config: Dict[str, Any]) -> tuple[bool, Optional[str]]:
+ """
+ Validate provider-specific configuration.
+
+ This method allows providers to implement custom validation logic beyond
+ the basic schema validation. Override this method to add provider-specific
+ checks like URL format validation, API key format validation, etc.
+
+ Args:
+ config: Configuration dictionary to validate
+
+ Returns:
+ Tuple of (is_valid, error_message):
+ - is_valid: True if configuration is valid, False otherwise
+ - error_message: Error message if invalid, None if valid
+
+ Example:
+ >>> def validate_config(self, config):
+ >>> endpoint = config.get("endpoint", "")
+ >>> if not endpoint.startswith(("http://", "https://")):
+ >>> return False, "Endpoint must start with http:// or https://"
+ >>> return True, None
+ """
+ # Default implementation: no custom validation
+ return True, None
\ No newline at end of file
diff --git a/agent/sandbox/providers/e2b.py b/agent/sandbox/providers/e2b.py
new file mode 100644
index 00000000000..5c4bd5d912e
--- /dev/null
+++ b/agent/sandbox/providers/e2b.py
@@ -0,0 +1,233 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+E2B provider implementation.
+
+This provider integrates with E2B Cloud for cloud-based code execution
+using Firecracker microVMs.
+"""
+
+import uuid
+from typing import Dict, Any, List
+
+from .base import SandboxProvider, SandboxInstance, ExecutionResult
+
+
+class E2BProvider(SandboxProvider):
+ """
+ E2B provider implementation.
+
+ This provider uses E2B Cloud service for secure code execution
+ in Firecracker microVMs.
+ """
+
+ def __init__(self):
+ self.api_key: str = ""
+ self.region: str = "us"
+ self.timeout: int = 30
+ self._initialized: bool = False
+
+ def initialize(self, config: Dict[str, Any]) -> bool:
+ """
+ Initialize the provider with E2B credentials.
+
+ Args:
+ config: Configuration dictionary with keys:
+ - api_key: E2B API key
+ - region: Region (us, eu) (default: "us")
+ - timeout: Request timeout in seconds (default: 30)
+
+ Returns:
+ True if initialization successful, False otherwise
+ """
+ self.api_key = config.get("api_key", "")
+ self.region = config.get("region", "us")
+ self.timeout = config.get("timeout", 30)
+
+ # Validate required fields
+ if not self.api_key:
+ return False
+
+ # TODO: Implement actual E2B API client initialization
+ # For now, we'll mark as initialized but actual API calls will fail
+ self._initialized = True
+ return True
+
+ def create_instance(self, template: str = "python") -> SandboxInstance:
+ """
+ Create a new sandbox instance in E2B.
+
+ Args:
+ template: Programming language template (python, nodejs, go, bash)
+
+ Returns:
+ SandboxInstance object
+
+ Raises:
+ RuntimeError: If instance creation fails
+ """
+ if not self._initialized:
+ raise RuntimeError("Provider not initialized. Call initialize() first.")
+
+ # Normalize language
+ language = self._normalize_language(template)
+
+ # TODO: Implement actual E2B API call
+ # POST /sandbox with template
+ instance_id = str(uuid.uuid4())
+
+ return SandboxInstance(
+ instance_id=instance_id,
+ provider="e2b",
+ status="running",
+ metadata={
+ "language": language,
+ "region": self.region,
+ }
+ )
+
+ def execute_code(
+ self,
+ instance_id: str,
+ code: str,
+ language: str,
+ timeout: int = 10
+ ) -> ExecutionResult:
+ """
+ Execute code in the E2B instance.
+
+ Args:
+ instance_id: ID of the sandbox instance
+ code: Source code to execute
+ language: Programming language (python, nodejs, go, bash)
+ timeout: Maximum execution time in seconds
+
+ Returns:
+ ExecutionResult containing stdout, stderr, exit_code, and metadata
+
+ Raises:
+ RuntimeError: If execution fails
+ TimeoutError: If execution exceeds timeout
+ """
+ if not self._initialized:
+ raise RuntimeError("Provider not initialized. Call initialize() first.")
+
+ # TODO: Implement actual E2B API call
+ # POST /sandbox/{sandboxID}/execute
+
+ raise RuntimeError(
+ "E2B provider is not yet fully implemented. "
+ "Please use the self-managed provider or implement the E2B API integration. "
+ "See https://github.com/e2b-dev/e2b for API documentation."
+ )
+
+ def destroy_instance(self, instance_id: str) -> bool:
+ """
+ Destroy an E2B instance.
+
+ Args:
+ instance_id: ID of the instance to destroy
+
+ Returns:
+ True if destruction successful, False otherwise
+ """
+ if not self._initialized:
+ raise RuntimeError("Provider not initialized. Call initialize() first.")
+
+ # TODO: Implement actual E2B API call
+ # DELETE /sandbox/{sandboxID}
+ return True
+
+ def health_check(self) -> bool:
+ """
+ Check if the E2B service is accessible.
+
+ Returns:
+ True if provider is healthy, False otherwise
+ """
+ if not self._initialized:
+ return False
+
+ # TODO: Implement actual E2B health check API call
+ # GET /healthz or similar
+ # For now, return True if initialized with API key
+ return bool(self.api_key)
+
+ def get_supported_languages(self) -> List[str]:
+ """
+ Get list of supported programming languages.
+
+ Returns:
+ List of language identifiers
+ """
+ return ["python", "nodejs", "javascript", "go", "bash"]
+
+ @staticmethod
+ def get_config_schema() -> Dict[str, Dict]:
+ """
+ Return configuration schema for E2B provider.
+
+ Returns:
+ Dictionary mapping field names to their schema definitions
+ """
+ return {
+ "api_key": {
+ "type": "string",
+ "required": True,
+ "label": "API Key",
+ "placeholder": "e2b_sk_...",
+ "description": "E2B API key for authentication",
+ "secret": True,
+ },
+ "region": {
+ "type": "string",
+ "required": False,
+ "label": "Region",
+ "default": "us",
+ "description": "E2B service region (us or eu)",
+ },
+ "timeout": {
+ "type": "integer",
+ "required": False,
+ "label": "Request Timeout (seconds)",
+ "default": 30,
+ "min": 5,
+ "max": 300,
+ "description": "API request timeout for code execution",
+ }
+ }
+
+ def _normalize_language(self, language: str) -> str:
+ """
+ Normalize language identifier to E2B template format.
+
+ Args:
+ language: Language identifier
+
+ Returns:
+ Normalized language identifier
+ """
+ if not language:
+ return "python"
+
+ lang_lower = language.lower()
+ if lang_lower in ("python", "python3"):
+ return "python"
+ elif lang_lower in ("javascript", "nodejs"):
+ return "nodejs"
+ else:
+ return language
diff --git a/agent/sandbox/providers/manager.py b/agent/sandbox/providers/manager.py
new file mode 100644
index 00000000000..3a6fce5c25a
--- /dev/null
+++ b/agent/sandbox/providers/manager.py
@@ -0,0 +1,78 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Provider manager for sandbox providers.
+
+Since sandbox configuration is global (system-level), we only use one
+active provider at a time. This manager is a thin wrapper that holds a reference
+to the currently active provider.
+"""
+
+from typing import Optional
+from .base import SandboxProvider
+
+
+class ProviderManager:
+ """
+ Manages the currently active sandbox provider.
+
+ With global configuration, there's only one active provider at a time.
+ This manager simply holds a reference to that provider.
+ """
+
+ def __init__(self):
+ """Initialize an empty provider manager."""
+ self.current_provider: Optional[SandboxProvider] = None
+ self.current_provider_name: Optional[str] = None
+
+ def set_provider(self, name: str, provider: SandboxProvider):
+ """
+ Set the active provider.
+
+ Args:
+ name: Provider identifier (e.g., "self_managed", "e2b")
+ provider: Provider instance
+ """
+ self.current_provider = provider
+ self.current_provider_name = name
+
+ def get_provider(self) -> Optional[SandboxProvider]:
+ """
+ Get the active provider.
+
+ Returns:
+ Currently active SandboxProvider instance, or None if not set
+ """
+ return self.current_provider
+
+ def get_provider_name(self) -> Optional[str]:
+ """
+ Get the active provider name.
+
+ Returns:
+ Provider name (e.g., "self_managed"), or None if not set
+ """
+ return self.current_provider_name
+
+ def is_configured(self) -> bool:
+ """
+ Check if a provider is configured.
+
+ Returns:
+ True if a provider is set, False otherwise
+ """
+ return self.current_provider is not None
diff --git a/agent/sandbox/providers/self_managed.py b/agent/sandbox/providers/self_managed.py
new file mode 100644
index 00000000000..7078f6f761d
--- /dev/null
+++ b/agent/sandbox/providers/self_managed.py
@@ -0,0 +1,359 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Self-managed sandbox provider implementation.
+
+This provider wraps the existing executor_manager HTTP API which manages
+a pool of Docker containers with gVisor for secure code execution.
+"""
+
+import base64
+import time
+import uuid
+from typing import Dict, Any, List, Optional
+
+import requests
+
+from .base import SandboxProvider, SandboxInstance, ExecutionResult
+
+
+class SelfManagedProvider(SandboxProvider):
+ """
+ Self-managed sandbox provider using Daytona/Docker.
+
+ This provider communicates with the executor_manager HTTP API
+ which manages a pool of containers for code execution.
+ """
+
+ def __init__(self):
+ self.endpoint: str = "http://localhost:9385"
+ self.timeout: int = 30
+ self.max_retries: int = 3
+ self.pool_size: int = 10
+ self._initialized: bool = False
+
+ def initialize(self, config: Dict[str, Any]) -> bool:
+ """
+ Initialize the provider with configuration.
+
+ Args:
+ config: Configuration dictionary with keys:
+ - endpoint: HTTP endpoint (default: "http://localhost:9385")
+ - timeout: Request timeout in seconds (default: 30)
+ - max_retries: Maximum retry attempts (default: 3)
+ - pool_size: Container pool size for info (default: 10)
+
+ Returns:
+ True if initialization successful, False otherwise
+ """
+ self.endpoint = config.get("endpoint", "http://localhost:9385")
+ self.timeout = config.get("timeout", 30)
+ self.max_retries = config.get("max_retries", 3)
+ self.pool_size = config.get("pool_size", 10)
+
+ # Validate endpoint is accessible
+ if not self.health_check():
+ # Try to fall back to SANDBOX_HOST from settings if we are using localhost
+ if "localhost" in self.endpoint or "127.0.0.1" in self.endpoint:
+ try:
+ from api import settings
+ if settings.SANDBOX_HOST and settings.SANDBOX_HOST not in self.endpoint:
+ original_endpoint = self.endpoint
+ self.endpoint = f"http://{settings.SANDBOX_HOST}:9385"
+ if self.health_check():
+ import logging
+ logging.warning(f"Sandbox self_managed: Connected using settings.SANDBOX_HOST fallback: {self.endpoint} (original: {original_endpoint})")
+ self._initialized = True
+ return True
+ else:
+ self.endpoint = original_endpoint # Restore if fallback also fails
+ except ImportError:
+ pass
+
+ return False
+
+ self._initialized = True
+ return True
+
+ def create_instance(self, template: str = "python") -> SandboxInstance:
+ """
+ Create a new sandbox instance.
+
+ Note: For self-managed provider, instances are managed internally
+ by the executor_manager's container pool. This method returns
+ a logical instance handle.
+
+ Args:
+ template: Programming language (python, nodejs)
+
+ Returns:
+ SandboxInstance object
+
+ Raises:
+ RuntimeError: If instance creation fails
+ """
+ if not self._initialized:
+ raise RuntimeError("Provider not initialized. Call initialize() first.")
+
+ # Normalize language
+ language = self._normalize_language(template)
+
+ # The executor_manager manages instances internally via container pool
+ # We create a logical instance ID for tracking
+ instance_id = str(uuid.uuid4())
+
+ return SandboxInstance(
+ instance_id=instance_id,
+ provider="self_managed",
+ status="running",
+ metadata={
+ "language": language,
+ "endpoint": self.endpoint,
+ "pool_size": self.pool_size,
+ }
+ )
+
+ def execute_code(
+ self,
+ instance_id: str,
+ code: str,
+ language: str,
+ timeout: int = 10,
+ arguments: Optional[Dict[str, Any]] = None
+ ) -> ExecutionResult:
+ """
+ Execute code in the sandbox.
+
+ Args:
+ instance_id: ID of the sandbox instance (not used for self-managed)
+ code: Source code to execute
+ language: Programming language (python, nodejs, javascript)
+ timeout: Maximum execution time in seconds
+ arguments: Optional arguments dict to pass to main() function
+
+ Returns:
+ ExecutionResult containing stdout, stderr, exit_code, and metadata
+
+ Raises:
+ RuntimeError: If execution fails
+ TimeoutError: If execution exceeds timeout
+ """
+ if not self._initialized:
+ raise RuntimeError("Provider not initialized. Call initialize() first.")
+
+ # Normalize language
+ normalized_lang = self._normalize_language(language)
+
+ # Prepare request
+ code_b64 = base64.b64encode(code.encode("utf-8")).decode("utf-8")
+ payload = {
+ "code_b64": code_b64,
+ "language": normalized_lang,
+ "arguments": arguments or {}
+ }
+
+ url = f"{self.endpoint}/run"
+ exec_timeout = timeout or self.timeout
+
+ start_time = time.time()
+
+ try:
+ response = requests.post(
+ url,
+ json=payload,
+ timeout=exec_timeout,
+ headers={"Content-Type": "application/json"}
+ )
+
+ execution_time = time.time() - start_time
+
+ if response.status_code != 200:
+ raise RuntimeError(
+ f"HTTP {response.status_code}: {response.text}"
+ )
+
+ result = response.json()
+
+ return ExecutionResult(
+ stdout=result.get("stdout", ""),
+ stderr=result.get("stderr", ""),
+ exit_code=result.get("exit_code", 0),
+ execution_time=execution_time,
+ metadata={
+ "status": result.get("status"),
+ "time_used_ms": result.get("time_used_ms"),
+ "memory_used_kb": result.get("memory_used_kb"),
+ "detail": result.get("detail"),
+ "instance_id": instance_id,
+ }
+ )
+
+ except requests.Timeout:
+ execution_time = time.time() - start_time
+ raise TimeoutError(
+ f"Execution timed out after {exec_timeout} seconds"
+ )
+
+ except requests.RequestException as e:
+ raise RuntimeError(f"HTTP request failed: {str(e)}")
+
+ def destroy_instance(self, instance_id: str) -> bool:
+ """
+ Destroy a sandbox instance.
+
+ Note: For self-managed provider, instances are returned to the
+ internal pool automatically by executor_manager after execution.
+ This is a no-op for tracking purposes.
+
+ Args:
+ instance_id: ID of the instance to destroy
+
+ Returns:
+ True (always succeeds for self-managed)
+ """
+ # The executor_manager manages container lifecycle internally
+ # Container is returned to pool after execution
+ return True
+
+ def health_check(self) -> bool:
+ """
+ Check if the provider is healthy and accessible.
+
+ Returns:
+ True if provider is healthy, False otherwise
+ """
+ try:
+ url = f"{self.endpoint}/healthz"
+ response = requests.get(url, timeout=5)
+ return response.status_code == 200
+ except Exception:
+ return False
+
+ def get_supported_languages(self) -> List[str]:
+ """
+ Get list of supported programming languages.
+
+ Returns:
+ List of language identifiers
+ """
+ return ["python", "nodejs", "javascript"]
+
+ @staticmethod
+ def get_config_schema() -> Dict[str, Dict]:
+ """
+ Return configuration schema for self-managed provider.
+
+ Returns:
+ Dictionary mapping field names to their schema definitions
+ """
+ return {
+ "endpoint": {
+ "type": "string",
+ "required": True,
+ "label": "Executor Manager Endpoint",
+ "placeholder": "http://localhost:9385",
+ "default": "http://localhost:9385",
+ "description": "HTTP endpoint of the executor_manager service"
+ },
+ "timeout": {
+ "type": "integer",
+ "required": False,
+ "label": "Request Timeout (seconds)",
+ "default": 30,
+ "min": 5,
+ "max": 300,
+ "description": "HTTP request timeout for code execution"
+ },
+ "max_retries": {
+ "type": "integer",
+ "required": False,
+ "label": "Max Retries",
+ "default": 3,
+ "min": 0,
+ "max": 10,
+ "description": "Maximum number of retry attempts for failed requests"
+ },
+ "pool_size": {
+ "type": "integer",
+ "required": False,
+ "label": "Container Pool Size",
+ "default": 10,
+ "min": 1,
+ "max": 100,
+ "description": "Size of the container pool (configured in executor_manager)"
+ }
+ }
+
+ def _normalize_language(self, language: str) -> str:
+ """
+ Normalize language identifier to executor_manager format.
+
+ Args:
+ language: Language identifier (python, python3, nodejs, javascript)
+
+ Returns:
+ Normalized language identifier
+ """
+ if not language:
+ return "python"
+
+ lang_lower = language.lower()
+ if lang_lower in ("python", "python3"):
+ return "python"
+ elif lang_lower in ("javascript", "nodejs"):
+ return "nodejs"
+ else:
+ return language
+
+ def validate_config(self, config: dict) -> tuple[bool, Optional[str]]:
+ """
+ Validate self-managed provider configuration.
+
+ Performs custom validation beyond the basic schema validation,
+ such as checking URL format.
+
+ Args:
+ config: Configuration dictionary to validate
+
+ Returns:
+ Tuple of (is_valid, error_message)
+ """
+ # Validate endpoint URL format
+ endpoint = config.get("endpoint", "")
+ if endpoint:
+ # Check if it's a valid HTTP/HTTPS URL or localhost
+ import re
+ url_pattern = r'^(https?://|http://localhost|http://[\d\.]+:[a-z]+:[/]|http://[\w\.]+:)'
+ if not re.match(url_pattern, endpoint):
+ return False, f"Invalid endpoint format: {endpoint}. Must start with http:// or https://"
+
+ # Validate pool_size is positive
+ pool_size = config.get("pool_size", 10)
+ if isinstance(pool_size, int) and pool_size <= 0:
+ return False, "Pool size must be greater than 0"
+
+ # Validate timeout is reasonable
+ timeout = config.get("timeout", 30)
+ if isinstance(timeout, int) and (timeout < 1 or timeout > 600):
+ return False, "Timeout must be between 1 and 600 seconds"
+
+ # Validate max_retries
+ max_retries = config.get("max_retries", 3)
+ if isinstance(max_retries, int) and (max_retries < 0 or max_retries > 10):
+ return False, "Max retries must be between 0 and 10"
+
+ return True, None
diff --git a/sandbox/pyproject.toml b/agent/sandbox/pyproject.toml
similarity index 100%
rename from sandbox/pyproject.toml
rename to agent/sandbox/pyproject.toml
diff --git a/sandbox/sandbox_base_image/nodejs/Dockerfile b/agent/sandbox/sandbox_base_image/nodejs/Dockerfile
similarity index 92%
rename from sandbox/sandbox_base_image/nodejs/Dockerfile
rename to agent/sandbox/sandbox_base_image/nodejs/Dockerfile
index ada730faf1c..fe7b19f7733 100644
--- a/sandbox/sandbox_base_image/nodejs/Dockerfile
+++ b/agent/sandbox/sandbox_base_image/nodejs/Dockerfile
@@ -1,4 +1,4 @@
-FROM node:24-bookworm-slim
+FROM node:24.13-bookworm-slim
RUN npm config set registry https://registry.npmmirror.com
diff --git a/sandbox/sandbox_base_image/nodejs/package-lock.json b/agent/sandbox/sandbox_base_image/nodejs/package-lock.json
similarity index 100%
rename from sandbox/sandbox_base_image/nodejs/package-lock.json
rename to agent/sandbox/sandbox_base_image/nodejs/package-lock.json
diff --git a/sandbox/sandbox_base_image/nodejs/package.json b/agent/sandbox/sandbox_base_image/nodejs/package.json
similarity index 100%
rename from sandbox/sandbox_base_image/nodejs/package.json
rename to agent/sandbox/sandbox_base_image/nodejs/package.json
diff --git a/sandbox/sandbox_base_image/python/Dockerfile b/agent/sandbox/sandbox_base_image/python/Dockerfile
similarity index 100%
rename from sandbox/sandbox_base_image/python/Dockerfile
rename to agent/sandbox/sandbox_base_image/python/Dockerfile
diff --git a/sandbox/sandbox_base_image/python/requirements.txt b/agent/sandbox/sandbox_base_image/python/requirements.txt
similarity index 100%
rename from sandbox/sandbox_base_image/python/requirements.txt
rename to agent/sandbox/sandbox_base_image/python/requirements.txt
diff --git a/agent/sandbox/sandbox_spec.md b/agent/sandbox/sandbox_spec.md
new file mode 100644
index 00000000000..56e832aaef7
--- /dev/null
+++ b/agent/sandbox/sandbox_spec.md
@@ -0,0 +1,1848 @@
+# RAGFlow Sandbox multi-provider architecture - design specification
+
+## 1. Overview
+
+The goal of this design specification is to enable RAGFlow to support multiple Sandbox deployment modes:
+
+- Self-Managed: On-premise deployment using Daytona/Docker (current implementation)
+- SaaS providers: Cloud-based sandbox services (Aliyun Code Interpreter, E2B)
+
+### Key requirements
+
+- Provider-agnostic interface for sandbox operations
+- Admin-configurable provider settings with dynamic schema
+- Multi-tenant isolation (1:1 session-to-sandbox mapping)
+- Graceful fallback and error handling
+- Unified monitoring and observability
+
+## Architecture
+
+### Provider abstraction layer
+
+Defines a unified `SandboxProvider` interface, and is located at `agent/sandbox/providers/`.
+
+```python
+# agent/sandbox/providers/base.py
+from abc import ABC, abstractmethod
+from typing import Dict, Any, Optional
+from dataclasses import dataclass
+
+@dataclass
+class SandboxInstance:
+ instance_id: str
+ provider: str
+ status: str # running, stopped, error
+ metadata: Dict[str, Any]
+
+@dataclass
+class ExecutionResult:
+ stdout: str
+ stderr: str
+ exit_code: int
+ execution_time: float
+ metadata: Dict[str, Any]
+
+class SandboxProvider(ABC):
+ """Base interface for all sandbox providers"""
+
+ @abstractmethod
+ def initialize(self, config: Dict[str, Any]) -> bool:
+ """Initialize provider with configuration"""
+ pass
+
+ @abstractmethod
+ def create_instance(self, template: str = "python") -> SandboxInstance:
+ """Create a new sandbox instance"""
+ pass
+
+ @abstractmethod
+ def execute_code(
+ self,
+ instance_id: str,
+ code: str,
+ language: str,
+ timeout: int = 10
+ ) -> ExecutionResult:
+ """Execute code in the sandbox"""
+ pass
+
+ @abstractmethod
+ def destroy_instance(self, instance_id: str) -> bool:
+ """Destroy a sandbox instance"""
+ pass
+
+ @abstractmethod
+ def health_check(self) -> bool:
+ """Check if provider is healthy"""
+ pass
+
+ @abstractmethod
+ def get_supported_languages(self) -> list[str]:
+ """Get list of supported programming languages"""
+ pass
+
+ @staticmethod
+ def get_config_schema() -> Dict[str, Dict]:
+ """
+ Return configuration schema for this provider.
+
+ Returns a dictionary mapping field names to their schema definitions,
+ including type, required status, validation rules, labels, and descriptions.
+ """
+ pass
+
+ def validate_config(self, config: Dict[str, Any]) -> tuple[bool, Optional[str]]:
+ """
+ Validate provider-specific configuration.
+
+ This method allows providers to implement custom validation logic beyond
+ the basic schema validation. Override this method to add provider-specific
+ checks like URL format validation, API key format validation, etc.
+
+ Args:
+ config: Configuration dictionary to validate
+
+ Returns:
+ Tuple of (is_valid, error_message):
+ - is_valid: True if configuration is valid, False otherwise
+ - error_message: Error message if invalid, None if valid
+ """
+ # Default implementation: no custom validation
+ return True, None
+```
+
+### Provider implementations
+
+#### Self-managed provider
+
+Wraps the existing executor_manager implementation. The implementation file is located at `agent/sandbox/providers/self_managed.py`.
+
+**Prerequisites**:
+
+- gVisor (runsc): Required for secure container isolation. Install with:
+ ```bash
+ go install gvisor.dev/gvisor/runsc@latest
+ sudo cp ~/go/bin/runsc /usr/local/bin/
+ runsc --version
+ ```
+ Or download from: https://github.com/google/gvisor/releases
+- Docker: Docker runtime with gVisor support.
+- Base Images: Pull sandbox base images:
+ ```bash
+ docker pull infiniflow/sandbox-base-python:latest
+ docker pull infiniflow/sandbox-base-nodejs:latest
+ ```
+
+**Configuration**: Docker API endpoint, pool size, resource limits:
+
+- `endpoint`: HTTP endpoint (default: "http://localhost:9385")
+- `timeout`: Request timeout in seconds (default: 30)
+- `max_retries`: Maximum retry attempts (default: 3)
+- `pool_size`: Container pool size (default: 10)
+
+**Languages**:
+- Python
+- Node.js
+- JavaScript
+
+**Security**:
+- gVisor (runsc runtime)
+- seccomp
+- read-only filesystem
+- memory limits
+
+**Advantages**:
+- Low latency (<90ms), data privacy, full control
+- No per-execution costs
+- Supports `arguments` parameter for passing data to `main()` function
+
+**Limitations**:
+- Operational overhead, finite resources
+- Requires gVisor installation for security
+- Pool exhaustion causes "Container pool is busy" errors
+
+**Common issues**:
+- `"Container pool is busy"`: Increase `SANDBOX_EXECUTOR_MANAGER_POOL_SIZE` (default: 1 in .env, should be 5+)
+- `Container creation fails`: Ensure gVisor is installed and accessible at `/usr/local/bin/runsc`
+
+#### 2.2.2 Aliyun code interpreter provider
+
+**File**: `agent/sandbox/providers/aliyun_codeinterpreter.py`
+
+SaaS integration with Aliyun Function Compute Code Interpreter service using the official agentrun-sdk.
+
+**Official Resources**:
+- API Documentation: https://help.aliyun.com/zh/functioncompute/fc/sandbox-sandbox-code-interepreter
+- Official SDK: https://github.com/Serverless-Devs/agentrun-sdk-python
+- SDK Docs: https://docs.agent.run
+
+**Implementation**:
+- Uses official `agentrun-sdk` package
+- SDK handles authentication (AccessKey signature) automatically
+- Supports environment variable configuration
+- Structured error handling with `ServerError` exceptions
+
+**Configuration**:
+- `access_key_id`: Aliyun AccessKey ID
+- `access_key_secret`: Aliyun AccessKey Secret
+- `account_id`: Aliyun primary account ID - Required for API calls
+- `region`: Region (cn-hangzhou, cn-beijing, cn-shanghai, cn-shenzhen, cn-guangzhou)
+- `template_name`: Optional sandbox template name for pre-configured environments
+- `timeout`: Execution timeout (max 30 seconds - hard limit)
+
+**Languages**: Python, JavaScript
+
+**Security**: Serverless microVM isolation, 30-second hard timeout limit
+
+**Advantages**:
+- Official SDK with automatic signature handling
+- Unlimited scalability, no maintenance
+- China region support with low latency
+- Built-in file system management
+- Support for execution contexts (Jupyter kernel)
+- Context-based execution for state persistence
+
+**Limitations**:
+- Network dependency
+- 30-second execution time limit (hard limit)
+- Pay-as-you-go costs
+- Requires Aliyun primary account ID for API calls
+
+**Setup instructions - Creating a RAM user with minimal privileges**:
+
+⚠️ **Security warning**: Never use your Aliyun primary account (root account) AccessKey for SDK operations. Primary accounts have full resource permissions, and leaked credentials pose significant security risks.
+
+**Step 1: Create a RAM user**
+
+1. Log in to [RAM Console](https://ram.console.aliyun.com/)
+2. Navigate to **People** → **Users**
+3. Click **Create User**
+4. Configure the user:
+ - **Username**: e.g., `ragflow-sandbox-user`
+ - **Display Name**: e.g., `RAGFlow Sandbox Service Account`
+ - **Access Mode**: Check ✅ **OpenAPI/Programmatic Access** (this creates an AccessKey)
+ - **Console Login**: Optional (not needed for SDK-only access)
+5. Click **OK** and save the AccessKey ID and Secret immediately (displayed only once!)
+
+**Step 2: Create a custom authorization policy**
+
+Navigate to **Permissions** → **Policies** → **Create Policy** → **Custom Policy** → **Configuration Script (JSON)**
+
+Choose one of the following policy options based on your security requirements:
+
+**Option A: Minimal privilege policy (Recommended)**
+
+Grants only the permissions required by the AgentRun SDK:
+
+```json
+{
+ "Version": "1",
+ "Statement": [
+ {
+ "Effect": "Allow",
+ "Action": [
+ "agentrun:CreateTemplate",
+ "agentrun:GetTemplate",
+ "agentrun:UpdateTemplate",
+ "agentrun:DeleteTemplate",
+ "agentrun:ListTemplates",
+ "agentrun:CreateSandbox",
+ "agentrun:GetSandbox",
+ "agentrun:DeleteSandbox",
+ "agentrun:StopSandbox",
+ "agentrun:ListSandboxes",
+ "agentrun:CreateContext",
+ "agentrun:ExecuteCode",
+ "agentrun:DeleteContext",
+ "agentrun:ListContexts",
+ "agentrun:CreateFile",
+ "agentrun:GetFile",
+ "agentrun:DeleteFile",
+ "agentrun:ListFiles",
+ "agentrun:CreateProcess",
+ "agentrun:GetProcess",
+ "agentrun:KillProcess",
+ "agentrun:ListProcesses",
+ "agentrun:CreateRecording",
+ "agentrun:GetRecording",
+ "agentrun:DeleteRecording",
+ "agentrun:ListRecordings",
+ "agentrun:CheckHealth"
+ ],
+ "Resource": [
+ "acs:agentrun:*:{account_id}:template/*",
+ "acs:agentrun:*:{account_id}:sandbox/*"
+ ]
+ }
+ ]
+}
+```
+
+> Replace `{account_id}` with your Aliyun primary account ID
+
+**Option B: Resource-Level privilege control (most secure)**
+
+Limits access to specific resource prefixes:
+
+```json
+{
+ "Version": "1",
+ "Statement": [
+ {
+ "Effect": "Allow",
+ "Action": [
+ "agentrun:CreateTemplate",
+ "agentrun:GetTemplate",
+ "agentrun:UpdateTemplate",
+ "agentrun:DeleteTemplate",
+ "agentrun:ListTemplates"
+ ],
+ "Resource": "acs:agentrun:*:{account_id}:template/ragflow-*"
+ },
+ {
+ "Effect": "Allow",
+ "Action": [
+ "agentrun:CreateSandbox",
+ "agentrun:GetSandbox",
+ "agentrun:DeleteSandbox",
+ "agentrun:StopSandbox",
+ "agentrun:ListSandboxes",
+ "agentrun:CheckHealth"
+ ],
+ "Resource": "acs:agentrun:*:{account_id}:sandbox/*"
+ },
+ {
+ "Effect": "Allow",
+ "Action": ["agentrun:*"],
+ "Resource": "acs:agentrun:*:{account_id}:sandbox/*/context/*"
+ },
+ {
+ "Effect": "Allow",
+ "Action": ["agentrun:*"],
+ "Resource": "acs:agentrun:*:{account_id}:sandbox/*/file/*"
+ },
+ {
+ "Effect": "Allow",
+ "Action": ["agentrun:*"],
+ "Resource": "acs:agentrun:*:{account_id}:sandbox/*/process/*"
+ },
+ {
+ "Effect": "Allow",
+ "Action": ["agentrun:*"],
+ "Resource": "acs:agentrun:*:{account_id}:sandbox/*/recording/*"
+ }
+ ]
+}
+```
+
+> This limits template creation to only those prefixed with `ragflow-*`
+
+**Option C: Full access (not recommended for production)**
+
+```json
+{
+ "Version": "1",
+ "Statement": [
+ {
+ "Effect": "Allow",
+ "Action": "agentrun:*",
+ "Resource": "*"
+ }
+ ]
+}
+```
+
+**Step 3: Authorize the RAM user**
+
+1. Return to **Users** list
+2. Find the user you just created (e.g., `ragflow-sandbox-user`)
+3. Click **Add Permissions** in the Actions column
+4. In the **Custom Policy** tab, select the policy you created in Step 2
+5. Click **OK**
+
+**Step 4: Configure RAGFlow with the RAM User credentials**
+
+After creating the RAM user and obtaining the AccessKey, configure it in RAGFlow's admin settings or environment variables:
+
+```bash
+# Method 1: Environment variables (for development/testing)
+export AGENTRUN_ACCESS_KEY_ID="LTAI5t..." # RAM user's AccessKey ID
+export AGENTRUN_ACCESS_KEY_SECRET="xxx..." # RAM user's AccessKey Secret
+export AGENTRUN_ACCOUNT_ID="123456789..." # Your primary account ID
+export AGENTRUN_REGION="cn-hangzhou"
+```
+
+Or via Admin UI (recommended for production):
+
+1. Navigate to **Admin Settings** → **Sandbox Providers**
+2. Select **Aliyun Code Interpreter** provider
+3. Fill in the configuration:
+ - `access_key_id`: RAM user's AccessKey ID
+ - `access_key_secret`: RAM user's AccessKey Secret
+ - `account_id`: Your primary account ID
+ - `region`: e.g., `cn-hangzhou`
+
+**Step 5: Verify permissions**
+
+Test if the RAM user permissions are correctly configured:
+
+```python
+from agentrun.sandbox import Sandbox, TemplateInput, TemplateType
+
+try:
+ # Test template creation
+ template = Sandbox.create_template(
+ input=TemplateInput(
+ template_name="ragflow-permission-test",
+ template_type=TemplateType.CODE_INTERPRETER
+ )
+ )
+ print("✅ RAM user permissions are correctly configured")
+except Exception as e:
+ print(f"❌ Permission test failed: {e}")
+finally:
+ # Cleanup test resources
+ try:
+ Sandbox.delete_template("ragflow-permission-test")
+ except:
+ pass
+```
+
+**Security best practices**:
+
+- **Always use RAM user AccessKeys**, never primary account AccessKeys.
+- **Follow the principle of least privilege** - grant only necessary permissions.
+- **Rotate AccessKeys regularly** - recommend every 3-6 months.
+- **Enable MFA** - enable multi-factor authentication for RAM users.
+- **Use secure storage** - store credentials in environment variables or secret management services, never hardcode in code.
+- **Restrict IP access** - add IP whitelist policies for RAM users if needed.
+- **Monitor access logs** - regularly check RAM user access logs in CloudTrail.
+
+**References**:
+
+- [Aliyun RAM Documentation](https://help.aliyun.com/product/28625.html)
+- [RAM Policy Language](https://help.aliyun.com/document_detail/100676.html)
+- [AgentRun Official Documentation](https://docs.agent.run)
+- [AgentRun SDK GitHub](https://github.com/Serverless-Devs/agentrun-sdk-python)
+
+#### E2B provider
+
+The file is located at `agent/sandbox/providers/e2b.py`.
+
+SaaS integration with E2B Cloud.
+- **Configuration**: api_key, region (us/eu)
+- **Languages**: Python, JavaScript, Go, Bash, etc.
+- **Security**: Firecracker microVMs
+- **Advantages**: Global CDN, fast startup, multiple language support
+- **Limitations**: International network latency for China users
+
+### Provider management
+
+The file is located at `agent/sandbox/providers/manager.py`.
+
+Since we only use one active provider at a time (configured globally), the provider management is simplified:
+
+```python
+class ProviderManager:
+ """Manages the currently active sandbox provider"""
+
+ def __init__(self):
+ self.current_provider: Optional[SandboxProvider] = None
+ self.current_provider_name: Optional[str] = None
+
+ def set_provider(self, name: str, provider: SandboxProvider):
+ """Set the active provider"""
+ self.current_provider = provider
+ self.current_provider_name = name
+
+ def get_provider(self) -> Optional[SandboxProvider]:
+ """Get the active provider"""
+ return self.current_provider
+
+ def get_provider_name(self) -> Optional[str]:
+ """Get the active provider name"""
+ return self.current_provider_name
+```
+
+**Rationale**: With global configuration, there's only one active provider at a time. The provider manager simply holds a reference to the currently active provider, making it a thin wrapper rather than a complex multi-provider manager.
+
+## Admin configuration
+
+### Database Schema
+
+Use the existing **SystemSettings** table for global sandbox configuration:
+
+```python
+# In api/db/db_models.py
+
+class SystemSettings(DataBaseModel):
+ name = CharField(max_length=128, primary_key=True)
+ source = CharField(max_length=32, null=False, index=False)
+ data_type = CharField(max_length=32, null=False, index=False)
+ value = CharField(max_length=1024, null=False, index=False)
+```
+
+**Rationale**: Sandbox manager is a **system-level service** shared by all tenants:
+- No per-tenant configuration needed (unlike LLM providers where each tenant has their own API keys)
+- Global settings like system email, DOC_ENGINE, etc.
+- Managed by administrators only
+- Leverages existing `SettingsMgr` in admin interface
+
+**Storage Strategy**: Each provider's configuration stored as a **single JSON object**:
+- `sandbox.provider_type` - Active provider selection ("self_managed", "aliyun_codeinterpreter", "e2b")
+- `sandbox.self_managed` - JSON config for self-managed provider
+- `sandbox.aliyun_codeinterpreter` - JSON config for Aliyun Code Interpreter provider
+- `sandbox.e2b` - JSON config for E2B provider
+
+**Note**: The `value` field has a 1024 character limit, which should be sufficient for typical sandbox configurations. If larger configs are needed, consider using a TextField or a separate configuration table.
+
+### Configuration Schema
+
+Each provider's configuration is stored as a **single JSON object** in the `value` field:
+
+#### Self-managed provider
+
+```json
+{
+ "name": "sandbox.self_managed",
+ "source": "variable",
+ "data_type": "json",
+ "value": "{\"endpoint\": \"http://localhost:9385\", \"pool_size\": 10, \"max_memory\": \"256m\", \"timeout\": 30}"
+}
+```
+
+#### Aliyun code interpreter
+```json
+{
+ "name": "sandbox.aliyun_codeinterpreter",
+ "source": "variable",
+ "data_type": "json",
+ "value": "{\"access_key_id\": \"LTAI5t...\", \"access_key_secret\": \"xxxxx\", \"account_id\": \"1234567890...\", \"region\": \"cn-hangzhou\", \"timeout\": 30}"
+}
+```
+
+#### E2B
+```json
+{
+ "name": "sandbox.e2b",
+ "source": "variable",
+ "data_type": "json",
+ "value": "{\"api_key\": \"e2b_sk_...\", \"region\": \"us\", \"timeout\": 30}"
+}
+```
+
+#### Active provider selection
+```json
+{
+ "name": "sandbox.provider_type",
+ "source": "variable",
+ "data_type": "string",
+ "value": "self_managed"
+}
+```
+
+### Provider self-describing Schema
+
+Each provider class implements a static method to describe its configuration schema:
+
+```python
+# agent/sandbox/providers/base.py
+
+class SandboxProvider(ABC):
+ """Base interface for all sandbox providers"""
+
+ @abstractmethod
+ def initialize(self, config: Dict[str, Any]) -> bool:
+ """Initialize provider with configuration"""
+ pass
+
+ @abstractmethod
+ def create_instance(self, template: str = "python") -> SandboxInstance:
+ """Create a new sandbox instance"""
+ pass
+
+ @abstractmethod
+ def execute_code(
+ self,
+ instance_id: str,
+ code: str,
+ language: str,
+ timeout: int = 10
+ ) -> ExecutionResult:
+ """Execute code in the sandbox"""
+ pass
+
+ @abstractmethod
+ def destroy_instance(self, instance_id: str) -> bool:
+ """Destroy a sandbox instance"""
+ pass
+
+ @abstractmethod
+ def health_check(self) -> bool:
+ """Check if provider is healthy"""
+ pass
+
+ @abstractmethod
+ def get_supported_languages(self) -> list[str]:
+ """Get list of supported programming languages"""
+ pass
+
+ @staticmethod
+ def get_config_schema() -> Dict[str, Dict]:
+ """Return configuration schema for this provider"""
+ return {}
+```
+
+**Example implementation**:
+
+```python
+# agent/sandbox/providers/self_managed.py
+
+class SelfManagedProvider(SandboxProvider):
+ @staticmethod
+ def get_config_schema() -> Dict[str, Dict]:
+ return {
+ "endpoint": {
+ "type": "string",
+ "required": True,
+ "label": "API Endpoint",
+ "placeholder": "http://localhost:9385"
+ },
+ "pool_size": {
+ "type": "integer",
+ "default": 10,
+ "label": "Container Pool Size",
+ "min": 1,
+ "max": 100
+ },
+ "max_memory": {
+ "type": "string",
+ "default": "256m",
+ "label": "Max Memory per Container",
+ "options": ["128m", "256m", "512m", "1g"]
+ },
+ "timeout": {
+ "type": "integer",
+ "default": 30,
+ "label": "Execution Timeout (seconds)",
+ "min": 5,
+ "max": 300
+ }
+ }
+
+# agent/sandbox/providers/aliyun_codeinterpreter.py
+
+class AliyunCodeInterpreterProvider(SandboxProvider):
+ @staticmethod
+ def get_config_schema() -> Dict[str, Dict]:
+ return {
+ "access_key_id": {
+ "type": "string",
+ "required": True,
+ "secret": True,
+ "label": "Access Key ID",
+ "description": "Aliyun AccessKey ID for authentication"
+ },
+ "access_key_secret": {
+ "type": "string",
+ "required": True,
+ "secret": True,
+ "label": "Access Key Secret",
+ "description": "Aliyun AccessKey Secret for authentication"
+ },
+ "account_id": {
+ "type": "string",
+ "required": True,
+ "label": "Account ID",
+ "description": "Aliyun primary account ID (主账号ID), required for API calls"
+ },
+ "region": {
+ "type": "string",
+ "default": "cn-hangzhou",
+ "label": "Region",
+ "options": ["cn-hangzhou", "cn-beijing", "cn-shanghai", "cn-shenzhen", "cn-guangzhou"],
+ "description": "Aliyun region for Code Interpreter service"
+ },
+ "template_name": {
+ "type": "string",
+ "required": False,
+ "label": "Template Name",
+ "description": "Optional sandbox template name for pre-configured environments"
+ },
+ "timeout": {
+ "type": "integer",
+ "default": 30,
+ "label": "Execution Timeout (seconds)",
+ "min": 1,
+ "max": 30,
+ "description": "Code execution timeout (max 30 seconds - hard limit)"
+ }
+ }
+
+# agent/sandbox/providers/e2b.py
+
+class E2BProvider(SandboxProvider):
+ @staticmethod
+ def get_config_schema() -> Dict[str, Dict]:
+ return {
+ "api_key": {
+ "type": "string",
+ "required": True,
+ "secret": True,
+ "label": "API Key"
+ },
+ "region": {
+ "type": "string",
+ "default": "us",
+ "label": "Region",
+ "options": ["us", "eu"]
+ },
+ "timeout": {
+ "type": "integer",
+ "default": 30,
+ "label": "Execution Timeout (seconds)",
+ "min": 5,
+ "max": 300
+ }
+ }
+```
+
+**Benefits of Self-describing providers**:
+
+- Single source of truth - schema defined alongside implementation
+- Easy to add new providers - no central registry to update
+- Type safety - schema stays in sync with provider code
+- Flexible - frontend can use schema for validation or hardcode if preferred
+
+### Admin API endpoints
+
+Follow existing pattern in `admin/server/routes.py` and use `SettingsMgr`:
+
+```python
+# admin/server/routes.py (add new endpoints)
+
+from flask import request, jsonify
+import json
+from api.db.services.system_settings_service import SystemSettingsService
+from agent.agent.sandbox.providers.self_managed import SelfManagedProvider
+from agent.agent.sandbox.providers.aliyun_codeinterpreter import AliyunCodeInterpreterProvider
+from agent.agent.sandbox.providers.e2b import E2BProvider
+from admin.server.services import SettingsMgr
+
+# Map provider IDs to their classes
+PROVIDER_CLASSES = {
+ "self_managed": SelfManagedProvider,
+ "aliyun_codeinterpreter": AliyunCodeInterpreterProvider,
+ "e2b": E2BProvider,
+}
+
+@admin_bp.route('/api/admin/sandbox/providers', methods=['GET'])
+def list_sandbox_providers():
+ """List available sandbox providers with their schemas"""
+ providers = []
+ for provider_id, provider_class in PROVIDER_CLASSES.items():
+ schema = provider_class.get_config_schema()
+ providers.append({
+ "id": provider_id,
+ "name": provider_id.replace("_", " ").title(),
+ "config_schema": schema
+ })
+ return jsonify({"data": providers})
+
+@admin_bp.route('/api/admin/sandbox/config', methods=['GET'])
+def get_sandbox_config():
+ """Get current sandbox configuration"""
+ # Get active provider
+ active_provider_setting = SystemSettingsService.get_by_name("sandbox.provider_type")
+ active_provider = active_provider_setting[0].value if active_provider_setting else None
+
+ config = {"active": active_provider}
+
+ # Load all provider configs
+ for provider_id in PROVIDER_CLASSES.keys():
+ setting = SystemSettingsService.get_by_name(f"sandbox.{provider_id}")
+ if setting:
+ try:
+ config[provider_id] = json.loads(setting[0].value)
+ except json.JSONDecodeError:
+ config[provider_id] = {}
+ else:
+ # Return default values from schema
+ provider_class = PROVIDER_CLASSES[provider_id]
+ schema = provider_class.get_config_schema()
+ config[provider_id] = {
+ key: field_def.get("default", "")
+ for key, field_def in schema.items()
+ }
+
+ return jsonify({"data": config})
+
+@admin_bp.route('/api/admin/sandbox/config', methods=['POST'])
+def set_sandbox_config():
+ """
+ Update sandbox provider configuration.
+
+ Request Parameters:
+ - provider_type: Provider identifier (e.g., "self_managed", "e2b")
+ - config: Provider configuration dictionary
+ - set_active: (optional) If True, also set this provider as active.
+ Default: True for backward compatibility.
+ Set to False to update config without switching providers.
+ - test_connection: (optional) If True, test connection before saving
+
+ Response: Success message
+ """
+ req = request.json
+ provider_type = req.get('provider_type')
+ config = req.get('config')
+ set_active = req.get('set_active', True) # Default to True
+
+ # Validate provider exists
+ if provider_type not in PROVIDER_CLASSES:
+ return jsonify({"error": "Unknown provider"}), 400
+
+ # Validate configuration against schema
+ provider_class = PROVIDER_CLASSES[provider_type]
+ schema = provider_class.get_config_schema()
+ validation_result = validate_config(config, schema)
+ if not validation_result.valid:
+ return jsonify({"error": "Invalid config", "details": validation_result.errors}), 400
+
+ # Test connection if requested
+ if req.get('test_connection'):
+ test_result = test_provider_connection(provider_type, config)
+ if not test_result.success:
+ return jsonify({"error": "Connection failed", "details": test_result.error}), 400
+
+ # Store entire config as a single JSON record
+ config_json = json.dumps(config)
+ setting_name = f"sandbox.{provider_type}"
+
+ existing = SystemSettingsService.get_by_name(setting_name)
+ if existing:
+ SettingsMgr.update_by_name(setting_name, config_json)
+ else:
+ SystemSettingsService.save(
+ name=setting_name,
+ source="variable",
+ data_type="json",
+ value=config_json
+ )
+
+ # Set as active provider if requested (default: True)
+ if set_active:
+ SettingsMgr.update_by_name("sandbox.provider_type", provider_type)
+
+ return jsonify({"message": "Configuration saved"})
+
+@admin_bp.route('/api/admin/sandbox/test', methods=['POST'])
+def test_sandbox_connection():
+ """Test connection to sandbox provider"""
+ provider_type = request.json.get('provider_type')
+ config = request.json.get('config')
+
+ test_result = test_provider_connection(provider_type, config)
+ return jsonify({
+ "success": test_result.success,
+ "message": test_result.message,
+ "latency_ms": test_result.latency_ms
+ })
+
+@admin_bp.route('/api/admin/sandbox/active', methods=['PUT'])
+def set_active_sandbox_provider():
+ """Set active sandbox provider"""
+ provider_name = request.json.get('provider')
+
+ if provider_name not in PROVIDER_CLASSES:
+ return jsonify({"error": "Unknown provider"}), 400
+
+ # Check if provider is configured
+ provider_setting = SystemSettingsService.get_by_name(f"sandbox.{provider_name}")
+ if not provider_setting:
+ return jsonify({"error": "Provider not configured"}), 400
+
+ SettingsMgr.update_by_name("sandbox.provider_type", provider_name)
+ return jsonify({"message": "Active provider updated"})
+```
+
+## Frontend integration
+
+### Admin settings UI
+
+**Location**: `web/src/pages/SandboxSettings/index.tsx`
+
+```typescript
+import { Form, Select, Input, Button, Card, Space, Tag, message } from 'antd';
+import { listSandboxProviders, getSandboxConfig, setSandboxConfig, testSandboxConnection } from '@/utils/api';
+
+const SandboxSettings: React.FC = () => {
+ const [providers, setProviders] = useState([]);
+ const [configs, setConfigs] = useState([]);
+ const [selectedProvider, setSelectedProvider] = useState('');
+ const [testing, setTesting] = useState(false);
+
+ const providerSchema = providers.find(p => p.id === selectedProvider);
+
+ const renderConfigForm = () => {
+ if (!providerSchema) return null;
+
+ return (
+
+ {schema.secret ? (
+
+ ) : schema.type === 'integer' ? (
+
+ ) : schema.options ? (
+
+ {schema.options.map((opt: string) => (
+ {opt}
+ ))}
+
+ ) : (
+
+ )}
+
+ ))}
+
+ );
+ };
+
+ return (
+
+
+ {/* Provider Selection */}
+
+
+ {providers.map(provider => (
+
+
+
+ {provider.name}
+ {provider.tags.map(tag => (
+ {tag}
+ ))}
+
+
+ ))}
+
+
+
+ {/* Dynamic Configuration Form */}
+ {renderConfigForm()}
+
+ {/* Actions */}
+
+
+ Save Configuration
+
+
+ Test Connection
+
+
+
+
+ );
+};
+```
+
+### API client
+
+**File**: `web/src/utils/api.ts`
+
+```typescript
+export async function listSandboxProviders() {
+ return request<{ data: Provider[] }>('/api/admin/sandbox/providers');
+}
+
+export async function getSandboxConfig() {
+ return request<{ data: SandboxConfig }>('/api/admin/sandbox/config');
+}
+
+export async function setSandboxConfig(config: SandboxConfigRequest) {
+ return request('/api/admin/sandbox/config', {
+ method: 'POST',
+ data: config,
+ });
+}
+
+export async function testSandboxConnection(provider: string, config: any) {
+ return request('/api/admin/sandbox/test', {
+ method: 'POST',
+ data: { provider, config },
+ });
+}
+
+export async function setActiveSandboxProvider(provider: string) {
+ return request('/api/admin/sandbox/active', {
+ method: 'PUT',
+ data: { provider },
+ });
+}
+```
+
+### 4.3 Type Definitions
+
+**File**: `web/src/types/sandbox.ts`
+
+```typescript
+interface Provider {
+ id: string;
+ name: string;
+ description: string;
+ icon: string;
+ tags: string[];
+ config_schema: Record;
+ supported_languages: string[];
+}
+
+interface ConfigField {
+ type: 'string' | 'integer' | 'boolean';
+ required: boolean;
+ secret?: boolean;
+ label: string;
+ placeholder?: string;
+ default?: any;
+ options?: string[];
+ min?: number;
+ max?: number;
+}
+
+// Configuration response grouped by provider
+interface SandboxConfig {
+ active: string; // Currently active provider
+ self_managed?: Record;
+ aliyun_codeinterpreter?: Record;
+ e2b?: Record;
+ // Add more providers as needed
+}
+
+// Request to update provider configuration
+interface SandboxConfigRequest {
+ provider_type: string;
+ config: Record;
+ test_connection?: boolean;
+ set_active?: boolean;
+}
+```
+
+## Integration with Agent system
+
+### Agent component usage
+
+The agent system will use the sandbox through the simplified provider manager, loading global configuration from SystemSettings:
+
+```python
+# In agent/components/code_executor.py
+
+import json
+from agent.agent.sandbox.providers.manager import ProviderManager
+from agent.agent.sandbox.providers.self_managed import SelfManagedProvider
+from agent.agent.sandbox.providers.aliyun_codeinterpreter import AliyunCodeInterpreterProvider
+from agent.agent.sandbox.providers.e2b import E2BProvider
+from api.db.services.system_settings_service import SystemSettingsService
+
+# Map provider IDs to their classes
+PROVIDER_CLASSES = {
+ "self_managed": SelfManagedProvider,
+ "aliyun_codeinterpreter": AliyunCodeInterpreterProvider,
+ "e2b": E2BProvider,
+}
+
+class CodeExecutorComponent:
+ def __init__(self):
+ self.provider_manager = ProviderManager()
+ self._load_active_provider()
+
+ def _load_active_provider(self):
+ """Load the active provider from system settings"""
+ # Get active provider
+ active_setting = SystemSettingsService.get_by_name("sandbox.provider_type")
+ if not active_setting:
+ raise RuntimeError("No sandbox provider configured")
+
+ active_provider = active_setting[0].value
+
+ # Load configuration for active provider (single JSON record)
+ provider_setting = SystemSettingsService.get_by_name(f"sandbox.{active_provider}")
+ if not provider_setting:
+ raise RuntimeError(f"Sandbox provider {active_provider} not configured")
+
+ # Parse JSON configuration
+ try:
+ config = json.loads(provider_setting[0].value)
+ except json.JSONDecodeError as e:
+ raise RuntimeError(f"Invalid sandbox configuration for {active_provider}: {e}")
+
+ # Get provider class
+ provider_class = PROVIDER_CLASSES.get(active_provider)
+ if not provider_class:
+ raise RuntimeError(f"Unknown provider: {active_provider}")
+
+ # Initialize provider
+ provider = provider_class()
+ provider.initialize(config)
+
+ # Set as active provider in manager
+ self.provider_manager.set_provider(active_provider, provider)
+
+ def execute(self, code: str, language: str) -> ExecutionResult:
+ """Execute code using the active provider"""
+ provider = self.provider_manager.get_provider()
+
+ if not provider:
+ raise RuntimeError("No sandbox provider configured")
+
+ # Create instance
+ instance = provider.create_instance(template=language)
+
+ try:
+ # Execute code
+ result = provider.execute_code(
+ instance_id=instance.instance_id,
+ code=code,
+ language=language
+ )
+ return result
+ finally:
+ # Always cleanup
+ provider.destroy_instance(instance.instance_id)
+```
+
+## Security considerations
+
+### Credential storage
+- Sensitive credentials (API keys, secrets) encrypted at rest in database
+- Use RAGFlow's existing encryption mechanisms (AES-256)
+- Never log or expose credentials in error messages or API responses
+- Credentials redacted in UI (show only last 4 characters)
+
+### Tenant isolation
+
+- **Configuration**: Global sandbox settings shared by all tenants (admin-only access)
+- **Execution**: Sandboxes never shared across tenants/sessions during runtime
+- **Instance IDs**: Scoped to tenant: `{tenant_id}:{session_id}:{instance_id}`
+- **Network Isolation**: Between tenant sandboxes (VPC per tenant for SaaS providers)
+- **Resource Quotas**: Per-tenant limits on concurrent executions, total execution time
+- **Audit Logging**: All sandbox executions logged with tenant_id for traceability
+
+### Resource limits
+- Timeout limits per execution (configurable per provider, default 30s)
+- Memory/CPU limits enforced at provider level
+- Automatic cleanup of stale instances (max lifetime: 5 minutes)
+- Rate limiting per tenant (max concurrent executions: 10)
+
+### Code security
+- For self-managed: AST-based security analysis before execution
+- Blocked operations: file system writes, network calls, system commands
+- Allowlist approach: only specific imports allowed
+- Runtime monitoring for malicious patterns
+
+### Network security
+- Self-managed: Network isolation by default, no external access
+- SaaS: HTTPS only, certificate pinning
+- IP whitelisting for self-managed endpoint access
+
+## Monitoring and observability
+
+### Metrics to track
+
+**Common metrics (all providers)**:
+- Execution success rate (target: >95%)
+- Average execution time (p50, p95, p99)
+- Error rate by error type
+- Active instance count
+- Queue depth (for self-managed pool)
+
+**Self-managed specific**:
+- Container pool utilization (target: 60-80%)
+- Host resource usage (CPU, memory, disk)
+- Container creation latency
+- Container restart rate
+- gVisor runtime health
+
+**SaaS specific**:
+- API call latency by region
+- Rate limit usage and throttling events
+- Cost estimation (execution count × unit cost)
+- Provider availability (uptime %)
+- API error rate by error code
+
+### Logging
+
+Structured logging for all provider operations:
+```json
+{
+ "timestamp": "2025-01-26T10:00:00Z",
+ "tenant_id": "tenant_123",
+ "provider": "aliyun_codeinterpreter",
+ "operation": "execute_code",
+ "instance_id": "inst_xyz",
+ "language": "python",
+ "code_hash": "sha256:...",
+ "duration_ms": 1234,
+ "status": "success",
+ "exit_code": 0,
+ "memory_used_mb": 64,
+ "region": "cn-hangzhou"
+}
+```
+
+### Alerts
+
+**Critical alerts**:
+- Provider availability < 99%
+- Error rate > 5%
+- Average execution time > 10s
+- Container pool exhaustion (0 available)
+
+**Warning alerts**:
+- Cost spike (2x daily average)
+- Rate limit approaching (>80%)
+- High memory usage (>90%)
+- Slow execution times (p95 > 5s)
+
+## Migration path
+
+### Phase 1: Refactor existing code (week 1-2)
+**Goals**: Extract current implementation into provider pattern
+
+**Tasks**:
+- [ ] Create `agent/sandbox/providers/base.py` with `SandboxProvider` interface
+- [ ] Implement `agent/sandbox/providers/self_managed.py` wrapping executor_manager
+- [ ] Create `agent/sandbox/providers/manager.py` for provider management
+- [ ] Write unit tests for self-managed provider
+- [ ] Document existing behavior and configuration
+
+**Deliverables**:
+- Provider abstraction layer
+- Self-managed provider implementation
+- Unit test suite
+
+### Phase 2: Database entegration (week 3)
+**Goals**: Add sandbox configuration to admin system
+
+**Tasks**:
+- [ ] Add sandbox entries to `conf/system_settings.json` initialization file
+- [ ] Extend `SettingsMgr` in `admin/server/services.py` with sandbox-specific methods
+- [ ] Add admin endpoints to `admin/server/routes.py`
+- [ ] Implement configuration validation logic
+- [ ] Add provider connection testing
+- [ ] Write API tests
+
+**Deliverables**:
+- SystemSettings integration
+- Admin API endpoints (`/api/admin/sandbox/*`)
+- Configuration validation
+- API test suite
+
+### Phase 3: Frontend UI (week 4)
+**Goals**: Build admin settings interface
+
+**Tasks**:
+- [ ] Create `web/src/pages/SandboxSettings/index.tsx`
+- [ ] Implement dynamic form generation from provider schema
+- [ ] Add connection testing UI
+- [ ] Create TypeScript types
+- [ ] Write frontend tests
+
+**Deliverables**:
+- Admin settings UI
+- Type definitions
+- Frontend test suite
+
+### Phase 4: SaaS provider implementation (Week 5-6)
+**Goals**: Implement Aliyun Code Interpreter and E2B providers
+
+**Tasks**:
+- [ ] Implement `agent/sandbox/providers/aliyun_codeinterpreter.py`
+- [ ] Implement `agent/sandbox/providers/e2b.py`
+- [ ] Add provider-specific tests with mocking
+- [ ] Document provider-specific behaviors
+- [ ] Create provider setup guides
+
+**Deliverables**:
+- Aliyun Code Interpreter provider
+- E2B provider
+- Provider documentation
+
+### Phase 5: Agent integration (week 7)
+**Goals**: Update agent components to use new provider system
+
+**Tasks**:
+- [ ] Update `agent/components/code_executor.py` to use ProviderManager
+- [ ] Implement fallback logic
+- [ ] Add tenant-specific provider loading
+- [ ] Update agent tests
+- [ ] Performance testing
+
+**Deliverables**:
+- Agent integration
+- Fallback mechanism
+- Updated test suite
+
+### Phase 6: Monitoring & documentation (week 8)
+**Goals**: Add observability and complete documentation
+
+**Tasks**:
+- [ ] Implement metrics collection
+- [ ] Add structured logging
+- [ ] Configure alerts
+- [ ] Write deployment guide
+- [ ] Write user documentation
+- [ ] Create troubleshooting guide
+
+**Deliverables**:
+- Monitoring dashboards
+- Complete documentation
+- Deployment guides
+
+## Testing strategy
+
+### Unit tests
+
+**Provider tests** (`test/agent/sandbox/providers/test_*.py`):
+```python
+class TestSelfManagedProvider:
+ def test_initialize_with_config():
+ provider = SelfManagedProvider()
+ assert provider.initialize({"endpoint": "http://localhost:9385"})
+
+ def test_create_python_instance():
+ provider = SelfManagedProvider()
+ provider.initialize(test_config)
+ instance = provider.create_instance("python")
+ assert instance.status == "running"
+
+ def test_execute_code():
+ provider = SelfManagedProvider()
+ result = provider.execute_code(instance_id, "print('hello')", "python")
+ assert result.exit_code == 0
+ assert "hello" in result.stdout
+```
+
+**Configuration tests**:
+- Test configuration validation for each provider schema
+- Test error handling for invalid configurations
+- Test secret field redaction
+
+### Integration tests
+
+**Provider Switching**:
+- Test switching between providers
+- Test fallback mechanism
+- Test concurrent provider usage
+
+**Multi-Tenant Isolation**:
+- Test tenant configuration isolation
+- Test instance ID scoping
+- Test resource separation
+
+**Admin API Tests**:
+- Test CRUD operations for configurations
+- Test connection testing endpoint
+- Test validation error responses
+
+### E2E tests
+
+**Complete flow tests**:
+```python
+def test_sandbox_execution_flow():
+ # 1. Configure provider via admin API
+ setSandboxConfig(provider="self_managed", config={...})
+
+ # 2. Create agent task with code execution
+ task = create_agent_task(code="print('test')")
+
+ # 3. Execute task
+ result = execute_agent_task(task.id)
+
+ # 4. Verify result
+ assert result.status == "success"
+ assert "test" in result.output
+
+ # 5. Verify sandbox cleanup
+ assert get_active_instances() == 0
+```
+
+**Admin UI tests**:
+- Test provider configuration flow
+- Test connection testing
+- Test error handling in UI
+
+### Performance tests
+
+**Load Testing**:
+- Test 100 concurrent executions
+- Test pool exhaustion behavior
+- Test queue performance (self-managed)
+
+**Latency Testing**:
+- Measure cold start time per provider
+- Measure execution latency percentiles
+- Compare provider performance
+
+## Cost considerations
+
+### Self-managed costs
+
+**Infrastructure**:
+- Server hosting: $X/month (depends on specs)
+- Maintenance: engineering time
+- Scaling: manual, requires additional servers
+
+**Pros**:
+- Predictable costs
+- No per-execution fees
+- Full control over resources
+
+**Cons**:
+- High initial setup cost
+- Operational overhead
+- Finite capacity
+
+### SaaS costs
+
+**Aliyun Code Interpreter** (estimated):
+- Pricing: execution time × memory configuration
+- Example: 1000 executions/day × 30s × $0.01/1000s = ~$0.30/day
+
+**E2B** (estimated):
+- Pricing: $0.02/execution-second
+- Example: 1000 executions/day × 30s × $0.02/s = ~$600/day
+
+**Pros**:
+- No upfront costs
+- Automatic scaling
+- No maintenance
+
+**Cons**:
+- Variable costs (can spike with usage)
+- Network dependency
+- Potential for runaway costs
+
+### Cost optimization
+
+**Recommendations**:
+- **Hybrid Approach**: Use self-managed for base load, SaaS for spikes
+- **Cost Monitoring**: Set budget alerts per tenant
+- **Resource Limits**: Enforce max executions per tenant/day
+- **Caching**: Reuse instances when possible (self-managed pool)
+- **Smart Routing**: Route to cheapest provider based on availability
+
+## Future extensibility
+
+The architecture supports easy addition of new providers:
+
+### Adding a new provider
+
+**Step 1**: Implement provider class with schema
+
+```python
+# agent/sandbox/providers/new_provider.py
+from .base import SandboxProvider
+
+class NewProvider(SandboxProvider):
+ @staticmethod
+ def get_config_schema() -> Dict[str, Dict]:
+ return {
+ "api_key": {
+ "type": "string",
+ "required": True,
+ "secret": True,
+ "label": "API Key"
+ },
+ "region": {
+ "type": "string",
+ "default": "us-east-1",
+ "label": "Region"
+ }
+ }
+
+ def initialize(self, config: Dict[str, Any]) -> bool:
+ self.api_key = config.get("api_key")
+ self.region = config.get("region", "us-east-1")
+ # Initialize client
+ return True
+
+ # Implement other abstract methods...
+```
+
+**Step 2**: Register in provider mapping
+
+```python
+# In api/apps/sandbox_app.py or wherever providers are listed
+from agent.agent.sandbox.providers.new_provider import NewProvider
+
+PROVIDER_CLASSES = {
+ "self_managed": SelfManagedProvider,
+ "aliyun_codeinterpreter": AliyunCodeInterpreterProvider,
+ "e2b": E2BProvider,
+ "new_provider": NewProvider, # Add here
+}
+```
+
+**No central registry to update** - just import and add to the mapping!
+
+### Potential future providers
+
+- **GitHub Codespaces**: For GitHub-integrated workflows
+- **Gitpod**: Cloud development environments
+- **CodeSandbox**: Frontend code execution
+- **AWS Firecracker**: Raw microVM management
+- **Custom Provider**: User-defined provider implementations
+
+### Advanced features
+
+**Feature pooling**:
+- Share instances across executions (same language, same user)
+- Warm pool for reduced latency
+- Instance hibernation for cost savings
+
+**Feature multi-region**:
+- Route to nearest region
+- Failover across regions
+- Regional cost optimization
+
+**Feature hybrid execution**:
+- Split workloads between providers
+- Dynamic provider selection based on cost/performance
+- A/B testing for provider performance
+
+## Appendix
+
+### Configuration examples
+
+**SystemSettings initialization file** (`conf/system_settings.json` - add these entries):
+
+```json
+{
+ "system_settings": [
+ {
+ "name": "sandbox.provider_type",
+ "source": "variable",
+ "data_type": "string",
+ "value": "self_managed"
+ },
+ {
+ "name": "sandbox.self_managed",
+ "source": "variable",
+ "data_type": "json",
+ "value": "{\"endpoint\": \"http://sandbox-internal:9385\", \"pool_size\": 20, \"max_memory\": \"512m\", \"timeout\": 60, \"enable_seccomp\": true, \"enable_ast_analysis\": true}"
+ },
+ {
+ "name": "sandbox.aliyun_codeinterpreter",
+ "source": "variable",
+ "data_type": "json",
+ "value": "{\"access_key_id\": \"\", \"access_key_secret\": \"\", \"account_id\": \"\", \"region\": \"cn-hangzhou\", \"template_name\": \"\", \"timeout\": 30}"
+ },
+ {
+ "name": "sandbox.e2b",
+ "source": "variable",
+ "data_type": "json",
+ "value": "{\"api_key\": \"\", \"region\": \"us\", \"timeout\": 30}"
+ }
+ ]
+}
+```
+
+**Admin API request example** (POST to `/api/admin/sandbox/config`):
+
+```json
+{
+ "provider_type": "self_managed",
+ "config": {
+ "endpoint": "http://sandbox-internal:9385",
+ "pool_size": 20,
+ "max_memory": "512m",
+ "timeout": 60,
+ "enable_seccomp": true,
+ "enable_ast_analysis": true
+ },
+ "test_connection": true,
+ "set_active": true
+}
+```
+
+**Note**: The `config` object in the request is a plain JSON object. The API will serialize it to a JSON string before storing in SystemSettings.
+
+**Admin API response example** (GET from `/api/admin/sandbox/config`):
+
+```json
+{
+ "data": {
+ "active": "self_managed",
+ "self_managed": {
+ "endpoint": "http://sandbox-internal:9385",
+ "pool_size": 20,
+ "max_memory": "512m",
+ "timeout": 60,
+ "enable_seccomp": true,
+ "enable_ast_analysis": true
+ },
+ "aliyun_codeinterpreter": {
+ "access_key_id": "",
+ "access_key_secret": "",
+ "region": "cn-hangzhou",
+ "workspace_id": ""
+ },
+ "e2b": {
+ "api_key": "",
+ "region": "us",
+ "timeout": 30
+ }
+ }
+}
+```
+
+**Note**: The response deserializes the JSON strings back to objects for easier frontend consumption.
+
+### Error codes
+
+| Code | Description | Resolution |
+|------|-------------|------------|
+| SB001 | Provider not initialized | Configure provider in admin |
+| SB002 | Invalid configuration | Check configuration values |
+| SB003 | Connection failed | Check network and credentials |
+| SB004 | Instance creation failed | Check provider capacity |
+| SB005 | Execution timeout | Increase timeout or optimize code |
+| SB006 | Out of memory | Reduce memory usage or increase limits |
+| SB007 | Code blocked by security policy | Remove blocked imports/operations |
+| SB008 | Rate limit exceeded | Reduce concurrency or upgrade plan |
+| SB009 | Provider unavailable | Check provider status or use fallback |
+
+### References
+
+- [Daytona Documentation](https://daytona.dev/docs)
+- [Aliyun Code Interpreter](https://help.aliyun.com/...)
+- [E2B Documentation](https://e2b.dev/docs)
+
+---
+
+**Document version**: 1.0
+**Last updated**: 2026-01-26
+**Author**: RAGFlow Team
+**Status**: Design Specification - Ready for Review
+
+## Appendix C: configuration storage considerations
+
+### Current implementation
+- **Storage**: SystemSettings table with `value` field as `TextField` (unlimited length)
+- **Migration**: Database migration added to convert from `CharField(1024)` to `TextField`
+- **Benefit**: Supports arbitrarily long API keys, workspace IDs, and other SaaS provider credentials
+
+### Validation
+- **Schema validation**: Type checking, range validation, required field validation
+- **Provider-specific validation**: Custom validation via `validate_config()` method
+- **Example**: SelfManagedProvider validates URL format, timeout ranges, pool size constraints
+
+### Configuration storage format
+Each provider's configuration is stored as JSON in `SystemSettings.value`:
+- `sandbox.provider_type`: Active provider selection
+- `sandbox.self_managed`: Self-managed provider JSON config
+- `sandbox.aliyun_codeinterpreter`: Aliyun provider JSON config
+- `sandbox.e2b`: E2B provider JSON config
+
+## Appendix D: Configuration hot reload limitations
+
+### Current behavior
+**Provider configuration requires restart**: When switching sandbox providers in the admin panel, the ragflow service must be restarted for changes to take effect.
+
+**Reason**:
+- Admin and ragflow are separate processes
+- ragflow loads sandbox provider configuration only at startup
+- The `get_provider_manager()` function caches the provider globally
+- Configuration changes in MySQL are not automatically detected
+
+**Impact**:
+- Switching from `self_managed` → `aliyun_codeinterpreter` requires ragflow restart
+- Updating credentials/config requires ragflow restart
+- Not a dynamic configuration system
+
+**Workarounds**:
+1. **Production**: Restart ragflow service after configuration changes:
+ ```bash
+ cd docker
+ docker compose restart ragflow-server
+ ```
+
+2. **Development**: Use the `reload_provider()` function in code:
+ ```python
+ from agent.sandbox.client import reload_provider
+ reload_provider() # Reloads from MySQL settings
+ ```
+
+**Future enhancement**:
+To support hot reload without restart, implement configuration change detection:
+```python
+# In agent/sandbox/client.py
+_config_timestamp: Optional[int] = None
+
+def get_provider_manager() -> ProviderManager:
+ global _provider_manager, _config_timestamp
+
+ # Check if configuration has changed
+ setting = SystemSettingsService.get_by_name("sandbox.provider_type")
+ current_timestamp = setting[0].update_time if setting else 0
+
+ if _config_timestamp is None or current_timestamp > _config_timestamp:
+ # Configuration changed, reload provider
+ _provider_manager = None
+ _load_provider_from_settings()
+ _config_timestamp = current_timestamp
+
+ return _provider_manager
+```
+
+However, this adds overhead on every `execute_code()` call. For production use, explicit restart is preferred for simplicity and reliability.
+
+## Appendix E: Arguments parameter support
+
+### Overview
+All sandbox providers support passing arguments to the `main()` function in user code. This enables dynamic parameter injection for code execution.
+
+### Implementation details
+
+**Base interface**:
+```python
+# agent/sandbox/providers/base.py
+@abstractmethod
+def execute_code(
+ self,
+ instance_id: str,
+ code: str,
+ language: str,
+ timeout: int = 10,
+ arguments: Optional[Dict[str, Any]] = None
+) -> ExecutionResult:
+ """
+ Execute code in the sandbox.
+
+ The code should contain a main() function that will be called with:
+ - Python: main(**arguments) if arguments provided, else main()
+ - JavaScript: main(arguments) if arguments provided, else main()
+ """
+ pass
+```
+
+**Provider implementations**:
+
+1. **Self-managed provider** ([self_managed.py:164](agent/sandbox/providers/self_managed.py:164)):
+ - Passes arguments via HTTP API: `"arguments": arguments or {}`
+ - executor_manager receives and passes to code via command line
+ - Runner script: `args = json.loads(sys.argv[1])` then `result = main(**args)`
+
+2. **Aliyun Code Interpreter** ([aliyun_codeinterpreter.py:260-275](agent/sandbox/providers/aliyun_codeinterpreter.py:260-275)):
+ - Wraps user code to call `main(**arguments)` or `main()` if no arguments
+ - Python example:
+ ```python
+ if arguments:
+ wrapped_code = f'''{code}
+
+ if __name__ == "__main__":
+ import json
+ result = main(**{json.dumps(arguments)})
+ print(json.dumps(result) if isinstance(result, dict) else result)
+ '''
+ ```
+ - JavaScript example:
+ ```javascript
+ if arguments:
+ wrapped_code = f'''{code}
+
+ const result = main({json.dumps(arguments)});
+ console.log(typeof result === 'object' ? JSON.stringify(result) : String(result));
+ '''
+ ```
+
+**Client layer** ([client.py:138-190](agent/sandbox/client.py:138-190)):
+```python
+def execute_code(
+ code: str,
+ language: str = "python",
+ timeout: int = 30,
+ arguments: Optional[Dict[str, Any]] = None
+) -> ExecutionResult:
+ provider_manager = get_provider_manager()
+ provider = provider_manager.get_provider()
+
+ instance = provider.create_instance(template=language)
+ try:
+ result = provider.execute_code(
+ instance_id=instance.instance_id,
+ code=code,
+ language=language,
+ timeout=timeout,
+ arguments=arguments # Passed through to provider
+ )
+ return result
+ finally:
+ provider.destroy_instance(instance.instance_id)
+```
+
+**CodeExec tool integration** ([code_exec.py:136-165](agent/tools/code_exec.py:136-165)):
+```python
+def _execute_code(self, language: str, code: str, arguments: dict):
+ # ... collect arguments from component configuration
+
+ result = sandbox_execute_code(
+ code=code,
+ language=language,
+ timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)),
+ arguments=arguments # Passed through to sandbox client
+ )
+```
+
+### Usage examples
+
+**Python code with arguments**:
+```python
+# User code
+def main(name: str, count: int) -> dict:
+ """Generate greeting"""
+ return {"message": f"Hello {name}!" * count}
+
+# Called with: arguments={"name": "World", "count": 3}
+# Result: {"message": "Hello World!Hello World!Hello World!"}
+```
+
+**JavaScript code with arguments**:
+```javascript
+// User code
+function main(args) {
+ const { name, count } = args;
+ return `Hello ${name}!`.repeat(count);
+}
+
+// Called with: arguments={"name": "World", "count": 3}
+// Result: "Hello World!Hello World!Hello World!"
+```
+
+### Important notes
+
+1. **Function signature**: Code MUST define a `main()` function
+ - Python: `def main(**kwargs)` or `def main()` if no arguments
+ - JavaScript: `function main(args)` or `function main()` if no arguments
+
+2. **Type consistency**: Arguments are passed as JSON, so types are preserved:
+ - Numbers → int/float
+ - Strings → str
+ - Booleans → bool
+ - Objects → dict (Python) / object (JavaScript)
+ - Arrays → list (Python) / array (JavaScript)
+
+3. **Return value**: Return value is serialized as JSON for parsing
+ - Python: `print(json.dumps(result))` if dict
+ - JavaScript: `console.log(JSON.stringify(result))` if object
+
+4. **Provider alignment**: All providers (self_managed, aliyun_codeinterpreter, e2b) implement arguments passing consistently
diff --git a/sandbox/scripts/restart.sh b/agent/sandbox/scripts/restart.sh
similarity index 100%
rename from sandbox/scripts/restart.sh
rename to agent/sandbox/scripts/restart.sh
diff --git a/sandbox/scripts/start.sh b/agent/sandbox/scripts/start.sh
similarity index 100%
rename from sandbox/scripts/start.sh
rename to agent/sandbox/scripts/start.sh
diff --git a/sandbox/scripts/stop.sh b/agent/sandbox/scripts/stop.sh
similarity index 100%
rename from sandbox/scripts/stop.sh
rename to agent/sandbox/scripts/stop.sh
diff --git a/sandbox/scripts/wait-for-it-http.sh b/agent/sandbox/scripts/wait-for-it-http.sh
similarity index 100%
rename from sandbox/scripts/wait-for-it-http.sh
rename to agent/sandbox/scripts/wait-for-it-http.sh
diff --git a/sandbox/scripts/wait-for-it.sh b/agent/sandbox/scripts/wait-for-it.sh
similarity index 100%
rename from sandbox/scripts/wait-for-it.sh
rename to agent/sandbox/scripts/wait-for-it.sh
diff --git a/agent/sandbox/tests/MIGRATION_GUIDE.md b/agent/sandbox/tests/MIGRATION_GUIDE.md
new file mode 100644
index 00000000000..93bb27ba87d
--- /dev/null
+++ b/agent/sandbox/tests/MIGRATION_GUIDE.md
@@ -0,0 +1,261 @@
+# Aliyun Code Interpreter Provider - 使用官方 SDK
+
+## 重要变更
+
+### 官方资源
+- **Code Interpreter API**: https://help.aliyun.com/zh/functioncompute/fc/sandbox-sandbox-code-interepreter
+- **官方 SDK**: https://github.com/Serverless-Devs/agentrun-sdk-python
+- **SDK 文档**: https://docs.agent.run
+
+## 使用官方 SDK 的优势
+
+从手动 HTTP 请求迁移到官方 SDK (`agentrun-sdk`) 有以下优势:
+
+### 1. **自动签名认证**
+- SDK 自动处理 Aliyun API 签名(无需手动实现 `Authorization` 头)
+- 支持多种认证方式:AccessKey、STS Token
+- 自动读取环境变量
+
+### 2. **简化的 API**
+```python
+# 旧实现(手动 HTTP 请求)
+response = requests.post(
+ f"{DATA_ENDPOINT}/sandboxes/{sandbox_id}/execute",
+ headers={"X-Acs-Parent-Id": account_id},
+ json={"code": code, "language": "python"}
+)
+
+# 新实现(使用 SDK)
+sandbox = CodeInterpreterSandbox(template_name="python-sandbox", config=config)
+result = sandbox.context.execute(code="print('hello')")
+```
+
+### 3. **更好的错误处理**
+- 结构化的异常类型 (`ServerError`)
+- 自动重试机制
+- 详细的错误信息
+
+## 主要变更
+
+### 1. 文件重命名
+
+| 旧文件名 | 新文件名 | 说明 |
+|---------|---------|------|
+| `aliyun_opensandbox.py` | `aliyun_codeinterpreter.py` | 提供商实现 |
+| `test_aliyun_provider.py` | `test_aliyun_codeinterpreter.py` | 单元测试 |
+| `test_aliyun_integration.py` | `test_aliyun_codeinterpreter_integration.py` | 集成测试 |
+
+### 2. 配置字段变更
+
+#### 旧配置(OpenSandbox)
+```json
+{
+ "access_key_id": "LTAI5t...",
+ "access_key_secret": "...",
+ "region": "cn-hangzhou",
+ "workspace_id": "ws-xxxxx"
+}
+```
+
+#### 新配置(Code Interpreter)
+```json
+{
+ "access_key_id": "LTAI5t...",
+ "access_key_secret": "...",
+ "account_id": "1234567890...", // 新增:阿里云主账号ID(必需)
+ "region": "cn-hangzhou",
+ "template_name": "python-sandbox", // 新增:沙箱模板名称
+ "timeout": 30 // 最大 30 秒(硬限制)
+}
+```
+
+### 3. 关键差异
+
+| 特性 | OpenSandbox | Code Interpreter |
+|------|-------------|-----------------|
+| **API 端点** | `opensandbox.{region}.aliyuncs.com` | `agentrun.{region}.aliyuncs.com` (控制面) |
+| **API 版本** | `2024-01-01` | `2025-09-10` |
+| **认证** | 需要 AccessKey | 需要 AccessKey + 主账号ID |
+| **请求头** | 标准签名 | 需要 `X-Acs-Parent-Id` 头 |
+| **超时限制** | 可配置 | **最大 30 秒**(硬限制) |
+| **上下文** | 不支持 | 支持上下文(Jupyter kernel) |
+
+### 4. API 调用方式变更
+
+#### 旧实现(假设的 OpenSandbox)
+```python
+# 单一端点
+API_ENDPOINT = "https://opensandbox.cn-hangzhou.aliyuncs.com"
+
+# 简单的请求/响应
+response = requests.post(
+ f"{API_ENDPOINT}/execute",
+ json={"code": "print('hello')", "language": "python"}
+)
+```
+
+#### 新实现(Code Interpreter)
+```python
+# 控制面 API - 管理沙箱生命周期
+CONTROL_ENDPOINT = "https://agentrun.cn-hangzhou.aliyuncs.com/2025-09-10"
+
+# 数据面 API - 执行代码
+DATA_ENDPOINT = "https://{account_id}.agentrun-data.cn-hangzhou.aliyuncs.com"
+
+# 创建沙箱(控制面)
+response = requests.post(
+ f"{CONTROL_ENDPOINT}/sandboxes",
+ headers={"X-Acs-Parent-Id": account_id},
+ json={"templateName": "python-sandbox"}
+)
+
+# 执行代码(数据面)
+response = requests.post(
+ f"{DATA_ENDPOINT}/sandboxes/{sandbox_id}/execute",
+ headers={"X-Acs-Parent-Id": account_id},
+ json={"code": "print('hello')", "language": "python", "timeout": 30}
+)
+```
+
+### 5. 迁移步骤
+
+#### 步骤 1: 更新配置
+
+如果您之前使用的是 `aliyun_opensandbox`:
+
+**旧配置**:
+```json
+{
+ "name": "sandbox.provider_type",
+ "value": "aliyun_opensandbox"
+}
+```
+
+**新配置**:
+```json
+{
+ "name": "sandbox.provider_type",
+ "value": "aliyun_codeinterpreter"
+}
+```
+
+#### 步骤 2: 添加必需的 account_id
+
+在 Aliyun 控制台右上角点击头像,获取主账号 ID:
+1. 登录 [阿里云控制台](https://ram.console.aliyun.com/manage/ak)
+2. 点击右上角头像
+3. 复制主账号 ID(16 位数字)
+
+#### 步骤 3: 更新环境变量
+
+```bash
+# 新增必需的环境变量
+export ALIYUN_ACCOUNT_ID="1234567890123456"
+
+# 其他环境变量保持不变
+export ALIYUN_ACCESS_KEY_ID="LTAI5t..."
+export ALIYUN_ACCESS_KEY_SECRET="..."
+export ALIYUN_REGION="cn-hangzhou"
+```
+
+#### 步骤 4: 运行测试
+
+```bash
+# 单元测试(不需要真实凭据)
+pytest agent/sandbox/tests/test_aliyun_codeinterpreter.py -v
+
+# 集成测试(需要真实凭据)
+pytest agent/sandbox/tests/test_aliyun_codeinterpreter_integration.py -v -m integration
+```
+
+## 文件变更清单
+
+### ✅ 已完成
+
+- [x] 创建 `aliyun_codeinterpreter.py` - 新的提供商实现
+- [x] 更新 `sandbox_spec.md` - 规范文档
+- [x] 更新 `admin/services.py` - 服务管理器
+- [x] 更新 `providers/__init__.py` - 包导出
+- [x] 创建 `test_aliyun_codeinterpreter.py` - 单元测试
+- [x] 创建 `test_aliyun_codeinterpreter_integration.py` - 集成测试
+
+### 📝 可选清理
+
+如果您想删除旧的 OpenSandbox 实现:
+
+```bash
+# 删除旧文件(可选)
+rm agent/sandbox/providers/aliyun_opensandbox.py
+rm agent/sandbox/tests/test_aliyun_provider.py
+rm agent/sandbox/tests/test_aliyun_integration.py
+```
+
+**注意**: 保留旧文件不会影响新功能,只是代码冗余。
+
+## API 参考
+
+### 控制面 API(沙箱管理)
+
+| 端点 | 方法 | 说明 |
+|------|------|------|
+| `/sandboxes` | POST | 创建沙箱实例 |
+| `/sandboxes/{id}/stop` | POST | 停止实例 |
+| `/sandboxes/{id}` | DELETE | 删除实例 |
+| `/templates` | GET | 列出模板 |
+
+### 数据面 API(代码执行)
+
+| 端点 | 方法 | 说明 |
+|------|------|------|
+| `/sandboxes/{id}/execute` | POST | 执行代码(简化版) |
+| `/sandboxes/{id}/contexts` | POST | 创建上下文 |
+| `/sandboxes/{id}/contexts/{ctx_id}/execute` | POST | 在上下文中执行 |
+| `/sandboxes/{id}/health` | GET | 健康检查 |
+| `/sandboxes/{id}/files` | GET/POST | 文件读写 |
+| `/sandboxes/{id}/processes/cmd` | POST | 执行 Shell 命令 |
+
+## 常见问题
+
+### Q: 为什么要添加 account_id?
+
+**A**: Code Interpreter API 需要在请求头中提供 `X-Acs-Parent-Id`(阿里云主账号ID)进行身份验证。这是 Aliyun Code Interpreter API 的必需参数。
+
+### Q: 30 秒超时限制可以绕过吗?
+
+**A**: 不可以。这是 Aliyun Code Interpreter 的**硬限制**,无法通过配置或请求参数绕过。如果代码执行时间超过 30 秒,请考虑:
+1. 优化代码逻辑
+2. 分批处理数据
+3. 使用上下文保持状态
+
+### Q: 旧的 OpenSandbox 配置还能用吗?
+
+**A**: 不能。OpenSandbox 和 Code Interpreter 是两个不同的服务,API 不兼容。必须迁移到新的配置格式。
+
+### Q: 如何获取阿里云主账号 ID?
+
+**A**:
+1. 登录阿里云控制台
+2. 点击右上角的头像
+3. 在弹出的信息中可以看到"主账号ID"
+
+### Q: 迁移后会影响现有功能吗?
+
+**A**:
+- **自我管理提供商(self_managed)**: 不受影响
+- **E2B 提供商**: 不受影响
+- **Aliyun 提供商**: 需要更新配置并重新测试
+
+## 相关文档
+
+- [官方文档](https://help.aliyun.com/zh/functioncompute/fc/sandbox-sandbox-code-interepreter)
+- [sandbox 规范](../docs/develop/sandbox_spec.md)
+- [测试指南](./README.md)
+- [快速开始](./QUICKSTART.md)
+
+## 技术支持
+
+如有问题,请:
+1. 查看官方文档
+2. 检查配置是否正确
+3. 查看测试输出中的错误信息
+4. 联系 RAGFlow 团队
diff --git a/agent/sandbox/tests/QUICKSTART.md b/agent/sandbox/tests/QUICKSTART.md
new file mode 100644
index 00000000000..51a23eeae12
--- /dev/null
+++ b/agent/sandbox/tests/QUICKSTART.md
@@ -0,0 +1,178 @@
+# Aliyun OpenSandbox Provider - 快速测试指南
+
+## 测试说明
+
+### 1. 单元测试(不需要真实凭据)
+
+单元测试使用 mock,**不需要**真实的 Aliyun 凭据,可以随时运行。
+
+```bash
+# 运行 Aliyun 提供商的单元测试
+pytest agent/sandbox/tests/test_aliyun_provider.py -v
+
+# 预期输出:
+# test_aliyun_provider.py::TestAliyunOpenSandboxProvider::test_provider_initialization PASSED
+# test_aliyun_provider.py::TestAliyunOpenSandboxProvider::test_initialize_success PASSED
+# ...
+# ========================= 48 passed in 2.34s ==========================
+```
+
+### 2. 集成测试(需要真实凭据)
+
+集成测试会调用真实的 Aliyun API,需要配置凭据。
+
+#### 步骤 1: 配置环境变量
+
+```bash
+export ALIYUN_ACCESS_KEY_ID="LTAI5t..." # 替换为真实的 Access Key ID
+export ALIYUN_ACCESS_KEY_SECRET="..." # 替换为真实的 Access Key Secret
+export ALIYUN_REGION="cn-hangzhou" # 可选,默认为 cn-hangzhou
+```
+
+#### 步骤 2: 运行集成测试
+
+```bash
+# 运行所有集成测试
+pytest agent/sandbox/tests/test_aliyun_integration.py -v -m integration
+
+# 运行特定测试
+pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_health_check -v
+```
+
+#### 步骤 3: 预期输出
+
+```
+test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_initialize_provider PASSED
+test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_health_check PASSED
+test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_execute_python_code PASSED
+...
+========================== 10 passed in 15.67s ==========================
+```
+
+### 3. 测试场景
+
+#### 基础功能测试
+
+```bash
+# 健康检查
+pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_health_check -v
+
+# 创建实例
+pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_create_python_instance -v
+
+# 执行代码
+pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_execute_python_code -v
+
+# 销毁实例
+pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_destroy_instance -v
+```
+
+#### 错误处理测试
+
+```bash
+# 代码执行错误
+pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_execute_python_code_with_error -v
+
+# 超时处理
+pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_execute_python_code_timeout -v
+```
+
+#### 真实场景测试
+
+```bash
+# 数据处理工作流
+pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunRealWorldScenarios::test_data_processing_workflow -v
+
+# 字符串操作
+pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunRealWorldScenarios::test_string_manipulation -v
+
+# 多次执行
+pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunRealWorldScenarios::test_multiple_executions_same_instance -v
+```
+
+## 常见问题
+
+### Q: 没有凭据怎么办?
+
+**A:** 运行单元测试即可,不需要真实凭据:
+```bash
+pytest agent/sandbox/tests/test_aliyun_provider.py -v
+```
+
+### Q: 如何跳过集成测试?
+
+**A:** 使用 pytest 标记跳过:
+```bash
+# 只运行单元测试,跳过集成测试
+pytest agent/sandbox/tests/ -v -m "not integration"
+```
+
+### Q: 集成测试失败怎么办?
+
+**A:** 检查以下几点:
+
+1. **凭据是否正确**
+ ```bash
+ echo $ALIYUN_ACCESS_KEY_ID
+ echo $ALIYUN_ACCESS_KEY_SECRET
+ ```
+
+2. **网络连接是否正常**
+ ```bash
+ curl -I https://opensandbox.cn-hangzhou.aliyuncs.com
+ ```
+
+3. **是否有 OpenSandbox 服务权限**
+ - 登录阿里云控制台
+ - 检查是否已开通 OpenSandbox 服务
+ - 检查 AccessKey 权限
+
+4. **查看详细错误信息**
+ ```bash
+ pytest agent/sandbox/tests/test_aliyun_integration.py -v -s
+ ```
+
+### Q: 测试超时怎么办?
+
+**A:** 增加超时时间或检查网络:
+```bash
+# 使用更长的超时
+pytest agent/sandbox/tests/test_aliyun_integration.py -v --timeout=60
+```
+
+## 测试命令速查表
+
+| 命令 | 说明 | 需要凭据 |
+|------|------|---------|
+| `pytest agent/sandbox/tests/test_aliyun_provider.py -v` | 单元测试 | ❌ |
+| `pytest agent/sandbox/tests/test_aliyun_integration.py -v` | 集成测试 | ✅ |
+| `pytest agent/sandbox/tests/ -v -m "not integration"` | 仅单元测试 | ❌ |
+| `pytest agent/sandbox/tests/ -v -m integration` | 仅集成测试 | ✅ |
+| `pytest agent/sandbox/tests/ -v` | 所有测试 | 部分需要 |
+
+## 获取 Aliyun 凭据
+
+1. 访问 [阿里云控制台](https://ram.console.aliyun.com/manage/ak)
+2. 创建 AccessKey
+3. 保存 AccessKey ID 和 AccessKey Secret
+4. 设置环境变量
+
+⚠️ **安全提示:**
+- 不要在代码中硬编码凭据
+- 使用环境变量或配置文件
+- 定期轮换 AccessKey
+- 限制 AccessKey 权限
+
+## 下一步
+
+1. ✅ **运行单元测试** - 验证代码逻辑
+2. 🔧 **配置凭据** - 设置环境变量
+3. 🚀 **运行集成测试** - 测试真实 API
+4. 📊 **查看结果** - 确保所有测试通过
+5. 🎯 **集成到系统** - 使用 admin API 配置提供商
+
+## 需要帮助?
+
+- 查看 [完整文档](README.md)
+- 检查 [sandbox 规范](../../../../../docs/develop/sandbox_spec.md)
+- 联系 RAGFlow 团队
diff --git a/agent/sandbox/tests/README.md b/agent/sandbox/tests/README.md
new file mode 100644
index 00000000000..11b350d3c3c
--- /dev/null
+++ b/agent/sandbox/tests/README.md
@@ -0,0 +1,213 @@
+# Sandbox Provider Tests
+
+This directory contains tests for the RAGFlow sandbox provider system.
+
+## Test Structure
+
+```
+tests/
+├── pytest.ini # Pytest configuration
+├── test_providers.py # Unit tests for all providers (mocked)
+├── test_aliyun_provider.py # Unit tests for Aliyun provider (mocked)
+├── test_aliyun_integration.py # Integration tests for Aliyun (real API)
+└── sandbox_security_tests_full.py # Security tests for self-managed provider
+```
+
+## Test Types
+
+### 1. Unit Tests (No Credentials Required)
+
+Unit tests use mocks and don't require any external services or credentials.
+
+**Files:**
+- `test_providers.py` - Tests for base provider interface and manager
+- `test_aliyun_provider.py` - Tests for Aliyun provider with mocked API calls
+
+**Run unit tests:**
+```bash
+# Run all unit tests
+pytest agent/sandbox/tests/test_providers.py -v
+pytest agent/sandbox/tests/test_aliyun_provider.py -v
+
+# Run specific test
+pytest agent/sandbox/tests/test_aliyun_provider.py::TestAliyunOpenSandboxProvider::test_initialize_success -v
+
+# Run all unit tests (skip integration)
+pytest agent/sandbox/tests/ -v -m "not integration"
+```
+
+### 2. Integration Tests (Real Credentials Required)
+
+Integration tests make real API calls to Aliyun OpenSandbox service.
+
+**Files:**
+- `test_aliyun_integration.py` - Tests with real Aliyun API calls
+
+**Setup environment variables:**
+```bash
+export ALIYUN_ACCESS_KEY_ID="LTAI5t..."
+export ALIYUN_ACCESS_KEY_SECRET="..."
+export ALIYUN_REGION="cn-hangzhou" # Optional, defaults to cn-hangzhou
+export ALIYUN_WORKSPACE_ID="ws-..." # Optional
+```
+
+**Run integration tests:**
+```bash
+# Run only integration tests
+pytest agent/sandbox/tests/test_aliyun_integration.py -v -m integration
+
+# Run all tests including integration
+pytest agent/sandbox/tests/ -v
+
+# Run specific integration test
+pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_health_check -v
+```
+
+### 3. Security Tests
+
+Security tests validate the security features of the self-managed sandbox provider.
+
+**Files:**
+- `sandbox_security_tests_full.py` - Comprehensive security tests
+
+**Run security tests:**
+```bash
+# Run all security tests
+pytest agent/sandbox/tests/sandbox_security_tests_full.py -v
+
+# Run specific security test
+pytest agent/sandbox/tests/sandbox_security_tests_full.py -k "test_dangerous_imports" -v
+```
+
+## Test Commands
+
+### Quick Test Commands
+
+```bash
+# Run all sandbox tests (unit only, fast)
+pytest agent/sandbox/tests/ -v -m "not integration" --tb=short
+
+# Run tests with coverage
+pytest agent/sandbox/tests/ -v --cov=agent.sandbox --cov-report=term-missing -m "not integration"
+
+# Run tests and stop on first failure
+pytest agent/sandbox/tests/ -v -x -m "not integration"
+
+# Run tests in parallel (requires pytest-xdist)
+pytest agent/sandbox/tests/ -v -n auto -m "not integration"
+```
+
+### Aliyun Provider Testing
+
+```bash
+# 1. Run unit tests (no credentials needed)
+pytest agent/sandbox/tests/test_aliyun_provider.py -v
+
+# 2. Set up credentials for integration tests
+export ALIYUN_ACCESS_KEY_ID="your-key-id"
+export ALIYUN_ACCESS_KEY_SECRET="your-secret"
+export ALIYUN_REGION="cn-hangzhou"
+
+# 3. Run integration tests (makes real API calls)
+pytest agent/sandbox/tests/test_aliyun_integration.py -v
+
+# 4. Test specific scenarios
+pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_execute_python_code -v
+pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunRealWorldScenarios -v
+```
+
+## Understanding Test Results
+
+### Unit Test Output
+
+```
+agent/sandbox/tests/test_aliyun_provider.py::TestAliyunOpenSandboxProvider::test_initialize_success PASSED
+agent/sandbox/tests/test_aliyun_provider.py::TestAliyunOpenSandboxProvider::test_create_instance_python PASSED
+...
+========================== 48 passed in 2.34s ===========================
+```
+
+### Integration Test Output
+
+```
+agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_health_check PASSED
+agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_create_python_instance PASSED
+agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_execute_python_code PASSED
+...
+========================== 10 passed in 15.67s ===========================
+```
+
+**Note:** Integration tests will be skipped if credentials are not set:
+```
+agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_health_check SKIPPED
+...
+========================== 48 skipped, 10 passed in 0.12s ===========================
+```
+
+## Troubleshooting
+
+### Integration Tests Fail
+
+1. **Check credentials:**
+ ```bash
+ echo $ALIYUN_ACCESS_KEY_ID
+ echo $ALIYUN_ACCESS_KEY_SECRET
+ ```
+
+2. **Check network connectivity:**
+ ```bash
+ curl -I https://opensandbox.cn-hangzhou.aliyuncs.com
+ ```
+
+3. **Verify permissions:**
+ - Make sure your Aliyun account has OpenSandbox service enabled
+ - Check that your AccessKey has the required permissions
+
+4. **Check region:**
+ - Verify the region is correct for your account
+ - Try different regions: cn-hangzhou, cn-beijing, cn-shanghai, etc.
+
+### Tests Timeout
+
+If tests timeout, increase the timeout in the test configuration or run with a longer timeout:
+```bash
+pytest agent/sandbox/tests/test_aliyun_integration.py -v --timeout=60
+```
+
+### Mock Tests Fail
+
+If unit tests fail, it's likely a code issue, not a credentials issue:
+1. Check the test error message
+2. Review the code changes
+3. Run with verbose output: `pytest -vv`
+
+## Contributing
+
+When adding new providers:
+
+1. **Create unit tests** in `test_{provider}_provider.py` with mocks
+2. **Create integration tests** in `test_{provider}_integration.py` with real API calls
+3. **Add markers** to distinguish test types
+4. **Update this README** with provider-specific testing instructions
+
+Example:
+```python
+@pytest.mark.integration
+def test_new_provider_real_api():
+ """Test with real API calls."""
+ # Your test here
+```
+
+## Continuous Integration
+
+In CI/CD pipelines:
+
+```yaml
+# Run unit tests only (fast, no credentials)
+pytest agent/sandbox/tests/ -v -m "not integration"
+
+# Run integration tests if credentials available
+if [ -n "$ALIYUN_ACCESS_KEY_ID" ]; then
+ pytest agent/sandbox/tests/test_aliyun_integration.py -v -m integration
+fi
+```
diff --git a/sdk/python/test/libs/__init__.py b/agent/sandbox/tests/__init__.py
similarity index 93%
rename from sdk/python/test/libs/__init__.py
rename to agent/sandbox/tests/__init__.py
index 177b91dd051..f6a24fc983e 100644
--- a/sdk/python/test/libs/__init__.py
+++ b/agent/sandbox/tests/__init__.py
@@ -13,3 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+
+"""
+Sandbox provider tests package.
+"""
diff --git a/agent/sandbox/tests/pytest.ini b/agent/sandbox/tests/pytest.ini
new file mode 100644
index 00000000000..61b0d3392ec
--- /dev/null
+++ b/agent/sandbox/tests/pytest.ini
@@ -0,0 +1,33 @@
+[pytest]
+# Pytest configuration for sandbox tests
+
+# Test discovery patterns
+python_files = test_*.py
+python_classes = Test*
+python_functions = test_*
+
+# Markers for different test types
+markers =
+ integration: Tests that require external services (Aliyun API, etc.)
+ unit: Fast tests that don't require external services
+ slow: Tests that take a long time to run
+
+# Test paths
+testpaths = .
+
+# Minimum version
+minversion = 7.0
+
+# Output options
+addopts =
+ -v
+ --strict-markers
+ --tb=short
+ --disable-warnings
+
+# Log options
+log_cli = false
+log_cli_level = INFO
+
+# Coverage options (if using pytest-cov)
+# addopts = --cov=agent.sandbox --cov-report=html --cov-report=term
diff --git a/sandbox/tests/sandbox_security_tests_full.py b/agent/sandbox/tests/sandbox_security_tests_full.py
similarity index 100%
rename from sandbox/tests/sandbox_security_tests_full.py
rename to agent/sandbox/tests/sandbox_security_tests_full.py
diff --git a/agent/sandbox/tests/test_aliyun_codeinterpreter.py b/agent/sandbox/tests/test_aliyun_codeinterpreter.py
new file mode 100644
index 00000000000..9b4a369b572
--- /dev/null
+++ b/agent/sandbox/tests/test_aliyun_codeinterpreter.py
@@ -0,0 +1,329 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Unit tests for Aliyun Code Interpreter provider.
+
+These tests use mocks and don't require real Aliyun credentials.
+
+Official Documentation: https://help.aliyun.com/zh/functioncompute/fc/sandbox-sandbox-code-interepreter
+Official SDK: https://github.com/Serverless-Devs/agentrun-sdk-python
+"""
+
+import pytest
+from unittest.mock import patch, MagicMock
+
+from agent.sandbox.providers.base import SandboxProvider
+from agent.sandbox.providers.aliyun_codeinterpreter import AliyunCodeInterpreterProvider
+
+
+class TestAliyunCodeInterpreterProvider:
+ """Test AliyunCodeInterpreterProvider implementation."""
+
+ def test_provider_initialization(self):
+ """Test provider initialization."""
+ provider = AliyunCodeInterpreterProvider()
+
+ assert provider.access_key_id == ""
+ assert provider.access_key_secret == ""
+ assert provider.account_id == ""
+ assert provider.region == "cn-hangzhou"
+ assert provider.template_name == ""
+ assert provider.timeout == 30
+ assert not provider._initialized
+
+ @patch("agent.sandbox.providers.aliyun_codeinterpreter.Template")
+ def test_initialize_success(self, mock_template):
+ """Test successful initialization."""
+ # Mock health check response
+ mock_template.list.return_value = []
+
+ provider = AliyunCodeInterpreterProvider()
+ result = provider.initialize(
+ {
+ "access_key_id": "LTAI5tXXXXXXXXXX",
+ "access_key_secret": "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
+ "account_id": "1234567890123456",
+ "region": "cn-hangzhou",
+ "template_name": "python-sandbox",
+ "timeout": 20,
+ }
+ )
+
+ assert result is True
+ assert provider.access_key_id == "LTAI5tXXXXXXXXXX"
+ assert provider.access_key_secret == "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
+ assert provider.account_id == "1234567890123456"
+ assert provider.region == "cn-hangzhou"
+ assert provider.template_name == "python-sandbox"
+ assert provider.timeout == 20
+ assert provider._initialized
+
+ def test_initialize_missing_credentials(self):
+ """Test initialization with missing credentials."""
+ provider = AliyunCodeInterpreterProvider()
+
+ # Missing access_key_id
+ result = provider.initialize({"access_key_secret": "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"})
+ assert result is False
+
+ # Missing access_key_secret
+ result = provider.initialize({"access_key_id": "LTAI5tXXXXXXXXXX"})
+ assert result is False
+
+ # Missing account_id
+ provider2 = AliyunCodeInterpreterProvider()
+ result = provider2.initialize({"access_key_id": "LTAI5tXXXXXXXXXX", "access_key_secret": "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"})
+ assert result is False
+
+ @patch("agent.sandbox.providers.aliyun_codeinterpreter.Template")
+ def test_initialize_default_config(self, mock_template):
+ """Test initialization with default config."""
+ mock_template.list.return_value = []
+
+ provider = AliyunCodeInterpreterProvider()
+ result = provider.initialize({"access_key_id": "LTAI5tXXXXXXXXXX", "access_key_secret": "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", "account_id": "1234567890123456"})
+
+ assert result is True
+ assert provider.region == "cn-hangzhou"
+ assert provider.template_name == ""
+
+ @patch("agent.sandbox.providers.aliyun_codeinterpreter.CodeInterpreterSandbox")
+ def test_create_instance_python(self, mock_sandbox_class):
+ """Test creating a Python instance."""
+ # Mock successful instance creation
+ mock_sandbox = MagicMock()
+ mock_sandbox.sandbox_id = "01JCED8Z9Y6XQVK8M2NRST5WXY"
+ mock_sandbox_class.return_value = mock_sandbox
+
+ provider = AliyunCodeInterpreterProvider()
+ provider._initialized = True
+ provider._config = MagicMock()
+
+ instance = provider.create_instance("python")
+
+ assert instance.provider == "aliyun_codeinterpreter"
+ assert instance.status == "READY"
+ assert instance.metadata["language"] == "python"
+
+ @patch("agent.sandbox.providers.aliyun_codeinterpreter.CodeInterpreterSandbox")
+ def test_create_instance_javascript(self, mock_sandbox_class):
+ """Test creating a JavaScript instance."""
+ mock_sandbox = MagicMock()
+ mock_sandbox.sandbox_id = "01JCED8Z9Y6XQVK8M2NRST5WXY"
+ mock_sandbox_class.return_value = mock_sandbox
+
+ provider = AliyunCodeInterpreterProvider()
+ provider._initialized = True
+ provider._config = MagicMock()
+
+ instance = provider.create_instance("javascript")
+
+ assert instance.metadata["language"] == "javascript"
+
+ def test_create_instance_not_initialized(self):
+ """Test creating instance when provider not initialized."""
+ provider = AliyunCodeInterpreterProvider()
+
+ with pytest.raises(RuntimeError, match="Provider not initialized"):
+ provider.create_instance("python")
+
+ @patch("agent.sandbox.providers.aliyun_codeinterpreter.CodeInterpreterSandbox")
+ def test_execute_code_success(self, mock_sandbox_class):
+ """Test successful code execution."""
+ # Mock sandbox instance
+ mock_sandbox = MagicMock()
+ mock_sandbox.context.execute.return_value = {
+ "results": [{"type": "stdout", "text": "Hello, World!"}, {"type": "result", "text": "None"}, {"type": "endOfExecution", "status": "ok"}],
+ "contextId": "kernel-12345-67890",
+ }
+ mock_sandbox_class.return_value = mock_sandbox
+
+ provider = AliyunCodeInterpreterProvider()
+ provider._initialized = True
+ provider._config = MagicMock()
+
+ result = provider.execute_code(instance_id="01JCED8Z9Y6XQVK8M2NRST5WXY", code="print('Hello, World!')", language="python", timeout=10)
+
+ assert result.stdout == "Hello, World!"
+ assert result.stderr == ""
+ assert result.exit_code == 0
+ assert result.execution_time > 0
+
+ @patch("agent.sandbox.providers.aliyun_codeinterpreter.CodeInterpreterSandbox")
+ def test_execute_code_timeout(self, mock_sandbox_class):
+ """Test code execution timeout."""
+ from agentrun.utils.exception import ServerError
+
+ mock_sandbox = MagicMock()
+ mock_sandbox.context.execute.side_effect = ServerError(408, "Request timeout")
+ mock_sandbox_class.return_value = mock_sandbox
+
+ provider = AliyunCodeInterpreterProvider()
+ provider._initialized = True
+ provider._config = MagicMock()
+
+ with pytest.raises(TimeoutError, match="Execution timed out"):
+ provider.execute_code(instance_id="01JCED8Z9Y6XQVK8M2NRST5WXY", code="while True: pass", language="python", timeout=5)
+
+ @patch("agent.sandbox.providers.aliyun_codeinterpreter.CodeInterpreterSandbox")
+ def test_execute_code_with_error(self, mock_sandbox_class):
+ """Test code execution with error."""
+ mock_sandbox = MagicMock()
+ mock_sandbox.context.execute.return_value = {
+ "results": [{"type": "stderr", "text": "Traceback..."}, {"type": "error", "text": "NameError: name 'x' is not defined"}, {"type": "endOfExecution", "status": "error"}]
+ }
+ mock_sandbox_class.return_value = mock_sandbox
+
+ provider = AliyunCodeInterpreterProvider()
+ provider._initialized = True
+ provider._config = MagicMock()
+
+ result = provider.execute_code(instance_id="01JCED8Z9Y6XQVK8M2NRST5WXY", code="print(x)", language="python")
+
+ assert result.exit_code != 0
+ assert len(result.stderr) > 0
+
+ def test_get_supported_languages(self):
+ """Test getting supported languages."""
+ provider = AliyunCodeInterpreterProvider()
+
+ languages = provider.get_supported_languages()
+
+ assert "python" in languages
+ assert "javascript" in languages
+
+ def test_get_config_schema(self):
+ """Test getting configuration schema."""
+ schema = AliyunCodeInterpreterProvider.get_config_schema()
+
+ assert "access_key_id" in schema
+ assert schema["access_key_id"]["required"] is True
+
+ assert "access_key_secret" in schema
+ assert schema["access_key_secret"]["required"] is True
+
+ assert "account_id" in schema
+ assert schema["account_id"]["required"] is True
+
+ assert "region" in schema
+ assert "template_name" in schema
+ assert "timeout" in schema
+
+ def test_validate_config_success(self):
+ """Test successful configuration validation."""
+ provider = AliyunCodeInterpreterProvider()
+
+ is_valid, error_msg = provider.validate_config({"access_key_id": "LTAI5tXXXXXXXXXX", "account_id": "1234567890123456", "region": "cn-hangzhou"})
+
+ assert is_valid is True
+ assert error_msg is None
+
+ def test_validate_config_invalid_access_key(self):
+ """Test validation with invalid access key format."""
+ provider = AliyunCodeInterpreterProvider()
+
+ is_valid, error_msg = provider.validate_config({"access_key_id": "INVALID_KEY"})
+
+ assert is_valid is False
+ assert "AccessKey ID format" in error_msg
+
+ def test_validate_config_missing_account_id(self):
+ """Test validation with missing account ID."""
+ provider = AliyunCodeInterpreterProvider()
+
+ is_valid, error_msg = provider.validate_config({})
+
+ assert is_valid is False
+ assert "Account ID" in error_msg
+
+ def test_validate_config_invalid_region(self):
+ """Test validation with invalid region."""
+ provider = AliyunCodeInterpreterProvider()
+
+ is_valid, error_msg = provider.validate_config(
+ {
+ "access_key_id": "LTAI5tXXXXXXXXXX",
+ "account_id": "1234567890123456", # Provide required field
+ "region": "us-west-1",
+ }
+ )
+
+ assert is_valid is False
+ assert "Invalid region" in error_msg
+
+ def test_validate_config_invalid_timeout(self):
+ """Test validation with invalid timeout (> 30 seconds)."""
+ provider = AliyunCodeInterpreterProvider()
+
+ is_valid, error_msg = provider.validate_config(
+ {
+ "access_key_id": "LTAI5tXXXXXXXXXX",
+ "account_id": "1234567890123456", # Provide required field
+ "timeout": 60,
+ }
+ )
+
+ assert is_valid is False
+ assert "Timeout must be between 1 and 30 seconds" in error_msg
+
+ def test_normalize_language_python(self):
+ """Test normalizing Python language identifier."""
+ provider = AliyunCodeInterpreterProvider()
+
+ assert provider._normalize_language("python") == "python"
+ assert provider._normalize_language("python3") == "python"
+ assert provider._normalize_language("PYTHON") == "python"
+
+ def test_normalize_language_javascript(self):
+ """Test normalizing JavaScript language identifier."""
+ provider = AliyunCodeInterpreterProvider()
+
+ assert provider._normalize_language("javascript") == "javascript"
+ assert provider._normalize_language("nodejs") == "javascript"
+ assert provider._normalize_language("JavaScript") == "javascript"
+
+
+class TestAliyunCodeInterpreterInterface:
+ """Test that Aliyun provider correctly implements the interface."""
+
+ def test_aliyun_provider_is_abstract(self):
+ """Test that AliyunCodeInterpreterProvider is a SandboxProvider."""
+ provider = AliyunCodeInterpreterProvider()
+
+ assert isinstance(provider, SandboxProvider)
+
+ def test_aliyun_provider_has_abstract_methods(self):
+ """Test that AliyunCodeInterpreterProvider implements all abstract methods."""
+ provider = AliyunCodeInterpreterProvider()
+
+ assert hasattr(provider, "initialize")
+ assert callable(provider.initialize)
+
+ assert hasattr(provider, "create_instance")
+ assert callable(provider.create_instance)
+
+ assert hasattr(provider, "execute_code")
+ assert callable(provider.execute_code)
+
+ assert hasattr(provider, "destroy_instance")
+ assert callable(provider.destroy_instance)
+
+ assert hasattr(provider, "health_check")
+ assert callable(provider.health_check)
+
+ assert hasattr(provider, "get_supported_languages")
+ assert callable(provider.get_supported_languages)
diff --git a/agent/sandbox/tests/test_aliyun_codeinterpreter_integration.py b/agent/sandbox/tests/test_aliyun_codeinterpreter_integration.py
new file mode 100644
index 00000000000..5aa11d52ef2
--- /dev/null
+++ b/agent/sandbox/tests/test_aliyun_codeinterpreter_integration.py
@@ -0,0 +1,353 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Integration tests for Aliyun Code Interpreter provider.
+
+These tests require real Aliyun credentials and will make actual API calls.
+To run these tests, set the following environment variables:
+
+ export AGENTRUN_ACCESS_KEY_ID="LTAI5t..."
+ export AGENTRUN_ACCESS_KEY_SECRET="..."
+ export AGENTRUN_ACCOUNT_ID="1234567890..." # Aliyun primary account ID (主账号ID)
+ export AGENTRUN_REGION="cn-hangzhou" # Note: AGENTRUN_REGION (SDK will read this)
+
+Then run:
+ pytest agent/sandbox/tests/test_aliyun_codeinterpreter_integration.py -v
+
+Official Documentation: https://help.aliyun.com/zh/functioncompute/fc/sandbox-sandbox-code-interepreter
+"""
+
+import os
+import pytest
+from agent.sandbox.providers.aliyun_codeinterpreter import AliyunCodeInterpreterProvider
+
+
+# Skip all tests if credentials are not provided
+pytestmark = pytest.mark.skipif(
+ not all(
+ [
+ os.getenv("AGENTRUN_ACCESS_KEY_ID"),
+ os.getenv("AGENTRUN_ACCESS_KEY_SECRET"),
+ os.getenv("AGENTRUN_ACCOUNT_ID"),
+ ]
+ ),
+ reason="Aliyun credentials not set. Set AGENTRUN_ACCESS_KEY_ID, AGENTRUN_ACCESS_KEY_SECRET, and AGENTRUN_ACCOUNT_ID.",
+)
+
+
+@pytest.fixture
+def aliyun_config():
+ """Get Aliyun configuration from environment variables."""
+ return {
+ "access_key_id": os.getenv("AGENTRUN_ACCESS_KEY_ID"),
+ "access_key_secret": os.getenv("AGENTRUN_ACCESS_KEY_SECRET"),
+ "account_id": os.getenv("AGENTRUN_ACCOUNT_ID"),
+ "region": os.getenv("AGENTRUN_REGION", "cn-hangzhou"),
+ "template_name": os.getenv("AGENTRUN_TEMPLATE_NAME", ""),
+ "timeout": 30,
+ }
+
+
+@pytest.fixture
+def provider(aliyun_config):
+ """Create an initialized Aliyun provider."""
+ provider = AliyunCodeInterpreterProvider()
+ initialized = provider.initialize(aliyun_config)
+ if not initialized:
+ pytest.skip("Failed to initialize Aliyun provider. Check credentials, account ID, and network.")
+ return provider
+
+
+@pytest.mark.integration
+class TestAliyunCodeInterpreterIntegration:
+ """Integration tests for Aliyun Code Interpreter provider."""
+
+ def test_initialize_provider(self, aliyun_config):
+ """Test provider initialization with real credentials."""
+ provider = AliyunCodeInterpreterProvider()
+ result = provider.initialize(aliyun_config)
+
+ assert result is True
+ assert provider._initialized is True
+
+ def test_health_check(self, provider):
+ """Test health check with real API."""
+ result = provider.health_check()
+
+ assert result is True
+
+ def test_get_supported_languages(self, provider):
+ """Test getting supported languages."""
+ languages = provider.get_supported_languages()
+
+ assert "python" in languages
+ assert "javascript" in languages
+ assert isinstance(languages, list)
+
+ def test_create_python_instance(self, provider):
+ """Test creating a Python sandbox instance."""
+ try:
+ instance = provider.create_instance("python")
+
+ assert instance.provider == "aliyun_codeinterpreter"
+ assert instance.status in ["READY", "CREATING"]
+ assert instance.metadata["language"] == "python"
+ assert len(instance.instance_id) > 0
+
+ # Clean up
+ provider.destroy_instance(instance.instance_id)
+ except Exception as e:
+ pytest.skip(f"Instance creation failed: {str(e)}. API might not be available yet.")
+
+ def test_execute_python_code(self, provider):
+ """Test executing Python code in the sandbox."""
+ try:
+ # Create instance
+ instance = provider.create_instance("python")
+
+ # Execute simple code
+ result = provider.execute_code(
+ instance_id=instance.instance_id,
+ code="print('Hello from Aliyun Code Interpreter!')\nprint(42)",
+ language="python",
+ timeout=30, # Max 30 seconds
+ )
+
+ assert result.exit_code == 0
+ assert "Hello from Aliyun Code Interpreter!" in result.stdout
+ assert "42" in result.stdout
+ assert result.execution_time > 0
+
+ # Clean up
+ provider.destroy_instance(instance.instance_id)
+ except Exception as e:
+ pytest.skip(f"Code execution test failed: {str(e)}. API might not be available yet.")
+
+ def test_execute_python_code_with_arguments(self, provider):
+ """Test executing Python code with arguments parameter."""
+ try:
+ # Create instance
+ instance = provider.create_instance("python")
+
+ # Execute code with arguments
+ result = provider.execute_code(
+ instance_id=instance.instance_id,
+ code="""def main(name: str, count: int) -> dict:
+ return {"message": f"Hello {name}!" * count}
+""",
+ language="python",
+ timeout=30,
+ arguments={"name": "World", "count": 2}
+ )
+
+ assert result.exit_code == 0
+ assert "Hello World!Hello World!" in result.stdout
+
+ # Clean up
+ provider.destroy_instance(instance.instance_id)
+ except Exception as e:
+ pytest.skip(f"Arguments test failed: {str(e)}. API might not be available yet.")
+
+ def test_execute_python_code_with_error(self, provider):
+ """Test executing Python code that produces an error."""
+ try:
+ # Create instance
+ instance = provider.create_instance("python")
+
+ # Execute code with error
+ result = provider.execute_code(instance_id=instance.instance_id, code="raise ValueError('Test error')", language="python", timeout=30)
+
+ assert result.exit_code != 0
+ assert len(result.stderr) > 0 or "ValueError" in result.stdout
+
+ # Clean up
+ provider.destroy_instance(instance.instance_id)
+ except Exception as e:
+ pytest.skip(f"Error handling test failed: {str(e)}. API might not be available yet.")
+
+ def test_execute_javascript_code(self, provider):
+ """Test executing JavaScript code in the sandbox."""
+ try:
+ # Create instance
+ instance = provider.create_instance("javascript")
+
+ # Execute simple code
+ result = provider.execute_code(instance_id=instance.instance_id, code="console.log('Hello from JavaScript!');", language="javascript", timeout=30)
+
+ assert result.exit_code == 0
+ assert "Hello from JavaScript!" in result.stdout
+
+ # Clean up
+ provider.destroy_instance(instance.instance_id)
+ except Exception as e:
+ pytest.skip(f"JavaScript execution test failed: {str(e)}. API might not be available yet.")
+
+ def test_execute_javascript_code_with_arguments(self, provider):
+ """Test executing JavaScript code with arguments parameter."""
+ try:
+ # Create instance
+ instance = provider.create_instance("javascript")
+
+ # Execute code with arguments
+ result = provider.execute_code(
+ instance_id=instance.instance_id,
+ code="""function main(args) {
+ const { name, count } = args;
+ return `Hello ${name}!`.repeat(count);
+}""",
+ language="javascript",
+ timeout=30,
+ arguments={"name": "World", "count": 2}
+ )
+
+ assert result.exit_code == 0
+ assert "Hello World!Hello World!" in result.stdout
+
+ # Clean up
+ provider.destroy_instance(instance.instance_id)
+ except Exception as e:
+ pytest.skip(f"JavaScript arguments test failed: {str(e)}. API might not be available yet.")
+
+ def test_destroy_instance(self, provider):
+ """Test destroying a sandbox instance."""
+ try:
+ # Create instance
+ instance = provider.create_instance("python")
+
+ # Destroy instance
+ result = provider.destroy_instance(instance.instance_id)
+
+ # Note: The API might return True immediately or async
+ assert result is True or result is False
+ except Exception as e:
+ pytest.skip(f"Destroy instance test failed: {str(e)}. API might not be available yet.")
+
+ def test_config_validation(self, provider):
+ """Test configuration validation."""
+ # Valid config
+ is_valid, error = provider.validate_config({"access_key_id": "LTAI5tXXXXXXXXXX", "account_id": "1234567890123456", "region": "cn-hangzhou", "timeout": 30})
+ assert is_valid is True
+ assert error is None
+
+ # Invalid access key
+ is_valid, error = provider.validate_config({"access_key_id": "INVALID_KEY"})
+ assert is_valid is False
+
+ # Missing account ID
+ is_valid, error = provider.validate_config({})
+ assert is_valid is False
+ assert "Account ID" in error
+
+ def test_timeout_limit(self, provider):
+ """Test that timeout is limited to 30 seconds."""
+ # Timeout > 30 should be clamped to 30
+ provider2 = AliyunCodeInterpreterProvider()
+ provider2.initialize(
+ {
+ "access_key_id": os.getenv("AGENTRUN_ACCESS_KEY_ID"),
+ "access_key_secret": os.getenv("AGENTRUN_ACCESS_KEY_SECRET"),
+ "account_id": os.getenv("AGENTRUN_ACCOUNT_ID"),
+ "timeout": 60, # Request 60 seconds
+ }
+ )
+
+ # Should be clamped to 30
+ assert provider2.timeout == 30
+
+
+@pytest.mark.integration
+class TestAliyunCodeInterpreterScenarios:
+ """Test real-world usage scenarios."""
+
+ def test_data_processing_workflow(self, provider):
+ """Test a simple data processing workflow."""
+ try:
+ instance = provider.create_instance("python")
+
+ # Execute data processing code
+ code = """
+import json
+data = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]
+result = json.dumps(data, indent=2)
+print(result)
+"""
+ result = provider.execute_code(instance_id=instance.instance_id, code=code, language="python", timeout=30)
+
+ assert result.exit_code == 0
+ assert "Alice" in result.stdout
+ assert "Bob" in result.stdout
+
+ provider.destroy_instance(instance.instance_id)
+ except Exception as e:
+ pytest.skip(f"Data processing test failed: {str(e)}")
+
+ def test_string_manipulation(self, provider):
+ """Test string manipulation operations."""
+ try:
+ instance = provider.create_instance("python")
+
+ code = """
+text = "Hello, World!"
+print(text.upper())
+print(text.lower())
+print(text.replace("World", "Aliyun"))
+"""
+ result = provider.execute_code(instance_id=instance.instance_id, code=code, language="python", timeout=30)
+
+ assert result.exit_code == 0
+ assert "HELLO, WORLD!" in result.stdout
+ assert "hello, world!" in result.stdout
+ assert "Hello, Aliyun!" in result.stdout
+
+ provider.destroy_instance(instance.instance_id)
+ except Exception as e:
+ pytest.skip(f"String manipulation test failed: {str(e)}")
+
+ def test_context_persistence(self, provider):
+ """Test code execution with context persistence."""
+ try:
+ instance = provider.create_instance("python")
+
+ # First execution - define variable
+ result1 = provider.execute_code(instance_id=instance.instance_id, code="x = 42\nprint(x)", language="python", timeout=30)
+ assert result1.exit_code == 0
+
+ # Second execution - use variable
+ # Note: Context persistence depends on whether the contextId is reused
+ result2 = provider.execute_code(instance_id=instance.instance_id, code="print(f'x is {x}')", language="python", timeout=30)
+
+ # Context might or might not persist depending on API implementation
+ assert result2.exit_code == 0
+
+ provider.destroy_instance(instance.instance_id)
+ except Exception as e:
+ pytest.skip(f"Context persistence test failed: {str(e)}")
+
+
+def test_without_credentials():
+ """Test that tests are skipped without credentials."""
+ # This test should always run (not skipped)
+ if all(
+ [
+ os.getenv("AGENTRUN_ACCESS_KEY_ID"),
+ os.getenv("AGENTRUN_ACCESS_KEY_SECRET"),
+ os.getenv("AGENTRUN_ACCOUNT_ID"),
+ ]
+ ):
+ assert True # Credentials are set
+ else:
+ assert True # Credentials not set, test still passes
diff --git a/agent/sandbox/tests/test_providers.py b/agent/sandbox/tests/test_providers.py
new file mode 100644
index 00000000000..fa2e97ad027
--- /dev/null
+++ b/agent/sandbox/tests/test_providers.py
@@ -0,0 +1,423 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Unit tests for sandbox provider abstraction layer.
+"""
+
+import pytest
+from unittest.mock import Mock, patch
+import requests
+
+from agent.sandbox.providers.base import SandboxProvider, SandboxInstance, ExecutionResult
+from agent.sandbox.providers.manager import ProviderManager
+from agent.sandbox.providers.self_managed import SelfManagedProvider
+
+
+class TestSandboxDataclasses:
+ """Test sandbox dataclasses."""
+
+ def test_sandbox_instance_creation(self):
+ """Test SandboxInstance dataclass creation."""
+ instance = SandboxInstance(
+ instance_id="test-123",
+ provider="self_managed",
+ status="running",
+ metadata={"language": "python"}
+ )
+
+ assert instance.instance_id == "test-123"
+ assert instance.provider == "self_managed"
+ assert instance.status == "running"
+ assert instance.metadata == {"language": "python"}
+
+ def test_sandbox_instance_default_metadata(self):
+ """Test SandboxInstance with None metadata."""
+ instance = SandboxInstance(
+ instance_id="test-123",
+ provider="self_managed",
+ status="running",
+ metadata=None
+ )
+
+ assert instance.metadata == {}
+
+ def test_execution_result_creation(self):
+ """Test ExecutionResult dataclass creation."""
+ result = ExecutionResult(
+ stdout="Hello, World!",
+ stderr="",
+ exit_code=0,
+ execution_time=1.5,
+ metadata={"status": "success"}
+ )
+
+ assert result.stdout == "Hello, World!"
+ assert result.stderr == ""
+ assert result.exit_code == 0
+ assert result.execution_time == 1.5
+ assert result.metadata == {"status": "success"}
+
+ def test_execution_result_default_metadata(self):
+ """Test ExecutionResult with None metadata."""
+ result = ExecutionResult(
+ stdout="output",
+ stderr="error",
+ exit_code=1,
+ execution_time=0.5,
+ metadata=None
+ )
+
+ assert result.metadata == {}
+
+
+class TestProviderManager:
+ """Test ProviderManager functionality."""
+
+ def test_manager_initialization(self):
+ """Test ProviderManager initialization."""
+ manager = ProviderManager()
+
+ assert manager.current_provider is None
+ assert manager.current_provider_name is None
+ assert not manager.is_configured()
+
+ def test_set_provider(self):
+ """Test setting a provider."""
+ manager = ProviderManager()
+ mock_provider = Mock(spec=SandboxProvider)
+
+ manager.set_provider("self_managed", mock_provider)
+
+ assert manager.current_provider == mock_provider
+ assert manager.current_provider_name == "self_managed"
+ assert manager.is_configured()
+
+ def test_get_provider(self):
+ """Test getting the current provider."""
+ manager = ProviderManager()
+ mock_provider = Mock(spec=SandboxProvider)
+
+ manager.set_provider("self_managed", mock_provider)
+
+ assert manager.get_provider() == mock_provider
+
+ def test_get_provider_name(self):
+ """Test getting the current provider name."""
+ manager = ProviderManager()
+ mock_provider = Mock(spec=SandboxProvider)
+
+ manager.set_provider("self_managed", mock_provider)
+
+ assert manager.get_provider_name() == "self_managed"
+
+ def test_get_provider_when_not_set(self):
+ """Test getting provider when none is set."""
+ manager = ProviderManager()
+
+ assert manager.get_provider() is None
+ assert manager.get_provider_name() is None
+
+
+class TestSelfManagedProvider:
+ """Test SelfManagedProvider implementation."""
+
+ def test_provider_initialization(self):
+ """Test provider initialization."""
+ provider = SelfManagedProvider()
+
+ assert provider.endpoint == "http://localhost:9385"
+ assert provider.timeout == 30
+ assert provider.max_retries == 3
+ assert provider.pool_size == 10
+ assert not provider._initialized
+
+ @patch('requests.get')
+ def test_initialize_success(self, mock_get):
+ """Test successful initialization."""
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_get.return_value = mock_response
+
+ provider = SelfManagedProvider()
+ result = provider.initialize({
+ "endpoint": "http://test-endpoint:9385",
+ "timeout": 60,
+ "max_retries": 5,
+ "pool_size": 20
+ })
+
+ assert result is True
+ assert provider.endpoint == "http://test-endpoint:9385"
+ assert provider.timeout == 60
+ assert provider.max_retries == 5
+ assert provider.pool_size == 20
+ assert provider._initialized
+ mock_get.assert_called_once_with("http://test-endpoint:9385/healthz", timeout=5)
+
+ @patch('requests.get')
+ def test_initialize_failure(self, mock_get):
+ """Test initialization failure."""
+ mock_get.side_effect = Exception("Connection error")
+
+ provider = SelfManagedProvider()
+ result = provider.initialize({"endpoint": "http://invalid:9385"})
+
+ assert result is False
+ assert not provider._initialized
+
+ def test_initialize_default_config(self):
+ """Test initialization with default config."""
+ with patch('requests.get') as mock_get:
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_get.return_value = mock_response
+
+ provider = SelfManagedProvider()
+ result = provider.initialize({})
+
+ assert result is True
+ assert provider.endpoint == "http://localhost:9385"
+ assert provider.timeout == 30
+
+ def test_create_instance_python(self):
+ """Test creating a Python instance."""
+ provider = SelfManagedProvider()
+ provider._initialized = True
+
+ instance = provider.create_instance("python")
+
+ assert instance.provider == "self_managed"
+ assert instance.status == "running"
+ assert instance.metadata["language"] == "python"
+ assert instance.metadata["endpoint"] == "http://localhost:9385"
+ assert len(instance.instance_id) > 0 # Verify instance_id exists
+
+ def test_create_instance_nodejs(self):
+ """Test creating a Node.js instance."""
+ provider = SelfManagedProvider()
+ provider._initialized = True
+
+ instance = provider.create_instance("nodejs")
+
+ assert instance.metadata["language"] == "nodejs"
+
+ def test_create_instance_not_initialized(self):
+ """Test creating instance when provider not initialized."""
+ provider = SelfManagedProvider()
+
+ with pytest.raises(RuntimeError, match="Provider not initialized"):
+ provider.create_instance("python")
+
+ @patch('requests.post')
+ def test_execute_code_success(self, mock_post):
+ """Test successful code execution."""
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {
+ "status": "success",
+ "stdout": '{"result": 42}',
+ "stderr": "",
+ "exit_code": 0,
+ "time_used_ms": 100.0,
+ "memory_used_kb": 1024.0
+ }
+ mock_post.return_value = mock_response
+
+ provider = SelfManagedProvider()
+ provider._initialized = True
+
+ result = provider.execute_code(
+ instance_id="test-123",
+ code="def main(): return {'result': 42}",
+ language="python",
+ timeout=10
+ )
+
+ assert result.stdout == '{"result": 42}'
+ assert result.stderr == ""
+ assert result.exit_code == 0
+ assert result.execution_time > 0
+ assert result.metadata["status"] == "success"
+ assert result.metadata["instance_id"] == "test-123"
+
+ @patch('requests.post')
+ def test_execute_code_timeout(self, mock_post):
+ """Test code execution timeout."""
+ mock_post.side_effect = requests.Timeout()
+
+ provider = SelfManagedProvider()
+ provider._initialized = True
+
+ with pytest.raises(TimeoutError, match="Execution timed out"):
+ provider.execute_code(
+ instance_id="test-123",
+ code="while True: pass",
+ language="python",
+ timeout=5
+ )
+
+ @patch('requests.post')
+ def test_execute_code_http_error(self, mock_post):
+ """Test code execution with HTTP error."""
+ mock_response = Mock()
+ mock_response.status_code = 500
+ mock_response.text = "Internal Server Error"
+ mock_post.return_value = mock_response
+
+ provider = SelfManagedProvider()
+ provider._initialized = True
+
+ with pytest.raises(RuntimeError, match="HTTP 500"):
+ provider.execute_code(
+ instance_id="test-123",
+ code="invalid code",
+ language="python"
+ )
+
+ def test_execute_code_not_initialized(self):
+ """Test executing code when provider not initialized."""
+ provider = SelfManagedProvider()
+
+ with pytest.raises(RuntimeError, match="Provider not initialized"):
+ provider.execute_code(
+ instance_id="test-123",
+ code="print('hello')",
+ language="python"
+ )
+
+ def test_destroy_instance(self):
+ """Test destroying an instance (no-op for self-managed)."""
+ provider = SelfManagedProvider()
+ provider._initialized = True
+
+ # For self-managed, destroy_instance is a no-op
+ result = provider.destroy_instance("test-123")
+
+ assert result is True
+
+ @patch('requests.get')
+ def test_health_check_success(self, mock_get):
+ """Test successful health check."""
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_get.return_value = mock_response
+
+ provider = SelfManagedProvider()
+
+ result = provider.health_check()
+
+ assert result is True
+ mock_get.assert_called_once_with("http://localhost:9385/healthz", timeout=5)
+
+ @patch('requests.get')
+ def test_health_check_failure(self, mock_get):
+ """Test health check failure."""
+ mock_get.side_effect = Exception("Connection error")
+
+ provider = SelfManagedProvider()
+
+ result = provider.health_check()
+
+ assert result is False
+
+ def test_get_supported_languages(self):
+ """Test getting supported languages."""
+ provider = SelfManagedProvider()
+
+ languages = provider.get_supported_languages()
+
+ assert "python" in languages
+ assert "nodejs" in languages
+ assert "javascript" in languages
+
+ def test_get_config_schema(self):
+ """Test getting configuration schema."""
+ schema = SelfManagedProvider.get_config_schema()
+
+ assert "endpoint" in schema
+ assert schema["endpoint"]["type"] == "string"
+ assert schema["endpoint"]["required"] is True
+ assert schema["endpoint"]["default"] == "http://localhost:9385"
+
+ assert "timeout" in schema
+ assert schema["timeout"]["type"] == "integer"
+ assert schema["timeout"]["default"] == 30
+
+ assert "max_retries" in schema
+ assert schema["max_retries"]["type"] == "integer"
+
+ assert "pool_size" in schema
+ assert schema["pool_size"]["type"] == "integer"
+
+ def test_normalize_language_python(self):
+ """Test normalizing Python language identifier."""
+ provider = SelfManagedProvider()
+
+ assert provider._normalize_language("python") == "python"
+ assert provider._normalize_language("python3") == "python"
+ assert provider._normalize_language("PYTHON") == "python"
+ assert provider._normalize_language("Python3") == "python"
+
+ def test_normalize_language_javascript(self):
+ """Test normalizing JavaScript language identifier."""
+ provider = SelfManagedProvider()
+
+ assert provider._normalize_language("javascript") == "nodejs"
+ assert provider._normalize_language("nodejs") == "nodejs"
+ assert provider._normalize_language("JavaScript") == "nodejs"
+ assert provider._normalize_language("NodeJS") == "nodejs"
+
+ def test_normalize_language_default(self):
+ """Test language normalization with empty/unknown input."""
+ provider = SelfManagedProvider()
+
+ assert provider._normalize_language("") == "python"
+ assert provider._normalize_language(None) == "python"
+ assert provider._normalize_language("unknown") == "unknown"
+
+
+class TestProviderInterface:
+ """Test that providers correctly implement the interface."""
+
+ def test_self_managed_provider_is_abstract(self):
+ """Test that SelfManagedProvider is a SandboxProvider."""
+ provider = SelfManagedProvider()
+
+ assert isinstance(provider, SandboxProvider)
+
+ def test_self_managed_provider_has_abstract_methods(self):
+ """Test that SelfManagedProvider implements all abstract methods."""
+ provider = SelfManagedProvider()
+
+ # Check all abstract methods are implemented
+ assert hasattr(provider, 'initialize')
+ assert callable(provider.initialize)
+
+ assert hasattr(provider, 'create_instance')
+ assert callable(provider.create_instance)
+
+ assert hasattr(provider, 'execute_code')
+ assert callable(provider.execute_code)
+
+ assert hasattr(provider, 'destroy_instance')
+ assert callable(provider.destroy_instance)
+
+ assert hasattr(provider, 'health_check')
+ assert callable(provider.health_check)
+
+ assert hasattr(provider, 'get_supported_languages')
+ assert callable(provider.get_supported_languages)
diff --git a/agent/sandbox/tests/verify_sdk.py b/agent/sandbox/tests/verify_sdk.py
new file mode 100644
index 00000000000..94aea18f887
--- /dev/null
+++ b/agent/sandbox/tests/verify_sdk.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+"""
+Quick verification script for Aliyun Code Interpreter provider using official SDK.
+"""
+
+import importlib.util
+import sys
+
+sys.path.insert(0, ".")
+
+print("=" * 60)
+print("Aliyun Code Interpreter Provider - SDK Verification")
+print("=" * 60)
+
+# Test 1: Import provider
+print("\n[1/5] Testing provider import...")
+try:
+ from agent.sandbox.providers.aliyun_codeinterpreter import AliyunCodeInterpreterProvider
+
+ print("✓ Provider imported successfully")
+except ImportError as e:
+ print(f"✗ Import failed: {e}")
+ sys.exit(1)
+
+# Test 2: Check provider class
+print("\n[2/5] Testing provider class...")
+provider = AliyunCodeInterpreterProvider()
+assert hasattr(provider, "initialize")
+assert hasattr(provider, "create_instance")
+assert hasattr(provider, "execute_code")
+assert hasattr(provider, "destroy_instance")
+assert hasattr(provider, "health_check")
+print("✓ Provider has all required methods")
+
+# Test 3: Check SDK imports
+print("\n[3/5] Testing SDK imports...")
+try:
+ # Check if agentrun SDK is available using importlib
+ if (
+ importlib.util.find_spec("agentrun.sandbox") is None
+ or importlib.util.find_spec("agentrun.utils.config") is None
+ or importlib.util.find_spec("agentrun.utils.exception") is None
+ ):
+ raise ImportError("agentrun SDK not found")
+
+ # Verify imports work (assign to _ to indicate they're intentionally unused)
+ from agentrun.sandbox import CodeInterpreterSandbox, TemplateType, CodeLanguage
+ from agentrun.utils.config import Config
+ from agentrun.utils.exception import ServerError
+ _ = (CodeInterpreterSandbox, TemplateType, CodeLanguage, Config, ServerError)
+
+ print("✓ SDK modules imported successfully")
+except ImportError as e:
+ print(f"✗ SDK import failed: {e}")
+ sys.exit(1)
+
+# Test 4: Check config schema
+print("\n[4/5] Testing configuration schema...")
+schema = AliyunCodeInterpreterProvider.get_config_schema()
+required_fields = ["access_key_id", "access_key_secret", "account_id"]
+for field in required_fields:
+ assert field in schema
+ assert schema[field]["required"] is True
+print(f"✓ All required fields present: {', '.join(required_fields)}")
+
+# Test 5: Check supported languages
+print("\n[5/5] Testing supported languages...")
+languages = provider.get_supported_languages()
+assert "python" in languages
+assert "javascript" in languages
+print(f"✓ Supported languages: {', '.join(languages)}")
+
+print("\n" + "=" * 60)
+print("All verification tests passed! ✓")
+print("=" * 60)
+print("\nNote: This provider now uses the official agentrun-sdk.")
+print("SDK Documentation: https://github.com/Serverless-Devs/agentrun-sdk-python")
+print("API Documentation: https://help.aliyun.com/zh/functioncompute/fc/sandbox-sandbox-code-interepreter")
diff --git a/sandbox/uv.lock b/agent/sandbox/uv.lock
similarity index 60%
rename from sandbox/uv.lock
rename to agent/sandbox/uv.lock
index ef681064619..e780a44ea65 100644
--- a/sandbox/uv.lock
+++ b/agent/sandbox/uv.lock
@@ -1,7 +1,16 @@
version = 1
-revision = 2
+revision = 3
requires-python = ">=3.10"
+[[package]]
+name = "annotated-doc"
+version = "0.0.4"
+source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
+sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" }
+wheels = [
+ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" },
+]
+
[[package]]
name = "annotated-types"
version = "0.7.0"
@@ -16,7 +25,6 @@ name = "anyio"
version = "4.9.0"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
dependencies = [
- { name = "exceptiongroup", marker = "python_full_version < '3.11'" },
{ name = "idna" },
{ name = "sniffio" },
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
@@ -53,32 +61,6 @@ version = "3.4.2"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e4/33/89c2ced2b67d1c2a61c19c6751aa8902d46ce3dacb23600a283619f5a12d/charset_normalizer-3.4.2.tar.gz", hash = "sha256:5baececa9ecba31eff645232d59845c07aa030f0c81ee70184a90d35099a0e63", size = 126367, upload-time = "2025-05-02T08:34:42.01Z" }
wheels = [
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/95/28/9901804da60055b406e1a1c5ba7aac1276fb77f1dde635aabfc7fd84b8ab/charset_normalizer-3.4.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7c48ed483eb946e6c04ccbe02c6b4d1d48e51944b6db70f697e089c193404941", size = 201818, upload-time = "2025-05-02T08:31:46.725Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d9/9b/892a8c8af9110935e5adcbb06d9c6fe741b6bb02608c6513983048ba1a18/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2d318c11350e10662026ad0eb71bb51c7812fc8590825304ae0bdd4ac283acd", size = 144649, upload-time = "2025-05-02T08:31:48.889Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/7b/a5/4179abd063ff6414223575e008593861d62abfc22455b5d1a44995b7c101/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9cbfacf36cb0ec2897ce0ebc5d08ca44213af24265bd56eca54bee7923c48fd6", size = 155045, upload-time = "2025-05-02T08:31:50.757Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3b/95/bc08c7dfeddd26b4be8c8287b9bb055716f31077c8b0ea1cd09553794665/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18dd2e350387c87dabe711b86f83c9c78af772c748904d372ade190b5c7c9d4d", size = 147356, upload-time = "2025-05-02T08:31:52.634Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a8/2d/7a5b635aa65284bf3eab7653e8b4151ab420ecbae918d3e359d1947b4d61/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8075c35cd58273fee266c58c0c9b670947c19df5fb98e7b66710e04ad4e9ff86", size = 149471, upload-time = "2025-05-02T08:31:56.207Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ae/38/51fc6ac74251fd331a8cfdb7ec57beba8c23fd5493f1050f71c87ef77ed0/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5bf4545e3b962767e5c06fe1738f951f77d27967cb2caa64c28be7c4563e162c", size = 151317, upload-time = "2025-05-02T08:31:57.613Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b7/17/edee1e32215ee6e9e46c3e482645b46575a44a2d72c7dfd49e49f60ce6bf/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:7a6ab32f7210554a96cd9e33abe3ddd86732beeafc7a28e9955cdf22ffadbab0", size = 146368, upload-time = "2025-05-02T08:31:59.468Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/26/2c/ea3e66f2b5f21fd00b2825c94cafb8c326ea6240cd80a91eb09e4a285830/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:b33de11b92e9f75a2b545d6e9b6f37e398d86c3e9e9653c4864eb7e89c5773ef", size = 154491, upload-time = "2025-05-02T08:32:01.219Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/52/47/7be7fa972422ad062e909fd62460d45c3ef4c141805b7078dbab15904ff7/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:8755483f3c00d6c9a77f490c17e6ab0c8729e39e6390328e42521ef175380ae6", size = 157695, upload-time = "2025-05-02T08:32:03.045Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/2f/42/9f02c194da282b2b340f28e5fb60762de1151387a36842a92b533685c61e/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:68a328e5f55ec37c57f19ebb1fdc56a248db2e3e9ad769919a58672958e8f366", size = 154849, upload-time = "2025-05-02T08:32:04.651Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/67/44/89cacd6628f31fb0b63201a618049be4be2a7435a31b55b5eb1c3674547a/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:21b2899062867b0e1fde9b724f8aecb1af14f2778d69aacd1a5a1853a597a5db", size = 150091, upload-time = "2025-05-02T08:32:06.719Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/1f/79/4b8da9f712bc079c0f16b6d67b099b0b8d808c2292c937f267d816ec5ecc/charset_normalizer-3.4.2-cp310-cp310-win32.whl", hash = "sha256:e8082b26888e2f8b36a042a58307d5b917ef2b1cacab921ad3323ef91901c71a", size = 98445, upload-time = "2025-05-02T08:32:08.66Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/7d/d7/96970afb4fb66497a40761cdf7bd4f6fca0fc7bafde3a84f836c1f57a926/charset_normalizer-3.4.2-cp310-cp310-win_amd64.whl", hash = "sha256:f69a27e45c43520f5487f27627059b64aaf160415589230992cec34c5e18a509", size = 105782, upload-time = "2025-05-02T08:32:10.46Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/05/85/4c40d00dcc6284a1c1ad5de5e0996b06f39d8232f1031cd23c2f5c07ee86/charset_normalizer-3.4.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:be1e352acbe3c78727a16a455126d9ff83ea2dfdcbc83148d2982305a04714c2", size = 198794, upload-time = "2025-05-02T08:32:11.945Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/41/d9/7a6c0b9db952598e97e93cbdfcb91bacd89b9b88c7c983250a77c008703c/charset_normalizer-3.4.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa88ca0b1932e93f2d961bf3addbb2db902198dca337d88c89e1559e066e7645", size = 142846, upload-time = "2025-05-02T08:32:13.946Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/66/82/a37989cda2ace7e37f36c1a8ed16c58cf48965a79c2142713244bf945c89/charset_normalizer-3.4.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d524ba3f1581b35c03cb42beebab4a13e6cdad7b36246bd22541fa585a56cccd", size = 153350, upload-time = "2025-05-02T08:32:15.873Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/df/68/a576b31b694d07b53807269d05ec3f6f1093e9545e8607121995ba7a8313/charset_normalizer-3.4.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28a1005facc94196e1fb3e82a3d442a9d9110b8434fc1ded7a24a2983c9888d8", size = 145657, upload-time = "2025-05-02T08:32:17.283Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/92/9b/ad67f03d74554bed3aefd56fe836e1623a50780f7c998d00ca128924a499/charset_normalizer-3.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fdb20a30fe1175ecabed17cbf7812f7b804b8a315a25f24678bcdf120a90077f", size = 147260, upload-time = "2025-05-02T08:32:18.807Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a6/e6/8aebae25e328160b20e31a7e9929b1578bbdc7f42e66f46595a432f8539e/charset_normalizer-3.4.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0f5d9ed7f254402c9e7d35d2f5972c9bbea9040e99cd2861bd77dc68263277c7", size = 149164, upload-time = "2025-05-02T08:32:20.333Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8b/f2/b3c2f07dbcc248805f10e67a0262c93308cfa149a4cd3d1fe01f593e5fd2/charset_normalizer-3.4.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:efd387a49825780ff861998cd959767800d54f8308936b21025326de4b5a42b9", size = 144571, upload-time = "2025-05-02T08:32:21.86Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/60/5b/c3f3a94bc345bc211622ea59b4bed9ae63c00920e2e8f11824aa5708e8b7/charset_normalizer-3.4.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:f0aa37f3c979cf2546b73e8222bbfa3dc07a641585340179d768068e3455e544", size = 151952, upload-time = "2025-05-02T08:32:23.434Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e2/4d/ff460c8b474122334c2fa394a3f99a04cf11c646da895f81402ae54f5c42/charset_normalizer-3.4.2-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e70e990b2137b29dc5564715de1e12701815dacc1d056308e2b17e9095372a82", size = 155959, upload-time = "2025-05-02T08:32:24.993Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a2/2b/b964c6a2fda88611a1fe3d4c400d39c66a42d6c169c924818c848f922415/charset_normalizer-3.4.2-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:0c8c57f84ccfc871a48a47321cfa49ae1df56cd1d965a09abe84066f6853b9c0", size = 153030, upload-time = "2025-05-02T08:32:26.435Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/59/2e/d3b9811db26a5ebf444bc0fa4f4be5aa6d76fc6e1c0fd537b16c14e849b6/charset_normalizer-3.4.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6b66f92b17849b85cad91259efc341dce9c1af48e2173bf38a85c6329f1033e5", size = 148015, upload-time = "2025-05-02T08:32:28.376Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/90/07/c5fd7c11eafd561bb51220d600a788f1c8d77c5eef37ee49454cc5c35575/charset_normalizer-3.4.2-cp311-cp311-win32.whl", hash = "sha256:daac4765328a919a805fa5e2720f3e94767abd632ae410a9062dff5412bae65a", size = 98106, upload-time = "2025-05-02T08:32:30.281Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a8/05/5e33dbef7e2f773d672b6d79f10ec633d4a71cd96db6673625838a4fd532/charset_normalizer-3.4.2-cp311-cp311-win_amd64.whl", hash = "sha256:e53efc7c7cee4c1e70661e2e112ca46a575f90ed9ae3fef200f2a25e954f4b28", size = 105402, upload-time = "2025-05-02T08:32:32.191Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/d7/a4/37f4d6035c89cac7930395a35cc0f1b872e652eaafb76a6075943754f095/charset_normalizer-3.4.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0c29de6a1a95f24b9a1aa7aefd27d2487263f00dfd55a77719b530788f75cff7", size = 199936, upload-time = "2025-05-02T08:32:33.712Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ee/8a/1a5e33b73e0d9287274f899d967907cd0bf9c343e651755d9307e0dbf2b3/charset_normalizer-3.4.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cddf7bd982eaa998934a91f69d182aec997c6c468898efe6679af88283b498d3", size = 143790, upload-time = "2025-05-02T08:32:35.768Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/66/52/59521f1d8e6ab1482164fa21409c5ef44da3e9f653c13ba71becdd98dec3/charset_normalizer-3.4.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fcbe676a55d7445b22c10967bceaaf0ee69407fbe0ece4d032b6eb8d4565982a", size = 153924, upload-time = "2025-05-02T08:32:37.284Z" },
@@ -141,27 +123,19 @@ wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/6e/c6/ac0b6c1e2d138f1002bcf799d330bd6d85084fece321e662a14223794041/Deprecated-1.2.18-py2.py3-none-any.whl", hash = "sha256:bd5011788200372a32418f888e326a09ff80d0214bd961147cfed01b5c018eec", size = 9998, upload-time = "2025-01-27T10:46:09.186Z" },
]
-[[package]]
-name = "exceptiongroup"
-version = "1.2.2"
-source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
-sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/09/35/2495c4ac46b980e4ca1f6ad6db102322ef3ad2410b79fdde159a4b0f3b92/exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc", size = 28883, upload-time = "2024-07-12T22:26:00.161Z" }
-wheels = [
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453, upload-time = "2024-07-12T22:25:58.476Z" },
-]
-
[[package]]
name = "fastapi"
-version = "0.115.12"
+version = "0.128.0"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
dependencies = [
+ { name = "annotated-doc" },
{ name = "pydantic" },
{ name = "starlette" },
{ name = "typing-extensions" },
]
-sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f4/55/ae499352d82338331ca1e28c7f4a63bfd09479b16395dce38cf50a39e2c2/fastapi-0.115.12.tar.gz", hash = "sha256:1e2c2a2646905f9e83d32f04a3f86aff4a286669c6c950ca95b5fd68c2602681", size = 295236, upload-time = "2025-03-23T22:55:43.822Z" }
+sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/52/08/8c8508db6c7b9aae8f7175046af41baad690771c9bcde676419965e338c7/fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a", size = 365682, upload-time = "2025-12-27T15:21:13.714Z" }
wheels = [
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/50/b3/b51f09c2ba432a576fe63758bddc81f78f0c6309d9e5c10d194313bf021e/fastapi-0.115.12-py3-none-any.whl", hash = "sha256:e94613d6c05e27be7ffebdd6ea5f388112e5e430c8f7d6494a9d1d88d43e814d", size = 95164, upload-time = "2025-03-23T22:55:42.101Z" },
+ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/5c/05/5cbb59154b093548acd0f4c7c474a118eda06da25aa75c616b72d8fcd92a/fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d", size = 103094, upload-time = "2025-12-27T15:21:12.154Z" },
]
[[package]]
@@ -304,33 +278,6 @@ dependencies = [
]
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ad/88/5f2260bdfae97aabf98f1778d43f69574390ad787afb646292a638c923d4/pydantic_core-2.33.2.tar.gz", hash = "sha256:7cb8bc3605c29176e1b105350d2e6474142d7c1bd1d9327c4a9bdb46bf827acc", size = 435195, upload-time = "2025-04-23T18:33:52.104Z" }
wheels = [
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e5/92/b31726561b5dae176c2d2c2dc43a9c5bfba5d32f96f8b4c0a600dd492447/pydantic_core-2.33.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2b3d326aaef0c0399d9afffeb6367d5e26ddc24d351dbc9c636840ac355dc5d8", size = 2028817, upload-time = "2025-04-23T18:30:43.919Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a3/44/3f0b95fafdaca04a483c4e685fe437c6891001bf3ce8b2fded82b9ea3aa1/pydantic_core-2.33.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e5b2671f05ba48b94cb90ce55d8bdcaaedb8ba00cc5359f6810fc918713983d", size = 1861357, upload-time = "2025-04-23T18:30:46.372Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/30/97/e8f13b55766234caae05372826e8e4b3b96e7b248be3157f53237682e43c/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0069c9acc3f3981b9ff4cdfaf088e98d83440a4c7ea1bc07460af3d4dc22e72d", size = 1898011, upload-time = "2025-04-23T18:30:47.591Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/9b/a3/99c48cf7bafc991cc3ee66fd544c0aae8dc907b752f1dad2d79b1b5a471f/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d53b22f2032c42eaaf025f7c40c2e3b94568ae077a606f006d206a463bc69572", size = 1982730, upload-time = "2025-04-23T18:30:49.328Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/de/8e/a5b882ec4307010a840fb8b58bd9bf65d1840c92eae7534c7441709bf54b/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0405262705a123b7ce9f0b92f123334d67b70fd1f20a9372b907ce1080c7ba02", size = 2136178, upload-time = "2025-04-23T18:30:50.907Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e4/bb/71e35fc3ed05af6834e890edb75968e2802fe98778971ab5cba20a162315/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b25d91e288e2c4e0662b8038a28c6a07eaac3e196cfc4ff69de4ea3db992a1b", size = 2736462, upload-time = "2025-04-23T18:30:52.083Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/31/0d/c8f7593e6bc7066289bbc366f2235701dcbebcd1ff0ef8e64f6f239fb47d/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6bdfe4b3789761f3bcb4b1ddf33355a71079858958e3a552f16d5af19768fef2", size = 2005652, upload-time = "2025-04-23T18:30:53.389Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d2/7a/996d8bd75f3eda405e3dd219ff5ff0a283cd8e34add39d8ef9157e722867/pydantic_core-2.33.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:efec8db3266b76ef9607c2c4c419bdb06bf335ae433b80816089ea7585816f6a", size = 2113306, upload-time = "2025-04-23T18:30:54.661Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ff/84/daf2a6fb2db40ffda6578a7e8c5a6e9c8affb251a05c233ae37098118788/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:031c57d67ca86902726e0fae2214ce6770bbe2f710dc33063187a68744a5ecac", size = 2073720, upload-time = "2025-04-23T18:30:56.11Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/77/fb/2258da019f4825128445ae79456a5499c032b55849dbd5bed78c95ccf163/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:f8de619080e944347f5f20de29a975c2d815d9ddd8be9b9b7268e2e3ef68605a", size = 2244915, upload-time = "2025-04-23T18:30:57.501Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d8/7a/925ff73756031289468326e355b6fa8316960d0d65f8b5d6b3a3e7866de7/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:73662edf539e72a9440129f231ed3757faab89630d291b784ca99237fb94db2b", size = 2241884, upload-time = "2025-04-23T18:30:58.867Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/0b/b0/249ee6d2646f1cdadcb813805fe76265745c4010cf20a8eba7b0e639d9b2/pydantic_core-2.33.2-cp310-cp310-win32.whl", hash = "sha256:0a39979dcbb70998b0e505fb1556a1d550a0781463ce84ebf915ba293ccb7e22", size = 1910496, upload-time = "2025-04-23T18:31:00.078Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/66/ff/172ba8f12a42d4b552917aa65d1f2328990d3ccfc01d5b7c943ec084299f/pydantic_core-2.33.2-cp310-cp310-win_amd64.whl", hash = "sha256:b0379a2b24882fef529ec3b4987cb5d003b9cda32256024e6fe1586ac45fc640", size = 1955019, upload-time = "2025-04-23T18:31:01.335Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3f/8d/71db63483d518cbbf290261a1fc2839d17ff89fce7089e08cad07ccfce67/pydantic_core-2.33.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:4c5b0a576fb381edd6d27f0a85915c6daf2f8138dc5c267a57c08a62900758c7", size = 2028584, upload-time = "2025-04-23T18:31:03.106Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/24/2f/3cfa7244ae292dd850989f328722d2aef313f74ffc471184dc509e1e4e5a/pydantic_core-2.33.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e799c050df38a639db758c617ec771fd8fb7a5f8eaaa4b27b101f266b216a246", size = 1855071, upload-time = "2025-04-23T18:31:04.621Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b3/d3/4ae42d33f5e3f50dd467761304be2fa0a9417fbf09735bc2cce003480f2a/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc46a01bf8d62f227d5ecee74178ffc448ff4e5197c756331f71efcc66dc980f", size = 1897823, upload-time = "2025-04-23T18:31:06.377Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f4/f3/aa5976e8352b7695ff808599794b1fba2a9ae2ee954a3426855935799488/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a144d4f717285c6d9234a66778059f33a89096dfb9b39117663fd8413d582dcc", size = 1983792, upload-time = "2025-04-23T18:31:07.93Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d5/7a/cda9b5a23c552037717f2b2a5257e9b2bfe45e687386df9591eff7b46d28/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:73cf6373c21bc80b2e0dc88444f41ae60b2f070ed02095754eb5a01df12256de", size = 2136338, upload-time = "2025-04-23T18:31:09.283Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/2b/9f/b8f9ec8dd1417eb9da784e91e1667d58a2a4a7b7b34cf4af765ef663a7e5/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3dc625f4aa79713512d1976fe9f0bc99f706a9dee21dfd1810b4bbbf228d0e8a", size = 2730998, upload-time = "2025-04-23T18:31:11.7Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/47/bc/cd720e078576bdb8255d5032c5d63ee5c0bf4b7173dd955185a1d658c456/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:881b21b5549499972441da4758d662aeea93f1923f953e9cbaff14b8b9565aef", size = 2003200, upload-time = "2025-04-23T18:31:13.536Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ca/22/3602b895ee2cd29d11a2b349372446ae9727c32e78a94b3d588a40fdf187/pydantic_core-2.33.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bdc25f3681f7b78572699569514036afe3c243bc3059d3942624e936ec93450e", size = 2113890, upload-time = "2025-04-23T18:31:15.011Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ff/e6/e3c5908c03cf00d629eb38393a98fccc38ee0ce8ecce32f69fc7d7b558a7/pydantic_core-2.33.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:fe5b32187cbc0c862ee201ad66c30cf218e5ed468ec8dc1cf49dec66e160cc4d", size = 2073359, upload-time = "2025-04-23T18:31:16.393Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/12/e7/6a36a07c59ebefc8777d1ffdaf5ae71b06b21952582e4b07eba88a421c79/pydantic_core-2.33.2-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:bc7aee6f634a6f4a95676fcb5d6559a2c2a390330098dba5e5a5f28a2e4ada30", size = 2245883, upload-time = "2025-04-23T18:31:17.892Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/16/3f/59b3187aaa6cc0c1e6616e8045b284de2b6a87b027cce2ffcea073adf1d2/pydantic_core-2.33.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:235f45e5dbcccf6bd99f9f472858849f73d11120d76ea8707115415f8e5ebebf", size = 2241074, upload-time = "2025-04-23T18:31:19.205Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e0/ed/55532bb88f674d5d8f67ab121a2a13c385df382de2a1677f30ad385f7438/pydantic_core-2.33.2-cp311-cp311-win32.whl", hash = "sha256:6368900c2d3ef09b69cb0b913f9f8263b03786e5b2a387706c5afb66800efd51", size = 1910538, upload-time = "2025-04-23T18:31:20.541Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/fe/1b/25b7cccd4519c0b23c2dd636ad39d381abf113085ce4f7bec2b0dc755eb1/pydantic_core-2.33.2-cp311-cp311-win_amd64.whl", hash = "sha256:1e063337ef9e9820c77acc768546325ebe04ee38b08703244c1309cccc4f1bab", size = 1952909, upload-time = "2025-04-23T18:31:22.371Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/49/a9/d809358e49126438055884c4366a1f6227f0f84f635a9014e2deb9b9de54/pydantic_core-2.33.2-cp311-cp311-win_arm64.whl", hash = "sha256:6b99022f1d19bc32a4c2a0d544fc9a76e3be90f0b3f4af413f87d38749300e65", size = 1897786, upload-time = "2025-04-23T18:31:24.161Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/18/8a/2b41c97f554ec8c71f2a8a5f85cb56a8b0956addfe8b0efb5b3d77e8bdc3/pydantic_core-2.33.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a7ec89dc587667f22b6a0b6579c249fca9026ce7c333fc142ba42411fa243cdc", size = 2009000, upload-time = "2025-04-23T18:31:25.863Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a1/02/6224312aacb3c8ecbaa959897af57181fb6cf3a3d7917fd44d0f2917e6f2/pydantic_core-2.33.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3c6db6e52c6d70aa0d00d45cdb9b40f0433b96380071ea80b09277dba021ddf7", size = 1847996, upload-time = "2025-04-23T18:31:27.341Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/d6/46/6dcdf084a523dbe0a0be59d054734b86a981726f221f4562aed313dbcb49/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e61206137cbc65e6d5256e1166f88331d3b6238e082d9f74613b9b765fb9025", size = 1880957, upload-time = "2025-04-23T18:31:28.956Z" },
@@ -362,24 +309,6 @@ wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a4/7d/e09391c2eebeab681df2b74bfe6c43422fffede8dc74187b2b0bf6fd7571/pydantic_core-2.33.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:61c18fba8e5e9db3ab908620af374db0ac1baa69f0f32df4f61ae23f15e586ac", size = 1806162, upload-time = "2025-04-23T18:32:20.188Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/f1/3d/847b6b1fed9f8ed3bb95a9ad04fbd0b212e832d4f0f50ff4d9ee5a9f15cf/pydantic_core-2.33.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95237e53bb015f67b63c91af7518a62a8660376a6a0db19b89acc77a4d6199f5", size = 1981560, upload-time = "2025-04-23T18:32:22.354Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/6f/9a/e73262f6c6656262b5fdd723ad90f518f579b7bc8622e43a942eec53c938/pydantic_core-2.33.2-cp313-cp313t-win_amd64.whl", hash = "sha256:c2fc0a768ef76c15ab9238afa6da7f69895bb5d1ee83aeea2e3509af4472d0b9", size = 1935777, upload-time = "2025-04-23T18:32:25.088Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/30/68/373d55e58b7e83ce371691f6eaa7175e3a24b956c44628eb25d7da007917/pydantic_core-2.33.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5c4aa4e82353f65e548c476b37e64189783aa5384903bfea4f41580f255fddfa", size = 2023982, upload-time = "2025-04-23T18:32:53.14Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a4/16/145f54ac08c96a63d8ed6442f9dec17b2773d19920b627b18d4f10a061ea/pydantic_core-2.33.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d946c8bf0d5c24bf4fe333af284c59a19358aa3ec18cb3dc4370080da1e8ad29", size = 1858412, upload-time = "2025-04-23T18:32:55.52Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/41/b1/c6dc6c3e2de4516c0bb2c46f6a373b91b5660312342a0cf5826e38ad82fa/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:87b31b6846e361ef83fedb187bb5b4372d0da3f7e28d85415efa92d6125d6e6d", size = 1892749, upload-time = "2025-04-23T18:32:57.546Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/12/73/8cd57e20afba760b21b742106f9dbdfa6697f1570b189c7457a1af4cd8a0/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa9d91b338f2df0508606f7009fde642391425189bba6d8c653afd80fd6bb64e", size = 2067527, upload-time = "2025-04-23T18:32:59.771Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e3/d5/0bb5d988cc019b3cba4a78f2d4b3854427fc47ee8ec8e9eaabf787da239c/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2058a32994f1fde4ca0480ab9d1e75a0e8c87c22b53a3ae66554f9af78f2fe8c", size = 2108225, upload-time = "2025-04-23T18:33:04.51Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f1/c5/00c02d1571913d496aabf146106ad8239dc132485ee22efe08085084ff7c/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:0e03262ab796d986f978f79c943fc5f620381be7287148b8010b4097f79a39ec", size = 2069490, upload-time = "2025-04-23T18:33:06.391Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/22/a8/dccc38768274d3ed3a59b5d06f59ccb845778687652daa71df0cab4040d7/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:1a8695a8d00c73e50bff9dfda4d540b7dee29ff9b8053e38380426a85ef10052", size = 2237525, upload-time = "2025-04-23T18:33:08.44Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d4/e7/4f98c0b125dda7cf7ccd14ba936218397b44f50a56dd8c16a3091df116c3/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:fa754d1850735a0b0e03bcffd9d4b4343eb417e47196e4485d9cca326073a42c", size = 2238446, upload-time = "2025-04-23T18:33:10.313Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ce/91/2ec36480fdb0b783cd9ef6795753c1dea13882f2e68e73bce76ae8c21e6a/pydantic_core-2.33.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a11c8d26a50bfab49002947d3d237abe4d9e4b5bdc8846a63537b6488e197808", size = 2066678, upload-time = "2025-04-23T18:33:12.224Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/7b/27/d4ae6487d73948d6f20dddcd94be4ea43e74349b56eba82e9bdee2d7494c/pydantic_core-2.33.2-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:dd14041875d09cc0f9308e37a6f8b65f5585cf2598a53aa0123df8b129d481f8", size = 2025200, upload-time = "2025-04-23T18:33:14.199Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f1/b8/b3cb95375f05d33801024079b9392a5ab45267a63400bf1866e7ce0f0de4/pydantic_core-2.33.2-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:d87c561733f66531dced0da6e864f44ebf89a8fba55f31407b00c2f7f9449593", size = 1859123, upload-time = "2025-04-23T18:33:16.555Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/05/bc/0d0b5adeda59a261cd30a1235a445bf55c7e46ae44aea28f7bd6ed46e091/pydantic_core-2.33.2-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f82865531efd18d6e07a04a17331af02cb7a651583c418df8266f17a63c6612", size = 1892852, upload-time = "2025-04-23T18:33:18.513Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3e/11/d37bdebbda2e449cb3f519f6ce950927b56d62f0b84fd9cb9e372a26a3d5/pydantic_core-2.33.2-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bfb5112df54209d820d7bf9317c7a6c9025ea52e49f46b6a2060104bba37de7", size = 2067484, upload-time = "2025-04-23T18:33:20.475Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8c/55/1f95f0a05ce72ecb02a8a8a1c3be0579bbc29b1d5ab68f1378b7bebc5057/pydantic_core-2.33.2-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:64632ff9d614e5eecfb495796ad51b0ed98c453e447a76bcbeeb69615079fc7e", size = 2108896, upload-time = "2025-04-23T18:33:22.501Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/53/89/2b2de6c81fa131f423246a9109d7b2a375e83968ad0800d6e57d0574629b/pydantic_core-2.33.2-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:f889f7a40498cc077332c7ab6b4608d296d852182211787d4f3ee377aaae66e8", size = 2069475, upload-time = "2025-04-23T18:33:24.528Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b8/e9/1f7efbe20d0b2b10f6718944b5d8ece9152390904f29a78e68d4e7961159/pydantic_core-2.33.2-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:de4b83bb311557e439b9e186f733f6c645b9417c84e2eb8203f3f820a4b988bf", size = 2239013, upload-time = "2025-04-23T18:33:26.621Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3c/b2/5309c905a93811524a49b4e031e9851a6b00ff0fb668794472ea7746b448/pydantic_core-2.33.2-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:82f68293f055f51b51ea42fafc74b6aad03e70e191799430b90c13d643059ebb", size = 2238715, upload-time = "2025-04-23T18:33:28.656Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/32/56/8a7ca5d2cd2cda1d245d34b1c9a942920a718082ae8e54e5f3e5a58b7add/pydantic_core-2.33.2-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:329467cecfb529c925cf2bbd4d60d2c509bc2fb52a20c1045bf09bb70971a9c1", size = 2066757, upload-time = "2025-04-23T18:33:30.645Z" },
]
[[package]]
@@ -420,14 +349,15 @@ wheels = [
[[package]]
name = "starlette"
-version = "0.46.2"
+version = "0.49.1"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
dependencies = [
{ name = "anyio" },
+ { name = "typing-extensions", marker = "python_full_version < '3.13'" },
]
-sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ce/20/08dfcd9c983f6a6f4a1000d934b9e6d626cff8d2eeb77a89a68eef20a2b7/starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5", size = 2580846, upload-time = "2025-04-13T13:56:17.942Z" }
+sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/1b/3f/507c21db33b66fb027a332f2cb3abbbe924cc3a79ced12f01ed8645955c9/starlette-0.49.1.tar.gz", hash = "sha256:481a43b71e24ed8c43b11ea02f5353d77840e01480881b8cb5a26b8cae64a8cb", size = 2654703, upload-time = "2025-10-28T17:34:10.928Z" }
wheels = [
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8b/0c/9d30a4ebeb6db2b25a841afbb80f6ef9a854fc3b41be131d249a977b4959/starlette-0.46.2-py3-none-any.whl", hash = "sha256:595633ce89f8ffa71a015caed34a5b2dc1c0cdb3f0f1fbd1e69339cf2abeec35", size = 72037, upload-time = "2025-04-13T13:56:16.21Z" },
+ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/51/da/545b75d420bb23b5d494b0517757b351963e974e79933f01e05c929f20a6/starlette-0.49.1-py3-none-any.whl", hash = "sha256:d92ce9f07e4a3caa3ac13a79523bd18e3bc0042bb8ff2d759a8e7dd0e1859875", size = 74175, upload-time = "2025-10-28T17:34:09.13Z" },
]
[[package]]
@@ -453,11 +383,11 @@ wheels = [
[[package]]
name = "urllib3"
-version = "2.4.0"
+version = "2.6.3"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
-sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8a/78/16493d9c386d8e60e442a35feac5e00f0913c0f4b7c217c11e8ec2ff53e0/urllib3-2.4.0.tar.gz", hash = "sha256:414bc6535b787febd7567804cc015fee39daab8ad86268f1310a9250697de466", size = 390672, upload-time = "2025-04-10T15:23:39.232Z" }
+sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c7/24/5f1b3bdffd70275f6661c76461e25f024d5a38a46f04aaca912426a2b1d3/urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed", size = 435556, upload-time = "2026-01-07T16:24:43.925Z" }
wheels = [
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6b/11/cc635220681e93a0183390e26485430ca2c7b5f9d33b15c74c2861cb8091/urllib3-2.4.0-py3-none-any.whl", hash = "sha256:4e16665048960a0900c702d4a66415956a584919c03361cac9f1df5c5dd7e813", size = 128680, upload-time = "2025-04-10T15:23:37.377Z" },
+ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" },
]
[[package]]
@@ -467,7 +397,6 @@ source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
dependencies = [
{ name = "click" },
{ name = "h11" },
- { name = "typing-extensions", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a6/ae/9bbb19b9e1c450cf9ecaef06463e40234d98d95bf572fab11b4f19ae5ded/uvicorn-0.34.2.tar.gz", hash = "sha256:0e929828f6186353a80b58ea719861d2629d766293b6d19baf086ba31d4f3328", size = 76815, upload-time = "2025-04-19T06:02:50.101Z" }
wheels = [
@@ -480,28 +409,6 @@ version = "1.17.2"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c3/fc/e91cc220803d7bc4db93fb02facd8461c37364151b8494762cc88b0fbcef/wrapt-1.17.2.tar.gz", hash = "sha256:41388e9d4d1522446fe79d3213196bd9e3b301a336965b9e27ca2788ebd122f3", size = 55531, upload-time = "2025-01-14T10:35:45.465Z" }
wheels = [
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/5a/d1/1daec934997e8b160040c78d7b31789f19b122110a75eca3d4e8da0049e1/wrapt-1.17.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3d57c572081fed831ad2d26fd430d565b76aa277ed1d30ff4d40670b1c0dd984", size = 53307, upload-time = "2025-01-14T10:33:13.616Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/1b/7b/13369d42651b809389c1a7153baa01d9700430576c81a2f5c5e460df0ed9/wrapt-1.17.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5e251054542ae57ac7f3fba5d10bfff615b6c2fb09abeb37d2f1463f841ae22", size = 38486, upload-time = "2025-01-14T10:33:15.947Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/62/bf/e0105016f907c30b4bd9e377867c48c34dc9c6c0c104556c9c9126bd89ed/wrapt-1.17.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:80dd7db6a7cb57ffbc279c4394246414ec99537ae81ffd702443335a61dbf3a7", size = 38777, upload-time = "2025-01-14T10:33:17.462Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/27/70/0f6e0679845cbf8b165e027d43402a55494779295c4b08414097b258ac87/wrapt-1.17.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a6e821770cf99cc586d33833b2ff32faebdbe886bd6322395606cf55153246c", size = 83314, upload-time = "2025-01-14T10:33:21.282Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/0f/77/0576d841bf84af8579124a93d216f55d6f74374e4445264cb378a6ed33eb/wrapt-1.17.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b60fb58b90c6d63779cb0c0c54eeb38941bae3ecf7a73c764c52c88c2dcb9d72", size = 74947, upload-time = "2025-01-14T10:33:24.414Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/90/ec/00759565518f268ed707dcc40f7eeec38637d46b098a1f5143bff488fe97/wrapt-1.17.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b870b5df5b71d8c3359d21be8f0d6c485fa0ebdb6477dda51a1ea54a9b558061", size = 82778, upload-time = "2025-01-14T10:33:26.152Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f8/5a/7cffd26b1c607b0b0c8a9ca9d75757ad7620c9c0a9b4a25d3f8a1480fafc/wrapt-1.17.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4011d137b9955791f9084749cba9a367c68d50ab8d11d64c50ba1688c9b457f2", size = 81716, upload-time = "2025-01-14T10:33:27.372Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/7e/09/dccf68fa98e862df7e6a60a61d43d644b7d095a5fc36dbb591bbd4a1c7b2/wrapt-1.17.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:1473400e5b2733e58b396a04eb7f35f541e1fb976d0c0724d0223dd607e0f74c", size = 74548, upload-time = "2025-01-14T10:33:28.52Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b7/8e/067021fa3c8814952c5e228d916963c1115b983e21393289de15128e867e/wrapt-1.17.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3cedbfa9c940fdad3e6e941db7138e26ce8aad38ab5fe9dcfadfed9db7a54e62", size = 81334, upload-time = "2025-01-14T10:33:29.643Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/4b/0d/9d4b5219ae4393f718699ca1c05f5ebc0c40d076f7e65fd48f5f693294fb/wrapt-1.17.2-cp310-cp310-win32.whl", hash = "sha256:582530701bff1dec6779efa00c516496968edd851fba224fbd86e46cc6b73563", size = 36427, upload-time = "2025-01-14T10:33:30.832Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/72/6a/c5a83e8f61aec1e1aeef939807602fb880e5872371e95df2137142f5c58e/wrapt-1.17.2-cp310-cp310-win_amd64.whl", hash = "sha256:58705da316756681ad3c9c73fd15499aa4d8c69f9fd38dc8a35e06c12468582f", size = 38774, upload-time = "2025-01-14T10:33:32.897Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/cd/f7/a2aab2cbc7a665efab072344a8949a71081eed1d2f451f7f7d2b966594a2/wrapt-1.17.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ff04ef6eec3eee8a5efef2401495967a916feaa353643defcc03fc74fe213b58", size = 53308, upload-time = "2025-01-14T10:33:33.992Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/50/ff/149aba8365fdacef52b31a258c4dc1c57c79759c335eff0b3316a2664a64/wrapt-1.17.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4db983e7bca53819efdbd64590ee96c9213894272c776966ca6306b73e4affda", size = 38488, upload-time = "2025-01-14T10:33:35.264Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/65/46/5a917ce85b5c3b490d35c02bf71aedaa9f2f63f2d15d9949cc4ba56e8ba9/wrapt-1.17.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9abc77a4ce4c6f2a3168ff34b1da9b0f311a8f1cfd694ec96b0603dff1c79438", size = 38776, upload-time = "2025-01-14T10:33:38.28Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ca/74/336c918d2915a4943501c77566db41d1bd6e9f4dbc317f356b9a244dfe83/wrapt-1.17.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b929ac182f5ace000d459c59c2c9c33047e20e935f8e39371fa6e3b85d56f4a", size = 83776, upload-time = "2025-01-14T10:33:40.678Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/09/99/c0c844a5ccde0fe5761d4305485297f91d67cf2a1a824c5f282e661ec7ff/wrapt-1.17.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f09b286faeff3c750a879d336fb6d8713206fc97af3adc14def0cdd349df6000", size = 75420, upload-time = "2025-01-14T10:33:41.868Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b4/b0/9fc566b0fe08b282c850063591a756057c3247b2362b9286429ec5bf1721/wrapt-1.17.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a7ed2d9d039bd41e889f6fb9364554052ca21ce823580f6a07c4ec245c1f5d6", size = 83199, upload-time = "2025-01-14T10:33:43.598Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/9d/4b/71996e62d543b0a0bd95dda485219856def3347e3e9380cc0d6cf10cfb2f/wrapt-1.17.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:129a150f5c445165ff941fc02ee27df65940fcb8a22a61828b1853c98763a64b", size = 82307, upload-time = "2025-01-14T10:33:48.499Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/39/35/0282c0d8789c0dc9bcc738911776c762a701f95cfe113fb8f0b40e45c2b9/wrapt-1.17.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1fb5699e4464afe5c7e65fa51d4f99e0b2eadcc176e4aa33600a3df7801d6662", size = 75025, upload-time = "2025-01-14T10:33:51.191Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/4f/6d/90c9fd2c3c6fee181feecb620d95105370198b6b98a0770cba090441a828/wrapt-1.17.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9a2bce789a5ea90e51a02dfcc39e31b7f1e662bc3317979aa7e5538e3a034f72", size = 81879, upload-time = "2025-01-14T10:33:52.328Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8f/fa/9fb6e594f2ce03ef03eddbdb5f4f90acb1452221a5351116c7c4708ac865/wrapt-1.17.2-cp311-cp311-win32.whl", hash = "sha256:4afd5814270fdf6380616b321fd31435a462019d834f83c8611a0ce7484c7317", size = 36419, upload-time = "2025-01-14T10:33:53.551Z" },
- { url = "https://pypi.tuna.tsinghua.edu.cn/packages/47/f8/fb1773491a253cbc123c5d5dc15c86041f746ed30416535f2a8df1f4a392/wrapt-1.17.2-cp311-cp311-win_amd64.whl", hash = "sha256:acc130bc0375999da18e3d19e5a86403667ac0c4042a094fefb7eec8ebac7cf3", size = 38773, upload-time = "2025-01-14T10:33:56.323Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a1/bd/ab55f849fd1f9a58ed7ea47f5559ff09741b25f00c191231f9f059c83949/wrapt-1.17.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:d5e2439eecc762cd85e7bd37161d4714aa03a33c5ba884e26c81559817ca0925", size = 53799, upload-time = "2025-01-14T10:33:57.4Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/53/18/75ddc64c3f63988f5a1d7e10fb204ffe5762bc663f8023f18ecaf31a332e/wrapt-1.17.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3fc7cb4c1c744f8c05cd5f9438a3caa6ab94ce8344e952d7c45a8ed59dd88392", size = 38821, upload-time = "2025-01-14T10:33:59.334Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/48/2a/97928387d6ed1c1ebbfd4efc4133a0633546bec8481a2dd5ec961313a1c7/wrapt-1.17.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8fdbdb757d5390f7c675e558fd3186d590973244fab0c5fe63d373ade3e99d40", size = 38919, upload-time = "2025-01-14T10:34:04.093Z" },
diff --git a/agent/templates/advanced_ingestion_pipeline.json b/agent/templates/advanced_ingestion_pipeline.json
index 2e996e248be..97a4c221055 100644
--- a/agent/templates/advanced_ingestion_pipeline.json
+++ b/agent/templates/advanced_ingestion_pipeline.json
@@ -193,7 +193,7 @@
"presence_penalty": 0.4,
"prompts": [
{
- "content": "Text Content:\n{Splitter:NineTiesSin@chunks}\n",
+ "content": "Text Content:\n{Extractor:NineTiesSin@chunks}\n",
"role": "user"
}
],
@@ -226,7 +226,7 @@
"presence_penalty": 0.4,
"prompts": [
{
- "content": "Text Content:\n\n{Splitter:TastyPointsLay@chunks}\n",
+ "content": "Text Content:\n\n{Extractor:TastyPointsLay@chunks}\n",
"role": "user"
}
],
@@ -259,7 +259,7 @@
"presence_penalty": 0.4,
"prompts": [
{
- "content": "Content: \n\n{Splitter:CuteBusesBet@chunks}",
+ "content": "Content: \n\n{Extractor:BlueResultsWink@chunks}",
"role": "user"
}
],
@@ -485,7 +485,7 @@
"outputs": {},
"presencePenaltyEnabled": false,
"presence_penalty": 0.4,
- "prompts": "Text Content:\n{Splitter:NineTiesSin@chunks}\n",
+ "prompts": "Text Content:\n{Extractor:NineTiesSin@chunks}\n",
"sys_prompt": "Role\nYou are a text analyzer.\n\nTask\nExtract the most important keywords/phrases of a given piece of text content.\n\nRequirements\n- Summarize the text content, and give the top 5 important keywords/phrases.\n- The keywords MUST be in the same language as the given piece of text content.\n- The keywords are delimited by ENGLISH COMMA.\n- Output keywords ONLY.",
"temperature": 0.1,
"temperatureEnabled": false,
@@ -522,7 +522,7 @@
"outputs": {},
"presencePenaltyEnabled": false,
"presence_penalty": 0.4,
- "prompts": "Text Content:\n\n{Splitter:TastyPointsLay@chunks}\n",
+ "prompts": "Text Content:\n\n{Extractor:TastyPointsLay@chunks}\n",
"sys_prompt": "Role\nYou are a text analyzer.\n\nTask\nPropose 3 questions about a given piece of text content.\n\nRequirements\n- Understand and summarize the text content, and propose the top 3 important questions.\n- The questions SHOULD NOT have overlapping meanings.\n- The questions SHOULD cover the main content of the text as much as possible.\n- The questions MUST be in the same language as the given piece of text content.\n- One question per line.\n- Output questions ONLY.",
"temperature": 0.1,
"temperatureEnabled": false,
@@ -559,7 +559,7 @@
"outputs": {},
"presencePenaltyEnabled": false,
"presence_penalty": 0.4,
- "prompts": "Content: \n\n{Splitter:BlueResultsWink@chunks}",
+ "prompts": "Content: \n\n{Extractor:BlueResultsWink@chunks}",
"sys_prompt": "Extract important structured information from the given content. Output ONLY a valid JSON string with no additional text. If no important structured information is found, output an empty JSON object: {}.\n\nImportant structured information may include: names, dates, locations, events, key facts, numerical data, or other extractable entities.",
"temperature": 0.1,
"temperatureEnabled": false,
diff --git a/agent/templates/choose_your_knowledge_base_agent.json b/agent/templates/choose_your_knowledge_base_agent.json
index 65c02512cda..a4b7ac93794 100644
--- a/agent/templates/choose_your_knowledge_base_agent.json
+++ b/agent/templates/choose_your_knowledge_base_agent.json
@@ -5,9 +5,9 @@
"de": "Wählen Sie Ihren Wissensdatenbank Agenten",
"zh": "选择知识库智能体"},
"description": {
- "en": "Select your desired knowledge base from the dropdown menu. The Agent will only retrieve from the selected knowledge base and use this content to generate responses.",
- "de": "Wählen Sie Ihre gewünschte Wissensdatenbank aus dem Dropdown-Menü. Der Agent ruft nur Informationen aus der ausgewählten Wissensdatenbank ab und verwendet diesen Inhalt zur Generierung von Antworten.",
- "zh": "从下拉菜单中选择知识库,智能体将仅根据所选知识库内容生成回答。"},
+ "en": "This Agent generates responses solely from the specified dataset (knowledge base). You are required to select a knowledge base from the dropdown when running the Agent.",
+ "de": "Dieser Agent erzeugt Antworten ausschließlich aus dem angegebenen Datensatz (Wissensdatenbank). Beim Ausführen des Agents müssen Sie eine Wissensdatenbank aus dem Dropdown-Menü auswählen.",
+ "zh": "本工作流仅根据指定知识库内容生成回答。运行时,请在下拉菜单选择需要查询的知识库。"},
"canvas_type": "Agent",
"dsl": {
"components": {
@@ -387,10 +387,10 @@
{
"data": {
"form": {
- "text": "Select your desired knowledge base from the dropdown menu. \nThe Agent will only retrieve from the selected knowledge base and use this content to generate responses."
+ "text": "This Agent generates responses solely from the specified dataset (knowledge base). \nYou are required to select a knowledge base from the dropdown when running the Agent."
},
"label": "Note",
- "name": "Workflow overall description"
+ "name": "Workflow description"
},
"dragHandle": ".note-drag-handle",
"dragging": false,
diff --git a/agent/templates/choose_your_knowledge_base_workflow.json b/agent/templates/choose_your_knowledge_base_workflow.json
index 3239bd7d351..79886ed3586 100644
--- a/agent/templates/choose_your_knowledge_base_workflow.json
+++ b/agent/templates/choose_your_knowledge_base_workflow.json
@@ -5,9 +5,9 @@
"de": "Wählen Sie Ihren Wissensdatenbank Workflow",
"zh": "选择知识库工作流"},
"description": {
- "en": "Select your desired knowledge base from the dropdown menu. The retrieval assistant will only use data from your selected knowledge base to generate responses.",
- "de": "Wählen Sie Ihre gewünschte Wissensdatenbank aus dem Dropdown-Menü. Der Abrufassistent verwendet nur Daten aus Ihrer ausgewählten Wissensdatenbank, um Antworten zu generieren.",
- "zh": "从下拉菜单中选择知识库,工作流将仅根据所选知识库内容生成回答。"},
+ "en": "This Agent generates responses solely from the specified dataset (knowledge base). You are required to select a knowledge base from the dropdown when running the Agent.",
+ "de": "Dieser Agent erzeugt Antworten ausschließlich aus dem angegebenen Datensatz (Wissensdatenbank). Beim Ausführen des Agents müssen Sie eine Wissensdatenbank aus dem Dropdown-Menü auswählen.",
+ "zh": "本工作流仅根据指定知识库内容生成回答。运行时,请在下拉菜单选择需要查询的知识库。"},
"canvas_type": "Other",
"dsl": {
"components": {
@@ -334,10 +334,10 @@
{
"data": {
"form": {
- "text": "Select your desired knowledge base from the dropdown menu. \nThe retrieval assistant will only use data from your selected knowledge base to generate responses."
+ "text": "This Agent generates responses solely from the specified dataset (knowledge base). \nYou are required to select a knowledge base from the dropdown when running the Agent."
},
"label": "Note",
- "name": "Workflow overall description"
+ "name": "Workflow description"
},
"dragHandle": ".note-drag-handle",
"dragging": false,
diff --git a/agent/templates/user_interaction.json b/agent/templates/user_interaction.json
index 57790f9bb27..c575fc721d3 100644
--- a/agent/templates/user_interaction.json
+++ b/agent/templates/user_interaction.json
@@ -2,10 +2,12 @@
"id": 27,
"title": {
"en": "Interactive Agent",
+ "de": "Interaktiver Agent",
"zh": "可交互的 Agent"
},
"description": {
"en": "During the Agent’s execution, users can actively intervene and interact with the Agent to adjust or guide its output, ensuring the final result aligns with their intentions.",
+ "de": "Wahrend der Ausführung des Agenten können Benutzer aktiv eingreifen und mit dem Agenten interagieren, um dessen Ausgabe zu steuern, sodass das Endergebnis ihren Vorstellungen entspricht.",
"zh": "在 Agent 的运行过程中,用户可以随时介入,与 Agent 进行交互,以调整或引导生成结果,使最终输出更符合预期。"
},
"canvas_type": "Agent",
diff --git a/agent/tools/base.py b/agent/tools/base.py
index ac8336f5d32..1f629a252bc 100644
--- a/agent/tools/base.py
+++ b/agent/tools/base.py
@@ -27,6 +27,10 @@
from timeit import default_timer as timer
+
+
+from common.misc_utils import thread_pool_exec
+
class ToolParameter(TypedDict):
type: str
description: str
@@ -56,12 +60,12 @@ async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any:
st = timer()
tool_obj = self.tools_map[name]
if isinstance(tool_obj, MCPToolCallSession):
- resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60)
+ resp = await thread_pool_exec(tool_obj.tool_call, name, arguments, 60)
else:
if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
resp = await tool_obj.invoke_async(**arguments)
else:
- resp = await asyncio.to_thread(tool_obj.invoke, **arguments)
+ resp = await thread_pool_exec(tool_obj.invoke, **arguments)
self.callback(name, arguments, resp, elapsed_time=timer()-st)
return resp
@@ -122,6 +126,7 @@ def get_meta(self):
class ToolBase(ComponentBase):
def __init__(self, canvas, id, param: ComponentParamBase):
from agent.canvas import Canvas # Local import to avoid cyclic dependency
+
assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
self._canvas = canvas
self._id = id
@@ -164,7 +169,7 @@ async def invoke_async(self, **kwargs):
elif asyncio.iscoroutinefunction(self._invoke):
res = await self._invoke(**kwargs)
else:
- res = await asyncio.to_thread(self._invoke, **kwargs)
+ res = await thread_pool_exec(self._invoke, **kwargs)
except Exception as e:
self._param.outputs["_ERROR"] = {"value": str(e)}
logging.exception(e)
diff --git a/agent/tools/code_exec.py b/agent/tools/code_exec.py
index 678d56f020a..bc42415e0f1 100644
--- a/agent/tools/code_exec.py
+++ b/agent/tools/code_exec.py
@@ -110,7 +110,7 @@ def fibonacci_recursive(n):
self.lang = Language.PYTHON.value
self.script = 'def main(arg1: str, arg2: str) -> dict: return {"result": arg1 + arg2}'
self.arguments = {}
- self.outputs = {"result": {"value": "", "type": "string"}}
+ self.outputs = {"result": {"value": "", "type": "object"}}
def check(self):
self.check_valid_value(self.lang, "Support languages", ["python", "python3", "nodejs", "javascript"])
@@ -140,26 +140,61 @@ def _invoke(self, **kwargs):
continue
arguments[k] = self._canvas.get_variable_value(v) if v else None
- self._execute_code(language=lang, code=script, arguments=arguments)
+ return self._execute_code(language=lang, code=script, arguments=arguments)
def _execute_code(self, language: str, code: str, arguments: dict):
import requests
if self.check_if_canceled("CodeExec execution"):
- return
+ return self.output()
try:
+ # Try using the new sandbox provider system first
+ try:
+ from agent.sandbox.client import execute_code as sandbox_execute_code
+
+ if self.check_if_canceled("CodeExec execution"):
+ return
+
+ # Execute code using the provider system
+ result = sandbox_execute_code(
+ code=code,
+ language=language,
+ timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)),
+ arguments=arguments
+ )
+
+ if self.check_if_canceled("CodeExec execution"):
+ return
+
+ # Process the result
+ if result.stderr:
+ self.set_output("_ERROR", result.stderr)
+ return
+
+ parsed_stdout = self._deserialize_stdout(result.stdout)
+ logging.info(f"[CodeExec]: Provider system -> {parsed_stdout}")
+ self._populate_outputs(parsed_stdout, result.stdout)
+ return
+
+ except (ImportError, RuntimeError) as provider_error:
+ # Provider system not available or not configured, fall back to HTTP
+ logging.info(f"[CodeExec]: Provider system not available, using HTTP fallback: {provider_error}")
+
+ # Fallback to direct HTTP request
code_b64 = self._encode_code(code)
code_req = CodeExecutionRequest(code_b64=code_b64, language=language, arguments=arguments).model_dump()
except Exception as e:
if self.check_if_canceled("CodeExec execution"):
- return
+ return self.output()
self.set_output("_ERROR", "construct code request error: " + str(e))
+ return self.output()
try:
if self.check_if_canceled("CodeExec execution"):
- return "Task has been canceled"
+ self.set_output("_ERROR", "Task has been canceled")
+ return self.output()
resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run, code_req: {code_req}, resp.status_code {resp.status_code}:")
@@ -174,17 +209,18 @@ def _execute_code(self, language: str, code: str, arguments: dict):
stderr = body.get("stderr")
if stderr:
self.set_output("_ERROR", stderr)
- return
+ return self.output()
raw_stdout = body.get("stdout", "")
parsed_stdout = self._deserialize_stdout(raw_stdout)
logging.info(f"[CodeExec]: http://{settings.SANDBOX_HOST}:9385/run -> {parsed_stdout}")
self._populate_outputs(parsed_stdout, raw_stdout)
else:
self.set_output("_ERROR", "There is no response from sandbox")
+ return self.output()
except Exception as e:
if self.check_if_canceled("CodeExec execution"):
- return
+ return self.output()
self.set_output("_ERROR", "Exception executing code: " + str(e))
@@ -295,6 +331,8 @@ def _populate_outputs(self, parsed_stdout, raw_stdout: str):
if key.startswith("_"):
continue
val = self._get_by_path(parsed_stdout, key)
+ if val is None and len(outputs_items) == 1:
+ val = parsed_stdout
coerced = self._coerce_output_value(val, meta.get("type"))
logging.info(f"[CodeExec]: populate dict key='{key}' raw='{val}' coerced='{coerced}'")
self.set_output(key, coerced)
diff --git a/agent/tools/exesql.py b/agent/tools/exesql.py
index 012b00d84e2..3f969f43164 100644
--- a/agent/tools/exesql.py
+++ b/agent/tools/exesql.py
@@ -53,7 +53,7 @@ def __init__(self):
self.max_records = 1024
def check(self):
- self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2', 'trino'])
+ self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2', 'trino', 'oceanbase'])
self.check_empty(self.database, "Database name")
self.check_empty(self.username, "database username")
self.check_empty(self.host, "IP Address")
@@ -86,6 +86,12 @@ def _invoke(self, **kwargs):
def convert_decimals(obj):
from decimal import Decimal
+ import math
+ if isinstance(obj, float):
+ # Handle NaN and Infinity which are not valid JSON values
+ if math.isnan(obj) or math.isinf(obj):
+ return None
+ return obj
if isinstance(obj, Decimal):
return float(obj) # 或 str(obj)
elif isinstance(obj, dict):
@@ -120,6 +126,9 @@ def convert_decimals(obj):
if self._param.db_type in ["mysql", "mariadb"]:
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
port=self._param.port, password=self._param.password)
+ elif self._param.db_type == 'oceanbase':
+ db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
+ port=self._param.port, password=self._param.password, charset='utf8mb4')
elif self._param.db_type == 'postgres':
db = psycopg2.connect(dbname=self._param.database, user=self._param.username, host=self._param.host,
port=self._param.port, password=self._param.password)
diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py
index 21df960befb..29bddde238d 100644
--- a/agent/tools/retrieval.py
+++ b/agent/tools/retrieval.py
@@ -21,7 +21,7 @@
from abc import ABC
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from common.constants import LLMType
-from api.db.services.document_service import DocumentService
+from api.db.services.doc_metadata_service import DocMetadataService
from common.metadata_utils import apply_meta_data_filter
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
@@ -125,7 +125,7 @@ async def _retrieve_kb(self, query_text: str):
doc_ids = []
if self._param.meta_data_filter != {}:
- metas = DocumentService.get_meta_by_kbs(kb_ids)
+ metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
def _resolve_manual_filter(flt: dict) -> dict:
pat = re.compile(self.variable_ref_patt)
@@ -174,7 +174,7 @@ def _resolve_manual_filter(flt: dict) -> dict:
if kbs:
query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
- kbinfos = settings.retriever.retrieval(
+ kbinfos = await settings.retriever.retrieval(
query,
embd_mdl,
[kb.tenant_id for kb in kbs],
@@ -193,7 +193,7 @@ def _resolve_manual_filter(flt: dict) -> dict:
if self._param.toc_enhance:
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
- cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs],
+ cks = await settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs],
chat_mdl, self._param.top_n)
if self.check_if_canceled("Retrieval processing"):
return
@@ -202,7 +202,7 @@ def _resolve_manual_filter(flt: dict) -> dict:
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"],
[kb.tenant_id for kb in kbs])
if self._param.use_kg:
- ck = settings.kg_retriever.retrieval(query,
+ ck = await settings.kg_retriever.retrieval(query,
[kb.tenant_id for kb in kbs],
kb_ids,
embd_mdl,
@@ -215,7 +215,7 @@ def _resolve_manual_filter(flt: dict) -> dict:
kbinfos = {"chunks": [], "doc_aggs": []}
if self._param.use_kg and kbs:
- ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
+ ck = await settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
if self.check_if_canceled("Retrieval processing"):
return
diff --git a/agentic_reasoning/__init__.py b/agentic_reasoning/__init__.py
deleted file mode 100644
index 1422de46e4f..00000000000
--- a/agentic_reasoning/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .deep_research import DeepResearcher as DeepResearcher
\ No newline at end of file
diff --git a/agentic_reasoning/deep_research.py b/agentic_reasoning/deep_research.py
deleted file mode 100644
index 20f7017f474..00000000000
--- a/agentic_reasoning/deep_research.py
+++ /dev/null
@@ -1,238 +0,0 @@
-#
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import logging
-import re
-from functools import partial
-from agentic_reasoning.prompts import BEGIN_SEARCH_QUERY, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT, MAX_SEARCH_LIMIT, \
- END_SEARCH_QUERY, REASON_PROMPT, RELEVANT_EXTRACTION_PROMPT
-from api.db.services.llm_service import LLMBundle
-from rag.nlp import extract_between
-from rag.prompts import kb_prompt
-from rag.utils.tavily_conn import Tavily
-
-
-class DeepResearcher:
- def __init__(self,
- chat_mdl: LLMBundle,
- prompt_config: dict,
- kb_retrieve: partial = None,
- kg_retrieve: partial = None
- ):
- self.chat_mdl = chat_mdl
- self.prompt_config = prompt_config
- self._kb_retrieve = kb_retrieve
- self._kg_retrieve = kg_retrieve
-
- def _remove_tags(text: str, start_tag: str, end_tag: str) -> str:
- """General Tag Removal Method"""
- pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
- return re.sub(pattern, "", text)
-
- @staticmethod
- def _remove_query_tags(text: str) -> str:
- """Remove Query Tags"""
- return DeepResearcher._remove_tags(text, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
-
- @staticmethod
- def _remove_result_tags(text: str) -> str:
- """Remove Result Tags"""
- return DeepResearcher._remove_tags(text, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT)
-
- async def _generate_reasoning(self, msg_history):
- """Generate reasoning steps"""
- query_think = ""
- if msg_history[-1]["role"] != "user":
- msg_history.append({"role": "user", "content": "Continues reasoning with the new information.\n"})
- else:
- msg_history[-1]["content"] += "\n\nContinues reasoning with the new information.\n"
-
- async for ans in self.chat_mdl.async_chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
- ans = re.sub(r"^.*", "", ans, flags=re.DOTALL)
- if not ans:
- continue
- query_think = ans
- yield query_think
- query_think = ""
- yield query_think
-
- def _extract_search_queries(self, query_think, question, step_index):
- """Extract search queries from thinking"""
- queries = extract_between(query_think, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
- if not queries and step_index == 0:
- # If this is the first step and no queries are found, use the original question as the query
- queries = [question]
- return queries
-
- def _truncate_previous_reasoning(self, all_reasoning_steps):
- """Truncate previous reasoning steps to maintain a reasonable length"""
- truncated_prev_reasoning = ""
- for i, step in enumerate(all_reasoning_steps):
- truncated_prev_reasoning += f"Step {i + 1}: {step}\n\n"
-
- prev_steps = truncated_prev_reasoning.split('\n\n')
- if len(prev_steps) <= 5:
- truncated_prev_reasoning = '\n\n'.join(prev_steps)
- else:
- truncated_prev_reasoning = ''
- for i, step in enumerate(prev_steps):
- if i == 0 or i >= len(prev_steps) - 4 or BEGIN_SEARCH_QUERY in step or BEGIN_SEARCH_RESULT in step:
- truncated_prev_reasoning += step + '\n\n'
- else:
- if truncated_prev_reasoning[-len('\n\n...\n\n'):] != '\n\n...\n\n':
- truncated_prev_reasoning += '...\n\n'
-
- return truncated_prev_reasoning.strip('\n')
-
- def _retrieve_information(self, search_query):
- """Retrieve information from different sources"""
- # 1. Knowledge base retrieval
- kbinfos = []
- try:
- kbinfos = self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []}
- except Exception as e:
- logging.error(f"Knowledge base retrieval error: {e}")
-
- # 2. Web retrieval (if Tavily API is configured)
- try:
- if self.prompt_config.get("tavily_api_key"):
- tav = Tavily(self.prompt_config["tavily_api_key"])
- tav_res = tav.retrieve_chunks(search_query)
- kbinfos["chunks"].extend(tav_res["chunks"])
- kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
- except Exception as e:
- logging.error(f"Web retrieval error: {e}")
-
- # 3. Knowledge graph retrieval (if configured)
- try:
- if self.prompt_config.get("use_kg") and self._kg_retrieve:
- ck = self._kg_retrieve(question=search_query)
- if ck["content_with_weight"]:
- kbinfos["chunks"].insert(0, ck)
- except Exception as e:
- logging.error(f"Knowledge graph retrieval error: {e}")
-
- return kbinfos
-
- def _update_chunk_info(self, chunk_info, kbinfos):
- """Update chunk information for citations"""
- if not chunk_info["chunks"]:
- # If this is the first retrieval, use the retrieval results directly
- for k in chunk_info.keys():
- chunk_info[k] = kbinfos[k]
- else:
- # Merge newly retrieved information, avoiding duplicates
- cids = [c["chunk_id"] for c in chunk_info["chunks"]]
- for c in kbinfos["chunks"]:
- if c["chunk_id"] not in cids:
- chunk_info["chunks"].append(c)
-
- dids = [d["doc_id"] for d in chunk_info["doc_aggs"]]
- for d in kbinfos["doc_aggs"]:
- if d["doc_id"] not in dids:
- chunk_info["doc_aggs"].append(d)
-
- async def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos):
- """Extract and summarize relevant information"""
- summary_think = ""
- async for ans in self.chat_mdl.async_chat_streamly(
- RELEVANT_EXTRACTION_PROMPT.format(
- prev_reasoning=truncated_prev_reasoning,
- search_query=search_query,
- document="\n".join(kb_prompt(kbinfos, 4096))
- ),
- [{"role": "user",
- "content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],
- {"temperature": 0.7}):
- ans = re.sub(r"^.*", "", ans, flags=re.DOTALL)
- if not ans:
- continue
- summary_think = ans
- yield summary_think
- summary_think = ""
-
- yield summary_think
-
- async def thinking(self, chunk_info: dict, question: str):
- executed_search_queries = []
- msg_history = [{"role": "user", "content": f'Question:\"{question}\"\n'}]
- all_reasoning_steps = []
- think = ""
-
- for step_index in range(MAX_SEARCH_LIMIT + 1):
- # Check if the maximum search limit has been reached
- if step_index == MAX_SEARCH_LIMIT - 1:
- summary_think = f"\n{BEGIN_SEARCH_RESULT}\nThe maximum search limit is exceeded. You are not allowed to search.\n{END_SEARCH_RESULT}\n"
- yield {"answer": think + summary_think + " ", "reference": {}, "audio_binary": None}
- all_reasoning_steps.append(summary_think)
- msg_history.append({"role": "assistant", "content": summary_think})
- break
-
- # Step 1: Generate reasoning
- query_think = ""
- async for ans in self._generate_reasoning(msg_history):
- query_think = ans
- yield {"answer": think + self._remove_query_tags(query_think) + "", "reference": {}, "audio_binary": None}
-
- think += self._remove_query_tags(query_think)
- all_reasoning_steps.append(query_think)
-
- # Step 2: Extract search queries
- queries = self._extract_search_queries(query_think, question, step_index)
- if not queries and step_index > 0:
- # If not the first step and no queries, end the search process
- break
-
- # Process each search query
- for search_query in queries:
- logging.info(f"[THINK]Query: {step_index}. {search_query}")
- msg_history.append({"role": "assistant", "content": search_query})
- think += f"\n\n> {step_index + 1}. {search_query}\n\n"
- yield {"answer": think + "", "reference": {}, "audio_binary": None}
-
- # Check if the query has already been executed
- if search_query in executed_search_queries:
- summary_think = f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. Please refer to previous results.\n{END_SEARCH_RESULT}\n"
- yield {"answer": think + summary_think + "", "reference": {}, "audio_binary": None}
- all_reasoning_steps.append(summary_think)
- msg_history.append({"role": "user", "content": summary_think})
- think += summary_think
- continue
-
- executed_search_queries.append(search_query)
-
- # Step 3: Truncate previous reasoning steps
- truncated_prev_reasoning = self._truncate_previous_reasoning(all_reasoning_steps)
-
- # Step 4: Retrieve information
- kbinfos = self._retrieve_information(search_query)
-
- # Step 5: Update chunk information
- self._update_chunk_info(chunk_info, kbinfos)
-
- # Step 6: Extract relevant information
- think += "\n\n"
- summary_think = ""
- async for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
- summary_think = ans
- yield {"answer": think + self._remove_result_tags(summary_think) + "", "reference": {}, "audio_binary": None}
-
- all_reasoning_steps.append(summary_think)
- msg_history.append(
- {"role": "user", "content": f"\n\n{BEGIN_SEARCH_RESULT}{summary_think}{END_SEARCH_RESULT}\n\n"})
- think += self._remove_result_tags(summary_think)
- logging.info(f"[THINK]Summary: {step_index}. {summary_think}")
-
- yield think + ""
diff --git a/agentic_reasoning/prompts.py b/agentic_reasoning/prompts.py
deleted file mode 100644
index 8bf101b291a..00000000000
--- a/agentic_reasoning/prompts.py
+++ /dev/null
@@ -1,147 +0,0 @@
-#
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
-END_SEARCH_QUERY = "<|end_search_query|>"
-BEGIN_SEARCH_RESULT = "<|begin_search_result|>"
-END_SEARCH_RESULT = "<|end_search_result|>"
-MAX_SEARCH_LIMIT = 6
-
-REASON_PROMPT = f"""You are an advanced reasoning agent. Your goal is to answer the user's question by breaking it down into a series of verifiable steps.
-
-You have access to a powerful search tool to find information.
-
-**Your Task:**
-1. Analyze the user's question.
-2. If you need information, issue a search query to find a specific fact.
-3. Review the search results.
-4. Repeat the search process until you have all the facts needed to answer the question.
-5. Once you have gathered sufficient information, synthesize the facts and provide the final answer directly.
-
-**Tool Usage:**
-- To search, you MUST write your query between the special tokens: {BEGIN_SEARCH_QUERY}your query{END_SEARCH_QUERY}.
-- The system will provide results between {BEGIN_SEARCH_RESULT}search results{END_SEARCH_RESULT}.
-- You have a maximum of {MAX_SEARCH_LIMIT} search attempts.
-
----
-**Example 1: Multi-hop Question**
-
-**Question:** "Are both the directors of Jaws and Casino Royale from the same country?"
-
-**Your Thought Process & Actions:**
-First, I need to identify the director of Jaws.
-{BEGIN_SEARCH_QUERY}who is the director of Jaws?{END_SEARCH_QUERY}
-[System returns search results]
-{BEGIN_SEARCH_RESULT}
-Jaws is a 1975 American thriller film directed by Steven Spielberg.
-{END_SEARCH_RESULT}
-Okay, the director of Jaws is Steven Spielberg. Now I need to find out his nationality.
-{BEGIN_SEARCH_QUERY}where is Steven Spielberg from?{END_SEARCH_QUERY}
-[System returns search results]
-{BEGIN_SEARCH_RESULT}
-Steven Allan Spielberg is an American filmmaker. Born in Cincinnati, Ohio...
-{END_SEARCH_RESULT}
-So, Steven Spielberg is from the USA. Next, I need to find the director of Casino Royale.
-{BEGIN_SEARCH_QUERY}who is the director of Casino Royale 2006?{END_SEARCH_QUERY}
-[System returns search results]
-{BEGIN_SEARCH_RESULT}
-Casino Royale is a 2006 spy film directed by Martin Campbell.
-{END_SEARCH_RESULT}
-The director of Casino Royale is Martin Campbell. Now I need his nationality.
-{BEGIN_SEARCH_QUERY}where is Martin Campbell from?{END_SEARCH_QUERY}
-[System returns search results]
-{BEGIN_SEARCH_RESULT}
-Martin Campbell (born 24 October 1943) is a New Zealand film and television director.
-{END_SEARCH_RESULT}
-I have all the information. Steven Spielberg is from the USA, and Martin Campbell is from New Zealand. They are not from the same country.
-
-Final Answer: No, the directors of Jaws and Casino Royale are not from the same country. Steven Spielberg is from the USA, and Martin Campbell is from New Zealand.
-
----
-**Example 2: Simple Fact Retrieval**
-
-**Question:** "When was the founder of craigslist born?"
-
-**Your Thought Process & Actions:**
-First, I need to know who founded craigslist.
-{BEGIN_SEARCH_QUERY}who founded craigslist?{END_SEARCH_QUERY}
-[System returns search results]
-{BEGIN_SEARCH_RESULT}
-Craigslist was founded in 1995 by Craig Newmark.
-{END_SEARCH_RESULT}
-The founder is Craig Newmark. Now I need his birth date.
-{BEGIN_SEARCH_QUERY}when was Craig Newmark born?{END_SEARCH_QUERY}
-[System returns search results]
-{BEGIN_SEARCH_RESULT}
-Craig Newmark was born on December 6, 1952.
-{END_SEARCH_RESULT}
-I have found the answer.
-
-Final Answer: The founder of craigslist, Craig Newmark, was born on December 6, 1952.
-
----
-**Important Rules:**
-- **One Fact at a Time:** Decompose the problem and issue one search query at a time to find a single, specific piece of information.
-- **Be Precise:** Formulate clear and precise search queries. If a search fails, rephrase it.
-- **Synthesize at the End:** Do not provide the final answer until you have completed all necessary searches.
-- **Language Consistency:** Your search queries should be in the same language as the user's question.
-
-Now, begin your work. Please answer the following question by thinking step-by-step.
-"""
-
-RELEVANT_EXTRACTION_PROMPT = """You are a highly efficient information extraction module. Your sole purpose is to extract the single most relevant piece of information from the provided `Searched Web Pages` that directly answers the `Current Search Query`.
-
-**Your Task:**
-1. Read the `Current Search Query` to understand what specific information is needed.
-2. Scan the `Searched Web Pages` to find the answer to that query.
-3. Extract only the essential, factual information that answers the query. Be concise.
-
-**Context (For Your Information Only):**
-The `Previous Reasoning Steps` are provided to give you context on the overall goal, but your primary focus MUST be on answering the `Current Search Query`. Do not use information from the previous steps in your output.
-
-**Output Format:**
-Your response must follow one of two formats precisely.
-
-1. **If a direct and relevant answer is found:**
- - Start your response immediately with `Final Information`.
- - Provide only the extracted fact(s). Do not add any extra conversational text.
-
- *Example:*
- `Current Search Query`: Where is Martin Campbell from?
- `Searched Web Pages`: [Long article snippet about Martin Campbell's career, which includes the sentence "Martin Campbell (born 24 October 1943) is a New Zealand film and television director..."]
-
- *Your Output:*
- Final Information
- Martin Campbell is a New Zealand film and television director.
-
-2. **If no relevant answer that directly addresses the query is found in the web pages:**
- - Start your response immediately with `Final Information`.
- - Write the exact phrase: `No helpful information found.`
-
----
-**BEGIN TASK**
-
-**Inputs:**
-
-- **Previous Reasoning Steps:**
-{prev_reasoning}
-
-- **Current Search Query:**
-{search_query}
-
-- **Searched Web Pages:**
-{document}
-"""
\ No newline at end of file
diff --git a/api/apps/__init__.py b/api/apps/__init__.py
index c329679f8fb..7feae696e35 100644
--- a/api/apps/__init__.py
+++ b/api/apps/__init__.py
@@ -16,21 +16,23 @@
import logging
import os
import sys
+import time
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
-from quart import Blueprint, Quart, request, g, current_app, session
-from flasgger import Swagger
+from quart import Blueprint, Quart, request, g, current_app, session, jsonify
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from quart_cors import cors
-from common.constants import StatusEnum
+from common.constants import StatusEnum, RetCode
from api.db.db_models import close_connection, APIToken
from api.db.services import UserService
from api.utils.json_encode import CustomJSONEncoder
from api.utils import commands
-from quart_auth import Unauthorized
+from quart_auth import Unauthorized as QuartAuthUnauthorized
+from werkzeug.exceptions import Unauthorized as WerkzeugUnauthorized
+from quart_schema import QuartSchema
from common import settings
-from api.utils.api_utils import server_error_response
+from api.utils.api_utils import server_error_response, get_json_result
from api.constants import API_VERSION
from common.misc_utils import get_uuid
@@ -38,40 +40,27 @@
__all__ = ["app"]
+UNAUTHORIZED_MESSAGE = ""
+
+
+def _unauthorized_message(error):
+ if error is None:
+ return UNAUTHORIZED_MESSAGE
+ try:
+ msg = repr(error)
+ except Exception:
+ return UNAUTHORIZED_MESSAGE
+ if msg == UNAUTHORIZED_MESSAGE:
+ return msg
+ if "Unauthorized" in msg and "401" in msg:
+ return msg
+ return UNAUTHORIZED_MESSAGE
+
app = Quart(__name__)
app = cors(app, allow_origin="*")
-# Add this at the beginning of your file to configure Swagger UI
-swagger_config = {
- "headers": [],
- "specs": [
- {
- "endpoint": "apispec",
- "route": "/apispec.json",
- "rule_filter": lambda rule: True, # Include all endpoints
- "model_filter": lambda tag: True, # Include all models
- }
- ],
- "static_url_path": "/flasgger_static",
- "swagger_ui": True,
- "specs_route": "/apidocs/",
-}
-
-swagger = Swagger(
- app,
- config=swagger_config,
- template={
- "swagger": "2.0",
- "info": {
- "title": "RAGFlow API",
- "description": "",
- "version": "1.0.0",
- },
- "securityDefinitions": {
- "ApiKeyAuth": {"type": "apiKey", "name": "Authorization", "in": "header"}
- },
- },
-)
+# openapi supported
+QuartSchema(app)
app.url_map.strict_slashes = False
app.json_encoder = CustomJSONEncoder
@@ -125,18 +114,28 @@ def _load_user():
user = UserService.query(
access_token=access_token, status=StatusEnum.VALID.value
)
- if not user and len(authorization.split()) == 2:
- objs = APIToken.query(token=authorization.split()[1])
- if objs:
- user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value)
if user:
if not user[0].access_token or not user[0].access_token.strip():
logging.warning(f"User {user[0].email} has empty access_token in database")
return None
g.user = user[0]
return user[0]
- except Exception as e:
- logging.warning(f"load_user got exception {e}")
+ except Exception as e_auth:
+ logging.warning(f"load_user got exception {e_auth}")
+ try:
+ authorization = request.headers.get("Authorization")
+ if len(authorization.split()) == 2:
+ objs = APIToken.query(token=authorization.split()[1])
+ if objs:
+ user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value)
+ if user:
+ if not user[0].access_token or not user[0].access_token.strip():
+ logging.warning(f"User {user[0].email} has empty access_token in database")
+ return None
+ g.user = user[0]
+ return user[0]
+ except Exception as e_api_token:
+ logging.warning(f"load_user got exception {e_api_token}")
current_user = LocalProxy(_load_user)
@@ -164,10 +163,18 @@ async def index():
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
- if not current_user: # or not session.get("_user_id"):
- raise Unauthorized()
- else:
- return await current_app.ensure_async(func)(*args, **kwargs)
+ timing_enabled = os.getenv("RAGFLOW_API_TIMING")
+ t_start = time.perf_counter() if timing_enabled else None
+ user = current_user
+ if timing_enabled:
+ logging.info(
+ "api_timing login_required auth_ms=%.2f path=%s",
+ (time.perf_counter() - t_start) * 1000,
+ request.path,
+ )
+ if not user: # or not session.get("_user_id"):
+ raise QuartAuthUnauthorized()
+ return await current_app.ensure_async(func)(*args, **kwargs)
return wrapper
@@ -277,13 +284,33 @@ def register_page(page_path):
@app.errorhandler(404)
async def not_found(error):
- error_msg: str = f"The requested URL {request.path} was not found"
- logging.error(error_msg)
- return {
+ logging.error(f"The requested URL {request.path} was not found")
+ message = f"Not Found: {request.path}"
+ response = {
+ "code": RetCode.NOT_FOUND,
+ "message": message,
+ "data": None,
"error": "Not Found",
- "message": error_msg,
- }, 404
+ }
+ return jsonify(response), RetCode.NOT_FOUND
+
+
+@app.errorhandler(401)
+async def unauthorized(error):
+ logging.warning("Unauthorized request")
+ return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED
+
+
+@app.errorhandler(QuartAuthUnauthorized)
+async def unauthorized_quart_auth(error):
+ logging.warning("Unauthorized request (quart_auth)")
+ return get_json_result(code=RetCode.UNAUTHORIZED, message=repr(error)), RetCode.UNAUTHORIZED
+
+@app.errorhandler(WerkzeugUnauthorized)
+async def unauthorized_werkzeug(error):
+ logging.warning("Unauthorized request (werkzeug)")
+ return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED
@app.teardown_request
def _db_close(exception):
diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py
index 21bd237894f..25bfae9534f 100644
--- a/api/apps/canvas_app.py
+++ b/api/apps/canvas_app.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import asyncio
+import copy
import inspect
import json
import logging
@@ -29,9 +29,14 @@
from api.db.services.user_service import TenantService
from api.db.services.user_canvas_version import UserCanvasVersionService
from common.constants import RetCode
-from common.misc_utils import get_uuid
-from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
- get_request_json
+from common.misc_utils import get_uuid, thread_pool_exec
+from api.utils.api_utils import (
+ get_json_result,
+ server_error_response,
+ validate_request,
+ get_data_error_result,
+ get_request_json,
+)
from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase
from api.db.db_models import APIToken, Task
@@ -42,6 +47,7 @@
from rag.utils.redis_conn import REDIS_CONN
from common import settings
from api.apps import login_required, current_user
+from api.db.services.canvas_service import completion as agent_completion
@manager.route('/templates', methods=['GET']) # noqa: F821
@@ -132,12 +138,12 @@ async def run():
files = req.get("files", [])
inputs = req.get("inputs", {})
user_id = req.get("user_id", current_user.id)
- if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id):
+ if not await thread_pool_exec(UserCanvasService.accessible, req["id"], current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
- e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"])
+ e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, req["id"])
if not e:
return get_data_error_result(message="canvas not found.")
@@ -147,7 +153,7 @@ async def run():
if cvs.canvas_category == CanvasCategory.DataFlow:
task_id = get_uuid()
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
- ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
+ ok, error_message = await thread_pool_exec(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
if not ok:
return get_data_error_result(message=error_message)
return get_json_result(data={"message_id": task_id})
@@ -180,6 +186,50 @@ async def sse():
return resp
+@manager.route("//completion", methods=["POST"]) # noqa: F821
+@login_required
+async def exp_agent_completion(canvas_id):
+ tenant_id = current_user.id
+ req = await get_request_json()
+ return_trace = bool(req.get("return_trace", False))
+ async def generate():
+ trace_items = []
+ async for answer in agent_completion(tenant_id=tenant_id, agent_id=canvas_id, **req):
+ if isinstance(answer, str):
+ try:
+ ans = json.loads(answer[5:]) # remove "data:"
+ except Exception:
+ continue
+
+ event = ans.get("event")
+ if event == "node_finished":
+ if return_trace:
+ data = ans.get("data", {})
+ trace_items.append(
+ {
+ "component_id": data.get("component_id"),
+ "trace": [copy.deepcopy(data)],
+ }
+ )
+ ans.setdefault("data", {})["trace"] = trace_items
+ answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
+ yield answer
+
+ if event not in ["message", "message_end"]:
+ continue
+
+ yield answer
+
+ yield "data:[DONE]\n\n"
+
+ resp = Response(generate(), mimetype="text/event-stream")
+ resp.headers.add_header("Cache-control", "no-cache")
+ resp.headers.add_header("Connection", "keep-alive")
+ resp.headers.add_header("X-Accel-Buffering", "no")
+ resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
+ return resp
+
+
@manager.route('/rerun', methods=['POST']) # noqa: F821
@validate_request("id", "dsl", "component_id")
@login_required
@@ -249,11 +299,14 @@ async def upload(canvas_id):
user_id = cvs["user_id"]
files = await request.files
- file = files['file'] if files and files.get("file") else None
+ file_objs = files.getlist("file") if files and files.get("file") else []
try:
- return get_json_result(data=FileService.upload_info(user_id, file, request.args.get("url")))
+ if len(file_objs) == 1:
+ return get_json_result(data=FileService.upload_info(user_id, file_objs[0], request.args.get("url")))
+ results = [FileService.upload_info(user_id, f) for f in file_objs]
+ return get_json_result(data=results)
except Exception as e:
- return server_error_response(e)
+ return server_error_response(e)
@manager.route('/input_form', methods=['GET']) # noqa: F821
@@ -322,6 +375,9 @@ async def test_db_connect():
if req["db_type"] in ["mysql", "mariadb"]:
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
password=req["password"])
+ elif req["db_type"] == "oceanbase":
+ db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
+ password=req["password"], charset="utf8mb4")
elif req["db_type"] == 'postgres':
db = PostgresqlDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
password=req["password"])
@@ -522,24 +578,81 @@ def sessions(canvas_id):
from_date = request.args.get("from_date")
to_date = request.args.get("to_date")
orderby = request.args.get("orderby", "update_time")
+ exp_user_id = request.args.get("exp_user_id")
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
desc = False
else:
desc = True
+
+ if exp_user_id:
+ sess = API4ConversationService.get_names(canvas_id, exp_user_id)
+ return get_json_result(data={"total": len(sess), "sessions": sess})
+
# dsl defaults to True in all cases except for False and false
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
total, sess = API4ConversationService.get_list(canvas_id, tenant_id, page_number, items_per_page, orderby, desc,
- None, user_id, include_dsl, keywords, from_date, to_date)
+ None, user_id, include_dsl, keywords, from_date, to_date, exp_user_id=exp_user_id)
try:
return get_json_result(data={"total": total, "sessions": sess})
except Exception as e:
return server_error_response(e)
+@manager.route('//sessions', methods=['PUT']) # noqa: F821
+@login_required
+async def set_session(canvas_id):
+ req = await get_request_json()
+ tenant_id = current_user.id
+ e, cvs = UserCanvasService.get_by_id(canvas_id)
+ assert e, "Agent not found."
+ if not isinstance(cvs.dsl, str):
+ cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
+ session_id=get_uuid()
+ canvas = Canvas(cvs.dsl, tenant_id, canvas_id, canvas_id=cvs.id)
+ canvas.reset()
+ conv = {
+ "id": session_id,
+ "name": req.get("name", ""),
+ "dialog_id": cvs.id,
+ "user_id": tenant_id,
+ "exp_user_id": tenant_id,
+ "message": [],
+ "source": "agent",
+ "dsl": cvs.dsl,
+ "reference": []
+ }
+ API4ConversationService.save(**conv)
+ return get_json_result(data=conv)
+
+
+@manager.route('//sessions/', methods=['GET']) # noqa: F821
+@login_required
+def get_session(canvas_id, session_id):
+ tenant_id = current_user.id
+ if not UserCanvasService.accessible(canvas_id, tenant_id):
+ return get_json_result(
+ data=False, message='Only owner of canvas authorized for this operation.',
+ code=RetCode.OPERATING_ERROR)
+ _, conv = API4ConversationService.get_by_id(session_id)
+ return get_json_result(data=conv.to_dict())
+
+
+@manager.route('//sessions/', methods=['DELETE']) # noqa: F821
+@login_required
+def del_session(canvas_id, session_id):
+ tenant_id = current_user.id
+ if not UserCanvasService.accessible(canvas_id, tenant_id):
+ return get_json_result(
+ data=False, message='Only owner of canvas authorized for this operation.',
+ code=RetCode.OPERATING_ERROR)
+ return get_json_result(data=API4ConversationService.delete_by_id(session_id))
+
+
@manager.route('/prompts', methods=['GET']) # noqa: F821
@login_required
def prompts():
from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
+
return get_json_result(data={
"task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
"plan_generation": NEXT_STEP,
diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
index f5b248fd5ef..c1be1ef88c6 100644
--- a/api/apps/chunk_app.py
+++ b/api/apps/chunk_app.py
@@ -13,22 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import asyncio
+import base64
import datetime
import json
+import logging
import re
-import base64
import xxhash
from quart import request
from api.db.services.document_service import DocumentService
+from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from common.metadata_utils import apply_meta_data_filter
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
-from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
- get_request_json
+from api.utils.api_utils import (
+ get_data_error_result,
+ get_json_result,
+ server_error_response,
+ validate_request,
+ get_request_json,
+)
+from common.misc_utils import thread_pool_exec
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
@@ -38,7 +45,6 @@
from common import settings
from api.apps import login_required, current_user
-
@manager.route('/list', methods=['POST']) # noqa: F821
@login_required
@validate_request("doc_id")
@@ -61,7 +67,7 @@ async def list_chunk():
}
if "available_int" in req:
query["available_int"] = int(req["available_int"])
- sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
+ sres = await settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
for id in sres.ids:
d = {
@@ -126,10 +132,15 @@ def get():
@validate_request("doc_id", "chunk_id", "content_with_weight")
async def set():
req = await get_request_json()
+ content_with_weight = req["content_with_weight"]
+ if not isinstance(content_with_weight, (str, bytes)):
+ raise TypeError("expected string or bytes-like object")
+ if isinstance(content_with_weight, bytes):
+ content_with_weight = content_with_weight.decode("utf-8", errors="ignore")
d = {
"id": req["chunk_id"],
- "content_with_weight": req["content_with_weight"]}
- d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
+ "content_with_weight": content_with_weight}
+ d["content_ltks"] = rag_tokenizer.tokenize(content_with_weight)
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
if "important_kwd" in req:
if not isinstance(req["important_kwd"], list):
@@ -171,20 +182,21 @@ def _set_sync():
_d = beAdoc(d, q, a, not any(
[rag_tokenizer.is_chinese(t) for t in q + a]))
- v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
+ v, c = embd_mdl.encode([doc.name, content_with_weight if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
_d["q_%d_vec" % len(v)] = v.tolist()
settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
# update image
image_base64 = req.get("image_base64", None)
- if image_base64:
- bkt, name = req.get("img_id", "-").split("-")
+ img_id = req.get("img_id", "")
+ if image_base64 and img_id and "-" in img_id:
+ bkt, name = img_id.split("-", 1)
image_binary = base64.b64decode(image_base64)
settings.STORAGE_IMPL.put(bkt, name, image_binary)
return get_json_result(data=True)
- return await asyncio.to_thread(_set_sync)
+ return await thread_pool_exec(_set_sync)
except Exception as e:
return server_error_response(e)
@@ -207,7 +219,7 @@ def _switch_sync():
return get_data_error_result(message="Index updating failure")
return get_json_result(data=True)
- return await asyncio.to_thread(_switch_sync)
+ return await thread_pool_exec(_switch_sync)
except Exception as e:
return server_error_response(e)
@@ -222,19 +234,34 @@ def _rm_sync():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
- if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
- search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
- doc.kb_id):
+ condition = {"id": req["chunk_ids"], "doc_id": req["doc_id"]}
+ try:
+ deleted_count = settings.docStoreConn.delete(condition,
+ search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
+ doc.kb_id)
+ except Exception:
return get_data_error_result(message="Chunk deleting failure")
deleted_chunk_ids = req["chunk_ids"]
- chunk_number = len(deleted_chunk_ids)
+ if isinstance(deleted_chunk_ids, list):
+ unique_chunk_ids = list(dict.fromkeys(deleted_chunk_ids))
+ has_ids = len(unique_chunk_ids) > 0
+ else:
+ unique_chunk_ids = [deleted_chunk_ids]
+ has_ids = deleted_chunk_ids not in (None, "")
+ if has_ids and deleted_count == 0:
+ return get_data_error_result(message="Index updating failure")
+ if deleted_count > 0 and deleted_count < len(unique_chunk_ids):
+ deleted_count += settings.docStoreConn.delete({"doc_id": req["doc_id"]},
+ search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
+ doc.kb_id)
+ chunk_number = deleted_count
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
for cid in deleted_chunk_ids:
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
return get_json_result(data=True)
- return await asyncio.to_thread(_rm_sync)
+ return await thread_pool_exec(_rm_sync)
except Exception as e:
return server_error_response(e)
@@ -244,6 +271,7 @@ def _rm_sync():
@validate_request("doc_id", "content_with_weight")
async def create():
req = await get_request_json()
+ req_id = request.headers.get("X-Request-ID")
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
"content_with_weight": req["content_with_weight"]}
@@ -260,14 +288,23 @@ async def create():
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
if "tag_feas" in req:
d["tag_feas"] = req["tag_feas"]
- if "tag_feas" in req:
- d["tag_feas"] = req["tag_feas"]
try:
+ def _log_response(resp, code, message):
+ logging.info(
+ "chunk_create response req_id=%s status=%s code=%s message=%s",
+ req_id,
+ getattr(resp, "status_code", None),
+ code,
+ message,
+ )
+
def _create_sync():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
- return get_data_error_result(message="Document not found!")
+ resp = get_data_error_result(message="Document not found!")
+ _log_response(resp, RetCode.DATA_ERROR, "Document not found!")
+ return resp
d["kb_id"] = [doc.kb_id]
d["docnm_kwd"] = doc.name
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
@@ -275,11 +312,15 @@ def _create_sync():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
- return get_data_error_result(message="Tenant not found!")
+ resp = get_data_error_result(message="Tenant not found!")
+ _log_response(resp, RetCode.DATA_ERROR, "Tenant not found!")
+ return resp
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
if not e:
- return get_data_error_result(message="Knowledgebase not found!")
+ resp = get_data_error_result(message="Knowledgebase not found!")
+ _log_response(resp, RetCode.DATA_ERROR, "Knowledgebase not found!")
+ return resp
if kb.pagerank:
d[PAGERANK_FLD] = kb.pagerank
@@ -293,10 +334,13 @@ def _create_sync():
DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0)
- return get_json_result(data={"chunk_id": chunck_id})
+ resp = get_json_result(data={"chunk_id": chunck_id})
+ _log_response(resp, RetCode.SUCCESS, "success")
+ return resp
- return await asyncio.to_thread(_create_sync)
+ return await thread_pool_exec(_create_sync)
except Exception as e:
+ logging.info("chunk_create exception req_id=%s error=%r", req_id, e)
return server_error_response(e)
@@ -338,7 +382,7 @@ async def _retrieval():
chat_mdl = LLMBundle(user_id, LLMType.CHAT)
if meta_data_filter:
- metas = DocumentService.get_meta_by_kbs(kb_ids)
+ metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, local_doc_ids)
tenants = UserTenantService.query(user_id=user_id)
@@ -372,16 +416,23 @@ async def _retrieval():
_question += await keyword_extraction(chat_mdl, _question)
labels = label_question(_question, [kb])
- ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
- float(req.get("similarity_threshold", 0.0)),
- float(req.get("vector_similarity_weight", 0.3)),
- top,
- local_doc_ids, rerank_mdl=rerank_mdl,
- highlight=req.get("highlight", False),
- rank_feature=labels
- )
+ ranks = await settings.retriever.retrieval(
+ _question,
+ embd_mdl,
+ tenant_ids,
+ kb_ids,
+ page,
+ size,
+ float(req.get("similarity_threshold", 0.0)),
+ float(req.get("vector_similarity_weight", 0.3)),
+ doc_ids=local_doc_ids,
+ top=top,
+ rerank_mdl=rerank_mdl,
+ rank_feature=labels
+ )
+
if use_kg:
- ck = settings.kg_retriever.retrieval(_question,
+ ck = await settings.kg_retriever.retrieval(_question,
tenant_ids,
kb_ids,
embd_mdl,
@@ -407,7 +458,7 @@ async def _retrieval():
@manager.route('/knowledge_graph', methods=['GET']) # noqa: F821
@login_required
-def knowledge_graph():
+async def knowledge_graph():
doc_id = request.args["doc_id"]
tenant_id = DocumentService.get_tenant_id(doc_id)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
@@ -415,7 +466,7 @@ def knowledge_graph():
"doc_ids": [doc_id],
"knowledge_graph_kwd": ["graph", "mind_map"]
}
- sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
+ sres = await settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
obj = {"graph": {}, "mind_map": {}}
for id in sres.ids[:2]:
ty = sres.field[id]["knowledge_graph_kwd"]
diff --git a/api/apps/connector_app.py b/api/apps/connector_app.py
index fb074419bb5..0e687ea69a7 100644
--- a/api/apps/connector_app.py
+++ b/api/apps/connector_app.py
@@ -52,7 +52,7 @@ async def set_connector():
"source": req["source"],
"input_type": InputType.POLL,
"config": req["config"],
- "refresh_freq": int(req.get("refresh_freq", 30)),
+ "refresh_freq": int(req.get("refresh_freq", 5)),
"prune_freq": int(req.get("prune_freq", 720)),
"timeout_secs": int(req.get("timeout_secs", 60 * 29)),
"status": TaskStatus.SCHEDULE,
diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py
index d2aad88ee1a..9b7617797d8 100644
--- a/api/apps/dialog_app.py
+++ b/api/apps/dialog_app.py
@@ -25,6 +25,7 @@
from common.misc_utils import get_uuid
from common.constants import RetCode
from api.apps import login_required, current_user
+import logging
@manager.route('/set', methods=['POST']) # noqa: F821
@@ -42,13 +43,19 @@ async def set_dialog():
if len(name.encode("utf-8")) > 255:
return get_data_error_result(message=f"Dialog name length is {len(name)} which is larger than 255")
- if is_create and DialogService.query(tenant_id=current_user.id, name=name.strip()):
- name = name.strip()
- name = duplicate_name(
- DialogService.query,
- name=name,
- tenant_id=current_user.id,
- status=StatusEnum.VALID.value)
+ name = name.strip()
+ if is_create:
+ # only for chat creating
+ existing_names = {
+ d.name.casefold()
+ for d in DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value)
+ if d.name
+ }
+ if name.casefold() in existing_names:
+ def _name_exists(name: str, **_kwargs) -> bool:
+ return name.casefold() in existing_names
+
+ name = duplicate_name(_name_exists, name=name)
description = req.get("description", "A helpful dialog")
icon = req.get("icon", "")
@@ -63,16 +70,30 @@ async def set_dialog():
meta_data_filter = req.get("meta_data_filter", {})
prompt_config = req["prompt_config"]
+ # Set default parameters for datasets with knowledge retrieval
+ # All datasets with {knowledge} in system prompt need "knowledge" parameter to enable retrieval
+ kb_ids = req.get("kb_ids", [])
+ parameters = prompt_config.get("parameters")
+ logging.debug(f"set_dialog: kb_ids={kb_ids}, parameters={parameters}, is_create={not is_create}")
+ # Check if parameters is missing, None, or empty list
+ if kb_ids and not parameters:
+ # Check if system prompt uses {knowledge} placeholder
+ if "{knowledge}" in prompt_config.get("system", ""):
+ # Set default parameters for any dataset with knowledge placeholder
+ prompt_config["parameters"] = [{"key": "knowledge", "optional": False}]
+ logging.debug(f"Set default parameters for datasets with knowledge placeholder: {kb_ids}")
+
if not is_create:
- if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config['system']:
+ # only for chat updating
+ if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config.get("system", ""):
return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.")
- for p in prompt_config["parameters"]:
- if p["optional"]:
- continue
- if prompt_config["system"].find("{%s}" % p["key"]) < 0:
- return get_data_error_result(
- message="Parameter '{}' is not used".format(p["key"]))
+ for p in prompt_config.get("parameters", []):
+ if p["optional"]:
+ continue
+ if prompt_config.get("system", "").find("{%s}" % p["key"]) < 0:
+ return get_data_error_result(
+ message="Parameter '{}' is not used".format(p["key"]))
try:
e, tenant = TenantService.get_by_id(current_user.id)
diff --git a/api/apps/document_app.py b/api/apps/document_app.py
index 4fcc07e65c8..cc2b7c8c4a2 100644
--- a/api/apps/document_app.py
+++ b/api/apps/document_app.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License
#
-import asyncio
import json
import os.path
import pathlib
@@ -27,18 +26,20 @@
from api.db.db_models import Task
from api.db.services import duplicate_name
from api.db.services.document_service import DocumentService, doc_upload_and_parse
-from common.metadata_utils import meta_filter, convert_conditions
+from api.db.services.doc_metadata_service import DocMetadataService
+from common.metadata_utils import meta_filter, convert_conditions, turn2jsonschema
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import TaskService, cancel_all_task_of
from api.db.services.user_service import UserTenantService
-from common.misc_utils import get_uuid
+from common.misc_utils import get_uuid, thread_pool_exec
from api.utils.api_utils import (
get_data_error_result,
get_json_result,
server_error_response,
- validate_request, get_request_json,
+ validate_request,
+ get_request_json,
)
from api.utils.file_utils import filename_type, thumbnail
from common.file_utils import get_project_base_directory
@@ -62,10 +63,21 @@ async def upload():
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
file_objs = files.getlist("file")
+ def _close_file_objs(objs):
+ for obj in objs:
+ try:
+ obj.close()
+ except Exception:
+ try:
+ obj.stream.close()
+ except Exception:
+ pass
for file_obj in file_objs:
if file_obj.filename == "":
+ _close_file_objs(file_objs)
return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
+ _close_file_objs(file_objs)
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id)
@@ -74,8 +86,9 @@ async def upload():
if not check_kb_team_permission(kb, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
- err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
+ err, files = await thread_pool_exec(FileService.upload_document, kb, file_objs, current_user.id)
if err:
+ files = [f[0] for f in files] if files else []
return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
if not files:
@@ -214,6 +227,7 @@ async def list_docs():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
@@ -268,7 +282,7 @@ async def list_docs():
doc_ids_filter = None
metas = None
if metadata_condition or metadata:
- metas = DocumentService.get_flatted_meta_by_kbs([kb_id])
+ metas = DocMetadataService.get_flatted_meta_by_kbs([kb_id])
if metadata_condition:
doc_ids_filter = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
@@ -333,6 +347,8 @@ async def list_docs():
doc_item["thumbnail"] = f"/v1/document/image/{kb_id}-{doc_item['thumbnail']}"
if doc_item.get("source_type"):
doc_item["source_type"] = doc_item["source_type"].split("/")[0]
+ if doc_item["parser_config"].get("metadata"):
+ doc_item["parser_config"]["metadata"] = turn2jsonschema(doc_item["parser_config"]["metadata"])
return get_json_result(data={"total": tol, "docs": docs})
except Exception as e:
@@ -386,7 +402,11 @@ async def doc_infos():
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
docs = DocumentService.get_by_ids(doc_ids)
- return get_json_result(data=list(docs.dicts()))
+ docs_list = list(docs.dicts())
+ # Add meta_fields for each document
+ for doc in docs_list:
+ doc["meta_fields"] = DocMetadataService.get_document_metadata(doc["id"])
+ return get_json_result(data=docs_list)
@manager.route("/metadata/summary", methods=["POST"]) # noqa: F821
@@ -394,6 +414,7 @@ async def doc_infos():
async def metadata_summary():
req = await get_request_json()
kb_id = req.get("kb_id")
+ doc_ids = req.get("doc_ids")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
@@ -405,7 +426,7 @@ async def metadata_summary():
return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
try:
- summary = DocumentService.get_metadata_summary(kb_id)
+ summary = DocMetadataService.get_metadata_summary(kb_id, doc_ids)
return get_json_result(data={"summary": summary})
except Exception as e:
return server_error_response(e)
@@ -413,36 +434,20 @@ async def metadata_summary():
@manager.route("/metadata/update", methods=["POST"]) # noqa: F821
@login_required
+@validate_request("doc_ids")
async def metadata_update():
req = await get_request_json()
kb_id = req.get("kb_id")
- if not kb_id:
- return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
-
- tenants = UserTenantService.query(user_id=current_user.id)
- for tenant in tenants:
- if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
- break
- else:
- return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
-
- selector = req.get("selector", {}) or {}
+ document_ids = req.get("doc_ids")
updates = req.get("updates", []) or []
deletes = req.get("deletes", []) or []
- if not isinstance(selector, dict):
- return get_json_result(data=False, message="selector must be an object.", code=RetCode.ARGUMENT_ERROR)
+ if not kb_id:
+ return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+
if not isinstance(updates, list) or not isinstance(deletes, list):
return get_json_result(data=False, message="updates and deletes must be lists.", code=RetCode.ARGUMENT_ERROR)
- metadata_condition = selector.get("metadata_condition", {}) or {}
- if metadata_condition and not isinstance(metadata_condition, dict):
- return get_json_result(data=False, message="metadata_condition must be an object.", code=RetCode.ARGUMENT_ERROR)
-
- document_ids = selector.get("document_ids", []) or []
- if document_ids and not isinstance(document_ids, list):
- return get_json_result(data=False, message="document_ids must be a list.", code=RetCode.ARGUMENT_ERROR)
-
for upd in updates:
if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd:
return get_json_result(data=False, message="Each update requires key and value.", code=RetCode.ARGUMENT_ERROR)
@@ -450,24 +455,8 @@ async def metadata_update():
if not isinstance(d, dict) or not d.get("key"):
return get_json_result(data=False, message="Each delete requires key.", code=RetCode.ARGUMENT_ERROR)
- kb_doc_ids = KnowledgebaseService.list_documents_by_ids([kb_id])
- target_doc_ids = set(kb_doc_ids)
- if document_ids:
- invalid_ids = set(document_ids) - set(kb_doc_ids)
- if invalid_ids:
- return get_json_result(data=False, message=f"These documents do not belong to dataset {kb_id}: {', '.join(invalid_ids)}", code=RetCode.ARGUMENT_ERROR)
- target_doc_ids = set(document_ids)
-
- if metadata_condition:
- metas = DocumentService.get_flatted_meta_by_kbs([kb_id])
- filtered_ids = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
- target_doc_ids = target_doc_ids & filtered_ids
- if metadata_condition.get("conditions") and not target_doc_ids:
- return get_json_result(data={"updated": 0, "matched_docs": 0})
-
- target_doc_ids = list(target_doc_ids)
- updated = DocumentService.batch_update_metadata(kb_id, target_doc_ids, updates, deletes)
- return get_json_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})
+ updated = DocMetadataService.batch_update_metadata(kb_id, document_ids, updates, deletes)
+ return get_json_result(data={"updated": updated, "matched_docs": len(document_ids)})
@manager.route("/update_metadata_setting", methods=["POST"]) # noqa: F821
@@ -521,31 +510,61 @@ async def change_status():
return get_json_result(data=False, message='"Status" must be either 0 or 1!', code=RetCode.ARGUMENT_ERROR)
result = {}
+ has_error = False
for doc_id in doc_ids:
if not DocumentService.accessible(doc_id, current_user.id):
result[doc_id] = {"error": "No authorization."}
+ has_error = True
continue
try:
e, doc = DocumentService.get_by_id(doc_id)
if not e:
result[doc_id] = {"error": "No authorization."}
+ has_error = True
continue
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
if not e:
result[doc_id] = {"error": "Can't find this dataset!"}
+ has_error = True
+ continue
+ current_status = str(doc.status)
+ if current_status == status:
+ result[doc_id] = {"status": status}
continue
if not DocumentService.update_by_id(doc_id, {"status": str(status)}):
result[doc_id] = {"error": "Database error (Document update)!"}
+ has_error = True
continue
status_int = int(status)
- if not settings.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
- result[doc_id] = {"error": "Database error (docStore update)!"}
+ if getattr(doc, "chunk_num", 0) > 0:
+ try:
+ ok = settings.docStoreConn.update(
+ {"doc_id": doc_id},
+ {"available_int": status_int},
+ search.index_name(kb.tenant_id),
+ doc.kb_id,
+ )
+ except Exception as exc:
+ msg = str(exc)
+ if "3022" in msg:
+ result[doc_id] = {"error": "Document store table missing."}
+ else:
+ result[doc_id] = {"error": f"Document store update failed: {msg}"}
+ has_error = True
+ continue
+ if not ok:
+ result[doc_id] = {"error": "Database error (docStore update)!"}
+ has_error = True
+ continue
result[doc_id] = {"status": status}
except Exception as e:
result[doc_id] = {"error": f"Internal server error: {str(e)}"}
+ has_error = True
+ if has_error:
+ return get_json_result(data=result, message="Partial failure", code=RetCode.SERVER_ERROR)
return get_json_result(data=result)
@@ -562,7 +581,7 @@ async def rm():
if not DocumentService.accessible4deletion(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
- errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
+ errors = await thread_pool_exec(FileService.delete_docs, doc_ids, current_user.id)
if errors:
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
@@ -575,10 +594,11 @@ async def rm():
@validate_request("doc_ids", "run")
async def run():
req = await get_request_json()
+ uid = current_user.id
try:
def _run_sync():
for doc_id in req["doc_ids"]:
- if not DocumentService.accessible(doc_id, current_user.id):
+ if not DocumentService.accessible(doc_id, uid):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
kb_table_num_map = {}
@@ -597,7 +617,9 @@ def _run_sync():
return get_data_error_result(message="Document not found!")
if str(req["run"]) == TaskStatus.CANCEL.value:
- if str(doc.run) == TaskStatus.RUNNING.value:
+ tasks = list(TaskService.query(doc_id=id))
+ has_unfinished_task = any((task.progress or 0) < 1 for task in tasks)
+ if str(doc.run) in [TaskStatus.RUNNING.value, TaskStatus.CANCEL.value] or has_unfinished_task:
cancel_all_task_of(id)
else:
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
@@ -615,6 +637,7 @@ def _run_sync():
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
if not e:
raise LookupError("Can't find this dataset!")
+ doc.parser_config["llm_id"] = kb.parser_config.get("llm_id")
doc.parser_config["enable_metadata"] = kb.parser_config.get("enable_metadata", False)
doc.parser_config["metadata"] = kb.parser_config.get("metadata", {})
DocumentService.update_parser_config(doc.id, doc.parser_config)
@@ -623,7 +646,7 @@ def _run_sync():
return get_json_result(data=True)
- return await asyncio.to_thread(_run_sync)
+ return await thread_pool_exec(_run_sync)
except Exception as e:
return server_error_response(e)
@@ -633,9 +656,10 @@ def _run_sync():
@validate_request("doc_id", "name")
async def rename():
req = await get_request_json()
+ uid = current_user.id
try:
def _rename_sync():
- if not DocumentService.accessible(req["doc_id"], current_user.id):
+ if not DocumentService.accessible(req["doc_id"], uid):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
e, doc = DocumentService.get_by_id(req["doc_id"])
@@ -674,14 +698,14 @@ def _rename_sync():
)
return get_json_result(data=True)
- return await asyncio.to_thread(_rename_sync)
+ return await thread_pool_exec(_rename_sync)
except Exception as e:
return server_error_response(e)
@manager.route("/get/", methods=["GET"]) # noqa: F821
-# @login_required
+@login_required
async def get(doc_id):
try:
e, doc = DocumentService.get_by_id(doc_id)
@@ -689,7 +713,7 @@ async def get(doc_id):
return get_data_error_result(message="Document not found!")
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
- data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
+ data = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
response = await make_response(data)
ext = re.search(r"\.([^.]+)$", doc.name.lower())
@@ -711,7 +735,7 @@ async def get(doc_id):
async def download_attachment(attachment_id):
try:
ext = request.args.get("ext", "markdown")
- data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
+ data = await thread_pool_exec(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
response = await make_response(data)
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
@@ -784,7 +808,7 @@ async def get_image(image_id):
if len(arr) != 2:
return get_data_error_result(message="Image not found.")
bkt, nm = image_id.split("-")
- data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
+ data = await thread_pool_exec(settings.STORAGE_IMPL.get, bkt, nm)
response = await make_response(data)
response.headers.set("Content-Type", "image/JPEG")
return response
@@ -892,7 +916,7 @@ async def set_meta():
if not e:
return get_data_error_result(message="Document not found!")
- if not DocumentService.update_by_id(req["doc_id"], {"meta_fields": meta}):
+ if not DocMetadataService.update_document_metadata(req["doc_id"], meta):
return get_data_error_result(message="Database error (meta updates)!")
return get_json_result(data=True)
diff --git a/api/apps/file_app.py b/api/apps/file_app.py
index 1ce5d4caed9..50cbd185aff 100644
--- a/api/apps/file_app.py
+++ b/api/apps/file_app.py
@@ -14,7 +14,6 @@
# limitations under the License
#
import logging
-import asyncio
import os
import pathlib
import re
@@ -25,7 +24,7 @@
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
-from common.misc_utils import get_uuid
+from common.misc_utils import get_uuid, thread_pool_exec
from common.constants import RetCode, FileSource
from api.db import FileType
from api.db.services import duplicate_name
@@ -35,7 +34,6 @@
from api.utils.web_utils import CONTENT_TYPE_MAP
from common import settings
-
@manager.route('/upload', methods=['POST']) # noqa: F821
@login_required
# @validate_request("parent_id")
@@ -65,7 +63,7 @@ async def upload():
async def _handle_single_file(file_obj):
MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
- if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
+ if 0 < MAX_FILE_NUM_PER_USER <= await thread_pool_exec(DocumentService.get_doc_count, current_user.id):
return get_data_error_result( message="Exceed the maximum file number of a free user!")
# split file name path
@@ -77,35 +75,35 @@ async def _handle_single_file(file_obj):
file_len = len(file_obj_names)
# get folder
- file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
+ file_id_list = await thread_pool_exec(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
len_id_list = len(file_id_list)
# create folder
if file_len != len_id_list:
- e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
+ e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 1])
if not e:
return get_data_error_result(message="Folder not found!")
- last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
+ last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
len_id_list)
else:
- e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
+ e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 2])
if not e:
return get_data_error_result(message="Folder not found!")
- last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
+ last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
len_id_list)
# file type
filetype = filename_type(file_obj_names[file_len - 1])
location = file_obj_names[file_len - 1]
- while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
+ while await thread_pool_exec(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
location += "_"
- blob = await asyncio.to_thread(file_obj.read)
- filename = await asyncio.to_thread(
+ blob = await thread_pool_exec(file_obj.read)
+ filename = await thread_pool_exec(
duplicate_name,
FileService.query,
name=file_obj_names[file_len - 1],
parent_id=last_folder.id)
- await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
+ await thread_pool_exec(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
file_data = {
"id": get_uuid(),
"parent_id": last_folder.id,
@@ -116,7 +114,7 @@ async def _handle_single_file(file_obj):
"location": location,
"size": len(blob),
}
- inserted = await asyncio.to_thread(FileService.insert, file_data)
+ inserted = await thread_pool_exec(FileService.insert, file_data)
return inserted.to_json()
for file_obj in file_objs:
@@ -249,6 +247,7 @@ def get_all_parent_folders():
async def rm():
req = await get_request_json()
file_ids = req["file_ids"]
+ uid = current_user.id
try:
def _delete_single_file(file):
@@ -287,21 +286,21 @@ def _rm_sync():
return get_data_error_result(message="File or Folder not found!")
if not file.tenant_id:
return get_data_error_result(message="Tenant not found!")
- if not check_file_team_permission(file, current_user.id):
+ if not check_file_team_permission(file, uid):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
if file.source_type == FileSource.KNOWLEDGEBASE:
continue
if file.type == FileType.FOLDER.value:
- _delete_folder_recursive(file, current_user.id)
+ _delete_folder_recursive(file, uid)
continue
_delete_single_file(file)
return get_json_result(data=True)
- return await asyncio.to_thread(_rm_sync)
+ return await thread_pool_exec(_rm_sync)
except Exception as e:
return server_error_response(e)
@@ -357,10 +356,10 @@ async def get(file_id):
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
- blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
+ blob = await thread_pool_exec(settings.STORAGE_IMPL.get, file.parent_id, file.location)
if not blob:
b, n = File2DocumentService.get_storage_address(file_id=file_id)
- blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
+ blob = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
response = await make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name.lower())
@@ -460,7 +459,7 @@ def _move_sync():
_move_entry_recursive(file, dest_folder)
return get_json_result(data=True)
- return await asyncio.to_thread(_move_sync)
+ return await thread_pool_exec(_move_sync)
except Exception as e:
return server_error_response(e)
diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py
index fff982563f9..efb028bf15f 100644
--- a/api/apps/kb_app.py
+++ b/api/apps/kb_app.py
@@ -17,21 +17,29 @@
import logging
import random
import re
-import asyncio
+from common.metadata_utils import turn2jsonschema
from quart import request
import numpy as np
from api.db.services.connector_service import Connector2KbService
from api.db.services.llm_service import LLMBundle
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
+from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
from api.db.services.user_service import TenantService, UserTenantService
-from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
- get_request_json
+from api.utils.api_utils import (
+ get_error_data_result,
+ server_error_response,
+ get_data_error_result,
+ validate_request,
+ not_allowed_parameters,
+ get_request_json,
+)
+from common.misc_utils import thread_pool_exec
from api.db import VALID_FILE_TYPES
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File
@@ -44,7 +52,6 @@
from common.doc_store.doc_store_base import OrderByExpr
from api.apps import login_required, current_user
-
@manager.route('/create', methods=['post']) # noqa: F821
@login_required
@validate_request("name")
@@ -82,6 +89,20 @@ async def update():
return get_data_error_result(
message=f"Dataset name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}")
req["name"] = req["name"].strip()
+ if settings.DOC_ENGINE_INFINITY:
+ parser_id = req.get("parser_id")
+ if isinstance(parser_id, str) and parser_id.lower() == "tag":
+ return get_json_result(
+ code=RetCode.OPERATING_ERROR,
+ message="The chunking method Tag has not been supported by Infinity yet.",
+ data=False,
+ )
+ if "pagerank" in req and req["pagerank"] > 0:
+ return get_json_result(
+ code=RetCode.DATA_ERROR,
+ message="'pagerank' can only be set when doc_engine is elasticsearch",
+ data=False,
+ )
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result(
@@ -130,7 +151,7 @@ async def update():
if kb.pagerank != req.get("pagerank", 0):
if req.get("pagerank", 0) > 0:
- await asyncio.to_thread(
+ await thread_pool_exec(
settings.docStoreConn.update,
{"kb_id": kb.id},
{PAGERANK_FLD: req["pagerank"]},
@@ -139,7 +160,7 @@ async def update():
)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
- await asyncio.to_thread(
+ await thread_pool_exec(
settings.docStoreConn.update,
{"exists": PAGERANK_FLD},
{"remove": PAGERANK_FLD},
@@ -174,6 +195,7 @@ async def update_metadata_setting():
message="Database error (Knowledgebase rename)!")
kb = kb.to_dict()
kb["parser_config"]["metadata"] = req["metadata"]
+ kb["parser_config"]["enable_metadata"] = req.get("enable_metadata", True)
KnowledgebaseService.update_by_id(kb["id"], kb)
return get_json_result(data=kb)
@@ -198,6 +220,8 @@ def detail():
message="Can't find this dataset!")
kb["size"] = DocumentService.get_total_size_by_kb_id(kb_id=kb["id"],keywords="", run_status=[], types=[])
kb["connectors"] = Connector2KbService.list_connectors(kb_id)
+ if kb["parser_config"].get("metadata"):
+ kb["parser_config"]["metadata"] = turn2jsonschema(kb["parser_config"]["metadata"])
for key in ["graphrag_task_finish_at", "raptor_task_finish_at", "mindmap_task_finish_at"]:
if finish_at := kb.get(key):
@@ -249,7 +273,8 @@ async def list_kbs():
@validate_request("kb_id")
async def rm():
req = await get_request_json()
- if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
+ uid = current_user.id
+ if not KnowledgebaseService.accessible4deletion(req["kb_id"], uid):
return get_json_result(
data=False,
message='No authorization.',
@@ -257,7 +282,7 @@ async def rm():
)
try:
kbs = KnowledgebaseService.query(
- created_by=current_user.id, id=req["kb_id"])
+ created_by=uid, id=req["kb_id"])
if not kbs:
return get_json_result(
data=False, message='Only owner of dataset authorized for this operation.',
@@ -280,17 +305,24 @@ def _rm_sync():
File.name == kbs[0].name,
]
)
+ # Delete the table BEFORE deleting the database record
+ for kb in kbs:
+ try:
+ settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
+ settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
+ logging.info(f"Dropped index for dataset {kb.id}")
+ except Exception as e:
+ logging.error(f"Failed to drop index for dataset {kb.id}: {e}")
+
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
return get_data_error_result(
message="Database error (Knowledgebase removal)!")
for kb in kbs:
- settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
- settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
settings.STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)
- return await asyncio.to_thread(_rm_sync)
+ return await thread_pool_exec(_rm_sync)
except Exception as e:
return server_error_response(e)
@@ -372,7 +404,7 @@ async def rename_tags(kb_id):
@manager.route('//knowledge_graph', methods=['GET']) # noqa: F821
@login_required
-def knowledge_graph(kb_id):
+async def knowledge_graph(kb_id):
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -388,7 +420,7 @@ def knowledge_graph(kb_id):
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id):
return get_json_result(data=obj)
- sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
+ sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
if not len(sres.ids):
return get_json_result(data=obj)
@@ -436,7 +468,7 @@ def get_meta():
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
- return get_json_result(data=DocumentService.get_meta_by_kbs(kb_ids))
+ return get_json_result(data=DocMetadataService.get_flatted_meta_by_kbs(kb_ids))
@manager.route("/basic_info", methods=["GET"]) # noqa: F821
diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index 9a68e825606..9d2fed80262 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import asyncio
import logging
import json
import os
@@ -64,13 +65,17 @@ async def set_api_key():
chat_passed, embd_passed, rerank_passed = False, False, False
factory = req["llm_factory"]
extra = {"provider": factory}
+ timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
msg = ""
for llm in LLMService.query(fid=factory):
if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
mdl = EmbeddingModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"))
try:
- arr, tc = mdl.encode(["Test if the api key is available"])
+ arr, tc = await asyncio.wait_for(
+ asyncio.to_thread(mdl.encode, ["Test if the api key is available"]),
+ timeout=timeout_seconds,
+ )
if len(arr[0]) == 0:
raise Exception("Fail")
embd_passed = True
@@ -80,17 +85,27 @@ async def set_api_key():
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
try:
- m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
+ m, tc = await asyncio.wait_for(
+ mdl.async_chat(
+ None,
+ [{"role": "user", "content": "Hello! How are you doing!"}],
+ {"temperature": 0.9, "max_tokens": 50},
+ ),
+ timeout=timeout_seconds,
+ )
if m.find("**ERROR**") >= 0:
raise Exception(m)
chat_passed = True
except Exception as e:
msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e)
- elif not rerank_passed and llm.model_type == LLMType.RERANK:
+ elif not rerank_passed and llm.model_type == LLMType.RERANK.value:
assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
mdl = RerankModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"))
try:
- arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
+ arr, tc = await asyncio.wait_for(
+ asyncio.to_thread(mdl.similarity, "What's the weather?", ["Is it sunny today?"]),
+ timeout=timeout_seconds,
+ )
if len(arr) == 0 or tc == 0:
raise Exception("Fail")
rerank_passed = True
@@ -101,6 +116,9 @@ async def set_api_key():
msg = ""
break
+ if req.get("verify", False):
+ return get_json_result(data={"message": msg, "success": len(msg.strip())==0})
+
if msg:
return get_data_error_result(message=msg)
@@ -133,6 +151,7 @@ async def add_llm():
factory = req["llm_factory"]
api_key = req.get("api_key", "x")
llm_name = req.get("llm_name")
+ timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
if factory not in [f.name for f in get_allowed_llm_factories()]:
return get_data_error_result(message=f"LLM factory {factory} is not allowed")
@@ -146,10 +165,6 @@ def apikey_json(keys):
# Assemble ark_api_key endpoint_id into api_key
api_key = apikey_json(["ark_api_key", "endpoint_id"])
- elif factory == "Tencent Hunyuan":
- req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
- return await set_api_key()
-
elif factory == "Tencent Cloud":
req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
return await set_api_key()
@@ -195,6 +210,9 @@ def apikey_json(keys):
elif factory == "MinerU":
api_key = apikey_json(["api_key", "provider_order"])
+ elif factory == "PaddleOCR":
+ api_key = apikey_json(["api_key", "provider_order"])
+
llm = {
"tenant_id": current_user.id,
"llm_factory": factory,
@@ -216,7 +234,10 @@ def apikey_json(keys):
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
mdl = EmbeddingModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
try:
- arr, tc = mdl.encode(["Test if the api key is available"])
+ arr, tc = await asyncio.wait_for(
+ asyncio.to_thread(mdl.encode, ["Test if the api key is available"]),
+ timeout=timeout_seconds,
+ )
if len(arr[0]) == 0:
raise Exception("Fail")
except Exception as e:
@@ -230,8 +251,14 @@ def apikey_json(keys):
**extra,
)
try:
- m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
- {"temperature": 0.9})
+ m, tc = await asyncio.wait_for(
+ mdl.async_chat(
+ None,
+ [{"role": "user", "content": "Hello! How are you doing!"}],
+ {"temperature": 0.9},
+ ),
+ timeout=timeout_seconds,
+ )
if not tc and m.find("**ERROR**:") >= 0:
raise Exception(m)
except Exception as e:
@@ -241,7 +268,10 @@ def apikey_json(keys):
assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
try:
mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
- arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
+ arr, tc = await asyncio.wait_for(
+ asyncio.to_thread(mdl.similarity, "Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]),
+ timeout=timeout_seconds,
+ )
if len(arr) == 0:
raise Exception("Not known.")
except KeyError:
@@ -254,7 +284,10 @@ def apikey_json(keys):
mdl = CvModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
try:
image_data = test_image
- m, tc = mdl.describe(image_data)
+ m, tc = await asyncio.wait_for(
+ asyncio.to_thread(mdl.describe, image_data),
+ timeout=timeout_seconds,
+ )
if not tc and m.find("**ERROR**:") >= 0:
raise Exception(m)
except Exception as e:
@@ -263,20 +296,29 @@ def apikey_json(keys):
assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
mdl = TTSModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
try:
- for resp in mdl.tts("Hello~ RAGFlower!"):
- pass
+ def drain_tts():
+ for _ in mdl.tts("Hello~ RAGFlower!"):
+ pass
+
+ await asyncio.wait_for(
+ asyncio.to_thread(drain_tts),
+ timeout=timeout_seconds,
+ )
except RuntimeError as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
case LLMType.OCR.value:
assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
try:
mdl = OcrModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
- ok, reason = mdl.check_available()
+ ok, reason = await asyncio.wait_for(
+ asyncio.to_thread(mdl.check_available),
+ timeout=timeout_seconds,
+ )
if not ok:
raise RuntimeError(reason or "Model not available")
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
- case LLMType.SPEECH2TEXT:
+ case LLMType.SPEECH2TEXT.value:
assert factory in Seq2txtModel, f"Speech model from {factory} is not supported yet."
try:
mdl = Seq2txtModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
@@ -286,6 +328,9 @@ def apikey_json(keys):
case _:
raise RuntimeError(f"Unknown model type: {model_type}")
+ if req.get("verify", False):
+ return get_json_result(data={"message": msg, "success": len(msg.strip()) == 0})
+
if msg:
return get_data_error_result(message=msg)
@@ -371,17 +416,18 @@ def my_llms():
@manager.route("/list", methods=["GET"]) # noqa: F821
@login_required
-def list_app():
+async def list_app():
self_deployed = ["FastEmbed", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
weighted = []
model_type = request.args.get("model_type")
+ tenant_id = current_user.id
try:
- TenantLLMService.ensure_mineru_from_env(current_user.id)
- objs = TenantLLMService.query(tenant_id=current_user.id)
+ TenantLLMService.ensure_mineru_from_env(tenant_id)
+ objs = TenantLLMService.query(tenant_id=tenant_id)
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status == StatusEnum.VALID.value])
status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value}
llms = LLMService.get_all()
- llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.fid == 'Builtin' or (m.llm_name + "@" + m.fid) in status)]
+ llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.fid == "Builtin" or (m.llm_name + "@" + m.fid) in status)]
for m in llms:
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deployed
if "tei-" in os.getenv("COMPOSE_PROFILES", "") and m["model_type"] == LLMType.EMBEDDING and m["fid"] == "Builtin" and m["llm_name"] == os.getenv("TEI_MODEL", ""):
diff --git a/api/apps/mcp_server_app.py b/api/apps/mcp_server_app.py
index 62ae2e3c06b..187560d626b 100644
--- a/api/apps/mcp_server_app.py
+++ b/api/apps/mcp_server_app.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import asyncio
-
from quart import Response, request
from api.apps import current_user, login_required
@@ -23,12 +21,11 @@
from api.db.services.user_service import TenantService
from common.constants import RetCode, VALID_MCP_SERVER_TYPES
-from common.misc_utils import get_uuid
+from common.misc_utils import get_uuid, thread_pool_exec
from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
from api.utils.web_utils import get_float, safe_json_parse
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
-
@manager.route("/list", methods=["POST"]) # noqa: F821
@login_required
async def list_mcp() -> Response:
@@ -108,7 +105,7 @@ async def create() -> Response:
return get_data_error_result(message="Tenant not found.")
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
- server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
+ server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
if err_message:
return get_data_error_result(err_message)
@@ -160,7 +157,7 @@ async def update() -> Response:
req["id"] = mcp_id
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
- server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
+ server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
if err_message:
return get_data_error_result(err_message)
@@ -244,7 +241,7 @@ async def import_multiple() -> Response:
headers = {"authorization_token": config["authorization_token"]} if "authorization_token" in config else {}
variables = {k: v for k, v in config.items() if k not in {"type", "url", "headers"}}
mcp_server = MCPServer(id=new_name, name=new_name, url=config["url"], server_type=config["type"], variables=variables, headers=headers)
- server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
+ server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
if err_message:
results.append({"server": base_name, "success": False, "message": err_message})
continue
@@ -324,7 +321,7 @@ async def list_tools() -> Response:
tool_call_sessions.append(tool_call_session)
try:
- tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
+ tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
except Exception as e:
return get_data_error_result(message=f"MCP list tools error: {e}")
@@ -341,7 +338,7 @@ async def list_tools() -> Response:
return server_error_response(e)
finally:
# PERF: blocking call to close sessions — consider moving to background thread or task queue
- await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
+ await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
@manager.route("/test_tool", methods=["POST"]) # noqa: F821
@@ -368,10 +365,10 @@ async def test_tool() -> Response:
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
tool_call_sessions.append(tool_call_session)
- result = await asyncio.to_thread(tool_call_session.tool_call, tool_name, arguments, timeout)
+ result = await thread_pool_exec(tool_call_session.tool_call, tool_name, arguments, timeout)
# PERF: blocking call to close sessions — consider moving to background thread or task queue
- await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
+ await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)
@@ -425,12 +422,12 @@ async def test_mcp() -> Response:
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
try:
- tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
+ tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
except Exception as e:
return get_data_error_result(message=f"Test MCP error: {e}")
finally:
# PERF: blocking call to close sessions — consider moving to background thread or task queue
- await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, [tool_call_session])
+ await thread_pool_exec(close_multiple_mcp_toolcall_sessions, [tool_call_session])
for tool in tools:
tool_dict = tool.model_dump()
diff --git a/api/apps/plugin_app.py b/api/apps/plugin_app.py
index 6e7a8769018..fb0a7bb6106 100644
--- a/api/apps/plugin_app.py
+++ b/api/apps/plugin_app.py
@@ -18,7 +18,7 @@
from quart import Response
from api.apps import login_required
from api.utils.api_utils import get_json_result
-from plugin import GlobalPluginManager
+from agent.plugin import GlobalPluginManager
@manager.route('/llm_tools', methods=['GET']) # noqa: F821
diff --git a/api/apps/sdk/agents.py b/api/apps/sdk/agents.py
index e6a68786992..0d5962a4f6a 100644
--- a/api/apps/sdk/agents.py
+++ b/api/apps/sdk/agents.py
@@ -51,7 +51,7 @@ def list_agents(tenant_id):
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 30))
order_by = request.args.get("orderby", "update_time")
- if request.args.get("desc") == "False" or request.args.get("desc") == "false":
+ if str(request.args.get("desc","false")).lower() == "false":
desc = False
else:
desc = True
@@ -162,6 +162,7 @@ async def webhook(agent_id: str):
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST
# 4. Check webhook configuration in DSL
+ webhook_cfg = {}
components = dsl.get("components", {})
for k, _ in components.items():
cpn_obj = components[k]["obj"]
diff --git a/api/apps/sdk/chat.py b/api/apps/sdk/chat.py
index 1efb628f1bc..786d1a733f7 100644
--- a/api/apps/sdk/chat.py
+++ b/api/apps/sdk/chat.py
@@ -51,7 +51,9 @@ async def create(tenant_id):
req["llm_id"] = llm.pop("model_name")
if req.get("llm_id") is not None:
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"])
- if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="chat"):
+ model_type = llm.get("model_type")
+ model_type = model_type if model_type in ["chat", "image2text"] else "chat"
+ if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type=model_type):
return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
req["llm_setting"] = req.pop("llm")
e, tenant = TenantService.get_by_id(tenant_id)
@@ -174,7 +176,7 @@ async def update(tenant_id, chat_id):
req["llm_id"] = llm.pop("model_name")
if req.get("llm_id") is not None:
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"])
- model_type = llm.pop("model_type")
+ model_type = llm.get("model_type")
model_type = model_type if model_type in ["chat", "image2text"] else "chat"
if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type=model_type):
return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py
index 7d52c3fec50..d0d7ff0c66a 100644
--- a/api/apps/sdk/dataset.py
+++ b/api/apps/sdk/dataset.py
@@ -233,6 +233,15 @@ async def delete(tenant_id):
File2DocumentService.delete_by_document_id(doc.id)
FileService.filter_delete(
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
+
+ # Drop index for this dataset
+ try:
+ from rag.nlp import search
+ idxnm = search.index_name(kb.tenant_id)
+ settings.docStoreConn.delete_idx(idxnm, kb_id)
+ except Exception as e:
+ logging.warning(f"Failed to drop index for dataset {kb_id}: {e}")
+
if not KnowledgebaseService.delete_by_id(kb_id):
errors.append(f"Delete dataset error for {kb_id}")
continue
@@ -481,7 +490,7 @@ def list_datasets(tenant_id):
@manager.route('/datasets//knowledge_graph', methods=['GET']) # noqa: F821
@token_required
-def knowledge_graph(tenant_id, dataset_id):
+async def knowledge_graph(tenant_id, dataset_id):
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
@@ -497,7 +506,7 @@ def knowledge_graph(tenant_id, dataset_id):
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id):
return get_result(data=obj)
- sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
+ sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
if not len(sres.ids):
return get_result(data=obj)
diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py
index 7a11688ddcb..881614e5d97 100644
--- a/api/apps/sdk/dify_retrieval.py
+++ b/api/apps/sdk/dify_retrieval.py
@@ -18,6 +18,7 @@
from quart import jsonify
from api.db.services.document_service import DocumentService
+from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from common.metadata_utils import meta_filter, convert_conditions
@@ -121,7 +122,7 @@ async def retrieval(tenant_id):
similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
top = int(retrieval_setting.get("top_k", 1024))
metadata_condition = req.get("metadata_condition", {}) or {}
- metas = DocumentService.get_meta_by_kbs([kb_id])
+ metas = DocMetadataService.get_meta_by_kbs([kb_id])
doc_ids = []
try:
@@ -135,7 +136,7 @@ async def retrieval(tenant_id):
doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
if not doc_ids and metadata_condition:
doc_ids = ["-999"]
- ranks = settings.retriever.retrieval(
+ ranks = await settings.retriever.retrieval(
question,
embd_mdl,
kb.tenant_id,
@@ -148,9 +149,10 @@ async def retrieval(tenant_id):
doc_ids=doc_ids,
rank_feature=label_question(question, [kb])
)
+ ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], [tenant_id])
if use_kg:
- ck = settings.kg_retriever.retrieval(question,
+ ck = await settings.kg_retriever.retrieval(question,
[tenant_id],
[kb_id],
embd_mdl,
diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py
index bef03d38ec4..16f5a2e8d27 100644
--- a/api/apps/sdk/doc.py
+++ b/api/apps/sdk/doc.py
@@ -27,8 +27,9 @@
from api.constants import FILE_NAME_LEN_LIMIT
from api.db import FileType
-from api.db.db_models import File, Task
+from api.db.db_models import APIToken, File, Task
from api.db.services.document_service import DocumentService
+from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -255,7 +256,8 @@ async def update_doc(tenant_id, dataset_id, document_id):
if "meta_fields" in req:
if not isinstance(req["meta_fields"], dict):
return get_error_data_result(message="meta_fields must be a dictionary")
- DocumentService.update_meta_fields(document_id, req["meta_fields"])
+ if not DocMetadataService.update_document_metadata(document_id, req["meta_fields"]):
+ return get_error_data_result(message="Failed to update metadata")
if "name" in req and req["name"] != doc.name:
if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
@@ -417,6 +419,36 @@ async def download(tenant_id, dataset_id, document_id):
)
+@manager.route("/documents/", methods=["GET"]) # noqa: F821
+async def download_doc(document_id):
+ token = request.headers.get("Authorization").split()
+ if len(token) != 2:
+ return get_error_data_result(message='Authorization is not valid!"')
+ token = token[1]
+ objs = APIToken.query(beta=token)
+ if not objs:
+ return get_error_data_result(message='Authentication error: API key is invalid!"')
+
+ if not document_id:
+ return get_error_data_result(message="Specify document_id please.")
+ doc = DocumentService.query(id=document_id)
+ if not doc:
+ return get_error_data_result(message=f"The dataset not own the document {document_id}.")
+ # The process of downloading
+ doc_id, doc_location = File2DocumentService.get_storage_address(doc_id=document_id) # minio address
+ file_stream = settings.STORAGE_IMPL.get(doc_id, doc_location)
+ if not file_stream:
+ return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR)
+ file = BytesIO(file_stream)
+ # Use send_file with a proper filename and MIME type
+ return await send_file(
+ file,
+ as_attachment=True,
+ attachment_filename=doc[0].name,
+ mimetype="application/octet-stream", # Set a default MIME type
+ )
+
+
@manager.route("/datasets//documents", methods=["GET"]) # noqa: F821
@token_required
def list_docs(dataset_id, tenant_id):
@@ -568,7 +600,7 @@ def list_docs(dataset_id, tenant_id):
doc_ids_filter = None
if metadata_condition:
- metas = DocumentService.get_flatted_meta_by_kbs([dataset_id])
+ metas = DocMetadataService.get_flatted_meta_by_kbs([dataset_id])
doc_ids_filter = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
if metadata_condition.get("conditions") and not doc_ids_filter:
return get_result(data={"total": 0, "docs": []})
@@ -606,12 +638,12 @@ def list_docs(dataset_id, tenant_id):
@manager.route("/datasets//metadata/summary", methods=["GET"]) # noqa: F821
@token_required
-def metadata_summary(dataset_id, tenant_id):
+async def metadata_summary(dataset_id, tenant_id):
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
-
+ req = await get_request_json()
try:
- summary = DocumentService.get_metadata_summary(dataset_id)
+ summary = DocMetadataService.get_metadata_summary(dataset_id, req.get("doc_ids"))
return get_result(data={"summary": summary})
except Exception as e:
return server_error_response(e)
@@ -647,24 +679,24 @@ async def metadata_batch_update(dataset_id, tenant_id):
for d in deletes:
if not isinstance(d, dict) or not d.get("key"):
return get_error_data_result(message="Each delete requires key.")
-
- kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id])
- target_doc_ids = set(kb_doc_ids)
+
if document_ids:
+ kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id])
+ target_doc_ids = set(kb_doc_ids)
invalid_ids = set(document_ids) - set(kb_doc_ids)
if invalid_ids:
return get_error_data_result(message=f"These documents do not belong to dataset {dataset_id}: {', '.join(invalid_ids)}")
target_doc_ids = set(document_ids)
if metadata_condition:
- metas = DocumentService.get_flatted_meta_by_kbs([dataset_id])
+ metas = DocMetadataService.get_flatted_meta_by_kbs([dataset_id])
filtered_ids = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
target_doc_ids = target_doc_ids & filtered_ids
if metadata_condition.get("conditions") and not target_doc_ids:
return get_result(data={"updated": 0, "matched_docs": 0})
target_doc_ids = list(target_doc_ids)
- updated = DocumentService.batch_update_metadata(dataset_id, target_doc_ids, updates, deletes)
+ updated = DocMetadataService.batch_update_metadata(dataset_id, target_doc_ids, updates, deletes)
return get_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})
@manager.route("/datasets//documents", methods=["DELETE"]) # noqa: F821
@@ -935,7 +967,7 @@ async def stop_parsing(tenant_id, dataset_id):
@manager.route("/datasets//documents//chunks", methods=["GET"]) # noqa: F821
@token_required
-def list_chunks(tenant_id, dataset_id, document_id):
+async def list_chunks(tenant_id, dataset_id, document_id):
"""
List chunks of a document.
---
@@ -1081,7 +1113,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
_ = Chunk(**final_chunk)
elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id):
- sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
+ sres = await settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
res["total"] = sres.total
for id in sres.ids:
d = {
@@ -1514,32 +1546,51 @@ async def retrieval_test(tenant_id):
page = int(req.get("page", 1))
size = int(req.get("page_size", 30))
question = req["question"]
+ # Trim whitespace and validate question
+ if isinstance(question, str):
+ question = question.strip()
+ # Return empty result if question is empty or whitespace-only
+ if not question:
+ return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}})
doc_ids = req.get("document_ids", [])
use_kg = req.get("use_kg", False)
toc_enhance = req.get("toc_enhance", False)
langs = req.get("cross_languages", [])
if not isinstance(doc_ids, list):
- return get_error_data_result("`documents` should be a list")
- doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
- for doc_id in doc_ids:
- if doc_id not in doc_ids_list:
- return get_error_data_result(f"The datasets don't own the document {doc_id}")
+ return get_error_data_result("`documents` should be a list")
+ if doc_ids:
+ doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
+ for doc_id in doc_ids:
+ if doc_id not in doc_ids_list:
+ return get_error_data_result(f"The datasets don't own the document {doc_id}")
if not doc_ids:
- metadata_condition = req.get("metadata_condition", {}) or {}
- metas = DocumentService.get_meta_by_kbs(kb_ids)
- doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
- # If metadata_condition has conditions but no docs match, return empty result
- if not doc_ids and metadata_condition.get("conditions"):
- return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}})
- if metadata_condition and not doc_ids:
- doc_ids = ["-999"]
+ metadata_condition = req.get("metadata_condition")
+ if metadata_condition:
+ metas = DocMetadataService.get_meta_by_kbs(kb_ids)
+ doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
+ # If metadata_condition has conditions but no docs match, return empty result
+ if not doc_ids and metadata_condition.get("conditions"):
+ return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}})
+ if metadata_condition and not doc_ids:
+ doc_ids = ["-999"]
+ else:
+ # If doc_ids is None all documents of the datasets are used
+ doc_ids = None
similarity_threshold = float(req.get("similarity_threshold", 0.2))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = int(req.get("top_k", 1024))
- if req.get("highlight") == "False" or req.get("highlight") == "false":
+ highlight_val = req.get("highlight", None)
+ if highlight_val is None:
highlight = False
+ elif isinstance(highlight_val, bool):
+ highlight = highlight_val
+ elif isinstance(highlight_val, str):
+ if highlight_val.lower() in ["true", "false"]:
+ highlight = highlight_val.lower() == "true"
+ else:
+ return get_error_data_result("`highlight` should be a boolean")
else:
- highlight = True
+ return get_error_data_result("`highlight` should be a boolean")
try:
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
@@ -1558,7 +1609,7 @@ async def retrieval_test(tenant_id):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += await keyword_extraction(chat_mdl, question)
- ranks = settings.retriever.retrieval(
+ ranks = await settings.retriever.retrieval(
question,
embd_mdl,
tenant_ids,
@@ -1575,11 +1626,12 @@ async def retrieval_test(tenant_id):
)
if toc_enhance:
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
- cks = settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size)
+ cks = await settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size)
if cks:
ranks["chunks"] = cks
+ ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)
if use_kg:
- ck = settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
+ ck = await settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)
diff --git a/api/apps/sdk/files.py b/api/apps/sdk/files.py
index a618777884e..759dfae80dd 100644
--- a/api/apps/sdk/files.py
+++ b/api/apps/sdk/files.py
@@ -14,7 +14,6 @@
# limitations under the License.
#
-import asyncio
import pathlib
import re
from quart import request, make_response
@@ -24,7 +23,7 @@
from api.db.services.file2document_service import File2DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required
-from common.misc_utils import get_uuid
+from common.misc_utils import get_uuid, thread_pool_exec
from api.db import FileType
from api.db.services import duplicate_name
from api.db.services.file_service import FileService
@@ -33,7 +32,6 @@
from common import settings
from common.constants import RetCode
-
@manager.route('/file/upload', methods=['POST']) # noqa: F821
@token_required
async def upload(tenant_id):
@@ -640,7 +638,7 @@ async def get(tenant_id, file_id):
async def download_attachment(tenant_id, attachment_id):
try:
ext = request.args.get("ext", "markdown")
- data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
+ data = await thread_pool_exec(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
response = await make_response(data)
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
diff --git a/api/apps/memories_app.py b/api/apps/sdk/memories.py
similarity index 73%
rename from api/apps/memories_app.py
rename to api/apps/sdk/memories.py
index 66fcabb4c99..ada4b34fab9 100644
--- a/api/apps/memories_app.py
+++ b/api/apps/sdk/memories.py
@@ -14,6 +14,8 @@
# limitations under the License.
#
import logging
+import os
+import time
from quart import request
from api.apps import login_required, current_user
@@ -21,6 +23,7 @@
from api.db.services.memory_service import MemoryService
from api.db.services.user_service import UserTenantService
from api.db.services.canvas_service import UserCanvasService
+from api.db.services.task_service import TaskService
from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
@@ -30,26 +33,60 @@
from common.constants import MemoryType, RetCode, ForgettingPolicy
-@manager.route("", methods=["POST"]) # noqa: F821
+@manager.route("/memories", methods=["POST"]) # noqa: F821
@login_required
@validate_request("name", "memory_type", "embd_id", "llm_id")
async def create_memory():
+ timing_enabled = os.getenv("RAGFLOW_API_TIMING")
+ t_start = time.perf_counter() if timing_enabled else None
req = await get_request_json()
+ t_parsed = time.perf_counter() if timing_enabled else None
# check name length
name = req["name"]
memory_name = name.strip()
if len(memory_name) == 0:
+ if timing_enabled:
+ logging.info(
+ "api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s",
+ (t_parsed - t_start) * 1000,
+ (time.perf_counter() - t_start) * 1000,
+ request.path,
+ )
return get_error_argument_result("Memory name cannot be empty or whitespace.")
if len(memory_name) > MEMORY_NAME_LIMIT:
+ if timing_enabled:
+ logging.info(
+ "api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s",
+ (t_parsed - t_start) * 1000,
+ (time.perf_counter() - t_start) * 1000,
+ request.path,
+ )
return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
# check memory_type valid
+ if not isinstance(req["memory_type"], list):
+ if timing_enabled:
+ logging.info(
+ "api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s",
+ (t_parsed - t_start) * 1000,
+ (time.perf_counter() - t_start) * 1000,
+ request.path,
+ )
+ return get_error_argument_result("Memory type must be a list.")
memory_type = set(req["memory_type"])
invalid_type = memory_type - {e.name.lower() for e in MemoryType}
if invalid_type:
+ if timing_enabled:
+ logging.info(
+ "api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s",
+ (t_parsed - t_start) * 1000,
+ (time.perf_counter() - t_start) * 1000,
+ request.path,
+ )
return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.")
memory_type = list(memory_type)
try:
+ t_before_db = time.perf_counter() if timing_enabled else None
res, memory = MemoryService.create_memory(
tenant_id=current_user.id,
name=memory_name,
@@ -57,6 +94,15 @@ async def create_memory():
embd_id=req["embd_id"],
llm_id=req["llm_id"]
)
+ if timing_enabled:
+ logging.info(
+ "api_timing create_memory parse_ms=%.2f validate_ms=%.2f db_ms=%.2f total_ms=%.2f path=%s",
+ (t_parsed - t_start) * 1000,
+ (t_before_db - t_parsed) * 1000,
+ (time.perf_counter() - t_before_db) * 1000,
+ (time.perf_counter() - t_start) * 1000,
+ request.path,
+ )
if res:
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
@@ -67,7 +113,7 @@ async def create_memory():
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
-@manager.route("/", methods=["PUT"]) # noqa: F821
+@manager.route("/memories/", methods=["PUT"]) # noqa: F821
@login_required
async def update_memory(memory_id):
req = await get_request_json()
@@ -151,7 +197,7 @@ async def update_memory(memory_id):
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
-@manager.route("/", methods=["DELETE"]) # noqa: F821
+@manager.route("/memories/", methods=["DELETE"]) # noqa: F821
@login_required
async def delete_memory(memory_id):
memory = MemoryService.get_by_memory_id(memory_id)
@@ -167,7 +213,7 @@ async def delete_memory(memory_id):
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
-@manager.route("", methods=["GET"]) # noqa: F821
+@manager.route("/memories", methods=["GET"]) # noqa: F821
@login_required
async def list_memory():
args = request.args
@@ -179,13 +225,18 @@ async def list_memory():
page = int(args.get("page", 1))
page_size = int(args.get("page_size", 50))
# make filter dict
- filter_dict = {"memory_type": memory_types, "storage_type": storage_type}
+ filter_dict: dict = {"storage_type": storage_type}
if not tenant_ids:
# restrict to current user's tenants
user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants]
else:
+ if len(tenant_ids) == 1 and ',' in tenant_ids[0]:
+ tenant_ids = tenant_ids[0].split(',')
filter_dict["tenant_id"] = tenant_ids
+ if memory_types and len(memory_types) == 1 and ',' in memory_types[0]:
+ memory_types = memory_types[0].split(',')
+ filter_dict["memory_type"] = memory_types
memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
[memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list]
@@ -196,7 +247,7 @@ async def list_memory():
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
-@manager.route("//config", methods=["GET"]) # noqa: F821
+@manager.route("/memories//config", methods=["GET"]) # noqa: F821
@login_required
async def get_memory_config(memory_id):
memory = MemoryService.get_with_owner_name_by_id(memory_id)
@@ -205,11 +256,13 @@ async def get_memory_config(memory_id):
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
-@manager.route("/", methods=["GET"]) # noqa: F821
+@manager.route("/memories/", methods=["GET"]) # noqa: F821
@login_required
async def get_memory_detail(memory_id):
args = request.args
agent_ids = args.getlist("agent_id")
+ if len(agent_ids) == 1 and ',' in agent_ids[0]:
+ agent_ids = agent_ids[0].split(',')
keywords = args.get("keywords", "")
keywords = keywords.strip()
page = int(args.get("page", 1))
@@ -220,9 +273,19 @@ async def get_memory_detail(memory_id):
messages = MessageService.list_message(
memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
agent_name_mapping = {}
+ extract_task_mapping = {}
if messages["message_list"]:
agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]])
agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list}
+ task_list = TaskService.get_tasks_progress_by_doc_ids([memory_id])
+ if task_list:
+ task_list.sort(key=lambda t: t["create_time"]) # asc, use newer when exist more than one task
+ for task in task_list:
+ # the 'digest' field carries the source_id when a task is created, so use 'digest' as key
+ extract_task_mapping.update({int(task["digest"]): task})
for message in messages["message_list"]:
message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown")
+ message["task"] = extract_task_mapping.get(message["message_id"], {})
+ for extract_msg in message["extract"]:
+ extract_msg["agent_name"] = agent_name_mapping.get(extract_msg["agent_id"], "Unknown")
return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True)
diff --git a/api/apps/messages_app.py b/api/apps/sdk/messages.py
similarity index 79%
rename from api/apps/messages_app.py
rename to api/apps/sdk/messages.py
index 2963baefa4a..5ed5902188a 100644
--- a/api/apps/messages_app.py
+++ b/api/apps/sdk/messages.py
@@ -24,44 +24,31 @@
from common.constants import RetCode
-@manager.route("", methods=["POST"]) # noqa: F821
+@manager.route("/messages", methods=["POST"]) # noqa: F821
@login_required
@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response")
async def add_message():
req = await get_request_json()
memory_ids = req["memory_id"]
- agent_id = req["agent_id"]
- session_id = req["session_id"]
- user_id = req["user_id"] if req.get("user_id") else ""
- user_input = req["user_input"]
- agent_response = req["agent_response"]
-
- res = []
- for memory_id in memory_ids:
- success, msg = await memory_message_service.save_to_memory(
- memory_id,
- {
- "user_id": user_id,
- "agent_id": agent_id,
- "session_id": session_id,
- "user_input": user_input,
- "agent_response": agent_response
- }
- )
- res.append({
- "memory_id": memory_id,
- "success": success,
- "message": msg
- })
-
- if all([r["success"] for r in res]):
- return get_json_result(message="Successfully added to memories.")
-
- return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add.", data=res)
-
-
-@manager.route("/:", methods=["DELETE"]) # noqa: F821
+
+ message_dict = {
+ "user_id": req.get("user_id"),
+ "agent_id": req["agent_id"],
+ "session_id": req["session_id"],
+ "user_input": req["user_input"],
+ "agent_response": req["agent_response"],
+ }
+
+ res, msg = await memory_message_service.queue_save_to_memory_task(memory_ids, message_dict)
+
+ if res:
+ return get_json_result(message=msg)
+
+ return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add. Detail:" + msg)
+
+
+@manager.route("/messages/:", methods=["DELETE"]) # noqa: F821
@login_required
async def forget_message(memory_id: str, message_id: int):
@@ -80,7 +67,7 @@ async def forget_message(memory_id: str, message_id: int):
return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to forget message '{message_id}' in memory '{memory_id}'.")
-@manager.route("/:", methods=["PUT"]) # noqa: F821
+@manager.route("/messages/:", methods=["PUT"]) # noqa: F821
@login_required
@validate_request("status")
async def update_message(memory_id: str, message_id: int):
@@ -100,16 +87,17 @@ async def update_message(memory_id: str, message_id: int):
return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.")
-@manager.route("/search", methods=["GET"]) # noqa: F821
+@manager.route("/messages/search", methods=["GET"]) # noqa: F821
@login_required
async def search_message():
args = request.args
- print(args, flush=True)
empty_fields = [f for f in ["memory_id", "query"] if not args.get(f)]
if empty_fields:
return get_error_argument_result(f"{', '.join(empty_fields)} can't be empty.")
memory_ids = args.getlist("memory_id")
+ if len(memory_ids) == 1 and ',' in memory_ids[0]:
+ memory_ids = memory_ids[0].split(',')
query = args.get("query")
similarity_threshold = float(args.get("similarity_threshold", 0.2))
keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7))
@@ -132,11 +120,13 @@ async def search_message():
return get_json_result(message=True, data=res)
-@manager.route("", methods=["GET"]) # noqa: F821
+@manager.route("/messages", methods=["GET"]) # noqa: F821
@login_required
async def get_messages():
args = request.args
memory_ids = args.getlist("memory_id")
+ if len(memory_ids) == 1 and ',' in memory_ids[0]:
+ memory_ids = memory_ids[0].split(',')
agent_id = args.get("agent_id", "")
session_id = args.get("session_id", "")
limit = int(args.get("limit", 10))
@@ -154,7 +144,7 @@ async def get_messages():
return get_json_result(message=True, data=res)
-@manager.route("/:/content", methods=["GET"]) # noqa: F821
+@manager.route("/messages/:/content", methods=["GET"]) # noqa: F821
@login_required
async def get_message_content(memory_id:str, message_id: int):
memory = MemoryService.get_by_memory_id(memory_id)
diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py
index f9615e36ba1..589521f0dbd 100644
--- a/api/apps/sdk/session.py
+++ b/api/apps/sdk/session.py
@@ -18,9 +18,14 @@
import re
import time
-import tiktoken
+import os
+import tempfile
+import logging
+
from quart import Response, jsonify, request
+from common.token_utils import num_tokens_from_string
+
from agent.canvas import Canvas
from api.db.db_models import APIToken
from api.db.services.api_service import API4ConversationService
@@ -30,12 +35,12 @@
from api.db.services.conversation_service import async_iframe_completion as iframe_completion
from api.db.services.conversation_service import async_completion as rag_completion
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
-from api.db.services.document_service import DocumentService
+from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from common.metadata_utils import apply_meta_data_filter, convert_conditions, meta_filter
from api.db.services.search_service import SearchService
-from api.db.services.user_service import UserTenantService
+from api.db.services.user_service import TenantService,UserTenantService
from common.misc_utils import get_uuid
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
get_result, get_request_json, server_error_response, token_required, validate_request
@@ -142,7 +147,7 @@ async def chat_completion(tenant_id, chat_id):
return get_error_data_result(message="metadata_condition must be an object.")
if metadata_condition and req.get("question"):
- metas = DocumentService.get_meta_by_kbs(dia.kb_ids or [])
+ metas = DocMetadataService.get_flatted_meta_by_kbs(dia.kb_ids or [])
filtered_doc_ids = meta_filter(
metas,
convert_conditions(metadata_condition),
@@ -187,6 +192,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
- If `stream` is True, the final answer and reference information will appear in the **last chunk** of the stream.
- If `stream` is False, the reference will be included in `choices[0].message.reference`.
+ - If `extra_body.reference_metadata.include` is True, each reference chunk may include `document_metadata` in both streaming and non-streaming responses.
Example usage:
@@ -201,7 +207,12 @@ async def chat_completion_openai_like(tenant_id, chat_id):
Alternatively, you can use Python's `OpenAI` client:
+ NOTE: Streaming via `client.chat.completions.create(stream=True, ...)` does
+ not return `reference` currently. The only way to return `reference` is
+ non-stream mode with `with_raw_response`.
+
from openai import OpenAI
+ import json
model = "model"
client = OpenAI(api_key="ragflow-api-key", base_url=f"http://ragflow_address/api/v1/chats_openai/")
@@ -209,17 +220,20 @@ async def chat_completion_openai_like(tenant_id, chat_id):
stream = True
reference = True
- completion = client.chat.completions.create(
- model=model,
+ request_kwargs = dict(
+ model="model",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who are you?"},
{"role": "assistant", "content": "I am an AI assistant named..."},
{"role": "user", "content": "Can you tell me how to install neovim"},
],
- stream=stream,
extra_body={
"reference": reference,
+ "reference_metadata": {
+ "include": True,
+ "fields": ["author", "year", "source"],
+ },
"metadata_condition": {
"logic": "and",
"conditions": [
@@ -230,19 +244,25 @@ async def chat_completion_openai_like(tenant_id, chat_id):
}
]
}
- }
+ },
)
if stream:
- for chunk in completion:
- print(chunk)
- if reference and chunk.choices[0].finish_reason == "stop":
- print(f"Reference:\n{chunk.choices[0].delta.reference}")
- print(f"Final content:\n{chunk.choices[0].delta.final_content}")
+ completion = client.chat.completions.create(stream=True, **request_kwargs)
+ for chunk in completion:
+ print(chunk)
else:
- print(completion.choices[0].message.content)
- if reference:
- print(completion.choices[0].message.reference)
+ resp = client.chat.completions.with_raw_response.create(
+ stream=False, **request_kwargs
+ )
+ print("status:", resp.http_response.status_code)
+ raw_text = resp.http_response.text
+ print("raw:", raw_text)
+
+ data = json.loads(raw_text)
+ print("assistant:", data["choices"][0]["message"].get("content"))
+ print("reference:", data["choices"][0]["message"].get("reference"))
+
"""
req = await get_request_json()
@@ -251,6 +271,13 @@ async def chat_completion_openai_like(tenant_id, chat_id):
return get_error_data_result("extra_body must be an object.")
need_reference = bool(extra_body.get("reference", False))
+ reference_metadata = extra_body.get("reference_metadata") or {}
+ if reference_metadata and not isinstance(reference_metadata, dict):
+ return get_error_data_result("reference_metadata must be an object.")
+ include_reference_metadata = bool(reference_metadata.get("include", False))
+ metadata_fields = reference_metadata.get("fields")
+ if metadata_fields is not None and not isinstance(metadata_fields, list):
+ return get_error_data_result("reference_metadata.fields must be an array.")
messages = req.get("messages", [])
# To prevent empty [] input
@@ -261,7 +288,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
prompt = messages[-1]["content"]
# Treat context tokens as reasoning tokens
- context_token_used = sum(len(message["content"]) for message in messages)
+ context_token_used = sum(num_tokens_from_string(message["content"]) for message in messages)
dia = DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value)
if not dia:
@@ -274,7 +301,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
doc_ids_str = None
if metadata_condition:
- metas = DocumentService.get_meta_by_kbs(dia.kb_ids or [])
+ metas = DocMetadataService.get_flatted_meta_by_kbs(dia.kb_ids or [])
filtered_doc_ids = meta_filter(
metas,
convert_conditions(metadata_condition),
@@ -304,9 +331,12 @@ async def chat_completion_openai_like(tenant_id, chat_id):
# The choices field on the last chunk will always be an empty array [].
async def streamed_response_generator(chat_id, dia, msg):
token_used = 0
- answer_cache = ""
- reasoning_cache = ""
last_ans = {}
+ full_content = ""
+ full_reasoning = ""
+ final_answer = None
+ final_reference = None
+ in_think = False
response = {
"id": f"chatcmpl-{chat_id}",
"choices": [
@@ -336,47 +366,30 @@ async def streamed_response_generator(chat_id, dia, msg):
chat_kwargs["doc_ids"] = doc_ids_str
async for ans in async_chat(dia, msg, True, **chat_kwargs):
last_ans = ans
- answer = ans["answer"]
-
- reasoning_match = re.search(r"(.*?) ", answer, flags=re.DOTALL)
- if reasoning_match:
- reasoning_part = reasoning_match.group(1)
- content_part = answer[reasoning_match.end() :]
- else:
- reasoning_part = ""
- content_part = answer
-
- reasoning_incremental = ""
- if reasoning_part:
- if reasoning_part.startswith(reasoning_cache):
- reasoning_incremental = reasoning_part.replace(reasoning_cache, "", 1)
- else:
- reasoning_incremental = reasoning_part
- reasoning_cache = reasoning_part
-
- content_incremental = ""
- if content_part:
- if content_part.startswith(answer_cache):
- content_incremental = content_part.replace(answer_cache, "", 1)
- else:
- content_incremental = content_part
- answer_cache = content_part
-
- token_used += len(reasoning_incremental) + len(content_incremental)
-
- if not any([reasoning_incremental, content_incremental]):
+ if ans.get("final"):
+ if ans.get("answer"):
+ full_content = ans["answer"]
+ final_answer = ans.get("answer") or full_content
+ final_reference = ans.get("reference", {})
continue
-
- if reasoning_incremental:
- response["choices"][0]["delta"]["reasoning_content"] = reasoning_incremental
+ if ans.get("start_to_think"):
+ in_think = True
+ continue
+ if ans.get("end_to_think"):
+ in_think = False
+ continue
+ delta = ans.get("answer") or ""
+ if not delta:
+ continue
+ token_used += num_tokens_from_string(delta)
+ if in_think:
+ full_reasoning += delta
+ response["choices"][0]["delta"]["reasoning_content"] = delta
+ response["choices"][0]["delta"]["content"] = None
else:
+ full_content += delta
+ response["choices"][0]["delta"]["content"] = delta
response["choices"][0]["delta"]["reasoning_content"] = None
-
- if content_incremental:
- response["choices"][0]["delta"]["content"] = content_incremental
- else:
- response["choices"][0]["delta"]["content"] = None
-
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
except Exception as e:
response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e)
@@ -386,10 +399,16 @@ async def streamed_response_generator(chat_id, dia, msg):
response["choices"][0]["delta"]["content"] = None
response["choices"][0]["delta"]["reasoning_content"] = None
response["choices"][0]["finish_reason"] = "stop"
- response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
+ prompt_tokens = num_tokens_from_string(prompt)
+ response["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": token_used, "total_tokens": prompt_tokens + token_used}
if need_reference:
- response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
- response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")
+ reference_payload = final_reference if final_reference is not None else last_ans.get("reference", [])
+ response["choices"][0]["delta"]["reference"] = _build_reference_chunks(
+ reference_payload,
+ include_metadata=include_reference_metadata,
+ metadata_fields=metadata_fields,
+ )
+ response["choices"][0]["delta"]["final_content"] = final_answer if final_answer is not None else full_content
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
yield "data:[DONE]\n\n"
@@ -416,12 +435,12 @@ async def streamed_response_generator(chat_id, dia, msg):
"created": int(time.time()),
"model": req.get("model", ""),
"usage": {
- "prompt_tokens": len(prompt),
- "completion_tokens": len(content),
- "total_tokens": len(prompt) + len(content),
+ "prompt_tokens": num_tokens_from_string(prompt),
+ "completion_tokens": num_tokens_from_string(content),
+ "total_tokens": num_tokens_from_string(prompt) + num_tokens_from_string(content),
"completion_tokens_details": {
"reasoning_tokens": context_token_used,
- "accepted_prediction_tokens": len(content),
+ "accepted_prediction_tokens": num_tokens_from_string(content),
"rejected_prediction_tokens": 0, # 0 for simplicity
},
},
@@ -438,7 +457,11 @@ async def streamed_response_generator(chat_id, dia, msg):
],
}
if need_reference:
- response["choices"][0]["message"]["reference"] = chunks_format(answer.get("reference", []))
+ response["choices"][0]["message"]["reference"] = _build_reference_chunks(
+ answer.get("reference", {}),
+ include_metadata=include_reference_metadata,
+ metadata_fields=metadata_fields,
+ )
return jsonify(response)
@@ -448,7 +471,6 @@ async def streamed_response_generator(chat_id, dia, msg):
@token_required
async def agents_completion_openai_compatibility(tenant_id, agent_id):
req = await get_request_json()
- tiktoken_encode = tiktoken.get_encoding("cl100k_base")
messages = req.get("messages", [])
if not messages:
return get_error_data_result("You must provide at least one message.")
@@ -456,7 +478,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
return get_error_data_result(f"You don't own the agent {agent_id}")
filtered_messages = [m for m in messages if m["role"] in ["user", "assistant"]]
- prompt_tokens = sum(len(tiktoken_encode.encode(m["content"])) for m in filtered_messages)
+ prompt_tokens = sum(num_tokens_from_string(m["content"]) for m in filtered_messages)
if not filtered_messages:
return jsonify(
get_data_openai(
@@ -464,7 +486,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
content="No valid messages found (user or assistant).",
finish_reason="stop",
model=req.get("model", ""),
- completion_tokens=len(tiktoken_encode.encode("No valid messages found (user or assistant).")),
+ completion_tokens=num_tokens_from_string("No valid messages found (user or assistant)."),
prompt_tokens=prompt_tokens,
)
)
@@ -943,6 +965,7 @@ async def chatbots_inputs(dialog_id):
"title": dialog.name,
"avatar": dialog.icon,
"prologue": dialog.prompt_config.get("prologue", ""),
+ "has_tavily_key": bool(dialog.prompt_config.get("tavily_api_key", "").strip()),
}
)
@@ -1058,11 +1081,13 @@ async def retrieval_test_embedded():
use_kg = req.get("use_kg", False)
top = int(req.get("top_k", 1024))
langs = req.get("cross_languages", [])
+ rerank_id = req.get("rerank_id", "")
tenant_id = objs[0].tenant_id
if not tenant_id:
return get_error_data_result(message="permission denined.")
async def _retrieval():
+ nonlocal similarity_threshold, vector_similarity_weight, top, rerank_id
local_doc_ids = list(doc_ids) if doc_ids else []
tenant_ids = []
_question = question
@@ -1074,13 +1099,22 @@ async def _retrieval():
meta_data_filter = search_config.get("meta_data_filter", {})
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
+ # Apply search_config settings if not explicitly provided in request
+ if not req.get("similarity_threshold"):
+ similarity_threshold = float(search_config.get("similarity_threshold", similarity_threshold))
+ if not req.get("vector_similarity_weight"):
+ vector_similarity_weight = float(search_config.get("vector_similarity_weight", vector_similarity_weight))
+ if not req.get("top_k"):
+ top = int(search_config.get("top_k", top))
+ if not req.get("rerank_id"):
+ rerank_id = search_config.get("rerank_id", "")
else:
meta_data_filter = req.get("meta_data_filter") or {}
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
if meta_data_filter:
- metas = DocumentService.get_meta_by_kbs(kb_ids)
+ metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, _question, chat_mdl, local_doc_ids)
tenants = UserTenantService.query(user_id=tenant_id)
@@ -1103,20 +1137,20 @@ async def _retrieval():
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
rerank_mdl = None
- if req.get("rerank_id"):
- rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
+ if rerank_id:
+ rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=rerank_id)
if req.get("keyword", False):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
_question += await keyword_extraction(chat_mdl, _question)
labels = label_question(_question, [kb])
- ranks = settings.retriever.retrieval(
+ ranks = await settings.retriever.retrieval(
_question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
)
if use_kg:
- ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
+ ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)
@@ -1233,3 +1267,135 @@ async def mindmap():
if "error" in mind_map:
return server_error_response(Exception(mind_map["error"]))
return get_json_result(data=mind_map)
+
+@manager.route("/sequence2txt", methods=["POST"]) # noqa: F821
+@token_required
+async def sequence2txt(tenant_id):
+ req = await request.form
+ stream_mode = req.get("stream", "false").lower() == "true"
+ files = await request.files
+ if "file" not in files:
+ return get_error_data_result(message="Missing 'file' in multipart form-data")
+
+ uploaded = files["file"]
+
+ ALLOWED_EXTS = {
+ ".wav", ".mp3", ".m4a", ".aac",
+ ".flac", ".ogg", ".webm",
+ ".opus", ".wma"
+ }
+
+ filename = uploaded.filename or ""
+ suffix = os.path.splitext(filename)[-1].lower()
+ if suffix not in ALLOWED_EXTS:
+ return get_error_data_result(message=
+ f"Unsupported audio format: {suffix}. "
+ f"Allowed: {', '.join(sorted(ALLOWED_EXTS))}"
+ )
+ fd, temp_audio_path = tempfile.mkstemp(suffix=suffix)
+ os.close(fd)
+ await uploaded.save(temp_audio_path)
+
+ tenants = TenantService.get_info_by(tenant_id)
+ if not tenants:
+ return get_error_data_result(message="Tenant not found!")
+
+ asr_id = tenants[0]["asr_id"]
+ if not asr_id:
+ return get_error_data_result(message="No default ASR model is set")
+
+ asr_mdl=LLMBundle(tenants[0]["tenant_id"], LLMType.SPEECH2TEXT, asr_id)
+ if not stream_mode:
+ text = asr_mdl.transcription(temp_audio_path)
+ try:
+ os.remove(temp_audio_path)
+ except Exception as e:
+ logging.error(f"Failed to remove temp audio file: {str(e)}")
+ return get_json_result(data={"text": text})
+ async def event_stream():
+ try:
+ for evt in asr_mdl.stream_transcription(temp_audio_path):
+ yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n"
+ except Exception as e:
+ err = {"event": "error", "text": str(e)}
+ yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n"
+ finally:
+ try:
+ os.remove(temp_audio_path)
+ except Exception as e:
+ logging.error(f"Failed to remove temp audio file: {str(e)}")
+
+ return Response(event_stream(), content_type="text/event-stream")
+
+@manager.route("/tts", methods=["POST"]) # noqa: F821
+@token_required
+async def tts(tenant_id):
+ req = await get_request_json()
+ text = req["text"]
+
+ tenants = TenantService.get_info_by(tenant_id)
+ if not tenants:
+ return get_error_data_result(message="Tenant not found!")
+
+ tts_id = tenants[0]["tts_id"]
+ if not tts_id:
+ return get_error_data_result(message="No default TTS model is set")
+
+ tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
+
+ def stream_audio():
+ try:
+ for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
+ for chunk in tts_mdl.tts(txt):
+ yield chunk
+ except Exception as e:
+ yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
+
+ resp = Response(stream_audio(), mimetype="audio/mpeg")
+ resp.headers.add_header("Cache-Control", "no-cache")
+ resp.headers.add_header("Connection", "keep-alive")
+ resp.headers.add_header("X-Accel-Buffering", "no")
+
+ return resp
+
+
+def _build_reference_chunks(reference, include_metadata=False, metadata_fields=None):
+ chunks = chunks_format(reference)
+ if not include_metadata:
+ return chunks
+
+ doc_ids_by_kb = {}
+ for chunk in chunks:
+ kb_id = chunk.get("dataset_id")
+ doc_id = chunk.get("document_id")
+ if not kb_id or not doc_id:
+ continue
+ doc_ids_by_kb.setdefault(kb_id, set()).add(doc_id)
+
+ if not doc_ids_by_kb:
+ return chunks
+
+ meta_by_doc = {}
+ for kb_id, doc_ids in doc_ids_by_kb.items():
+ meta_map = DocMetadataService.get_metadata_for_documents(list(doc_ids), kb_id)
+ if meta_map:
+ meta_by_doc.update(meta_map)
+
+ if metadata_fields is not None:
+ metadata_fields = {f for f in metadata_fields if isinstance(f, str)}
+ if not metadata_fields:
+ return chunks
+
+ for chunk in chunks:
+ doc_id = chunk.get("document_id")
+ if not doc_id:
+ continue
+ meta = meta_by_doc.get(doc_id)
+ if not meta:
+ continue
+ if metadata_fields is not None:
+ meta = {k: v for k, v in meta.items() if k in metadata_fields}
+ if meta:
+ chunk["document_metadata"] = meta
+
+ return chunks
diff --git a/api/apps/system_app.py b/api/apps/system_app.py
index 379b597de9d..b15054490b0 100644
--- a/api/apps/system_app.py
+++ b/api/apps/system_app.py
@@ -35,7 +35,7 @@
from rag.utils.redis_conn import REDIS_CONN
from quart import jsonify
-from api.utils.health_utils import run_health_checks
+from api.utils.health_utils import run_health_checks, get_oceanbase_status
from common import settings
@@ -178,10 +178,46 @@ def healthz():
@manager.route("/ping", methods=["GET"]) # noqa: F821
-def ping():
+async def ping():
return "pong", 200
+@manager.route("/oceanbase/status", methods=["GET"]) # noqa: F821
+@login_required
+def oceanbase_status():
+ """
+ Get OceanBase health status and performance metrics.
+ ---
+ tags:
+ - System
+ security:
+ - ApiKeyAuth: []
+ responses:
+ 200:
+ description: OceanBase status retrieved successfully.
+ schema:
+ type: object
+ properties:
+ status:
+ type: string
+ description: Status (alive/timeout).
+ message:
+ type: object
+ description: Detailed status information including health and performance metrics.
+ """
+ try:
+ status_info = get_oceanbase_status()
+ return get_json_result(data=status_info)
+ except Exception as e:
+ return get_json_result(
+ data={
+ "status": "error",
+ "message": f"Failed to get OceanBase status: {str(e)}"
+ },
+ code=500
+ )
+
+
@manager.route("/new_token", methods=["POST"]) # noqa: F821
@login_required
def new_token():
diff --git a/api/apps/user_app.py b/api/apps/user_app.py
index e1ad157bc72..3eb8e6c3d3a 100644
--- a/api/apps/user_app.py
+++ b/api/apps/user_app.py
@@ -98,9 +98,7 @@ async def login():
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
email = json_body.get("email", "")
- if email == "admin@ragflow.io":
- return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Default admin account cannot be used to login normal services!")
-
+
users = UserService.query(email=email)
if not users:
return get_json_result(
diff --git a/api/db/db_models.py b/api/db/db_models.py
index 738e26a06ac..ca72be2101c 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -48,6 +48,7 @@
class TextFieldType(Enum):
MYSQL = "LONGTEXT"
+ OCEANBASE = "LONGTEXT"
POSTGRES = "TEXT"
@@ -281,7 +282,11 @@ def _handle_connection_loss(self):
except Exception as e:
logging.error(f"Failed to reconnect: {e}")
time.sleep(0.1)
- self.connect()
+ try:
+ self.connect()
+ except Exception as e2:
+ logging.error(f"Failed to reconnect on second attempt: {e2}")
+ raise
def begin(self):
for attempt in range(self.max_retries + 1):
@@ -352,7 +357,11 @@ def _handle_connection_loss(self):
except Exception as e:
logging.error(f"Failed to reconnect to PostgreSQL: {e}")
time.sleep(0.1)
- self.connect()
+ try:
+ self.connect()
+ except Exception as e2:
+ logging.error(f"Failed to reconnect to PostgreSQL on second attempt: {e2}")
+ raise
def begin(self):
for attempt in range(self.max_retries + 1):
@@ -375,13 +384,95 @@ def begin(self):
return None
+class RetryingPooledOceanBaseDatabase(PooledMySQLDatabase):
+ """Pooled OceanBase database with retry mechanism.
+
+ OceanBase is compatible with MySQL protocol, so we inherit from PooledMySQLDatabase.
+ This class provides connection pooling and automatic retry for connection issues.
+ """
+ def __init__(self, *args, **kwargs):
+ self.max_retries = kwargs.pop("max_retries", 5)
+ self.retry_delay = kwargs.pop("retry_delay", 1)
+ super().__init__(*args, **kwargs)
+
+ def execute_sql(self, sql, params=None, commit=True):
+ for attempt in range(self.max_retries + 1):
+ try:
+ return super().execute_sql(sql, params, commit)
+ except (OperationalError, InterfaceError) as e:
+ # OceanBase/MySQL specific error codes
+ # 2013: Lost connection to MySQL server during query
+ # 2006: MySQL server has gone away
+ error_codes = [2013, 2006]
+ error_messages = ['', 'Lost connection', 'gone away']
+
+ should_retry = (
+ (hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
+ any(msg in str(e).lower() for msg in error_messages) or
+ (hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
+ )
+
+ if should_retry and attempt < self.max_retries:
+ logging.warning(
+ f"OceanBase connection issue (attempt {attempt+1}/{self.max_retries}): {e}"
+ )
+ self._handle_connection_loss()
+ time.sleep(self.retry_delay * (2 ** attempt))
+ else:
+ logging.error(f"OceanBase execution failure: {e}")
+ raise
+ return None
+
+ def _handle_connection_loss(self):
+ try:
+ self.close()
+ except Exception:
+ pass
+ try:
+ self.connect()
+ except Exception as e:
+ logging.error(f"Failed to reconnect to OceanBase: {e}")
+ time.sleep(0.1)
+ try:
+ self.connect()
+ except Exception as e2:
+ logging.error(f"Failed to reconnect to OceanBase on second attempt: {e2}")
+ raise
+
+ def begin(self):
+ for attempt in range(self.max_retries + 1):
+ try:
+ return super().begin()
+ except (OperationalError, InterfaceError) as e:
+ error_codes = [2013, 2006]
+ error_messages = ['', 'Lost connection']
+
+ should_retry = (
+ (hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
+ (str(e) in error_messages) or
+ (hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
+ )
+
+ if should_retry and attempt < self.max_retries:
+ logging.warning(
+ f"Lost connection during transaction (attempt {attempt+1}/{self.max_retries})"
+ )
+ self._handle_connection_loss()
+ time.sleep(self.retry_delay * (2 ** attempt))
+ else:
+ raise
+ return None
+
+
class PooledDatabase(Enum):
MYSQL = RetryingPooledMySQLDatabase
+ OCEANBASE = RetryingPooledOceanBaseDatabase
POSTGRES = RetryingPooledPostgresqlDatabase
class DatabaseMigrator(Enum):
MYSQL = MySQLMigrator
+ OCEANBASE = MySQLMigrator
POSTGRES = PostgresqlMigrator
@@ -540,6 +631,7 @@ def magic(*args, **kwargs):
class DatabaseLock(Enum):
MYSQL = MysqlDatabaseLock
+ OCEANBASE = MysqlDatabaseLock
POSTGRES = PostgresDatabaseLock
@@ -787,7 +879,6 @@ class Document(DataBaseModel):
progress_msg = TextField(null=True, help_text="process message", default="")
process_begin_at = DateTimeField(null=True, index=True)
process_duration = FloatField(default=0)
- meta_fields = JSONField(null=True, default={})
suffix = CharField(max_length=32, null=False, help_text="The real file extension suffix", index=True)
run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0", index=True)
@@ -900,8 +991,10 @@ class Meta:
class API4Conversation(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
+ name = CharField(max_length=255, null=True, help_text="conversation name", index=False)
dialog_id = CharField(max_length=32, null=False, index=True)
user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
+ exp_user_id = CharField(max_length=255, null=True, help_text="exp_user_id", index=True)
message = JSONField(null=True)
reference = JSONField(null=True, default=[])
tokens = IntegerField(default=0)
@@ -1197,224 +1290,96 @@ class Memory(DataBaseModel):
class Meta:
db_table = "memory"
+class SystemSettings(DataBaseModel):
+ name = CharField(max_length=128, primary_key=True)
+ source = CharField(max_length=32, null=False, index=False)
+ data_type = CharField(max_length=32, null=False, index=False)
+ value = TextField(null=False, help_text="Configuration value (JSON, string, etc.)")
+ class Meta:
+ db_table = "system_settings"
-def migrate_db():
- logging.disable(logging.ERROR)
- migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
- try:
- migrate(migrator.add_column("file", "source_type", CharField(max_length=128, null=False, default="", help_text="where dose this document come from", index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("tenant", "rerank_id", CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID")))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("dialog", "rerank_id", CharField(max_length=128, null=False, default="", help_text="default rerank model ID")))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("dialog", "top_k", IntegerField(default=1024)))
- except Exception:
- pass
- try:
- migrate(migrator.alter_column_type("tenant_llm", "api_key", CharField(max_length=2048, null=True, help_text="API KEY", index=True)))
- except Exception:
- pass
+def alter_db_add_column(migrator, table_name, column_name, column_type):
try:
- migrate(migrator.add_column("api_token", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("tenant", "tts_id", CharField(max_length=256, null=True, help_text="default tts model ID", index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("api_4_conversation", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("task", "retry_count", IntegerField(default=0)))
- except Exception:
- pass
- try:
- migrate(migrator.alter_column_type("api_token", "dialog_id", CharField(max_length=32, null=True, index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("tenant_llm", "max_tokens", IntegerField(default=8192, index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("api_4_conversation", "dsl", JSONField(null=True, default={})))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("knowledgebase", "pagerank", IntegerField(default=0, index=False)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("api_token", "beta", CharField(max_length=255, null=True, index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("task", "digest", TextField(null=True, help_text="task digest", default="")))
- except Exception:
- pass
+ migrate(migrator.add_column(table_name, column_name, column_type))
+ except OperationalError as ex:
+ error_codes = [1060]
+ error_messages = ['Duplicate column name']
+
+ should_skip_error = (
+ (hasattr(ex, 'args') and ex.args and ex.args[0] in error_codes) or
+ (str(ex) in error_messages)
+ )
- try:
- migrate(migrator.add_column("task", "chunk_ids", LongTextField(null=True, help_text="chunk ids", default="")))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("conversation", "user_id", CharField(max_length=255, null=True, help_text="user_id", index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("document", "meta_fields", JSONField(null=True, default={})))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("task", "task_type", CharField(max_length=32, null=False, default="")))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("task", "priority", IntegerField(default=0)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("user_canvas", "permission", CharField(max_length=16, null=False, help_text="me|team", default="me", index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("llm", "is_tools", BooleanField(null=False, help_text="support tools", default=False)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("mcp_server", "variables", JSONField(null=True, help_text="MCP Server variables", default=dict)))
- except Exception:
- pass
- try:
- migrate(migrator.rename_column("task", "process_duation", "process_duration"))
- except Exception:
- pass
- try:
- migrate(migrator.rename_column("document", "process_duation", "process_duration"))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("document", "suffix", CharField(max_length=32, null=False, default="", help_text="The real file extension suffix", index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("api_4_conversation", "errors", TextField(null=True, help_text="errors")))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("dialog", "meta_data_filter", JSONField(null=True, default={})))
- except Exception:
- pass
- try:
- migrate(migrator.alter_column_type("canvas_template", "title", JSONField(null=True, default=dict, help_text="Canvas title")))
- except Exception:
- pass
- try:
- migrate(migrator.alter_column_type("canvas_template", "description", JSONField(null=True, default=dict, help_text="Canvas description")))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("user_canvas", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("canvas_template", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("knowledgebase", "raptor_task_id", CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("knowledgebase", "graphrag_task_finish_at", DateTimeField(null=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("knowledgebase", "raptor_task_finish_at", CharField(null=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("knowledgebase", "mindmap_task_id", CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("knowledgebase", "mindmap_task_finish_at", CharField(null=True)))
- except Exception:
- pass
- try:
- migrate(migrator.alter_column_type("tenant_llm", "api_key", TextField(null=True, help_text="API KEY")))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("tenant_llm", "status", CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("connector2kb", "auto_parse", CharField(max_length=1, null=False, default="1", index=False)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("llm_factories", "rank", IntegerField(default=0, index=False)))
- except Exception:
- pass
+ if not should_skip_error:
+ logging.critical(f"Failed to add {settings.DATABASE_TYPE.upper()}.{table_name} column {column_name}, operation error: {ex}")
- # RAG Evaluation tables
- try:
- migrate(migrator.add_column("evaluation_datasets", "id", CharField(max_length=32, primary_key=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("evaluation_datasets", "tenant_id", CharField(max_length=32, null=False, index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("evaluation_datasets", "name", CharField(max_length=255, null=False, index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("evaluation_datasets", "description", TextField(null=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("evaluation_datasets", "kb_ids", JSONField(null=False)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("evaluation_datasets", "created_by", CharField(max_length=32, null=False, index=True)))
- except Exception:
- pass
- try:
- migrate(migrator.add_column("evaluation_datasets", "create_time", BigIntegerField(null=False, index=True)))
- except Exception:
+ except Exception as ex:
+ logging.critical(f"Failed to add {settings.DATABASE_TYPE.upper()}.{table_name} column {column_name}, error: {ex}")
pass
+
+def alter_db_column_type(migrator, table_name, column_name, new_column_type):
try:
- migrate(migrator.add_column("evaluation_datasets", "update_time", BigIntegerField(null=False)))
- except Exception:
+ migrate(migrator.alter_column_type(table_name, column_name, new_column_type))
+ except Exception as ex:
+ logging.critical(f"Failed to alter {settings.DATABASE_TYPE.upper()}.{table_name} column {column_name} type, error: {ex}")
pass
+
+def alter_db_rename_column(migrator, table_name, old_column_name, new_column_name):
try:
- migrate(migrator.add_column("evaluation_datasets", "status", IntegerField(null=False, default=1)))
+ migrate(migrator.rename_column(table_name, old_column_name, new_column_name))
except Exception:
+ # rename fail will lead to a weired error.
+ # logging.critical(f"Failed to rename {settings.DATABASE_TYPE.upper()}.{table_name} column {old_column_name} to {new_column_name}, error: {ex}")
pass
+def migrate_db():
+ logging.disable(logging.ERROR)
+ migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
+ alter_db_add_column(migrator, "file", "source_type", CharField(max_length=128, null=False, default="", help_text="where dose this document come from", index=True))
+ alter_db_add_column(migrator, "tenant", "rerank_id", CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID"))
+ alter_db_add_column(migrator, "dialog", "rerank_id", CharField(max_length=128, null=False, default="", help_text="default rerank model ID"))
+ alter_db_column_type(migrator, "dialog", "top_k", IntegerField(default=1024))
+ alter_db_add_column(migrator, "tenant_llm", "api_key", CharField(max_length=2048, null=True, help_text="API KEY", index=True))
+ alter_db_add_column(migrator, "api_token", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
+ alter_db_add_column(migrator, "tenant", "tts_id", CharField(max_length=256, null=True, help_text="default tts model ID", index=True))
+ alter_db_add_column(migrator, "api_4_conversation", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
+ alter_db_add_column(migrator, "task", "retry_count", IntegerField(default=0))
+ alter_db_column_type(migrator, "api_token", "dialog_id", CharField(max_length=32, null=True, index=True))
+ alter_db_add_column(migrator, "tenant_llm", "max_tokens", IntegerField(default=8192, index=True))
+ alter_db_add_column(migrator, "api_4_conversation", "dsl", JSONField(null=True, default={}))
+ alter_db_add_column(migrator, "knowledgebase", "pagerank", IntegerField(default=0, index=False))
+ alter_db_add_column(migrator, "api_token", "beta", CharField(max_length=255, null=True, index=True))
+ alter_db_add_column(migrator, "task", "digest", TextField(null=True, help_text="task digest", default=""))
+ alter_db_add_column(migrator, "task", "chunk_ids", LongTextField(null=True, help_text="chunk ids", default=""))
+ alter_db_add_column(migrator, "conversation", "user_id", CharField(max_length=255, null=True, help_text="user_id", index=True))
+ alter_db_add_column(migrator, "task", "task_type", CharField(max_length=32, null=False, default=""))
+ alter_db_add_column(migrator, "task", "priority", IntegerField(default=0))
+ alter_db_add_column(migrator, "user_canvas", "permission", CharField(max_length=16, null=False, help_text="me|team", default="me", index=True))
+ alter_db_add_column(migrator, "llm", "is_tools", BooleanField(null=False, help_text="support tools", default=False))
+ alter_db_add_column(migrator, "mcp_server", "variables", JSONField(null=True, help_text="MCP Server variables", default=dict))
+ alter_db_rename_column(migrator, "task", "process_duation", "process_duration")
+ alter_db_rename_column(migrator, "document", "process_duation", "process_duration")
+ alter_db_add_column(migrator, "document", "suffix", CharField(max_length=32, null=False, default="", help_text="The real file extension suffix", index=True))
+ alter_db_add_column(migrator, "api_4_conversation", "errors", TextField(null=True, help_text="errors"))
+ alter_db_add_column(migrator, "dialog", "meta_data_filter", JSONField(null=True, default={}))
+ alter_db_column_type(migrator, "canvas_template", "title", JSONField(null=True, default=dict, help_text="Canvas title"))
+ alter_db_column_type(migrator, "canvas_template", "description", JSONField(null=True, default=dict, help_text="Canvas description"))
+ alter_db_add_column(migrator, "user_canvas", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True))
+ alter_db_add_column(migrator, "canvas_template", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True))
+ alter_db_add_column(migrator, "knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True))
+ alter_db_add_column(migrator, "document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True))
+ alter_db_add_column(migrator, "knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True))
+ alter_db_add_column(migrator, "knowledgebase", "raptor_task_id", CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True))
+ alter_db_add_column(migrator, "knowledgebase", "graphrag_task_finish_at", DateTimeField(null=True))
+ alter_db_add_column(migrator, "knowledgebase", "raptor_task_finish_at", CharField(null=True))
+ alter_db_add_column(migrator, "knowledgebase", "mindmap_task_id", CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True))
+ alter_db_add_column(migrator, "knowledgebase", "mindmap_task_finish_at", CharField(null=True))
+ alter_db_column_type(migrator, "tenant_llm", "api_key", TextField(null=True, help_text="API KEY"))
+ alter_db_add_column(migrator, "tenant_llm", "status", CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True))
+ alter_db_add_column(migrator, "connector2kb", "auto_parse", CharField(max_length=1, null=False, default="1", index=False))
+ alter_db_add_column(migrator, "llm_factories", "rank", IntegerField(default=0, index=False))
+ alter_db_add_column(migrator, "api_4_conversation", "name", CharField(max_length=255, null=True, help_text="conversation name", index=False))
+ alter_db_add_column(migrator, "api_4_conversation", "exp_user_id", CharField(max_length=255, null=True, help_text="exp_user_id", index=True))
+ # Migrate system_settings.value from CharField to TextField for longer sandbox configs
+ alter_db_column_type(migrator, "system_settings", "value", TextField(null=False, help_text="Configuration value (JSON, string, etc.)"))
logging.disable(logging.NOTSET)
diff --git a/api/db/init_data.py b/api/db/init_data.py
index 77f676f0962..49a094eb323 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -30,7 +30,8 @@
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
from api.db.services.user_service import TenantService, UserTenantService
-from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache
+from api.db.services.system_settings_service import SystemSettingsService
+from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache, fix_missing_tokenized_memory
from common.constants import LLMType
from common.file_utils import get_project_base_directory
from common import settings
@@ -158,13 +159,15 @@ def add_graph_templates():
CanvasTemplateService.save(**cnvs)
except Exception:
CanvasTemplateService.update_by_id(cnvs["id"], cnvs)
- except Exception:
- logging.exception("Add agent templates error: ")
+ except Exception as e:
+ logging.exception(f"Add agent templates error: {e}")
def init_web_data():
start_time = time.time()
+ init_table()
+
init_llm_factory()
# if not UserService.get_all().count():
# init_superuser()
@@ -172,8 +175,34 @@ def init_web_data():
add_graph_templates()
init_message_id_sequence()
init_memory_size_cache()
+ fix_missing_tokenized_memory()
logging.info("init web data success:{}".format(time.time() - start_time))
+def init_table():
+ # init system_settings
+ with open(os.path.join(get_project_base_directory(), "conf", "system_settings.json"), "r") as f:
+ records_from_file = json.load(f)["system_settings"]
+
+ record_index = {}
+ records_from_db = SystemSettingsService.get_all()
+ for index, record in enumerate(records_from_db):
+ record_index[record.name] = index
+
+ to_save = []
+ for record in records_from_file:
+ setting_name = record["name"]
+ if setting_name not in record_index:
+ to_save.append(record)
+
+ len_to_save = len(to_save)
+ if len_to_save > 0:
+ # not initialized
+ try:
+ SystemSettingsService.insert_many(to_save, len_to_save)
+ except Exception as e:
+ logging.exception("System settings init error: {}".format(e))
+ raise e
+
if __name__ == '__main__':
init_web_db()
diff --git a/api/db/joint_services/memory_message_service.py b/api/db/joint_services/memory_message_service.py
index 79848cad5c3..8f662124724 100644
--- a/api/db/joint_services/memory_message_service.py
+++ b/api/db/joint_services/memory_message_service.py
@@ -16,7 +16,6 @@
import logging
from typing import List
-from api.db.services.task_service import TaskService
from common import settings
from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms
from common.constants import MemoryType, LLMType
@@ -24,6 +23,7 @@
from common.misc_utils import get_uuid
from api.db.db_utils import bulk_insert_into_db
from api.db.db_models import Task
+from api.db.services.task_service import TaskService
from api.db.services.memory_service import MemoryService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.llm_service import LLMBundle
@@ -90,13 +90,19 @@ async def save_to_memory(memory_id: str, message_dict: dict):
return await embed_and_save(memory, message_list)
-async def save_extracted_to_memory_only(memory_id: str, message_dict, source_message_id: int):
+async def save_extracted_to_memory_only(memory_id: str, message_dict, source_message_id: int, task_id: str=None):
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
- return False, f"Memory '{memory_id}' not found."
+ msg = f"Memory '{memory_id}' not found."
+ if task_id:
+ TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
+ return False, msg
if memory.memory_type == MemoryType.RAW.value:
- return True, f"Memory '{memory_id}' don't need to extract."
+ msg = f"Memory '{memory_id}' don't need to extract."
+ if task_id:
+ TaskService.update_progress(task_id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
+ return True, msg
tenant_id = memory.tenant_id
extracted_content = await extract_by_llm(
@@ -105,7 +111,8 @@ async def save_extracted_to_memory_only(memory_id: str, message_dict, source_mes
{"temperature": memory.temperature},
get_memory_type_human(memory.memory_type),
message_dict.get("user_input", ""),
- message_dict.get("agent_response", "")
+ message_dict.get("agent_response", ""),
+ task_id=task_id
)
message_list = [{
"message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
@@ -122,13 +129,18 @@ async def save_extracted_to_memory_only(memory_id: str, message_dict, source_mes
"status": True
} for content in extracted_content]
if not message_list:
- return True, "No memory extracted from raw message."
+ msg = "No memory extracted from raw message."
+ if task_id:
+ TaskService.update_progress(task_id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
+ return True, msg
- return await embed_and_save(memory, message_list)
+ if task_id:
+ TaskService.update_progress(task_id, {"progress": 0.5, "progress_msg": timestamp_to_date(current_timestamp())+ " " + f"Extracted {len(message_list)} messages from raw dialogue."})
+ return await embed_and_save(memory, message_list, task_id)
async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
- agent_response: str, system_prompt: str = "", user_prompt: str="") -> List[dict]:
+ agent_response: str, system_prompt: str = "", user_prompt: str="", task_id: str=None) -> List[dict]:
llm_type = TenantLLMService.llm_id2llm_type(llm_id)
if not llm_type:
raise RuntimeError(f"Unknown type of LLM '{llm_id}'")
@@ -143,8 +155,12 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory
else:
user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
llm = LLMBundle(tenant_id, llm_type, llm_id)
+ if task_id:
+ TaskService.update_progress(task_id, {"progress": 0.15, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared prompts and LLM."})
res = await llm.async_chat(system_prompt, user_prompts, extract_conf)
res_json = get_json_result_from_llm_response(res)
+ if task_id:
+ TaskService.update_progress(task_id, {"progress": 0.35, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Get extracted result from LLM."})
return [{
"content": extracted_content["content"],
"valid_at": format_iso_8601_to_ymd_hms(extracted_content["valid_at"]),
@@ -153,16 +169,23 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory
} for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list]
-async def embed_and_save(memory, message_list: list[dict]):
+async def embed_and_save(memory, message_list: list[dict], task_id: str=None):
embedding_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
+ if task_id:
+ TaskService.update_progress(task_id, {"progress": 0.65, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared embedding model."})
vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
for idx, msg in enumerate(message_list):
msg["content_embed"] = vector_list[idx]
+ if task_id:
+ TaskService.update_progress(task_id, {"progress": 0.85, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Embedded extracted content."})
vector_dimension = len(vector_list[0])
if not MessageService.has_index(memory.tenant_id, memory.id):
created = MessageService.create_index(memory.tenant_id, memory.id, vector_size=vector_dimension)
if not created:
- return False, "Failed to create message index."
+ error_msg = "Failed to create message index."
+ if task_id:
+ TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + error_msg})
+ return False, error_msg
new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
current_memory_size = get_memory_size_cache(memory.tenant_id, memory.id)
@@ -174,11 +197,19 @@ async def embed_and_save(memory, message_list: list[dict]):
MessageService.delete_message({"message_id": message_ids_to_delete}, memory.tenant_id, memory.id)
decrease_memory_size_cache(memory.id, delete_size)
else:
- return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
+ error_msg = "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
+ if task_id:
+ TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + error_msg})
+ return False, error_msg
fail_cases = MessageService.insert_message(message_list, memory.tenant_id, memory.id)
if fail_cases:
- return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
+ error_msg = "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
+ if task_id:
+ TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + error_msg})
+ return False, error_msg
+ if task_id:
+ TaskService.update_progress(task_id, {"progress": 0.95, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Saved messages to storage."})
increase_memory_size_cache(memory.id, new_msg_size)
return True, "Message saved successfully."
@@ -275,6 +306,24 @@ def init_memory_size_cache():
logging.info("Memory size cache init done.")
+def fix_missing_tokenized_memory():
+ if settings.DOC_ENGINE != "elasticsearch":
+ logging.info("Not using elasticsearch as doc engine, no need to fix missing tokenized memory.")
+ return
+ memory_list = MemoryService.get_all_memory()
+ if not memory_list:
+ logging.info("No memory found, no need to fix missing tokenized memory.")
+ else:
+ for m in memory_list:
+ message_list = MessageService.get_missing_field_messages(m.id, m.tenant_id, "tokenized_content_ltks")
+ for msg in message_list:
+ # update content to refresh tokenized field
+ MessageService.update_message({"message_id": msg["message_id"], "memory_id": m.id}, {"content": msg["content"]}, m.tenant_id, m.id)
+ if message_list:
+ logging.info(f"Fixed {len(message_list)} messages missing tokenized field in memory: {m.name}.")
+ logging.info("Fix missing tokenized memory done.")
+
+
def judge_system_prompt_is_default(system_prompt: str, memory_type: int|list[str]):
memory_type_list = memory_type if isinstance(memory_type, list) else get_memory_type_human(memory_type)
return system_prompt == PromptAssembler.assemble_system_prompt({"memory_type": memory_type_list})
@@ -379,11 +428,11 @@ async def handle_save_to_memory_task(task_param: dict):
memory_id = task_param["memory_id"]
source_id = task_param["source_id"]
message_dict = task_param["message_dict"]
- success, msg = await save_extracted_to_memory_only(memory_id, message_dict, source_id)
+ success, msg = await save_extracted_to_memory_only(memory_id, message_dict, source_id, task.id)
if success:
- TaskService.update_progress(task.id, {"progress": 1.0, "progress_msg": msg})
+ TaskService.update_progress(task.id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
return True, msg
logging.error(msg)
- TaskService.update_progress(task.id, {"progress": -1, "progress_msg": None})
+ TaskService.update_progress(task.id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
return False, msg
diff --git a/api/db/joint_services/user_account_service.py b/api/db/joint_services/user_account_service.py
index 2e4dfeaab23..7490c9bad22 100644
--- a/api/db/joint_services/user_account_service.py
+++ b/api/db/joint_services/user_account_service.py
@@ -23,6 +23,7 @@
from api.db.services.conversation_service import ConversationService
from api.db.services.dialog_service import DialogService
from api.db.services.document_service import DocumentService
+from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.langfuse_service import TenantLangfuseService
@@ -107,6 +108,11 @@ def create_new_user(user_info: dict) -> dict:
except Exception as create_error:
logging.exception(create_error)
# rollback
+ try:
+ metadata_index_name = DocMetadataService._get_doc_meta_index_name(user_id)
+ settings.docStoreConn.delete_idx(metadata_index_name, "")
+ except Exception as e:
+ logging.exception(e)
try:
TenantService.delete_by_id(user_id)
except Exception as e:
@@ -165,6 +171,12 @@ def delete_user_data(user_id: str) -> dict:
# step1.1.2 delete file and document info in db
doc_ids = DocumentService.get_all_doc_ids_by_kb_ids(kb_ids)
if doc_ids:
+ for doc in doc_ids:
+ try:
+ DocMetadataService.delete_document_metadata(doc["id"], skip_empty_check=True)
+ except Exception as e:
+ logging.warning(f"Failed to delete metadata for document {doc['id']}: {e}")
+
doc_delete_res = DocumentService.delete_by_ids([i["id"] for i in doc_ids])
done_msg += f"- Deleted {doc_delete_res} document records.\n"
task_delete_res = TaskService.delete_by_doc_ids([i["id"] for i in doc_ids])
@@ -202,6 +214,13 @@ def delete_user_data(user_id: str) -> dict:
done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n"
langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id)
done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n"
+ try:
+ metadata_index_name = DocMetadataService._get_doc_meta_index_name(tenant_id)
+ settings.docStoreConn.delete_idx(metadata_index_name, "")
+ done_msg += f"- Deleted metadata table {metadata_index_name}.\n"
+ except Exception as e:
+ logging.warning(f"Failed to delete metadata table for tenant {tenant_id}: {e}")
+ done_msg += "- Warning: Failed to delete metadata table (continuing).\n"
# step1.3 delete memory and messages
user_memory = MemoryService.get_by_tenant_id(tenant_id)
if user_memory:
@@ -269,6 +288,11 @@ def delete_user_data(user_id: str) -> dict:
# step2.1.5 delete document record
doc_delete_res = DocumentService.delete_by_ids([d['id'] for d in created_documents])
done_msg += f"- Deleted {doc_delete_res} documents.\n"
+ for doc in created_documents:
+ try:
+ DocMetadataService.delete_document_metadata(doc['id'])
+ except Exception as e:
+ logging.warning(f"Failed to delete metadata for document {doc['id']}: {e}")
# step2.1.6 update dataset doc&chunk&token cnt
for kb_id, doc_num in kb_doc_info.items():
KnowledgebaseService.decrease_document_num_in_delete(kb_id, doc_num)
diff --git a/api/db/services/api_service.py b/api/db/services/api_service.py
index aee35422b7f..be41dc1b642 100644
--- a/api/db/services/api_service.py
+++ b/api/db/services/api_service.py
@@ -48,8 +48,8 @@ class API4ConversationService(CommonService):
@DB.connection_context()
def get_list(cls, dialog_id, tenant_id,
page_number, items_per_page,
- orderby, desc, id, user_id=None, include_dsl=True, keywords="",
- from_date=None, to_date=None
+ orderby, desc, id=None, user_id=None, include_dsl=True, keywords="",
+ from_date=None, to_date=None, exp_user_id=None
):
if include_dsl:
sessions = cls.model.select().where(cls.model.dialog_id == dialog_id)
@@ -66,6 +66,8 @@ def get_list(cls, dialog_id, tenant_id,
sessions = sessions.where(cls.model.create_date >= from_date)
if to_date:
sessions = sessions.where(cls.model.create_date <= to_date)
+ if exp_user_id:
+ sessions = sessions.where(cls.model.exp_user_id == exp_user_id)
if desc:
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
else:
@@ -74,6 +76,17 @@ def get_list(cls, dialog_id, tenant_id,
sessions = sessions.paginate(page_number, items_per_page)
return count, list(sessions.dicts())
+
+ @classmethod
+ @DB.connection_context()
+ def get_names(cls, dialog_id, exp_user_id):
+ fields = [cls.model.id, cls.model.name,]
+ sessions = cls.model.select(*fields).where(
+ cls.model.dialog_id == dialog_id,
+ cls.model.exp_user_id == exp_user_id
+ ).order_by(cls.model.getter_by("create_date").desc())
+
+ return list(sessions.dicts())
@classmethod
@DB.connection_context()
diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py
index 763e9c4601e..99cb1990044 100644
--- a/api/db/services/canvas_service.py
+++ b/api/db/services/canvas_service.py
@@ -146,7 +146,6 @@ def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
cls.model.id,
cls.model.avatar,
cls.model.title,
- cls.model.dsl,
cls.model.description,
cls.model.permission,
cls.model.user_id.alias("tenant_id"),
@@ -195,6 +194,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
files = kwargs.get("files", [])
inputs = kwargs.get("inputs", {})
user_id = kwargs.get("user_id", "")
+ custom_header = kwargs.get("custom_header", "")
if session_id:
e, conv = API4ConversationService.get_by_id(session_id)
@@ -203,7 +203,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
conv.message = []
if not isinstance(conv.dsl, str):
conv.dsl = json.dumps(conv.dsl, ensure_ascii=False)
- canvas = Canvas(conv.dsl, tenant_id, agent_id)
+ canvas = Canvas(conv.dsl, tenant_id, agent_id, canvas_id=agent_id, custom_header=custom_header)
else:
e, cvs = UserCanvasService.get_by_id(agent_id)
assert e, "Agent not found."
@@ -211,7 +211,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
session_id=get_uuid()
- canvas = Canvas(cvs.dsl, tenant_id, agent_id, canvas_id=cvs.id)
+ canvas = Canvas(cvs.dsl, tenant_id, agent_id, canvas_id=cvs.id, custom_header=custom_header)
canvas.reset()
conv = {
"id": session_id,
@@ -229,7 +229,8 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
conv.message.append({
"role": "user",
"content": query,
- "id": message_id
+ "id": message_id,
+ "files": files
})
txt = ""
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
diff --git a/api/db/services/common_service.py b/api/db/services/common_service.py
index 60db241cc8e..df95debb5f0 100644
--- a/api/db/services/common_service.py
+++ b/api/db/services/common_service.py
@@ -190,10 +190,15 @@ def insert_many(cls, data_list, batch_size=100):
data_list (list): List of dictionaries containing record data to insert.
batch_size (int, optional): Number of records to insert in each batch. Defaults to 100.
"""
+ current_ts = current_timestamp()
+ current_datetime = datetime_format(datetime.now())
with DB.atomic():
for d in data_list:
- d["create_time"] = current_timestamp()
- d["create_date"] = datetime_format(datetime.now())
+ d["create_time"] = current_ts
+ d["create_date"] = current_datetime
+ d["update_time"] = current_ts
+ d["update_date"] = current_datetime
+
for i in range(0, len(data_list), batch_size):
cls.model.insert_many(data_list[i : i + batch_size]).execute()
diff --git a/api/db/services/connector_service.py b/api/db/services/connector_service.py
index 660530c824b..d2fcb1b41d8 100644
--- a/api/db/services/connector_service.py
+++ b/api/db/services/connector_service.py
@@ -25,11 +25,11 @@
from api.db.db_models import Connector, SyncLogs, Connector2Kb, Knowledgebase
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
+from api.db.services.document_service import DocMetadataService
from common.misc_utils import get_uuid
from common.constants import TaskStatus
from common.time_utils import current_timestamp, timestamp_to_date
-
class ConnectorService(CommonService):
model = Connector
@@ -202,6 +202,7 @@ def duplicate_and_parse(cls, kb, docs, tenant_id, src, auto_parse=True):
return None
class FileObj(BaseModel):
+ id: str
filename: str
blob: bytes
@@ -209,7 +210,7 @@ def read(self) -> bytes:
return self.blob
errs = []
- files = [FileObj(filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"]) for d in docs]
+ files = [FileObj(id=d["id"], filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"]) for d in docs]
doc_ids = []
err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src)
errs.extend(err)
@@ -227,7 +228,7 @@ def read(self) -> bytes:
# Set metadata if available for this document
if doc["name"] in metadata_map:
- DocumentService.update_by_id(doc["id"], {"meta_fields": metadata_map[doc["name"]]})
+ DocMetadataService.update_document_metadata(doc["id"], metadata_map[doc["name"]])
if not auto_parse or auto_parse == "0":
continue
diff --git a/api/db/services/conversation_service.py b/api/db/services/conversation_service.py
index 2a5b06601dc..3287ac15784 100644
--- a/api/db/services/conversation_service.py
+++ b/api/db/services/conversation_service.py
@@ -64,11 +64,13 @@ def get_all_conversation_by_dialog_ids(cls, dialog_ids):
offset += limit
return res
+
def structure_answer(conv, ans, message_id, session_id):
reference = ans["reference"]
if not isinstance(reference, dict):
reference = {}
ans["reference"] = {}
+ is_final = ans.get("final", True)
chunk_list = chunks_format(reference)
@@ -81,14 +83,32 @@ def structure_answer(conv, ans, message_id, session_id):
if not conv.message:
conv.message = []
+ content = ans["answer"]
+ if ans.get("start_to_think"):
+ content = ""
+ elif ans.get("end_to_think"):
+ content = " "
+
if not conv.message or conv.message[-1].get("role", "") != "assistant":
- conv.message.append({"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id})
+ conv.message.append({"role": "assistant", "content": content, "created_at": time.time(), "id": message_id})
else:
- conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
+ if is_final:
+ if ans.get("answer"):
+ conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
+ else:
+ conv.message[-1]["created_at"] = time.time()
+ conv.message[-1]["id"] = message_id
+ else:
+ conv.message[-1]["content"] = (conv.message[-1].get("content") or "") + content
+ conv.message[-1]["created_at"] = time.time()
+ conv.message[-1]["id"] = message_id
if conv.reference:
- conv.reference[-1] = reference
+ should_update_reference = is_final or bool(reference.get("chunks")) or bool(reference.get("doc_aggs"))
+ if should_update_reference:
+ conv.reference[-1] = reference
return ans
+
async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
assert name, "`name` can not be empty."
dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 4bc24210b20..66025d13ef8 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import asyncio
import binascii
import logging
import re
@@ -23,20 +24,19 @@
from timeit import default_timer as timer
from langfuse import Langfuse
from peewee import fn
-from agentic_reasoning import DeepResearcher
from api.db.services.file_service import FileService
from common.constants import LLMType, ParserType, StatusEnum
from api.db.db_models import DB, Dialog
from api.db.services.common_service import CommonService
-from api.db.services.document_service import DocumentService
+from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import LLMBundle
from common.metadata_utils import apply_meta_data_filter
from api.db.services.tenant_llm_service import TenantLLMService
from common.time_utils import current_timestamp, datetime_format
-from graphrag.general.mind_map_extractor import MindMapExtractor
-from rag.app.resume import forbidden_select_fields4resume
+from rag.graphrag.general.mind_map_extractor import MindMapExtractor
+from rag.advanced_rag import DeepResearcher
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
@@ -196,19 +196,13 @@ async def async_chat_solo(dialog, messages, stream=True):
if attachments and msg:
msg[-1]["content"] += attachments
if stream:
- last_ans = ""
- delta_ans = ""
- answer = ""
- async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
- answer = ans
- delta_ans = ans[len(last_ans):]
- if num_tokens_from_string(delta_ans) < 16:
+ stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting)
+ async for kind, value, state in _stream_with_think_delta(stream_iter):
+ if kind == "marker":
+ flags = {"start_to_think": True} if value == "" else {"end_to_think": True}
+ yield {"answer": "", "reference": {}, "audio_binary": None, "prompt": "", "created_at": time.time(), "final": False, **flags}
continue
- last_ans = answer
- yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
- delta_ans = ""
- if delta_ans:
- yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
+ yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "prompt": "", "created_at": time.time(), "final": False}
else:
answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
user_content = msg[-1].get("content", "[content not available]")
@@ -279,6 +273,7 @@ def replacement(match):
async def async_chat(dialog, messages, stream=True, **kwargs):
+ logging.debug("Begin async_chat")
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
async for ans in async_chat_solo(dialog, messages, stream):
@@ -301,10 +296,14 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=dialog.tenant_id)
if langfuse_keys:
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
- if langfuse.auth_check():
- langfuse_tracer = langfuse
- trace_id = langfuse_tracer.create_trace_id()
- trace_context = {"trace_id": trace_id}
+ try:
+ if langfuse.auth_check():
+ langfuse_tracer = langfuse
+ trace_id = langfuse_tracer.create_trace_id()
+ trace_context = {"trace_id": trace_id}
+ except Exception:
+ # Skip langfuse tracing if connection fails
+ pass
check_langfuse_tracer_ts = timer()
kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
@@ -324,13 +323,20 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
prompt_config = dialog.prompt_config
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
+ logging.debug(f"field_map retrieved: {field_map}")
# try to use sql if field mapping is good to go
if field_map:
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
- if ans:
+ # For aggregate queries (COUNT, SUM, etc.), chunks may be empty but answer is still valid
+ if ans and (ans.get("reference", {}).get("chunks") or ans.get("answer")):
yield ans
return
+ else:
+ logging.debug("SQL failed or returned no results, falling back to vector search")
+
+ param_keys = [p["key"] for p in prompt_config.get("parameters", [])]
+ logging.debug(f"attachments={attachments}, param_keys={param_keys}, embd_mdl={embd_mdl}")
for p in prompt_config["parameters"]:
if p["key"] == "knowledge":
@@ -349,7 +355,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
questions = [await cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
if dialog.meta_data_filter:
- metas = DocumentService.get_meta_by_kbs(dialog.kb_ids)
+ metas = DocMetadataService.get_flatted_meta_by_kbs(dialog.kb_ids)
attachments = await apply_meta_data_filter(
dialog.meta_data_filter,
metas,
@@ -367,10 +373,11 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
knowledges = []
- if attachments is not None and "knowledge" in [p["key"] for p in prompt_config["parameters"]]:
+ if attachments is not None and "knowledge" in param_keys:
+ logging.debug("Proceeding with retrieval")
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
knowledges = []
- if prompt_config.get("reasoning", False):
+ if prompt_config.get("reasoning", False) or kwargs.get("reasoning"):
reasoner = DeepResearcher(
chat_mdl,
prompt_config,
@@ -386,16 +393,28 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
doc_ids=attachments,
),
)
+ queue = asyncio.Queue()
+ async def callback(msg:str):
+ nonlocal queue
+ await queue.put(msg + " ")
+
+ await callback("")
+ task = asyncio.create_task(reasoner.research(kbinfos, questions[-1], questions[-1], callback=callback))
+ while True:
+ msg = await queue.get()
+ if msg.find("") == 0:
+ yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, "start_to_think": True}
+ elif msg.find("") == 0:
+ yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, "end_to_think": True}
+ break
+ else:
+ yield {"answer": msg, "reference": {}, "audio_binary": None, "final": False}
+
+ await task
- async for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)):
- if isinstance(think, str):
- thought = think
- knowledges = [t for t in think.split("\n") if t]
- elif stream:
- yield think
else:
if embd_mdl:
- kbinfos = retriever.retrieval(
+ kbinfos = await retriever.retrieval(
" ".join(questions),
embd_mdl,
tenant_ids,
@@ -411,7 +430,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
rank_feature=label_question(" ".join(questions), kbs),
)
if prompt_config.get("toc_enhance"):
- cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
+ cks = await retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
if cks:
kbinfos["chunks"] = cks
kbinfos["chunks"] = retriever.retrieval_by_children(kbinfos["chunks"], tenant_ids)
@@ -421,21 +440,19 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
kbinfos["chunks"].extend(tav_res["chunks"])
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
if prompt_config.get("use_kg"):
- ck = settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
+ ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
LLMBundle(dialog.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
- knowledges = kb_prompt(kbinfos, max_tokens)
-
+ knowledges = kb_prompt(kbinfos, max_tokens)
logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
retrieval_ts = timer()
if not knowledges and prompt_config.get("empty_response"):
empty_res = prompt_config["empty_response"]
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
- "audio_binary": tts(tts_mdl, empty_res)}
- yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
+ "audio_binary": tts(tts_mdl, empty_res), "final": True}
return
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
@@ -538,21 +555,22 @@ def decorate_answer(answer):
)
if stream:
- last_ans = ""
- answer = ""
- async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
- if thought:
- ans = re.sub(r"^.* ", "", ans, flags=re.DOTALL)
- answer = ans
- delta_ans = ans[len(last_ans):]
- if num_tokens_from_string(delta_ans) < 16:
+ stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf)
+ last_state = None
+ async for kind, value, state in _stream_with_think_delta(stream_iter):
+ last_state = state
+ if kind == "marker":
+ flags = {"start_to_think": True} if value == "" else {"end_to_think": True}
+ yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, **flags}
continue
- last_ans = answer
- yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
- delta_ans = answer[len(last_ans):]
- if delta_ans:
- yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
- yield decorate_answer(thought + answer)
+ yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "final": False}
+ full_answer = last_state.full_text if last_state else ""
+ if full_answer:
+ final = decorate_answer(thought + full_answer)
+ final["final"] = True
+ final["audio_binary"] = None
+ final["answer"] = ""
+ yield final
else:
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
user_content = msg[-1].get("content", "[content not available]")
@@ -565,112 +583,362 @@ def decorate_answer(answer):
async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
- sys_prompt = """
-You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
-Ensure that:
-1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it.
-2. Write only the SQL, no explanations or additional text.
-"""
- user_prompt = """
-Table name: {};
-Table of database fields are as follows:
-{}
+ logging.debug(f"use_sql: Question: {question}")
-Question are as follows:
+ # Determine which document engine we're using
+ if settings.DOC_ENGINE_INFINITY:
+ doc_engine = "infinity"
+ elif settings.DOC_ENGINE_OCEANBASE:
+ doc_engine = "oceanbase"
+ else:
+ doc_engine = "es"
+
+ # Construct the full table name
+ # For Elasticsearch: ragflow_{tenant_id} (kb_id is in WHERE clause)
+ # For Infinity: ragflow_{tenant_id}_{kb_id} (each KB has its own table)
+ base_table = index_name(tenant_id)
+ if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1:
+ # Infinity: append kb_id to table name
+ table_name = f"{base_table}_{kb_ids[0]}"
+ logging.debug(f"use_sql: Using Infinity table name: {table_name}")
+ else:
+ # Elasticsearch/OpenSearch: use base index name
+ table_name = base_table
+ logging.debug(f"use_sql: Using ES/OS table name: {table_name}")
+
+ def is_row_count_question(q: str) -> bool:
+ q = (q or "").lower()
+ if not re.search(r"\bhow many rows\b|\bnumber of rows\b|\brow count\b", q):
+ return False
+ return bool(re.search(r"\bdataset\b|\btable\b|\bspreadsheet\b|\bexcel\b", q))
+
+ # Generate engine-specific SQL prompts
+ if doc_engine == "infinity":
+ # Build Infinity prompts with JSON extraction context
+ json_field_names = list(field_map.keys())
+ row_count_override = (
+ f"SELECT COUNT(*) AS rows FROM {table_name}"
+ if is_row_count_question(question)
+ else None
+ )
+ sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.
+
+JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
+Numeric Cast: CAST(json_extract_string(chunk_data, '$.FieldName') AS INTEGER/FLOAT)
+NULL Check: json_extract_isnull(chunk_data, '$.FieldName') == false
+
+RULES:
+1. Use EXACT field names (case-sensitive) from the list below
+2. For SELECT: include doc_id, docnm, and json_extract_string() for requested fields
+3. For COUNT: use COUNT(*) or COUNT(DISTINCT json_extract_string(...))
+4. Add AS alias for extracted field names
+5. DO NOT select 'content' field
+6. Only add NULL check (json_extract_isnull() == false) in WHERE clause when:
+ - Question asks to "show me" or "display" specific columns
+ - Question mentions "not null" or "excluding null"
+ - Add NULL check for count specific column
+ - DO NOT add NULL check for COUNT(*) queries (COUNT(*) counts all rows including nulls)
+7. Output ONLY the SQL, no explanations"""
+ user_prompt = """Table: {}
+Fields (EXACT case): {}
{}
-Please write the SQL, only SQL, without any other explanations or text.
-""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
+Question: {}
+Write SQL using json_extract_string() with exact field names. Include doc_id, docnm for data queries. Only SQL.""".format(
+ table_name,
+ ", ".join(json_field_names),
+ "\n".join([f" - {field}" for field in json_field_names]),
+ question
+ )
+ elif doc_engine == "oceanbase":
+ # Build OceanBase prompts with JSON extraction context
+ json_field_names = list(field_map.keys())
+ row_count_override = (
+ f"SELECT COUNT(*) AS rows FROM {table_name}"
+ if is_row_count_question(question)
+ else None
+ )
+ sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.
+
+JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
+Numeric Cast: CAST(json_extract_string(chunk_data, '$.FieldName') AS INTEGER/FLOAT)
+NULL Check: json_extract_isnull(chunk_data, '$.FieldName') == false
+
+RULES:
+1. Use EXACT field names (case-sensitive) from the list below
+2. For SELECT: include doc_id, docnm_kwd, and json_extract_string() for requested fields
+3. For COUNT: use COUNT(*) or COUNT(DISTINCT json_extract_string(...))
+4. Add AS alias for extracted field names
+5. DO NOT select 'content' field
+6. Only add NULL check (json_extract_isnull() == false) in WHERE clause when:
+ - Question asks to "show me" or "display" specific columns
+ - Question mentions "not null" or "excluding null"
+ - Add NULL check for count specific column
+ - DO NOT add NULL check for COUNT(*) queries (COUNT(*) counts all rows including nulls)
+7. Output ONLY the SQL, no explanations"""
+ user_prompt = """Table: {}
+Fields (EXACT case): {}
+{}
+Question: {}
+Write SQL using json_extract_string() with exact field names. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(
+ table_name,
+ ", ".join(json_field_names),
+ "\n".join([f" - {field}" for field in json_field_names]),
+ question
+ )
+ else:
+ # Build ES/OS prompts with direct field access
+ row_count_override = None
+ sys_prompt = """You are a Database Administrator. Write SQL queries.
+
+RULES:
+1. Use EXACT field names from the schema below (e.g., product_tks, not product)
+2. Quote field names starting with digit: "123_field"
+3. Add IS NOT NULL in WHERE clause when:
+ - Question asks to "show me" or "display" specific columns
+4. Include doc_id/docnm in non-aggregate statement
+5. Output ONLY the SQL, no explanations"""
+ user_prompt = """Table: {}
+Available fields:
+{}
+Question: {}
+Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(
+ table_name,
+ "\n".join([f" - {k} ({v})" for k, v in field_map.items()]),
+ question
+ )
+
tried_times = 0
async def get_table():
- nonlocal sys_prompt, user_prompt, question, tried_times
- sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
- sql = re.sub(r"^.* ", "", sql, flags=re.DOTALL)
- logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
- sql = re.sub(r"[\r\n]+", " ", sql.lower())
- sql = re.sub(r".*select ", "select ", sql.lower())
- sql = re.sub(r" +", " ", sql)
- sql = re.sub(r"([;;]|```).*", "", sql)
- sql = re.sub(r"&", "and", sql)
- if sql[: len("select ")] != "select ":
- return None, None
- if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
- if sql[: len("select *")] != "select *":
- sql = "select doc_id,docnm_kwd," + sql[6:]
+ nonlocal sys_prompt, user_prompt, question, tried_times, row_count_override
+ if row_count_override:
+ sql = row_count_override
+ else:
+ sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
+ logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}")
+ # Remove think blocks if present (format: ...)
+ sql = re.sub(r"\n.*?\n\s*", "", sql, flags=re.DOTALL)
+ sql = re.sub(r"思考\n.*?\n", "", sql, flags=re.DOTALL)
+ # Remove markdown code blocks (```sql ... ```)
+ sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE)
+ sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE)
+ # Remove trailing semicolon that ES SQL parser doesn't like
+ sql = sql.rstrip().rstrip(';').strip()
+
+ # Add kb_id filter for ES/OS only (Infinity already has it in table name)
+ if doc_engine != "infinity" and kb_ids:
+ # Build kb_filter: single KB or multiple KBs with OR
+ if len(kb_ids) == 1:
+ kb_filter = f"kb_id = '{kb_ids[0]}'"
else:
- flds = []
- for k in field_map.keys():
- if k in forbidden_select_fields4resume:
- continue
- if len(flds) > 11:
- break
- flds.append(k)
- sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
-
- if kb_ids:
- kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
- if "where" not in sql.lower():
+ kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
+
+ if "where " not in sql.lower():
o = sql.lower().split("order by")
if len(o) > 1:
sql = o[0] + f" WHERE {kb_filter} order by " + o[1]
else:
sql += f" WHERE {kb_filter}"
- else:
- sql += f" AND {kb_filter}"
+ elif "kb_id =" not in sql.lower() and "kb_id=" not in sql.lower():
+ sql = re.sub(r"\bwhere\b ", f"where {kb_filter} and ", sql, flags=re.IGNORECASE)
logging.debug(f"{question} get SQL(refined): {sql}")
tried_times += 1
- return settings.retriever.sql_retrieval(sql, format="json"), sql
+ logging.debug(f"use_sql: Executing SQL retrieval (attempt {tried_times})")
+ tbl = settings.retriever.sql_retrieval(sql, format="json")
+ if tbl is None:
+ logging.debug("use_sql: SQL retrieval returned None")
+ return None, sql
+ logging.debug(f"use_sql: SQL retrieval completed, got {len(tbl.get('rows', []))} rows")
+ return tbl, sql
try:
tbl, sql = await get_table()
+ logging.debug(f"use_sql: Initial SQL execution SUCCESS. SQL: {sql}")
+ logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows, columns: {[c['name'] for c in tbl.get('columns', [])]}")
except Exception as e:
- user_prompt = """
+ logging.warning(f"use_sql: Initial SQL execution FAILED with error: {e}")
+ # Build retry prompt with error information
+ if doc_engine in ("infinity", "oceanbase"):
+ # Build Infinity error retry prompt
+ json_field_names = list(field_map.keys())
+ user_prompt = """
+Table name: {};
+JSON fields available in 'chunk_data' column (use these exact names in json_extract_string):
+{}
+
+Question: {}
+Please write the SQL using json_extract_string(chunk_data, '$.field_name') with the field names from the list above. Only SQL, no explanations.
+
+
+The SQL error you provided last time is as follows:
+{}
+
+Please correct the error and write SQL again using json_extract_string(chunk_data, '$.field_name') syntax with the correct field names. Only SQL, no explanations.
+""".format(table_name, "\n".join([f" - {field}" for field in json_field_names]), question, e)
+ else:
+ # Build ES/OS error retry prompt
+ user_prompt = """
Table name: {};
- Table of database fields are as follows:
+ Table of database fields are as follows (use the field names directly in SQL):
{}
Question are as follows:
{}
- Please write the SQL, only SQL, without any other explanations or text.
+ Please write the SQL using the exact field names above, only SQL, without any other explanations or text.
The SQL error you provided last time is as follows:
{}
- Please correct the error and write SQL again, only SQL, without any other explanations or text.
- """.format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
+ Please correct the error and write SQL again using the exact field names above, only SQL, without any other explanations or text.
+ """.format(table_name, "\n".join([f"{k} ({v})" for k, v in field_map.items()]), question, e)
try:
tbl, sql = await get_table()
+ logging.debug(f"use_sql: Retry SQL execution SUCCESS. SQL: {sql}")
+ logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows on retry")
except Exception:
+ logging.error("use_sql: Retry SQL execution also FAILED, returning None")
return
if len(tbl["rows"]) == 0:
+ logging.warning(f"use_sql: No rows returned from SQL query, returning None. SQL: {sql}")
return None
- docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
- doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
+ logging.debug(f"use_sql: Proceeding with {len(tbl['rows'])} rows to build answer")
+
+ docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() == "doc_id"])
+ doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]])
+
+ logging.debug(f"use_sql: All columns: {[(i, c['name']) for i, c in enumerate(tbl['columns'])]}")
+ logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}")
+
column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
+ logging.debug(f"use_sql: column_idx={column_idx}")
+ logging.debug(f"use_sql: field_map={field_map}")
+
+ # Helper function to map column names to display names
+ def map_column_name(col_name):
+ if col_name.lower() == "count(star)":
+ return "COUNT(*)"
+
+ # First, try to extract AS alias from any expression (aggregate functions, json_extract_string, etc.)
+ # Pattern: anything AS alias_name
+ as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE)
+ if as_match:
+ alias = as_match.group(1).strip('"\'')
+
+ # Use the alias for display name lookup
+ if alias in field_map:
+ display = field_map[alias]
+ return re.sub(r"(/.*|([^()]+))", "", display)
+ # If alias not in field_map, try to match case-insensitively
+ for field_key, display_value in field_map.items():
+ if field_key.lower() == alias.lower():
+ return re.sub(r"(/.*|([^()]+))", "", display_value)
+ # Return alias as-is if no mapping found
+ return alias
+
+ # Try direct mapping first (for simple column names)
+ if col_name in field_map:
+ display = field_map[col_name]
+ # Clean up any suffix patterns
+ return re.sub(r"(/.*|([^()]+))", "", display)
+
+ # Try case-insensitive match for simple column names
+ col_lower = col_name.lower()
+ for field_key, display_value in field_map.items():
+ if field_key.lower() == col_lower:
+ return re.sub(r"(/.*|([^()]+))", "", display_value)
+
+ # For aggregate expressions or complex expressions without AS alias,
+ # try to replace field names with display names
+ result = col_name
+ for field_name, display_name in field_map.items():
+ # Replace field_name with display_name in the expression
+ result = result.replace(field_name, display_name)
+
+ # Clean up any suffix patterns
+ result = re.sub(r"(/.*|([^()]+))", "", result)
+ return result
+
# compose Markdown table
columns = (
"|" + "|".join(
- [re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + (
- "|Source|" if docid_idx and docid_idx else "|")
+ [map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + (
+ "|Source|" if docid_idx and doc_name_idx else "|")
)
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
- rows = ["|" + "|".join([remove_redundant_spaces(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
- rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
+ # Build rows ensuring column names match values - create a dict for each row
+ # keyed by column name to handle any SQL column order
+ rows = []
+ for row_idx, r in enumerate(tbl["rows"]):
+ row_dict = {tbl["columns"][i]["name"]: r[i] for i in range(len(tbl["columns"])) if i < len(r)}
+ if row_idx == 0:
+ logging.debug(f"use_sql: First row data: {row_dict}")
+ row_values = []
+ for col_idx in column_idx:
+ col_name = tbl["columns"][col_idx]["name"]
+ value = row_dict.get(col_name, " ")
+ row_values.append(remove_redundant_spaces(str(value)).replace("None", " "))
+ # Add Source column with citation marker if Source column exists
+ if docid_idx and doc_name_idx:
+ row_values.append(f" ##{row_idx}$$")
+ row_str = "|" + "|".join(row_values) + "|"
+ if re.sub(r"[ |]+", "", row_str):
+ rows.append(row_str)
if quota:
- rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
+ rows = "\n".join(rows)
else:
- rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
+ rows = "\n".join(rows)
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
if not docid_idx or not doc_name_idx:
- logging.warning("SQL missing field: " + sql)
+ logging.warning(f"use_sql: SQL missing required doc_id or docnm_kwd field. docid_idx={docid_idx}, doc_name_idx={doc_name_idx}. SQL: {sql}")
+ # For aggregate queries (COUNT, SUM, AVG, MAX, MIN, DISTINCT), fetch doc_id, docnm_kwd separately
+ # to provide source chunks, but keep the original table format answer
+ if re.search(r"(count|sum|avg|max|min|distinct)\s*\(", sql.lower()):
+ # Keep original table format as answer
+ answer = "\n".join([columns, line, rows])
+
+ # Now fetch doc_id, docnm_kwd to provide source chunks
+ # Extract WHERE clause from the original SQL
+ where_match = re.search(r"\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)", sql, re.IGNORECASE)
+ if where_match:
+ where_clause = where_match.group(1).strip()
+ # Build a query to get doc_id and docnm_kwd with the same WHERE clause
+ chunks_sql = f"select doc_id, docnm_kwd from {table_name} where {where_clause}"
+ # Add LIMIT to avoid fetching too many chunks
+ if "limit" not in chunks_sql.lower():
+ chunks_sql += " limit 20"
+ logging.debug(f"use_sql: Fetching chunks with SQL: {chunks_sql}")
+ try:
+ chunks_tbl = settings.retriever.sql_retrieval(chunks_sql, format="json")
+ if chunks_tbl.get("rows") and len(chunks_tbl["rows"]) > 0:
+ # Build chunks reference - use case-insensitive matching
+ chunks_did_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() == "doc_id"), None)
+ chunks_dn_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]), None)
+ if chunks_did_idx is not None and chunks_dn_idx is not None:
+ chunks = [{"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]} for r in chunks_tbl["rows"]]
+ # Build doc_aggs
+ doc_aggs = {}
+ for r in chunks_tbl["rows"]:
+ doc_id = r[chunks_did_idx]
+ doc_name = r[chunks_dn_idx]
+ if doc_id not in doc_aggs:
+ doc_aggs[doc_id] = {"doc_name": doc_name, "count": 0}
+ doc_aggs[doc_id]["count"] += 1
+ doc_aggs_list = [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]
+ logging.debug(f"use_sql: Returning aggregate answer with {len(chunks)} chunks from {len(doc_aggs)} documents")
+ return {"answer": answer, "reference": {"chunks": chunks, "doc_aggs": doc_aggs_list}, "prompt": sys_prompt}
+ except Exception as e:
+ logging.warning(f"use_sql: Failed to fetch chunks: {e}")
+ # Fallback: return answer without chunks
+ return {"answer": answer, "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
+ # Fallback to table format for other cases
return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
docid_idx = list(docid_idx)[0]
@@ -680,7 +948,8 @@ async def get_table():
if r[docid_idx] not in doc_aggs:
doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
doc_aggs[r[docid_idx]]["count"] += 1
- return {
+
+ result = {
"answer": "\n".join([columns, line, rows]),
"reference": {
"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
@@ -688,6 +957,8 @@ async def get_table():
},
"prompt": sys_prompt,
}
+ logging.debug(f"use_sql: Returning answer with {len(result['reference']['chunks'])} chunks from {len(doc_aggs)} documents")
+ return result
def clean_tts_text(text: str) -> str:
if not text:
@@ -733,6 +1004,84 @@ def tts(tts_mdl, text):
return None
return binascii.hexlify(bin).decode("utf-8")
+
+class _ThinkStreamState:
+ def __init__(self) -> None:
+ self.full_text = ""
+ self.last_idx = 0
+ self.endswith_think = False
+ self.last_full = ""
+ self.last_model_full = ""
+ self.in_think = False
+ self.buffer = ""
+
+
+def _next_think_delta(state: _ThinkStreamState) -> str:
+ full_text = state.full_text
+ if full_text == state.last_full:
+ return ""
+ state.last_full = full_text
+ delta_ans = full_text[state.last_idx:]
+
+ if delta_ans.find("") == 0:
+ state.last_idx += len("")
+ return ""
+ if delta_ans.find("") > 0:
+ delta_text = full_text[state.last_idx:state.last_idx + delta_ans.find("")]
+ state.last_idx += delta_ans.find("")
+ return delta_text
+ if delta_ans.endswith(" "):
+ state.endswith_think = True
+ elif state.endswith_think:
+ state.endswith_think = False
+ return " "
+
+ state.last_idx = len(full_text)
+ if full_text.endswith(" "):
+ state.last_idx -= len(" ")
+ return re.sub(r"(| )", "", delta_ans)
+
+
+async def _stream_with_think_delta(stream_iter, min_tokens: int = 16):
+ state = _ThinkStreamState()
+ async for chunk in stream_iter:
+ if not chunk:
+ continue
+ if chunk.startswith(state.last_model_full):
+ new_part = chunk[len(state.last_model_full):]
+ state.last_model_full = chunk
+ else:
+ new_part = chunk
+ state.last_model_full += chunk
+ if not new_part:
+ continue
+ state.full_text += new_part
+ delta = _next_think_delta(state)
+ if not delta:
+ continue
+ if delta in ("", " "):
+ if delta == "" and state.in_think:
+ continue
+ if delta == " " and not state.in_think:
+ continue
+ if state.buffer:
+ yield ("text", state.buffer, state)
+ state.buffer = ""
+ state.in_think = delta == ""
+ yield ("marker", delta, state)
+ continue
+ state.buffer += delta
+ if num_tokens_from_string(state.buffer) < min_tokens:
+ continue
+ yield ("text", state.buffer, state)
+ state.buffer = ""
+
+ if state.buffer:
+ yield ("text", state.buffer, state)
+ state.buffer = ""
+ if state.endswith_think:
+ yield ("marker", " ", state)
+
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
doc_ids = search_config.get("doc_ids", [])
rerank_mdl = None
@@ -755,10 +1104,10 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
if meta_data_filter:
- metas = DocumentService.get_meta_by_kbs(kb_ids)
+ metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
- kbinfos = retriever.retrieval(
+ kbinfos = await retriever.retrieval(
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_ids,
@@ -798,11 +1147,20 @@ def decorate_answer(answer):
refs["chunks"] = chunks_format(refs)
return {"answer": answer, "reference": refs}
- answer = ""
- async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
- answer = ans
- yield {"answer": answer, "reference": {}}
- yield decorate_answer(answer)
+ stream_iter = chat_mdl.async_chat_streamly_delta(sys_prompt, msg, {"temperature": 0.1})
+ last_state = None
+ async for kind, value, state in _stream_with_think_delta(stream_iter):
+ last_state = state
+ if kind == "marker":
+ flags = {"start_to_think": True} if value == "" else {"end_to_think": True}
+ yield {"answer": "", "reference": {}, "final": False, **flags}
+ continue
+ yield {"answer": value, "reference": {}, "final": False}
+ full_answer = last_state.full_text if last_state else ""
+ final = decorate_answer(full_answer)
+ final["final"] = True
+ final["answer"] = ""
+ yield final
async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
@@ -822,10 +1180,10 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
if meta_data_filter:
- metas = DocumentService.get_meta_by_kbs(kb_ids)
+ metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
- ranks = settings.retriever.retrieval(
+ ranks = await settings.retriever.retrieval(
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_ids,
diff --git a/api/db/services/doc_metadata_service.py b/api/db/services/doc_metadata_service.py
new file mode 100644
index 00000000000..339d51c3086
--- /dev/null
+++ b/api/db/services/doc_metadata_service.py
@@ -0,0 +1,1081 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Document Metadata Service
+
+Manages document-level metadata storage in ES/Infinity.
+This is the SOLE source of truth for document metadata - MySQL meta_fields column has been removed.
+"""
+
+import json
+import logging
+import re
+from copy import deepcopy
+from typing import Dict, List, Optional
+
+from api.db.db_models import DB, Document
+from common import settings
+from common.metadata_utils import dedupe_list
+from api.db.db_models import Knowledgebase
+from common.doc_store.doc_store_base import OrderByExpr
+
+
+class DocMetadataService:
+ """Service for managing document metadata in ES/Infinity"""
+
+ @staticmethod
+ def _get_doc_meta_index_name(tenant_id: str) -> str:
+ """
+ Get the index name for document metadata.
+
+ Args:
+ tenant_id: Tenant ID
+
+ Returns:
+ Index name for document metadata
+ """
+ return f"ragflow_doc_meta_{tenant_id}"
+
+ @staticmethod
+ def _extract_metadata(flat_meta: Dict) -> Dict:
+ """
+ Extract metadata from ES/Infinity document format.
+
+ Args:
+ flat_meta: Raw document from ES/Infinity with meta_fields field
+
+ Returns:
+ Simple metadata dictionary
+ """
+ if not flat_meta or not isinstance(flat_meta, dict):
+ return {}
+
+ meta_fields = flat_meta.get('meta_fields')
+ if not meta_fields:
+ return {}
+
+ # Parse JSON string if needed
+ if isinstance(meta_fields, str):
+ import json
+ try:
+ return json.loads(meta_fields)
+ except json.JSONDecodeError:
+ return {}
+
+ # Already a dict, return as-is
+ if isinstance(meta_fields, dict):
+ return meta_fields
+
+ return {}
+
+ @staticmethod
+ def _extract_doc_id(doc: Dict, hit: Dict = None) -> str:
+ """
+ Extract document ID from various formats.
+
+ Args:
+ doc: Document dictionary (from DataFrame or list format)
+ hit: Hit dictionary (from ES format with _id field)
+
+ Returns:
+ Document ID or empty string
+ """
+ if hit:
+ # ES format: doc is in _source, id is in _id
+ return hit.get('_id', '')
+ # DataFrame or list format: check multiple possible fields
+ return doc.get("doc_id") or doc.get("_id") or doc.get("id", "")
+
+ @classmethod
+ def _iter_search_results(cls, results):
+ """
+ Iterate over search results in various formats (DataFrame, ES, list).
+
+ Yields:
+ Tuple of (doc_id, doc_dict) for each document
+
+ Args:
+ results: Search results from ES/Infinity in any format
+ """
+ # Handle tuple return from Infinity: (DataFrame, int)
+ # Check this FIRST because pandas DataFrames also have __getitem__
+ if isinstance(results, tuple) and len(results) == 2:
+ results = results[0] # Extract DataFrame from tuple
+
+ # Check if results is a pandas DataFrame (from Infinity)
+ if hasattr(results, 'iterrows'):
+ # Handle pandas DataFrame - use iterrows() to iterate over rows
+ for _, row in results.iterrows():
+ doc = dict(row) # Convert Series to dict
+ doc_id = cls._extract_doc_id(doc)
+ if doc_id:
+ yield doc_id, doc
+
+ # Check if ES format (has 'hits' key)
+ # Note: ES returns ObjectApiResponse which is dict-like but not isinstance(dict)
+ elif hasattr(results, '__getitem__') and 'hits' in results:
+ # ES format: {"hits": {"hits": [{"_source": {...}, "_id": "..."}]}}
+ hits = results.get('hits', {}).get('hits', [])
+ for hit in hits:
+ doc = hit.get('_source', {})
+ doc_id = cls._extract_doc_id(doc, hit)
+ if doc_id:
+ yield doc_id, doc
+
+ # Handle list of dicts or other formats
+ elif isinstance(results, list):
+ for res in results:
+ if isinstance(res, dict):
+ docs = [res]
+ else:
+ docs = res
+
+ for doc in docs:
+ doc_id = cls._extract_doc_id(doc)
+ if doc_id:
+ yield doc_id, doc
+
+ @classmethod
+ def _search_metadata(cls, kb_id: str, condition: Dict = None, limit: int = 10000):
+ """
+ Common search logic for metadata queries.
+
+ Args:
+ kb_id: Knowledge base ID
+ condition: Optional search condition (defaults to {"kb_id": kb_id})
+ limit: Max results to return
+
+ Returns:
+ Search results from ES/Infinity, or empty list if index doesn't exist
+ """
+ kb = Knowledgebase.get_by_id(kb_id)
+ if not kb:
+ return []
+
+ tenant_id = kb.tenant_id
+ index_name = cls._get_doc_meta_index_name(tenant_id)
+
+ # Check if metadata index exists, create if it doesn't
+ if not settings.docStoreConn.index_exist(index_name, ""):
+ logging.debug(f"Metadata index {index_name} does not exist, creating it")
+ result = settings.docStoreConn.create_doc_meta_idx(index_name)
+ if result is False:
+ logging.error(f"Failed to create metadata index {index_name}")
+ return []
+ logging.debug(f"Successfully created metadata index {index_name}")
+
+ if condition is None:
+ condition = {"kb_id": kb_id}
+
+ order_by = OrderByExpr()
+
+ return settings.docStoreConn.search(
+ select_fields=["*"],
+ highlight_fields=[],
+ condition=condition,
+ match_expressions=[],
+ order_by=order_by,
+ offset=0,
+ limit=limit,
+ index_names=index_name,
+ knowledgebase_ids=[kb_id]
+ )
+
+ @classmethod
+ def _split_combined_values(cls, meta_fields: Dict) -> Dict:
+ """
+ Post-process metadata to split combined values by common delimiters.
+
+ For example: "关羽、孙权、张辽" -> ["关羽", "孙权", "张辽"]
+ This fixes LLM extraction where multiple values are extracted as one combined value.
+ Also removes duplicates after splitting.
+
+ Args:
+ meta_fields: Metadata dictionary
+
+ Returns:
+ Processed metadata with split values
+ """
+ if not meta_fields or not isinstance(meta_fields, dict):
+ return meta_fields
+
+ processed = {}
+ for key, value in meta_fields.items():
+ if isinstance(value, list):
+ # Process each item in the list
+ new_values = []
+ for item in value:
+ if isinstance(item, str):
+ # Split by common delimiters: Chinese comma (、), regular comma (,), pipe (|), semicolon (;), Chinese semicolon (;)
+ # Also handle mixed delimiters and spaces
+ split_items = re.split(r'[、,,;;|]+', item.strip())
+ # Trim whitespace and filter empty strings
+ split_items = [s.strip() for s in split_items if s.strip()]
+ if split_items:
+ new_values.extend(split_items)
+ else:
+ # Keep original if no split happened
+ new_values.append(item)
+ else:
+ new_values.append(item)
+ # Remove duplicates while preserving order
+ processed[key] = list(dict.fromkeys(new_values))
+ else:
+ processed[key] = value
+
+ if processed != meta_fields:
+ logging.debug(f"[METADATA SPLIT] Split combined values: {meta_fields} -> {processed}")
+ return processed
+
+ @classmethod
+ @DB.connection_context()
+ def insert_document_metadata(cls, doc_id: str, meta_fields: Dict) -> bool:
+ """
+ Insert document metadata into ES/Infinity.
+
+ Args:
+ doc_id: Document ID
+ meta_fields: Metadata dictionary
+
+ Returns:
+ True if successful, False otherwise
+ """
+ try:
+ # Get document with tenant_id (need to join with Knowledgebase)
+ doc_query = Document.select(Document, Knowledgebase.tenant_id).join(
+ Knowledgebase, on=(Knowledgebase.id == Document.kb_id)
+ ).where(Document.id == doc_id)
+
+ doc = doc_query.first()
+ if not doc:
+ logging.warning(f"Document {doc_id} not found for metadata insertion")
+ return False
+
+ # Extract document fields
+ doc_obj = doc # This is the Document object
+ tenant_id = doc.knowledgebase.tenant_id # Get tenant_id from joined Knowledgebase
+ kb_id = doc_obj.kb_id
+
+ # Prepare metadata document
+ doc_meta = {
+ "id": doc_obj.id,
+ "kb_id": kb_id,
+ }
+
+ # Store metadata as JSON object in meta_fields column (same as MySQL structure)
+ if meta_fields:
+ # Post-process to split combined values by common delimiters
+ meta_fields = cls._split_combined_values(meta_fields)
+ doc_meta["meta_fields"] = meta_fields
+ else:
+ doc_meta["meta_fields"] = {}
+
+ # Ensure index/table exists (per-tenant for both ES and Infinity)
+ index_name = cls._get_doc_meta_index_name(tenant_id)
+
+ # Check if table exists
+ table_exists = settings.docStoreConn.index_exist(index_name, kb_id)
+ logging.debug(f"Metadata table exists check: {index_name} -> {table_exists}")
+
+ # Create index if it doesn't exist
+ if not table_exists:
+ logging.debug(f"Creating metadata table: {index_name}")
+ # Both ES and Infinity now use per-tenant metadata tables
+ result = settings.docStoreConn.create_doc_meta_idx(index_name)
+ logging.debug(f"Table creation result: {result}")
+ if result is False:
+ logging.error(f"Failed to create metadata table {index_name}")
+ return False
+ else:
+ logging.debug(f"Metadata table already exists: {index_name}")
+
+ # Insert into ES/Infinity
+ result = settings.docStoreConn.insert(
+ [doc_meta],
+ index_name,
+ kb_id
+ )
+
+ if result:
+ logging.error(f"Failed to insert metadata for document {doc_id}: {result}")
+ return False
+ # Force ES refresh to make metadata immediately available for search
+ if not settings.DOC_ENGINE_INFINITY:
+ try:
+ settings.docStoreConn.es.indices.refresh(index=index_name)
+ logging.debug(f"Refreshed metadata index: {index_name}")
+ except Exception as e:
+ logging.warning(f"Failed to refresh metadata index {index_name}: {e}")
+
+ logging.debug(f"Successfully inserted metadata for document {doc_id}")
+ return True
+
+ except Exception as e:
+ logging.error(f"Error inserting metadata for document {doc_id}: {e}")
+ return False
+
+ @classmethod
+ @DB.connection_context()
+ def update_document_metadata(cls, doc_id: str, meta_fields: Dict) -> bool:
+ """
+ Update document metadata in ES/Infinity.
+
+ For Elasticsearch: Uses partial update to directly update the meta_fields field.
+ For Infinity: Falls back to delete+insert (Infinity doesn't support partial updates well).
+
+ Args:
+ doc_id: Document ID
+ meta_fields: Metadata dictionary
+
+ Returns:
+ True if successful, False otherwise
+ """
+ try:
+ # Get document with tenant_id
+ doc_query = Document.select(Document, Knowledgebase.tenant_id).join(
+ Knowledgebase, on=(Knowledgebase.id == Document.kb_id)
+ ).where(Document.id == doc_id)
+
+ doc = doc_query.first()
+ if not doc:
+ logging.warning(f"Document {doc_id} not found for metadata update")
+ return False
+
+ # Extract fields
+ doc_obj = doc
+ tenant_id = doc.knowledgebase.tenant_id
+ kb_id = doc_obj.kb_id
+ index_name = cls._get_doc_meta_index_name(tenant_id)
+
+ # Post-process to split combined values
+ processed_meta = cls._split_combined_values(meta_fields)
+
+ logging.debug(f"[update_document_metadata] Updating doc_id: {doc_id}, kb_id: {kb_id}, meta_fields: {processed_meta}")
+
+ # For Elasticsearch, use efficient partial update
+ if not settings.DOC_ENGINE_INFINITY:
+ try:
+ # Use ES partial update API - much more efficient than delete+insert
+ settings.docStoreConn.es.update(
+ index=index_name,
+ id=doc_id,
+ refresh=True, # Make changes immediately visible
+ doc={"meta_fields": processed_meta}
+ )
+ logging.debug(f"Successfully updated metadata for document {doc_id} using ES partial update")
+ return True
+ except Exception as e:
+ logging.error(f"ES partial update failed for document {doc_id}: {e}")
+ # Fall back to delete+insert if partial update fails
+ logging.info(f"Falling back to delete+insert for document {doc_id}")
+
+ # For Infinity or as fallback: use delete+insert
+ logging.debug(f"[update_document_metadata] Using delete+insert method for doc_id: {doc_id}")
+ cls.delete_document_metadata(doc_id, skip_empty_check=True)
+ return cls.insert_document_metadata(doc_id, processed_meta)
+
+ except Exception as e:
+ logging.error(f"Error updating metadata for document {doc_id}: {e}")
+ return False
+
+ @classmethod
+ @DB.connection_context()
+ def delete_document_metadata(cls, doc_id: str, skip_empty_check: bool = False) -> bool:
+ """
+ Delete document metadata from ES/Infinity.
+ Also drops the metadata table if it becomes empty (efficiently).
+ If document has no metadata in the table, this is a no-op.
+
+ Args:
+ doc_id: Document ID
+ skip_empty_check: If True, skip checking/dropping empty table (for bulk deletions)
+
+ Returns:
+ True if successful (or no metadata to delete), False otherwise
+ """
+ try:
+ logging.debug(f"[METADATA DELETE] Starting metadata deletion for document: {doc_id}")
+ # Get document with tenant_id
+ doc_query = Document.select(Document, Knowledgebase.tenant_id).join(
+ Knowledgebase, on=(Knowledgebase.id == Document.kb_id)
+ ).where(Document.id == doc_id)
+
+ doc = doc_query.first()
+ if not doc:
+ logging.warning(f"Document {doc_id} not found for metadata deletion")
+ return False
+
+ tenant_id = doc.knowledgebase.tenant_id
+ kb_id = doc.kb_id
+ index_name = cls._get_doc_meta_index_name(tenant_id)
+ logging.debug(f"[delete_document_metadata] Deleting doc_id: {doc_id}, kb_id: {kb_id}, index: {index_name}")
+
+ # Check if metadata table exists before attempting deletion
+ # This is the key optimization - no table = no metadata = nothing to delete
+ if not settings.docStoreConn.index_exist(index_name, ""):
+ logging.debug(f"Metadata table {index_name} does not exist, skipping metadata deletion for document {doc_id}")
+ return True # No metadata to delete is considered success
+
+ # Try to get the metadata to confirm it exists before deleting
+ # This is more efficient than attempting delete on non-existent records
+ try:
+ existing_metadata = settings.docStoreConn.get(
+ doc_id,
+ index_name,
+ [""] # Empty list for metadata tables
+ )
+ logging.debug(f"[METADATA DELETE] Get result: {existing_metadata is not None}")
+ if not existing_metadata:
+ logging.debug(f"[METADATA DELETE] Document {doc_id} has no metadata in table, skipping deletion")
+ # Only check/drop table if not skipped (tenant deletion will handle it)
+ if not skip_empty_check:
+ cls._drop_empty_metadata_table(index_name, tenant_id)
+ return True # No metadata to delete is success
+ except Exception as e:
+ # If get fails, document might not exist in metadata table, which is fine
+ logging.error(f"[METADATA DELETE] Get failed: {e}")
+ # Continue to check/drop table if needed
+
+ # Delete from ES/Infinity (only if metadata exists)
+ # For metadata tables, pass kb_id for the delete operation
+ # The delete() method will detect it's a metadata table and skip the kb_id filter
+ logging.debug(f"[METADATA DELETE] Deleting metadata with condition: {{'id': '{doc_id}'}}")
+ deleted_count = settings.docStoreConn.delete(
+ {"id": doc_id},
+ index_name,
+ kb_id # Pass actual kb_id (delete() will handle metadata tables correctly)
+ )
+ logging.debug(f"[METADATA DELETE] Deleted count: {deleted_count}")
+
+ # Only check if table should be dropped if not skipped (for bulk operations)
+ # Note: delete operation already uses refresh=True, so data is immediately available
+ if not skip_empty_check:
+ # Check by querying the actual metadata table (not MySQL)
+ cls._drop_empty_metadata_table(index_name, tenant_id)
+
+ logging.debug(f"Successfully deleted metadata for document {doc_id}")
+ return True
+
+ except Exception as e:
+ logging.error(f"Error deleting metadata for document {doc_id}: {e}")
+ return False
+
+ @classmethod
+ def _drop_empty_metadata_table(cls, index_name: str, tenant_id: str) -> None:
+ """
+ Check if metadata table is empty and drop it if so.
+ Uses optimized count query instead of full search.
+ This prevents accumulation of empty metadata tables.
+
+ Args:
+ index_name: Metadata table/index name
+ tenant_id: Tenant ID
+ """
+ try:
+ logging.debug(f"[DROP EMPTY TABLE] Starting empty table check for: {index_name}")
+
+ # Check if table exists first (cheap operation)
+ if not settings.docStoreConn.index_exist(index_name, ""):
+ logging.debug(f"[DROP EMPTY TABLE] Metadata table {index_name} does not exist, skipping")
+ return
+
+ logging.debug(f"[DROP EMPTY TABLE] Table {index_name} exists, checking if empty...")
+
+ # Use ES count API for accurate count
+ # Note: No need to refresh since delete operation already uses refresh=True
+ try:
+ count_response = settings.docStoreConn.es.count(index=index_name)
+ total_count = count_response['count']
+ logging.debug(f"[DROP EMPTY TABLE] ES count API result: {total_count} documents")
+ is_empty = (total_count == 0)
+ except Exception as e:
+ logging.warning(f"[DROP EMPTY TABLE] Count API failed, falling back to search: {e}")
+ # Fallback to search if count fails
+ results = settings.docStoreConn.search(
+ select_fields=["id"],
+ highlight_fields=[],
+ condition={},
+ match_expressions=[],
+ order_by=OrderByExpr(),
+ offset=0,
+ limit=1, # Only need 1 result to know if table is non-empty
+ index_names=index_name,
+ knowledgebase_ids=[""] # Metadata tables don't filter by KB
+ )
+
+ logging.debug(f"[DROP EMPTY TABLE] Search results type: {type(results)}, results: {results}")
+
+ # Check if empty based on return type (fallback search only)
+ if isinstance(results, tuple) and len(results) == 2:
+ # Infinity returns (DataFrame, int)
+ df, total = results
+ logging.debug(f"[DROP EMPTY TABLE] Infinity format - total: {total}, df length: {len(df) if hasattr(df, '__len__') else 'N/A'}")
+ is_empty = (total == 0 or (hasattr(df, '__len__') and len(df) == 0))
+ elif hasattr(results, 'get') and 'hits' in results:
+ # ES format - MUST check this before hasattr(results, '__len__')
+ # because ES response objects also have __len__
+ total = results.get('hits', {}).get('total', {})
+ hits = results.get('hits', {}).get('hits', [])
+
+ # ES 7.x+: total is a dict like {'value': 0, 'relation': 'eq'}
+ # ES 6.x: total is an int
+ if isinstance(total, dict):
+ total_count = total.get('value', 0)
+ else:
+ total_count = total
+
+ logging.debug(f"[DROP EMPTY TABLE] ES format - total: {total_count}, hits count: {len(hits)}")
+ is_empty = (total_count == 0 or len(hits) == 0)
+ elif hasattr(results, '__len__'):
+ # DataFrame or list (check this AFTER ES format)
+ result_len = len(results)
+ logging.debug(f"[DROP EMPTY TABLE] List/DataFrame format - length: {result_len}")
+ is_empty = result_len == 0
+ else:
+ logging.warning(f"[DROP EMPTY TABLE] Unknown result format: {type(results)}")
+ is_empty = False
+
+ if is_empty:
+ logging.debug(f"[DROP EMPTY TABLE] Metadata table {index_name} is empty, dropping it")
+ drop_result = settings.docStoreConn.delete_idx(index_name, "")
+ logging.debug(f"[DROP EMPTY TABLE] Drop result: {drop_result}")
+ else:
+ logging.debug(f"[DROP EMPTY TABLE] Metadata table {index_name} still has documents, keeping it")
+
+ except Exception as e:
+ # Log but don't fail - metadata deletion was successful
+ logging.error(f"[DROP EMPTY TABLE] Failed to check/drop empty metadata table {index_name}: {e}")
+
+ @classmethod
+ @DB.connection_context()
+ def get_document_metadata(cls, doc_id: str) -> Dict:
+ """
+ Get document metadata from ES/Infinity.
+
+ Args:
+ doc_id: Document ID
+
+ Returns:
+ Metadata dictionary, empty dict if not found
+ """
+ try:
+ # Get document with tenant_id
+ doc_query = Document.select(Document, Knowledgebase.tenant_id).join(
+ Knowledgebase, on=(Knowledgebase.id == Document.kb_id)
+ ).where(Document.id == doc_id)
+
+ doc = doc_query.first()
+ if not doc:
+ logging.warning(f"Document {doc_id} not found")
+ return {}
+
+ # Extract fields
+ doc_obj = doc
+ tenant_id = doc.knowledgebase.tenant_id
+ kb_id = doc_obj.kb_id
+ index_name = cls._get_doc_meta_index_name(tenant_id)
+
+ # Try to get metadata from ES/Infinity
+ metadata_doc = settings.docStoreConn.get(
+ doc_id,
+ index_name,
+ [kb_id]
+ )
+
+ if metadata_doc:
+ # Extract and unflatten metadata
+ return cls._extract_metadata(metadata_doc)
+
+ return {}
+
+ except Exception as e:
+ logging.error(f"Error getting metadata for document {doc_id}: {e}")
+ return {}
+
+ @classmethod
+ @DB.connection_context()
+ def get_meta_by_kbs(cls, kb_ids: List[str]) -> Dict:
+ """
+ Get metadata for documents in knowledge bases (Legacy).
+
+ Legacy metadata aggregator (backward-compatible).
+ - Does NOT expand list values and a list is kept as one string key.
+ Example: {"tags": ["foo","bar"]} -> meta["tags"]["['foo', 'bar']"] = [doc_id]
+ - Expects meta_fields is a dict.
+ Use when existing callers rely on the old list-as-string semantics.
+
+ Args:
+ kb_ids: List of knowledge base IDs
+
+ Returns:
+ Metadata dictionary in format: {field_name: {value: [doc_ids]}}
+ """
+ try:
+ # Get tenant_id from first KB
+ kb = Knowledgebase.get_by_id(kb_ids[0])
+ if not kb:
+ return {}
+
+ tenant_id = kb.tenant_id
+ index_name = cls._get_doc_meta_index_name(tenant_id)
+
+ condition = {"kb_id": kb_ids}
+ order_by = OrderByExpr()
+
+ # Query with large limit
+ results = settings.docStoreConn.search(
+ select_fields=["*"],
+ highlight_fields=[],
+ condition=condition,
+ match_expressions=[],
+ order_by=order_by,
+ offset=0,
+ limit=10000,
+ index_names=index_name,
+ knowledgebase_ids=kb_ids
+ )
+
+ logging.debug(f"[get_meta_by_kbs] index_name: {index_name}, kb_ids: {kb_ids}")
+
+ # Aggregate metadata (legacy: keeps lists as string keys)
+ meta = {}
+
+ # Use helper to iterate over results in any format
+ for doc_id, doc in cls._iter_search_results(results):
+ # Extract metadata fields (exclude system fields)
+ doc_meta = cls._extract_metadata(doc)
+
+ # Legacy: Keep lists as string keys (do NOT expand)
+ for k, v in doc_meta.items():
+ if k not in meta:
+ meta[k] = {}
+ # If not list, make it a list
+ if not isinstance(v, list):
+ v = [v]
+ # Legacy: Use the entire list as a string key
+ # Skip nested lists/dicts
+ if isinstance(v, list) and any(isinstance(x, (list, dict)) for x in v):
+ continue
+ list_key = str(v)
+ if list_key not in meta[k]:
+ meta[k][list_key] = []
+ meta[k][list_key].append(doc_id)
+
+ logging.debug(f"[get_meta_by_kbs] KBs: {kb_ids}, Returning metadata: {meta}")
+ return meta
+
+ except Exception as e:
+ logging.error(f"Error getting metadata for KBs {kb_ids}: {e}")
+ return {}
+
+ @classmethod
+ @DB.connection_context()
+ def get_flatted_meta_by_kbs(cls, kb_ids: List[str]) -> Dict:
+ """
+ Get flattened metadata for documents in knowledge bases.
+
+ - Parses stringified JSON meta_fields when possible and skips non-dict or unparsable values.
+ - Expands list values into individual entries.
+ Example: {"tags": ["foo","bar"], "author": "alice"} ->
+ meta["tags"]["foo"] = [doc_id], meta["tags"]["bar"] = [doc_id], meta["author"]["alice"] = [doc_id]
+ Prefer for metadata_condition filtering and scenarios that must respect list semantics.
+
+ Args:
+ kb_ids: List of knowledge base IDs
+
+ Returns:
+ Metadata dictionary in format: {field_name: {value: [doc_ids]}}
+ """
+ try:
+ # Get tenant_id from first KB
+ kb = Knowledgebase.get_by_id(kb_ids[0])
+ if not kb:
+ return {}
+
+ tenant_id = kb.tenant_id
+ index_name = cls._get_doc_meta_index_name(tenant_id)
+
+ condition = {"kb_id": kb_ids}
+ order_by = OrderByExpr()
+
+ # Query with large limit
+ results = settings.docStoreConn.search(
+ select_fields=["*"], # Get all fields
+ highlight_fields=[],
+ condition=condition,
+ match_expressions=[],
+ order_by=order_by,
+ offset=0,
+ limit=10000,
+ index_names=index_name,
+ knowledgebase_ids=kb_ids
+ )
+
+ logging.debug(f"[get_flatted_meta_by_kbs] index_name: {index_name}, kb_ids: {kb_ids}")
+ logging.debug(f"[get_flatted_meta_by_kbs] results type: {type(results)}")
+
+ # Aggregate metadata
+ meta = {}
+
+ # Use helper to iterate over results in any format
+ for doc_id, doc in cls._iter_search_results(results):
+ # Extract metadata fields (exclude system fields)
+ doc_meta = cls._extract_metadata(doc)
+
+ for k, v in doc_meta.items():
+ if k not in meta:
+ meta[k] = {}
+
+ values = v if isinstance(v, list) else [v]
+ for vv in values:
+ if vv is None:
+ continue
+ sv = str(vv)
+ if sv not in meta[k]:
+ meta[k][sv] = []
+ meta[k][sv].append(doc_id)
+
+ logging.debug(f"[get_flatted_meta_by_kbs] KBs: {kb_ids}, Returning metadata: {meta}")
+ return meta
+
+ except Exception as e:
+ logging.error(f"Error getting flattened metadata for KBs {kb_ids}: {e}")
+ return {}
+
+ @classmethod
+ def get_metadata_for_documents(cls, doc_ids: Optional[List[str]], kb_id: str) -> Dict[str, Dict]:
+ """
+ Get metadata fields for specific documents.
+ Returns a mapping of doc_id -> meta_fields
+
+ Args:
+ doc_ids: List of document IDs (if None, gets all documents with metadata for the KB)
+ kb_id: Knowledge base ID
+
+ Returns:
+ Dictionary mapping doc_id to meta_fields dict
+ """
+ try:
+ results = cls._search_metadata(kb_id, condition={"kb_id": kb_id})
+ if not results:
+ return {}
+
+ # Build mapping: doc_id -> meta_fields
+ meta_mapping = {}
+
+ # If doc_ids is provided, create a set for efficient lookup
+ doc_ids_set = set(doc_ids) if doc_ids else None
+
+ # Use helper to iterate over results in any format
+ for doc_id, doc in cls._iter_search_results(results):
+ # Filter by doc_ids if provided
+ if doc_ids_set is not None and doc_id not in doc_ids_set:
+ continue
+
+ # Extract metadata (handles both JSON strings and dicts)
+ doc_meta = cls._extract_metadata(doc)
+ if doc_meta:
+ meta_mapping[doc_id] = doc_meta
+
+ logging.debug(f"[get_metadata_for_documents] Found metadata for {len(meta_mapping)}/{len(doc_ids) if doc_ids else 'all'} documents")
+ return meta_mapping
+
+ except Exception as e:
+ logging.error(f"Error getting metadata for documents: {e}")
+ return {}
+
+ @classmethod
+ @DB.connection_context()
+ def get_metadata_summary(cls, kb_id: str, doc_ids=None) -> Dict:
+ """
+ Get metadata summary for documents in a knowledge base.
+
+ Args:
+ kb_id: Knowledge base ID
+ doc_ids: Optional list of document IDs to filter by
+
+ Returns:
+ Dictionary with metadata field statistics in format:
+ {
+ "field_name": {
+ "type": "string" | "number" | "list" | "time",
+ "values": [("value1", count1), ("value2", count2), ...] # sorted by count desc
+ }
+ }
+ """
+ def _is_time_string(value: str) -> bool:
+ """Check if a string value is an ISO 8601 datetime (e.g., '2026-02-03T00:00:00')."""
+ if not isinstance(value, str):
+ return False
+ return bool(re.match(r'^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$', value))
+
+ def _meta_value_type(value):
+ """Determine the type of a metadata value."""
+ if value is None:
+ return None
+ if isinstance(value, list):
+ return "list"
+ if isinstance(value, bool):
+ return "string"
+ if isinstance(value, (int, float)):
+ return "number"
+ if isinstance(value, str) and _is_time_string(value):
+ return "time"
+ return "string"
+
+ try:
+ results = cls._search_metadata(kb_id, condition={"kb_id": kb_id})
+ if not results:
+ return {}
+
+ # If doc_ids are provided, we'll filter after the search
+ doc_ids_set = set(doc_ids) if doc_ids else None
+
+ # Aggregate metadata
+ summary = {}
+ type_counter = {}
+
+ logging.debug(f"[METADATA SUMMARY] KB: {kb_id}, doc_ids: {doc_ids}")
+
+ # Use helper to iterate over results in any format
+ for doc_id, doc in cls._iter_search_results(results):
+ # Check doc_ids filter
+ if doc_ids_set and doc_id not in doc_ids_set:
+ continue
+
+ doc_meta = cls._extract_metadata(doc)
+
+ for k, v in doc_meta.items():
+ # Track type counts for this field
+ value_type = _meta_value_type(v)
+ if value_type:
+ if k not in type_counter:
+ type_counter[k] = {}
+ type_counter[k][value_type] = type_counter[k].get(value_type, 0) + 1
+
+ # Aggregate value counts
+ values = v if isinstance(v, list) else [v]
+ for vv in values:
+ if vv is None:
+ continue
+ sv = str(vv)
+ if k not in summary:
+ summary[k] = {}
+ summary[k][sv] = summary[k].get(sv, 0) + 1
+
+ # Build result with type information and sorted values
+ result = {}
+ for k, v in summary.items():
+ values = sorted([(val, cnt) for val, cnt in v.items()], key=lambda x: x[1], reverse=True)
+ type_counts = type_counter.get(k, {})
+ value_type = "string"
+ if type_counts:
+ value_type = max(type_counts.items(), key=lambda item: item[1])[0]
+ result[k] = {"type": value_type, "values": values}
+
+ logging.debug(f"[METADATA SUMMARY] Final result: {result}")
+ return result
+
+ except Exception as e:
+ logging.error(f"Error getting metadata summary for KB {kb_id}: {e}")
+ return {}
+
+ @classmethod
+ @DB.connection_context()
+ def batch_update_metadata(cls, kb_id: str, doc_ids: List[str], updates=None, deletes=None) -> int:
+ """
+ Batch update metadata for documents in a knowledge base.
+
+ Args:
+ kb_id: Knowledge base ID
+ doc_ids: List of document IDs to update
+ updates: List of update operations, each with:
+ - key: field name to update
+ - value: new value
+ - match (optional): only update if current value matches this
+ deletes: List of delete operations, each with:
+ - key: field name to delete from
+ - value (optional): specific value to delete (if not provided, deletes the entire field)
+
+ Returns:
+ Number of documents updated
+
+ Examples:
+ updates = [{"key": "author", "value": "John"}]
+ updates = [{"key": "tags", "value": "new", "match": "old"}] # Replace "old" with "new" in tags list
+ deletes = [{"key": "author"}] # Delete entire author field
+ deletes = [{"key": "tags", "value": "obsolete"}] # Remove "obsolete" from tags list
+ """
+ updates = updates or []
+ deletes = deletes or []
+ if not doc_ids:
+ return 0
+
+ def _normalize_meta(meta):
+ """Normalize metadata to a dict."""
+ if isinstance(meta, str):
+ try:
+ meta = json.loads(meta)
+ except Exception:
+ return {}
+ if not isinstance(meta, dict):
+ return {}
+ return deepcopy(meta)
+
+ def _str_equal(a, b):
+ """Compare two values as strings."""
+ return str(a) == str(b)
+
+ def _apply_updates(meta):
+ """Apply update operations to metadata."""
+ changed = False
+ for upd in updates:
+ key = upd.get("key")
+ if not key:
+ continue
+
+ new_value = upd.get("value")
+ match_value = upd.get("match", None)
+ match_provided = match_value is not None and match_value != ""
+
+ if key not in meta:
+ if match_provided:
+ continue
+ meta[key] = dedupe_list(new_value) if isinstance(new_value, list) else new_value
+ changed = True
+ continue
+
+ if isinstance(meta[key], list):
+ if not match_provided:
+ # No match provided, append new_value to the list
+ if isinstance(new_value, list):
+ meta[key] = dedupe_list(meta[key] + new_value)
+ else:
+ meta[key] = dedupe_list(meta[key] + [new_value])
+ changed = True
+ else:
+ # Replace items matching match_value with new_value
+ replaced = False
+ new_list = []
+ for item in meta[key]:
+ if _str_equal(item, match_value):
+ new_list.append(new_value)
+ replaced = True
+ else:
+ new_list.append(item)
+ if replaced:
+ meta[key] = dedupe_list(new_list)
+ changed = True
+ else:
+ if not match_provided:
+ meta[key] = new_value
+ changed = True
+ else:
+ if _str_equal(meta[key], match_value):
+ meta[key] = new_value
+ changed = True
+ return changed
+
+ def _apply_deletes(meta):
+ """Apply delete operations to metadata."""
+ changed = False
+ for d in deletes:
+ key = d.get("key")
+ if not key or key not in meta:
+ continue
+ value = d.get("value", None)
+ if isinstance(meta[key], list):
+ if value is None:
+ del meta[key]
+ changed = True
+ continue
+ new_list = [item for item in meta[key] if not _str_equal(item, value)]
+ if len(new_list) != len(meta[key]):
+ if new_list:
+ meta[key] = new_list
+ else:
+ del meta[key]
+ changed = True
+ else:
+ if value is None or _str_equal(meta[key], value):
+ del meta[key]
+ changed = True
+ return changed
+
+ try:
+ results = cls._search_metadata(kb_id, condition=None)
+ if not results:
+ results = [] # Treat as empty list if None
+
+ updated_docs = 0
+ doc_ids_set = set(doc_ids)
+ found_doc_ids = set()
+
+ logging.debug(f"[batch_update_metadata] Searching for doc_ids: {doc_ids}")
+
+ # Use helper to iterate over results in any format
+ for doc_id, doc in cls._iter_search_results(results):
+ # Filter to only process requested doc_ids
+ if doc_id not in doc_ids_set:
+ continue
+
+ found_doc_ids.add(doc_id)
+
+ # Get current metadata
+ current_meta = cls._extract_metadata(doc)
+ meta = _normalize_meta(current_meta)
+ original_meta = deepcopy(meta)
+
+ logging.debug(f"[batch_update_metadata] Doc {doc_id}: current_meta={current_meta}, meta={meta}")
+ logging.debug(f"[batch_update_metadata] Updates to apply: {updates}, Deletes: {deletes}")
+
+ # Apply updates and deletes
+ changed = _apply_updates(meta)
+ logging.debug(f"[batch_update_metadata] After _apply_updates: changed={changed}, meta={meta}")
+ changed = _apply_deletes(meta) or changed
+ logging.debug(f"[batch_update_metadata] After _apply_deletes: changed={changed}, meta={meta}")
+
+ # Update if changed
+ if changed and meta != original_meta:
+ logging.debug(f"[batch_update_metadata] Updating doc_id: {doc_id}, meta: {meta}")
+ # If metadata is empty, delete the row entirely instead of keeping empty metadata
+ if not meta:
+ cls.delete_document_metadata(doc_id, skip_empty_check=True)
+ else:
+ cls.update_document_metadata(doc_id, meta)
+ updated_docs += 1
+
+ # Handle documents that don't have metadata rows yet
+ # These documents weren't in the search results, so we need to insert new metadata for them
+ missing_doc_ids = doc_ids_set - found_doc_ids
+ if missing_doc_ids and updates:
+ logging.debug(f"[batch_update_metadata] Inserting new metadata for documents without metadata rows: {missing_doc_ids}")
+ for doc_id in missing_doc_ids:
+ # Apply updates to create new metadata
+ meta = {}
+ _apply_updates(meta)
+ if meta:
+ # Only insert if there's actual metadata to add
+ cls.update_document_metadata(doc_id, meta)
+ updated_docs += 1
+ logging.debug(f"[batch_update_metadata] Inserted metadata for doc_id: {doc_id}, meta: {meta}")
+
+ logging.debug(f"[batch_update_metadata] KB: {kb_id}, doc_ids: {doc_ids}, updated: {updated_docs}")
+ return updated_docs
+
+ except Exception as e:
+ logging.error(f"Error in batch_update_metadata for KB {kb_id}: {e}")
+ return 0
diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py
index a05d1783d9e..aa532af6250 100644
--- a/api/db/services/document_service.py
+++ b/api/db/services/document_service.py
@@ -33,7 +33,7 @@
from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
-from common.metadata_utils import dedupe_list
+from api.db.services.doc_metadata_service import DocMetadataService
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp, get_format_time
from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME
@@ -67,7 +67,6 @@ def get_cls_model_fields(cls):
cls.model.progress_msg,
cls.model.process_begin_at,
cls.model.process_duration,
- cls.model.meta_fields,
cls.model.suffix,
cls.model.run,
cls.model.status,
@@ -110,7 +109,12 @@ def get_list(cls, kb_id, page_number, items_per_page,
count = docs.count()
docs = docs.paginate(page_number, items_per_page)
- return list(docs.dicts()), count
+
+ docs_list = list(docs.dicts())
+ metadata_map = DocMetadataService.get_metadata_for_documents(None, kb_id)
+ for doc in docs_list:
+ doc["meta_fields"] = metadata_map.get(doc["id"], {})
+ return docs_list, count
@classmethod
@DB.connection_context()
@@ -154,8 +158,11 @@ def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keyword
docs = docs.where(cls.model.type.in_(types))
if suffix:
docs = docs.where(cls.model.suffix.in_(suffix))
- if return_empty_metadata:
- docs = docs.where(fn.COALESCE(fn.JSON_LENGTH(cls.model.meta_fields), 0) == 0)
+
+ metadata_map = DocMetadataService.get_metadata_for_documents(None, kb_id)
+ doc_ids_with_metadata = set(metadata_map.keys())
+ if return_empty_metadata and doc_ids_with_metadata:
+ docs = docs.where(cls.model.id.not_in(doc_ids_with_metadata))
count = docs.count()
if desc:
@@ -166,7 +173,14 @@ def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keyword
if page_number and items_per_page:
docs = docs.paginate(page_number, items_per_page)
- return list(docs.dicts()), count
+ docs_list = list(docs.dicts())
+ if return_empty_metadata:
+ for doc in docs_list:
+ doc["meta_fields"] = {}
+ else:
+ for doc in docs_list:
+ doc["meta_fields"] = metadata_map.get(doc["id"], {})
+ return docs_list, count
@classmethod
@DB.connection_context()
@@ -212,7 +226,7 @@ def get_filter_by_kb_id(cls, kb_id, keywords, run_status, types, suffix):
if suffix:
query = query.where(cls.model.suffix.in_(suffix))
- rows = query.select(cls.model.run, cls.model.suffix, cls.model.meta_fields)
+ rows = query.select(cls.model.run, cls.model.suffix, cls.model.id)
total = rows.count()
suffix_counter = {}
@@ -220,10 +234,18 @@ def get_filter_by_kb_id(cls, kb_id, keywords, run_status, types, suffix):
metadata_counter = {}
empty_metadata_count = 0
+ doc_ids = [row.id for row in rows]
+ metadata = {}
+ if doc_ids:
+ try:
+ metadata = DocMetadataService.get_metadata_for_documents(doc_ids, kb_id)
+ except Exception as e:
+ logging.warning(f"Failed to fetch metadata from ES/Infinity: {e}")
+
for row in rows:
suffix_counter[row.suffix] = suffix_counter.get(row.suffix, 0) + 1
run_status_counter[str(row.run)] = run_status_counter.get(str(row.run), 0) + 1
- meta_fields = row.meta_fields or {}
+ meta_fields = metadata.get(row.id, {})
if not meta_fields:
empty_metadata_count += 1
continue
@@ -338,16 +360,50 @@ def insert(cls, doc):
@classmethod
@DB.connection_context()
def remove_document(cls, doc, tenant_id):
- from api.db.services.task_service import TaskService
+ from api.db.services.task_service import TaskService, cancel_all_task_of
cls.clear_chunk_num(doc.id)
+
+ # Cancel all running tasks first Using preset function in task_service.py --- set cancel flag in Redis
+ try:
+ cancel_all_task_of(doc.id)
+ logging.info(f"Cancelled all tasks for document {doc.id}")
+ except Exception as e:
+ logging.warning(f"Failed to cancel tasks for document {doc.id}: {e}")
+
+ # Delete tasks from database
try:
TaskService.filter_delete([Task.doc_id == doc.id])
+ except Exception as e:
+ logging.warning(f"Failed to delete tasks for document {doc.id}: {e}")
+
+ # Delete chunk images (non-critical, log and continue)
+ try:
cls.delete_chunk_images(doc, tenant_id)
+ except Exception as e:
+ logging.warning(f"Failed to delete chunk images for document {doc.id}: {e}")
+
+ # Delete thumbnail (non-critical, log and continue)
+ try:
if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX):
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
+ except Exception as e:
+ logging.warning(f"Failed to delete thumbnail for document {doc.id}: {e}")
+
+ # Delete chunks from doc store - this is critical, log errors
+ try:
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+ except Exception as e:
+ logging.error(f"Failed to delete chunks from doc store for document {doc.id}: {e}")
+
+ # Delete document metadata (non-critical, log and continue)
+ try:
+ DocMetadataService.delete_document_metadata(doc.id)
+ except Exception as e:
+ logging.warning(f"Failed to delete metadata for document {doc.id}: {e}")
+ # Cleanup knowledge graph references (non-critical, log and continue)
+ try:
graph_source = settings.docStoreConn.get_fields(
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
)
@@ -360,8 +416,9 @@ def remove_document(cls, doc, tenant_id):
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
search.index_name(tenant_id), doc.kb_id)
- except Exception:
- pass
+ except Exception as e:
+ logging.warning(f"Failed to cleanup knowledge graph for document {doc.id}: {e}")
+
return cls.delete_by_id(doc.id)
@classmethod
@@ -423,6 +480,7 @@ def get_unfinished_docs(cls):
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
+ ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)),
(((cls.model.progress < 1) & (cls.model.progress > 0)) |
(cls.model.id.in_(unfinished_task_query)))) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
return list(docs.dicts())
@@ -645,8 +703,7 @@ def dfs_update(old, new):
if k not in old:
old[k] = v
continue
- if isinstance(v, dict):
- assert isinstance(old[k], dict)
+ if isinstance(v, dict) and isinstance(old[k], dict):
dfs_update(old[k], v)
else:
old[k] = v
@@ -678,209 +735,6 @@ def begin2parse(cls, doc_id, keep_progress=False):
cls.update_by_id(doc_id, info)
- @classmethod
- @DB.connection_context()
- def update_meta_fields(cls, doc_id, meta_fields):
- return cls.update_by_id(doc_id, {"meta_fields": meta_fields})
-
- @classmethod
- @DB.connection_context()
- def get_meta_by_kbs(cls, kb_ids):
- """
- Legacy metadata aggregator (backward-compatible).
- - Does NOT expand list values and a list is kept as one string key.
- Example: {"tags": ["foo","bar"]} -> meta["tags"]["['foo', 'bar']"] = [doc_id]
- - Expects meta_fields is a dict.
- Use when existing callers rely on the old list-as-string semantics.
- """
- fields = [
- cls.model.id,
- cls.model.meta_fields,
- ]
- meta = {}
- for r in cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids)):
- doc_id = r.id
- for k,v in r.meta_fields.items():
- if k not in meta:
- meta[k] = {}
- if not isinstance(v, list):
- v = [v]
- for vv in v:
- if vv not in meta[k]:
- if isinstance(vv, list) or isinstance(vv, dict):
- continue
- meta[k][vv] = []
- meta[k][vv].append(doc_id)
- return meta
-
- @classmethod
- @DB.connection_context()
- def get_flatted_meta_by_kbs(cls, kb_ids):
- """
- - Parses stringified JSON meta_fields when possible and skips non-dict or unparsable values.
- - Expands list values into individual entries.
- Example: {"tags": ["foo","bar"], "author": "alice"} ->
- meta["tags"]["foo"] = [doc_id], meta["tags"]["bar"] = [doc_id], meta["author"]["alice"] = [doc_id]
- Prefer for metadata_condition filtering and scenarios that must respect list semantics.
- """
- fields = [
- cls.model.id,
- cls.model.meta_fields,
- ]
- meta = {}
- for r in cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids)):
- doc_id = r.id
- meta_fields = r.meta_fields or {}
- if isinstance(meta_fields, str):
- try:
- meta_fields = json.loads(meta_fields)
- except Exception:
- continue
- if not isinstance(meta_fields, dict):
- continue
- for k, v in meta_fields.items():
- if k not in meta:
- meta[k] = {}
- values = v if isinstance(v, list) else [v]
- for vv in values:
- if vv is None:
- continue
- sv = str(vv)
- if sv not in meta[k]:
- meta[k][sv] = []
- meta[k][sv].append(doc_id)
- return meta
-
- @classmethod
- @DB.connection_context()
- def get_metadata_summary(cls, kb_id):
- fields = [cls.model.id, cls.model.meta_fields]
- summary = {}
- for r in cls.model.select(*fields).where(cls.model.kb_id == kb_id):
- meta_fields = r.meta_fields or {}
- if isinstance(meta_fields, str):
- try:
- meta_fields = json.loads(meta_fields)
- except Exception:
- continue
- if not isinstance(meta_fields, dict):
- continue
- for k, v in meta_fields.items():
- values = v if isinstance(v, list) else [v]
- for vv in values:
- if not vv:
- continue
- sv = str(vv)
- if k not in summary:
- summary[k] = {}
- summary[k][sv] = summary[k].get(sv, 0) + 1
- return {k: sorted([(val, cnt) for val, cnt in v.items()], key=lambda x: x[1], reverse=True) for k, v in summary.items()}
-
- @classmethod
- @DB.connection_context()
- def batch_update_metadata(cls, kb_id, doc_ids, updates=None, deletes=None):
- updates = updates or []
- deletes = deletes or []
- if not doc_ids:
- return 0
-
- def _normalize_meta(meta):
- if isinstance(meta, str):
- try:
- meta = json.loads(meta)
- except Exception:
- return {}
- if not isinstance(meta, dict):
- return {}
- return deepcopy(meta)
-
- def _str_equal(a, b):
- return str(a) == str(b)
-
- def _apply_updates(meta):
- changed = False
- for upd in updates:
- key = upd.get("key")
- if not key or key not in meta:
- continue
-
- new_value = upd.get("value")
- match_provided = "match" in upd
- if isinstance(meta[key], list):
- if not match_provided:
- if isinstance(new_value, list):
- meta[key] = dedupe_list(new_value)
- else:
- meta[key] = new_value
- changed = True
- else:
- match_value = upd.get("match")
- replaced = False
- new_list = []
- for item in meta[key]:
- if _str_equal(item, match_value):
- new_list.append(new_value)
- replaced = True
- else:
- new_list.append(item)
- if replaced:
- meta[key] = dedupe_list(new_list)
- changed = True
- else:
- if not match_provided:
- meta[key] = new_value
- changed = True
- else:
- match_value = upd.get("match")
- if _str_equal(meta[key], match_value):
- meta[key] = new_value
- changed = True
- return changed
-
- def _apply_deletes(meta):
- changed = False
- for d in deletes:
- key = d.get("key")
- if not key or key not in meta:
- continue
- value = d.get("value", None)
- if isinstance(meta[key], list):
- if value is None:
- del meta[key]
- changed = True
- continue
- new_list = [item for item in meta[key] if not _str_equal(item, value)]
- if len(new_list) != len(meta[key]):
- if new_list:
- meta[key] = new_list
- else:
- del meta[key]
- changed = True
- else:
- if value is None or _str_equal(meta[key], value):
- del meta[key]
- changed = True
- return changed
-
- updated_docs = 0
- with DB.atomic():
- rows = cls.model.select(cls.model.id, cls.model.meta_fields).where(
- (cls.model.id.in_(doc_ids)) & (cls.model.kb_id == kb_id)
- )
- for r in rows:
- meta = _normalize_meta(r.meta_fields or {})
- original_meta = deepcopy(meta)
- changed = _apply_updates(meta)
- changed = _apply_deletes(meta) or changed
- if changed and meta != original_meta:
- cls.model.update(
- meta_fields=meta,
- update_time=current_timestamp(),
- update_date=get_format_time()
- ).where(cls.model.id == r.id).execute()
- updated_docs += 1
- return updated_docs
-
@classmethod
@DB.connection_context()
def update_progress(cls):
@@ -914,6 +768,8 @@ def _sync_progress(cls, docs:list[dict]):
bad = 0
e, doc = DocumentService.get_by_id(d["id"])
status = doc.run # TaskStatus.RUNNING.value
+ if status == TaskStatus.CANCEL.value:
+ continue
doc_progress = doc.progress if doc and doc.progress else 0.0
special_task_running = False
priority = 0
@@ -957,7 +813,16 @@ def _sync_progress(cls, docs:list[dict]):
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
else:
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
- cls.update_by_id(d["id"], info)
+ info["update_time"] = current_timestamp()
+ info["update_date"] = get_format_time()
+ (
+ cls.model.update(info)
+ .where(
+ (cls.model.id == d["id"])
+ & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))
+ )
+ .execute()
+ )
except Exception as e:
if str(e).find("'0'") < 0:
logging.exception("fetch task exception")
@@ -990,7 +855,7 @@ def do_cancel(cls, doc_id):
@classmethod
@DB.connection_context()
def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]:
- # cancelled: run == "2" but progress can vary
+ # cancelled: run == "2"
cancelled = (
cls.model.select(fn.COUNT(1))
.where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL))
@@ -1217,7 +1082,7 @@ def embedding(doc_id, cnts, batch_size=16):
cks = [c for c in docs if c["doc_id"] == doc_id]
if parser_ids[doc_id] != ParserType.PICTURE.value:
- from graphrag.general.mind_map_extractor import MindMapExtractor
+ from rag.graphrag.general.mind_map_extractor import MindMapExtractor
mindmap = MindMapExtractor(llm_bdl)
try:
mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]))
@@ -1245,7 +1110,7 @@ def embedding(doc_id, cnts, batch_size=16):
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
if not settings.docStoreConn.index_exist(idxnm, kb_id):
- settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]))
+ settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id)
try_create_idx = False
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
diff --git a/api/db/services/evaluation_service.py b/api/db/services/evaluation_service.py
index 3f523b1d8c1..48255512f5a 100644
--- a/api/db/services/evaluation_service.py
+++ b/api/db/services/evaluation_service.py
@@ -225,21 +225,36 @@ def import_test_cases(cls, dataset_id: str, cases: List[Dict[str, Any]]) -> Tupl
"""
success_count = 0
failure_count = 0
+ case_instances = []
+
+ if not cases:
+ return success_count, failure_count
+
+ cur_timestamp = current_timestamp()
- for case_data in cases:
- success, _ = cls.add_test_case(
- dataset_id=dataset_id,
- question=case_data.get("question", ""),
- reference_answer=case_data.get("reference_answer"),
- relevant_doc_ids=case_data.get("relevant_doc_ids"),
- relevant_chunk_ids=case_data.get("relevant_chunk_ids"),
- metadata=case_data.get("metadata")
- )
+ try:
+ for case_data in cases:
+ case_id = get_uuid()
+ case_info = {
+ "id": case_id,
+ "dataset_id": dataset_id,
+ "question": case_data.get("question", ""),
+ "reference_answer": case_data.get("reference_answer"),
+ "relevant_doc_ids": case_data.get("relevant_doc_ids"),
+ "relevant_chunk_ids": case_data.get("relevant_chunk_ids"),
+ "metadata": case_data.get("metadata"),
+ "create_time": cur_timestamp
+ }
+
+ case_instances.append(EvaluationCase(**case_info))
+ EvaluationCase.bulk_create(case_instances, batch_size=300)
+ success_count = len(case_instances)
+ failure_count = 0
- if success:
- success_count += 1
- else:
- failure_count += 1
+ except Exception as e:
+ logging.error(f"Error bulk importing test cases: {str(e)}")
+ failure_count = len(cases)
+ success_count = 0
return success_count, failure_count
diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py
index d6a157b2d1e..eba59a3cf22 100644
--- a/api/db/services/file_service.py
+++ b/api/db/services/file_service.py
@@ -439,6 +439,15 @@ def upload_document(self, kb, file_objs, user_id, src="local", parent_path: str
err, files = [], []
for file in file_objs:
+ doc_id = file.id if hasattr(file, "id") else get_uuid()
+ e, doc = DocumentService.get_by_id(doc_id)
+ if e:
+ blob = file.read()
+ settings.STORAGE_IMPL.put(kb.id, doc.location, blob, kb.tenant_id)
+ doc.size = len(blob)
+ doc = doc.to_dict()
+ DocumentService.update_by_id(doc["id"], doc)
+ continue
try:
DocumentService.check_doc_health(kb.tenant_id, file.filename)
filename = duplicate_name(DocumentService.query, name=file.filename, kb_id=kb.id)
@@ -455,7 +464,6 @@ def upload_document(self, kb, file_objs, user_id, src="local", parent_path: str
blob = read_potential_broken_pdf(blob)
settings.STORAGE_IMPL.put(kb.id, location, blob)
- doc_id = get_uuid()
img = thumbnail_img(filename, blob)
thumbnail_location = ""
diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py
index 5f506888c0d..1f8b096daa3 100644
--- a/api/db/services/knowledgebase_service.py
+++ b/api/db/services/knowledgebase_service.py
@@ -397,7 +397,7 @@ def create_with_name(
if dataset_name == "":
return False, get_data_error_result(message="Dataset name can't be empty.")
if len(dataset_name.encode("utf-8")) > DATASET_NAME_LIMIT:
- return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}")
+ return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is large than {DATASET_NAME_LIMIT}")
# Deduplicate name within tenant
dataset_name = duplicate_name(
diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index e5505af8849..db65ec8ecbb 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -441,3 +441,46 @@ async def async_chat_streamly(self, system: str, history: list, gen_conf: dict =
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
generation.end()
return
+
+ async def async_chat_streamly_delta(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
+ total_tokens = 0
+ ans = ""
+ if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
+ stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
+ elif hasattr(self.mdl, "async_chat_streamly"):
+ stream_fn = getattr(self.mdl, "async_chat_streamly", None)
+ else:
+ raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
+
+ generation = None
+ if self.langfuse:
+ generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
+
+ if stream_fn:
+ chat_partial = partial(stream_fn, system, history, gen_conf)
+ use_kwargs = self._clean_param(chat_partial, **kwargs)
+ try:
+ async for txt in chat_partial(**use_kwargs):
+ if isinstance(txt, int):
+ total_tokens = txt
+ break
+
+ if txt.endswith(" "):
+ ans = ans[: -len(" ")]
+
+ if not self.verbose_tool_use:
+ txt = re.sub(r".*? ", "", txt, flags=re.DOTALL)
+
+ ans += txt
+ yield txt
+ except Exception as e:
+ if generation:
+ generation.update(output={"error": str(e)})
+ generation.end()
+ raise
+ if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
+ logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
+ if generation:
+ generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
+ generation.end()
+ return
diff --git a/api/db/services/memory_service.py b/api/db/services/memory_service.py
index 8a65d15e24d..215a198fe27 100644
--- a/api/db/services/memory_service.py
+++ b/api/db/services/memory_service.py
@@ -167,4 +167,4 @@ def update_memory(cls, tenant_id: str, memory_id: str, update_dict: dict):
@classmethod
@DB.connection_context()
def delete_memory(cls, memory_id: str):
- return cls.model.delete().where(cls.model.id == memory_id).execute()
+ return cls.delete_by_id(memory_id)
diff --git a/api/db/services/system_settings_service.py b/api/db/services/system_settings_service.py
new file mode 100644
index 00000000000..eac7019e6a1
--- /dev/null
+++ b/api/db/services/system_settings_service.py
@@ -0,0 +1,44 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from datetime import datetime
+from common.time_utils import current_timestamp, datetime_format
+from api.db.db_models import DB
+from api.db.db_models import SystemSettings
+from api.db.services.common_service import CommonService
+
+
+class SystemSettingsService(CommonService):
+ model = SystemSettings
+
+ @classmethod
+ @DB.connection_context()
+ def get_by_name(cls, name):
+ objs = cls.model.select().where(cls.model.name == name)
+ return objs
+
+ @classmethod
+ @DB.connection_context()
+ def update_by_name(cls, name, obj):
+ obj["update_time"] = current_timestamp()
+ obj["update_date"] = datetime_format(datetime.now())
+ cls.model.update(obj).where(cls.model.name == name).execute()
+ return SystemSettings(**obj)
+
+ @classmethod
+ @DB.connection_context()
+ def get_record_count(cls):
+ count = cls.model.select().count()
+ return count
diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py
index 065d2376dd7..3975c0ec3fc 100644
--- a/api/db/services/task_service.py
+++ b/api/db/services/task_service.py
@@ -121,13 +121,6 @@ def get_task(cls, task_id, doc_ids=[]):
.where(cls.model.id == task_id)
)
docs = list(docs.dicts())
- # Assuming docs = list(docs.dicts())
- if docs:
- kb_config = docs[0]['kb_parser_config'] # Dict from Knowledgebase.parser_config
- mineru_method = kb_config.get('mineru_parse_method', 'auto')
- mineru_formula = kb_config.get('mineru_formula_enable', True)
- mineru_table = kb_config.get('mineru_table_enable', True)
- print(mineru_method, mineru_formula, mineru_table)
if not docs:
return None
@@ -179,6 +172,40 @@ def get_tasks(cls, doc_id: str):
return None
return tasks
+ @classmethod
+ @DB.connection_context()
+ def get_tasks_progress_by_doc_ids(cls, doc_ids: list[str]):
+ """Retrieve all tasks associated with specific documents.
+
+ This method fetches all processing tasks for given document ids, ordered by
+ creation time. It includes task progress and chunk information.
+
+ Args:
+ doc_ids (str): The unique identifier of the document.
+
+ Returns:
+ list[dict]: List of task dictionaries containing task details.
+ Returns None if no tasks are found.
+ """
+ fields = [
+ cls.model.id,
+ cls.model.doc_id,
+ cls.model.from_page,
+ cls.model.progress,
+ cls.model.progress_msg,
+ cls.model.digest,
+ cls.model.chunk_ids,
+ cls.model.create_time
+ ]
+ tasks = (
+ cls.model.select(*fields).order_by(cls.model.create_time.desc())
+ .where(cls.model.doc_id.in_(doc_ids))
+ )
+ tasks = list(tasks.dicts())
+ if not tasks:
+ return None
+ return tasks
+
@classmethod
@DB.connection_context()
def update_chunk_ids(cls, id: str, chunk_ids: str):
@@ -495,6 +522,7 @@ def cancel_all_task_of(doc_id):
def has_canceled(task_id):
try:
if REDIS_CONN.get(f"{task_id}-cancel"):
+ logging.info(f"Task: {task_id} has been canceled")
return True
except Exception as e:
logging.exception(e)
diff --git a/api/db/services/tenant_llm_service.py b/api/db/services/tenant_llm_service.py
index 65771f60f41..5bd663734a8 100644
--- a/api/db/services/tenant_llm_service.py
+++ b/api/db/services/tenant_llm_service.py
@@ -19,7 +19,7 @@
from peewee import IntegrityError
from langfuse import Langfuse
from common import settings
-from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, LLMType
+from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, PADDLEOCR_DEFAULT_CONFIG, PADDLEOCR_ENV_KEYS, LLMType
from api.db.db_models import DB, LLMFactories, TenantLLM
from api.db.services.common_service import CommonService
from api.db.services.langfuse_service import TenantLangfuseService
@@ -60,10 +60,8 @@ def get_api_key(cls, tenant_id, model_name):
@classmethod
@DB.connection_context()
def get_my_llms(cls, tenant_id):
- fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name,
- cls.model.used_tokens, cls.model.status]
- objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
- cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
+ fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens, cls.model.status]
+ objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
return list(objs)
@@ -90,6 +88,7 @@ def split_model_name_and_factory(model_name):
@DB.connection_context()
def get_model_config(cls, tenant_id, llm_type, llm_name=None):
from api.db.services.llm_service import LLMService
+
e, tenant = TenantService.get_by_id(tenant_id)
if not e:
raise LookupError("Tenant not found")
@@ -119,9 +118,9 @@ def get_model_config(cls, tenant_id, llm_type, llm_name=None):
model_config = cls.get_api_key(tenant_id, mdlnm)
if model_config:
model_config = model_config.to_dict()
- elif llm_type == LLMType.EMBEDDING and fid == 'Builtin' and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv('TEI_MODEL', ''):
+ elif llm_type == LLMType.EMBEDDING and fid == "Builtin" and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv("TEI_MODEL", ""):
embedding_cfg = settings.EMBEDDING_CFG
- model_config = {"llm_factory": 'Builtin', "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]}
+ model_config = {"llm_factory": "Builtin", "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]}
else:
raise LookupError(f"Model({mdlnm}@{fid}) not authorized")
@@ -140,33 +139,27 @@ def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kw
if llm_type == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel:
return None
- return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
- base_url=model_config["api_base"])
+ return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
elif llm_type == LLMType.RERANK:
if model_config["llm_factory"] not in RerankModel:
return None
- return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
- base_url=model_config["api_base"])
+ return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
elif llm_type == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel:
return None
- return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang,
- base_url=model_config["api_base"], **kwargs)
+ return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
elif llm_type == LLMType.CHAT.value:
if model_config["llm_factory"] not in ChatModel:
return None
- return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
- base_url=model_config["api_base"], **kwargs)
+ return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
elif llm_type == LLMType.SPEECH2TEXT:
if model_config["llm_factory"] not in Seq2txtModel:
return None
- return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"],
- model_name=model_config["llm_name"], lang=lang,
- base_url=model_config["api_base"])
+ return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
elif llm_type == LLMType.TTS:
if model_config["llm_factory"] not in TTSModel:
return None
@@ -216,14 +209,11 @@ def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
try:
num = (
cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
- .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name,
- cls.model.llm_factory == llm_factory if llm_factory else True)
+ .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
.execute()
)
except Exception:
- logging.exception(
- "TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s",
- tenant_id, llm_name)
+ logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
return 0
return num
@@ -231,9 +221,7 @@ def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
@classmethod
@DB.connection_context()
def get_openai_models(cls):
- objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"),
- ~(cls.model.llm_name == "text-embedding-3-small"),
- ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
+ objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
return list(objs)
@classmethod
@@ -298,6 +286,68 @@ def _parse_api_key(raw: str) -> dict:
idx += 1
continue
+ @classmethod
+ def _collect_paddleocr_env_config(cls) -> dict | None:
+ cfg = PADDLEOCR_DEFAULT_CONFIG
+ found = False
+ for key in PADDLEOCR_ENV_KEYS:
+ val = os.environ.get(key)
+ if val:
+ found = True
+ cfg[key] = val
+ return cfg if found else None
+
+ @classmethod
+ @DB.connection_context()
+ def ensure_paddleocr_from_env(cls, tenant_id: str) -> str | None:
+ """
+ Ensure a PaddleOCR model exists for the tenant if env variables are present.
+ Return the existing or newly created llm_name, or None if env not set.
+ """
+ cfg = cls._collect_paddleocr_env_config()
+ if not cfg:
+ return None
+
+ saved_paddleocr_models = cls.query(tenant_id=tenant_id, llm_factory="PaddleOCR", model_type=LLMType.OCR.value)
+
+ def _parse_api_key(raw: str) -> dict:
+ try:
+ return json.loads(raw or "{}")
+ except Exception:
+ return {}
+
+ for item in saved_paddleocr_models:
+ api_cfg = _parse_api_key(item.api_key)
+ normalized = {k: api_cfg.get(k, PADDLEOCR_DEFAULT_CONFIG.get(k)) for k in PADDLEOCR_ENV_KEYS}
+ if normalized == cfg:
+ return item.llm_name
+
+ used_names = {item.llm_name for item in saved_paddleocr_models}
+ idx = 1
+ base_name = "paddleocr-from-env"
+ while True:
+ candidate = f"{base_name}-{idx}"
+ if candidate in used_names:
+ idx += 1
+ continue
+
+ try:
+ cls.save(
+ tenant_id=tenant_id,
+ llm_factory="PaddleOCR",
+ llm_name=candidate,
+ model_type=LLMType.OCR.value,
+ api_key=json.dumps(cfg),
+ api_base="",
+ max_tokens=0,
+ )
+ return candidate
+ except IntegrityError:
+ logging.warning("PaddleOCR env model %s already exists for tenant %s, retry with next name", candidate, tenant_id)
+ used_names.add(candidate)
+ idx += 1
+ continue
+
@classmethod
@DB.connection_context()
def delete_by_tenant_id(cls, tenant_id):
@@ -306,6 +356,7 @@ def delete_by_tenant_id(cls, tenant_id):
@staticmethod
def llm_id2llm_type(llm_id: str) -> str | None:
from api.db.services.llm_service import LLMService
+
llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
llm_factories = settings.FACTORY_LLM_INFOS
for llm_factory in llm_factories:
@@ -340,9 +391,12 @@ def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs)
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
self.langfuse = None
if langfuse_keys:
- langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key,
- host=langfuse_keys.host)
- if langfuse.auth_check():
- self.langfuse = langfuse
- trace_id = self.langfuse.create_trace_id()
- self.trace_context = {"trace_id": trace_id}
+ langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
+ try:
+ if langfuse.auth_check():
+ self.langfuse = langfuse
+ trace_id = self.langfuse.create_trace_id()
+ self.trace_context = {"trace_id": trace_id}
+ except Exception:
+ # Skip langfuse tracing if connection fails
+ pass
diff --git a/api/ragflow_server.py b/api/ragflow_server.py
index 26cd045c4de..1beb0cd099c 100644
--- a/api/ragflow_server.py
+++ b/api/ragflow_server.py
@@ -18,8 +18,8 @@
# from beartype.claw import beartype_all # <-- you didn't sign up for this
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
-from common.log_utils import init_root_logger
-from plugin import GlobalPluginManager
+import time
+start_ts = time.time()
import logging
import os
@@ -40,6 +40,8 @@
from common.versions import get_ragflow_version
from common.config_utils import show_configs
from common.mcp_tool_call_conn import shutdown_all_mcp_sessions
+from common.log_utils import init_root_logger
+from agent.plugin import GlobalPluginManager
from rag.utils.redis_conn import RedisDistributedLock
stop_event = threading.Event()
@@ -145,7 +147,7 @@ def delayed_start_update_progress():
# start http server
try:
- logging.info("RAGFlow HTTP server start...")
+ logging.info(f"RAGFlow server is ready after {time.time() - start_ts}s initialization.")
app.run(host=settings.HOST_IP, port=settings.HOST_PORT)
except Exception:
traceback.print_exc()
diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py
index afb4ff772de..326fb62bc66 100644
--- a/api/utils/api_utils.py
+++ b/api/utils/api_utils.py
@@ -29,8 +29,15 @@
from quart import (
Response,
jsonify,
- request
+ request,
+ has_app_context,
)
+from werkzeug.exceptions import BadRequest as WerkzeugBadRequest
+
+try:
+ from quart.exceptions import BadRequest as QuartBadRequest
+except ImportError: # pragma: no cover - optional dependency
+ QuartBadRequest = None
from peewee import OperationalError
@@ -42,41 +49,45 @@
from common.connection_utils import timeout
from common.constants import RetCode
from common import settings
+from common.misc_utils import thread_pool_exec
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
+def _safe_jsonify(payload: dict):
+ if has_app_context():
+ return jsonify(payload)
+ return payload
+
async def _coerce_request_data() -> dict:
"""Fetch JSON body with sane defaults; fallback to form data."""
+ if hasattr(request, "_cached_payload"):
+ return request._cached_payload
payload: Any = None
- last_error: Exception | None = None
-
- try:
- payload = await request.get_json(force=True, silent=True)
- except Exception as e:
- last_error = e
- payload = None
-
- if payload is None:
- try:
- form = await request.form
- payload = form.to_dict()
- except Exception as e:
- last_error = e
- payload = None
- if payload is None:
- if last_error is not None:
- raise last_error
- raise ValueError("No JSON body or form data found in request.")
-
- if isinstance(payload, dict):
- return payload or {}
-
- if isinstance(payload, str):
- raise AttributeError("'str' object has no attribute 'get'")
+ body_bytes = await request.get_data()
+ has_body = bool(body_bytes)
+ content_type = (request.content_type or "").lower()
+ is_json = content_type.startswith("application/json")
+
+ if not has_body:
+ payload = {}
+ elif is_json:
+ payload = await request.get_json(force=False, silent=False)
+ if isinstance(payload, dict):
+ payload = payload or {}
+ elif isinstance(payload, str):
+ raise AttributeError("'str' object has no attribute 'get'")
+ else:
+ raise TypeError("JSON payload must be an object.")
+ else:
+ form = await request.form
+ payload = form.to_dict() if form else None
+ if payload is None:
+ raise TypeError("Request body is not a valid form payload.")
- raise TypeError(f"Unsupported request payload type: {type(payload)!r}")
+ request._cached_payload = payload
+ return payload
async def get_request_json():
return await _coerce_request_data()
@@ -115,7 +126,7 @@ def get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! Data missing!
continue
else:
response[key] = value
- return jsonify(response)
+ return _safe_jsonify(response)
def server_error_response(e):
@@ -124,16 +135,12 @@ def server_error_response(e):
try:
msg = repr(e).lower()
if getattr(e, "code", None) == 401 or ("unauthorized" in msg) or ("401" in msg):
- return get_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))
+ resp = get_json_result(code=RetCode.UNAUTHORIZED, message="Unauthorized")
+ resp.status_code = RetCode.UNAUTHORIZED
+ return resp
except Exception as ex:
logging.warning(f"error checking authorization: {ex}")
- if len(e.args) > 1:
- try:
- serialized_data = serialize_for_json(e.args[1])
- return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=serialized_data)
- except Exception:
- return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=None)
if repr(e).find("index_not_found_exception") >= 0:
return get_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
@@ -168,7 +175,17 @@ def process_args(input_arguments):
def wrapper(func):
@wraps(func)
async def decorated_function(*_args, **_kwargs):
- errs = process_args(await _coerce_request_data())
+ exception_types = (AttributeError, TypeError, WerkzeugBadRequest)
+ if QuartBadRequest is not None:
+ exception_types = exception_types + (QuartBadRequest,)
+ if args or kwargs:
+ try:
+ input_arguments = await _coerce_request_data()
+ except exception_types:
+ input_arguments = {}
+ else:
+ input_arguments = await _coerce_request_data()
+ errs = process_args(input_arguments)
if errs:
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs)
if inspect.iscoroutinefunction(func):
@@ -215,7 +232,7 @@ async def wrapper(*args, **kwargs):
def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=None):
response = {"code": code, "message": message, "data": data}
- return jsonify(response)
+ return _safe_jsonify(response)
def apikey_required(func):
@@ -236,16 +253,16 @@ async def decorated_function(*args, **kwargs):
def build_error_result(code=RetCode.FORBIDDEN, message="success"):
response = {"code": code, "message": message}
- response = jsonify(response)
- response.status_code = code
+ response = _safe_jsonify(response)
+ if hasattr(response, "status_code"):
+ response.status_code = code
return response
def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=None):
if data is None:
- return jsonify({"code": code, "message": message})
- else:
- return jsonify({"code": code, "message": message, "data": data})
+ return _safe_jsonify({"code": code, "message": message})
+ return _safe_jsonify({"code": code, "message": message, "data": data})
def token_required(func):
@@ -304,7 +321,7 @@ def get_result(code=RetCode.SUCCESS, message="", data=None, total=None):
else:
response["message"] = message or "Error"
- return jsonify(response)
+ return _safe_jsonify(response)
def get_error_data_result(
@@ -318,7 +335,7 @@ def get_error_data_result(
continue
else:
response[key] = value
- return jsonify(response)
+ return _safe_jsonify(response)
def get_error_argument_result(message="Invalid arguments"):
@@ -683,7 +700,7 @@ async def _is_strong_enough():
nonlocal chat_model, embedding_model
if embedding_model:
await asyncio.wait_for(
- asyncio.to_thread(embedding_model.encode, ["Are you strong enough!?"]),
+ thread_pool_exec(embedding_model.encode, ["Are you strong enough!?"]),
timeout=10
)
diff --git a/api/utils/common.py b/api/utils/common.py
index 958cf20ffc2..4d38c40d218 100644
--- a/api/utils/common.py
+++ b/api/utils/common.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import xxhash
+
def string_to_bytes(string):
return string if isinstance(
@@ -22,3 +24,6 @@ def string_to_bytes(string):
def bytes_to_string(byte):
return byte.decode(encoding="utf-8")
+# 128 bit = 32 character
+def hash128(data: str) -> str:
+ return xxhash.xxh128(data).hexdigest()
diff --git a/api/utils/crypt.py b/api/utils/crypt.py
index 174ca356835..d81cf7c6a1c 100644
--- a/api/utils/crypt.py
+++ b/api/utils/crypt.py
@@ -24,7 +24,7 @@
def crypt(line):
"""
- decrypt(crypt(input_string)) == base64(input_string), which frontend and admin_client use.
+ decrypt(crypt(input_string)) == base64(input_string), which frontend and ragflow_cli use.
"""
file_path = os.path.join(get_project_base_directory(), "conf", "public.pem")
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
diff --git a/api/utils/file_utils.py b/api/utils/file_utils.py
index 4cad64c35ce..e73c5d21850 100644
--- a/api/utils/file_utils.py
+++ b/api/utils/file_utils.py
@@ -84,28 +84,6 @@ def thumbnail_img(filename, blob):
buffered = BytesIO()
image.save(buffered, format="png")
return buffered.getvalue()
-
- elif re.match(r".*\.(ppt|pptx)$", filename):
- import aspose.pydrawing as drawing
- import aspose.slides as slides
-
- try:
- with slides.Presentation(BytesIO(blob)) as presentation:
- buffered = BytesIO()
- scale = 0.03
- img = None
- for _ in range(10):
- # https://reference.aspose.com/slides/python-net/aspose.slides/slide/get_thumbnail/#float-float
- presentation.slides[0].get_thumbnail(scale, scale).save(buffered, drawing.imaging.ImageFormat.png)
- img = buffered.getvalue()
- if len(img) >= 64000:
- scale = scale / 2.0
- buffered = BytesIO()
- else:
- break
- return img
- except Exception:
- pass
return None
diff --git a/api/utils/health_utils.py b/api/utils/health_utils.py
index 0a7ab6e7a6f..7456ed0f88a 100644
--- a/api/utils/health_utils.py
+++ b/api/utils/health_utils.py
@@ -23,6 +23,7 @@
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.es_conn import ESConnection
from rag.utils.infinity_conn import InfinityConnection
+from rag.utils.ob_conn import OBConnection
from common import settings
@@ -100,6 +101,121 @@ def get_infinity_status():
}
+def get_oceanbase_status():
+ """
+ Get OceanBase health status and performance metrics.
+
+ Returns:
+ dict: OceanBase status with health information and performance metrics
+ """
+ doc_engine = os.getenv('DOC_ENGINE', 'elasticsearch')
+ if doc_engine != 'oceanbase':
+ raise Exception("OceanBase is not in use.")
+ try:
+ ob_conn = OBConnection()
+ health_info = ob_conn.health()
+ performance_metrics = ob_conn.get_performance_metrics()
+
+ # Combine health and performance metrics
+ status = "alive" if health_info.get("status") == "healthy" else "timeout"
+
+ return {
+ "status": status,
+ "message": {
+ "health": health_info,
+ "performance": performance_metrics
+ }
+ }
+ except Exception as e:
+ return {
+ "status": "timeout",
+ "message": f"error: {str(e)}",
+ }
+
+
+def check_oceanbase_health() -> dict:
+ """
+ Check OceanBase health status with comprehensive metrics.
+
+ This function provides detailed health information including:
+ - Connection status
+ - Query latency
+ - Storage usage
+ - Query throughput (QPS)
+ - Slow query statistics
+ - Connection pool statistics
+
+ Returns:
+ dict: Health status with detailed metrics
+ """
+ doc_engine = os.getenv('DOC_ENGINE', 'elasticsearch')
+ if doc_engine != 'oceanbase':
+ return {
+ "status": "not_configured",
+ "details": {
+ "connection": "not_configured",
+ "message": "OceanBase is not configured as the document engine"
+ }
+ }
+
+ try:
+ ob_conn = OBConnection()
+ health_info = ob_conn.health()
+ performance_metrics = ob_conn.get_performance_metrics()
+
+ # Determine overall health status
+ connection_status = performance_metrics.get("connection", "unknown")
+
+ # If connection is disconnected, return unhealthy
+ if connection_status == "disconnected" or health_info.get("status") != "healthy":
+ return {
+ "status": "unhealthy",
+ "details": {
+ "connection": connection_status,
+ "latency_ms": performance_metrics.get("latency_ms", 0),
+ "storage_used": performance_metrics.get("storage_used", "N/A"),
+ "storage_total": performance_metrics.get("storage_total", "N/A"),
+ "query_per_second": performance_metrics.get("query_per_second", 0),
+ "slow_queries": performance_metrics.get("slow_queries", 0),
+ "active_connections": performance_metrics.get("active_connections", 0),
+ "max_connections": performance_metrics.get("max_connections", 0),
+ "uri": health_info.get("uri", "unknown"),
+ "version": health_info.get("version_comment", "unknown"),
+ "error": health_info.get("error", performance_metrics.get("error"))
+ }
+ }
+
+ # Check if healthy (connected and low latency)
+ is_healthy = (
+ connection_status == "connected" and
+ performance_metrics.get("latency_ms", float('inf')) < 1000 # Latency under 1 second
+ )
+
+ return {
+ "status": "healthy" if is_healthy else "degraded",
+ "details": {
+ "connection": performance_metrics.get("connection", "unknown"),
+ "latency_ms": performance_metrics.get("latency_ms", 0),
+ "storage_used": performance_metrics.get("storage_used", "N/A"),
+ "storage_total": performance_metrics.get("storage_total", "N/A"),
+ "query_per_second": performance_metrics.get("query_per_second", 0),
+ "slow_queries": performance_metrics.get("slow_queries", 0),
+ "active_connections": performance_metrics.get("active_connections", 0),
+ "max_connections": performance_metrics.get("max_connections", 0),
+ "uri": health_info.get("uri", "unknown"),
+ "version": health_info.get("version_comment", "unknown")
+ }
+ }
+ except Exception as e:
+ return {
+ "status": "unhealthy",
+ "details": {
+ "connection": "disconnected",
+ "error": str(e)
+ }
+ }
+
+
def get_mysql_status():
try:
cursor = DB.execute_sql("SHOW PROCESSLIST;")
diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py
index 2dcace53fe9..d6178e641f4 100644
--- a/api/utils/validation_utils.py
+++ b/api/utils/validation_utils.py
@@ -82,6 +82,8 @@ async def validate_and_parse_json_request(request: Request, validator: type[Base
2. Extra fields added via `extras` parameter are automatically removed
from the final output after validation
"""
+ if request.mimetype != "application/json":
+ return None, f"Unsupported content type: Expected application/json, got {request.content_type}"
try:
payload = await request.get_json() or {}
except UnsupportedMediaType:
diff --git a/api/utils/web_utils.py b/api/utils/web_utils.py
index 11e8428b77c..2d262293115 100644
--- a/api/utils/web_utils.py
+++ b/api/utils/web_utils.py
@@ -86,6 +86,9 @@
"ico": "image/x-icon",
"avif": "image/avif",
"heic": "image/heic",
+ # PPTX
+ "ppt": "application/vnd.ms-powerpoint",
+ "pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
}
@@ -239,4 +242,4 @@ def hash_code(code: str, salt: bytes) -> str:
def captcha_key(email: str) -> str:
return f"captcha:{email}"
-
\ No newline at end of file
+
diff --git a/codecov.yml b/codecov.yml
new file mode 100644
index 00000000000..5dd21786318
--- /dev/null
+++ b/codecov.yml
@@ -0,0 +1,4 @@
+coverage:
+ status:
+ project: off
+ patch: off
\ No newline at end of file
diff --git a/common/constants.py b/common/constants.py
index 23a75505941..6a939cf4cfd 100644
--- a/common/constants.py
+++ b/common/constants.py
@@ -20,6 +20,7 @@
SERVICE_CONF = "service_conf.yaml"
RAG_FLOW_SERVICE_NAME = "ragflow"
+
class CustomEnum(Enum):
@classmethod
def valid(cls, value):
@@ -68,13 +69,13 @@ class ActiveEnum(Enum):
class LLMType(StrEnum):
- CHAT = 'chat'
- EMBEDDING = 'embedding'
- SPEECH2TEXT = 'speech2text'
- IMAGE2TEXT = 'image2text'
- RERANK = 'rerank'
- TTS = 'tts'
- OCR = 'ocr'
+ CHAT = "chat"
+ EMBEDDING = "embedding"
+ SPEECH2TEXT = "speech2text"
+ IMAGE2TEXT = "image2text"
+ RERANK = "rerank"
+ TTS = "tts"
+ OCR = "ocr"
class TaskStatus(StrEnum):
@@ -86,8 +87,7 @@ class TaskStatus(StrEnum):
SCHEDULE = "5"
-VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL,
- TaskStatus.SCHEDULE}
+VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL, TaskStatus.SCHEDULE}
class ParserType(StrEnum):
@@ -133,6 +133,12 @@ class FileSource(StrEnum):
GITHUB = "github"
GITLAB = "gitlab"
IMAP = "imap"
+ BITBUCKET = "bitbucket"
+ ZENDESK = "zendesk"
+ SEAFILE = "seafile"
+ MYSQL = "mysql"
+ POSTGRESQL = "postgresql"
+
class PipelineTaskType(StrEnum):
PARSE = "Parse"
@@ -143,15 +149,17 @@ class PipelineTaskType(StrEnum):
MEMORY = "Memory"
-VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR,
- PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP}
+VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP}
+
class MCPServerType(StrEnum):
SSE = "sse"
STREAMABLE_HTTP = "streamable-http"
+
VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
+
class Storage(Enum):
MINIO = 1
AZURE_SPN = 2
@@ -163,10 +171,10 @@ class Storage(Enum):
class MemoryType(Enum):
- RAW = 0b0001 # 1 << 0 = 1 (0b00000001)
- SEMANTIC = 0b0010 # 1 << 1 = 2 (0b00000010)
- EPISODIC = 0b0100 # 1 << 2 = 4 (0b00000100)
- PROCEDURAL = 0b1000 # 1 << 3 = 8 (0b00001000)
+ RAW = 0b0001 # 1 << 0 = 1 (0b00000001)
+ SEMANTIC = 0b0010 # 1 << 1 = 2 (0b00000010)
+ EPISODIC = 0b0100 # 1 << 2 = 4 (0b00000100)
+ PROCEDURAL = 0b1000 # 1 << 3 = 8 (0b00001000)
class MemoryStorageType(StrEnum):
@@ -237,3 +245,10 @@ class ForgettingPolicy(StrEnum):
"MINERU_SERVER_URL": "",
"MINERU_DELETE_OUTPUT": 1,
}
+
+PADDLEOCR_ENV_KEYS = ["PADDLEOCR_API_URL", "PADDLEOCR_ACCESS_TOKEN", "PADDLEOCR_ALGORITHM"]
+PADDLEOCR_DEFAULT_CONFIG = {
+ "PADDLEOCR_API_URL": "",
+ "PADDLEOCR_ACCESS_TOKEN": None,
+ "PADDLEOCR_ALGORITHM": "PaddleOCR-VL",
+}
diff --git a/common/data_source/__init__.py b/common/data_source/__init__.py
index 2619e779dcd..74baaee016f 100644
--- a/common/data_source/__init__.py
+++ b/common/data_source/__init__.py
@@ -34,11 +34,13 @@
from .jira.connector import JiraConnector
from .sharepoint_connector import SharePointConnector
from .teams_connector import TeamsConnector
-from .webdav_connector import WebDAVConnector
from .moodle_connector import MoodleConnector
from .airtable_connector import AirtableConnector
from .asana_connector import AsanaConnector
from .imap_connector import ImapConnector
+from .zendesk_connector import ZendeskConnector
+from .seafile_connector import SeaFileConnector
+from .rdbms_connector import RDBMSConnector
from .config import BlobType, DocumentSource
from .models import Document, TextSection, ImageSection, BasicExpertInfo
from .exceptions import (
@@ -61,7 +63,6 @@
"JiraConnector",
"SharePointConnector",
"TeamsConnector",
- "WebDAVConnector",
"MoodleConnector",
"BlobType",
"DocumentSource",
@@ -76,5 +77,8 @@
"UnexpectedValidationError",
"AirtableConnector",
"AsanaConnector",
- "ImapConnector"
+ "ImapConnector",
+ "ZendeskConnector",
+ "SeaFileConnector",
+ "RDBMSConnector",
]
diff --git a/common/data_source/airtable_connector.py b/common/data_source/airtable_connector.py
index 6f0b5a930cd..46dcf07ee47 100644
--- a/common/data_source/airtable_connector.py
+++ b/common/data_source/airtable_connector.py
@@ -75,7 +75,6 @@ def load_from_state(self) -> GenerateDocumentsOutput:
batch: list[Document] = []
for record in records:
- print(record)
record_id = record.get("id")
fields = record.get("fields", {})
created_time = record.get("createdTime")
diff --git a/graphrag/__init__.py b/common/data_source/bitbucket/__init__.py
similarity index 100%
rename from graphrag/__init__.py
rename to common/data_source/bitbucket/__init__.py
diff --git a/common/data_source/bitbucket/connector.py b/common/data_source/bitbucket/connector.py
new file mode 100644
index 00000000000..f355a8945fc
--- /dev/null
+++ b/common/data_source/bitbucket/connector.py
@@ -0,0 +1,388 @@
+from __future__ import annotations
+
+import copy
+from collections.abc import Callable
+from collections.abc import Iterator
+from datetime import datetime
+from datetime import timezone
+from typing import Any
+from typing import TYPE_CHECKING
+
+from typing_extensions import override
+
+from common.data_source.config import INDEX_BATCH_SIZE
+from common.data_source.config import DocumentSource
+from common.data_source.config import REQUEST_TIMEOUT_SECONDS
+from common.data_source.exceptions import (
+ ConnectorMissingCredentialError,
+ CredentialExpiredError,
+ InsufficientPermissionsError,
+ UnexpectedValidationError,
+)
+from common.data_source.interfaces import CheckpointedConnector
+from common.data_source.interfaces import CheckpointOutput
+from common.data_source.interfaces import IndexingHeartbeatInterface
+from common.data_source.interfaces import SecondsSinceUnixEpoch
+from common.data_source.interfaces import SlimConnectorWithPermSync
+from common.data_source.models import ConnectorCheckpoint
+from common.data_source.models import ConnectorFailure
+from common.data_source.models import DocumentFailure
+from common.data_source.models import SlimDocument
+from common.data_source.bitbucket.utils import (
+ build_auth_client,
+ list_repositories,
+ map_pr_to_document,
+ paginate,
+ PR_LIST_RESPONSE_FIELDS,
+ SLIM_PR_LIST_RESPONSE_FIELDS,
+)
+
+if TYPE_CHECKING:
+ import httpx
+
+
+class BitbucketConnectorCheckpoint(ConnectorCheckpoint):
+ """Checkpoint state for resumable Bitbucket PR indexing.
+
+ Fields:
+ repos_queue: Materialized list of repository slugs to process.
+ current_repo_index: Index of the repository currently being processed.
+ next_url: Bitbucket "next" URL for continuing pagination within the current repo.
+ """
+
+ repos_queue: list[str] = []
+ current_repo_index: int = 0
+ next_url: str | None = None
+
+
+class BitbucketConnector(
+ CheckpointedConnector[BitbucketConnectorCheckpoint],
+ SlimConnectorWithPermSync,
+):
+ """Connector for indexing Bitbucket Cloud pull requests.
+
+ Args:
+ workspace: Bitbucket workspace ID.
+ repositories: Comma-separated list of repository slugs to index.
+ projects: Comma-separated list of project keys to index all repositories within.
+ batch_size: Max number of documents to yield per batch.
+ """
+
+ def __init__(
+ self,
+ workspace: str,
+ repositories: str | None = None,
+ projects: str | None = None,
+ batch_size: int = INDEX_BATCH_SIZE,
+ ) -> None:
+ self.workspace = workspace
+ self._repositories = (
+ [s.strip() for s in repositories.split(",") if s.strip()]
+ if repositories
+ else None
+ )
+ self._projects: list[str] | None = (
+ [s.strip() for s in projects.split(",") if s.strip()] if projects else None
+ )
+ self.batch_size = batch_size
+ self.email: str | None = None
+ self.api_token: str | None = None
+
+ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
+ """Load API token-based credentials.
+
+ Expects a dict with keys: `bitbucket_email`, `bitbucket_api_token`.
+ """
+ self.email = credentials.get("bitbucket_email")
+ self.api_token = credentials.get("bitbucket_api_token")
+ if not self.email or not self.api_token:
+ raise ConnectorMissingCredentialError("Bitbucket")
+ return None
+
+ def _client(self) -> httpx.Client:
+ """Build an authenticated HTTP client or raise if credentials missing."""
+ if not self.email or not self.api_token:
+ raise ConnectorMissingCredentialError("Bitbucket")
+ return build_auth_client(self.email, self.api_token)
+
+ def _iter_pull_requests_for_repo(
+ self,
+ client: httpx.Client,
+ repo_slug: str,
+ params: dict[str, Any] | None = None,
+ start_url: str | None = None,
+ on_page: Callable[[str | None], None] | None = None,
+ ) -> Iterator[dict[str, Any]]:
+ base = f"https://api.bitbucket.org/2.0/repositories/{self.workspace}/{repo_slug}/pullrequests"
+ yield from paginate(
+ client,
+ base,
+ params,
+ start_url=start_url,
+ on_page=on_page,
+ )
+
+ def _build_params(
+ self,
+ fields: str = PR_LIST_RESPONSE_FIELDS,
+ start: SecondsSinceUnixEpoch | None = None,
+ end: SecondsSinceUnixEpoch | None = None,
+ ) -> dict[str, Any]:
+ """Build Bitbucket fetch params.
+
+ Always include OPEN, MERGED, and DECLINED PRs. If both ``start`` and
+ ``end`` are provided, apply a single updated_on time window.
+ """
+
+ def _iso(ts: SecondsSinceUnixEpoch) -> str:
+ return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()
+
+ def _tc_epoch(
+ lower_epoch: SecondsSinceUnixEpoch | None,
+ upper_epoch: SecondsSinceUnixEpoch | None,
+ ) -> str | None:
+ if lower_epoch is not None and upper_epoch is not None:
+ lower_iso = _iso(lower_epoch)
+ upper_iso = _iso(upper_epoch)
+ return f'(updated_on > "{lower_iso}" AND updated_on <= "{upper_iso}")'
+ return None
+
+ params: dict[str, Any] = {"fields": fields, "pagelen": 50}
+ time_clause = _tc_epoch(start, end)
+ q = '(state = "OPEN" OR state = "MERGED" OR state = "DECLINED")'
+ if time_clause:
+ q = f"{q} AND {time_clause}"
+ params["q"] = q
+ return params
+
+ def _iter_target_repositories(self, client: httpx.Client) -> Iterator[str]:
+ """Yield repository slugs based on configuration.
+
+ Priority:
+ - repositories list
+ - projects list (list repos by project key)
+ - workspace (all repos)
+ """
+ if self._repositories:
+ for slug in self._repositories:
+ yield slug
+ return
+ if self._projects:
+ for project_key in self._projects:
+ for repo in list_repositories(client, self.workspace, project_key):
+ slug_val = repo.get("slug")
+ if isinstance(slug_val, str) and slug_val:
+ yield slug_val
+ return
+ for repo in list_repositories(client, self.workspace, None):
+ slug_val = repo.get("slug")
+ if isinstance(slug_val, str) and slug_val:
+ yield slug_val
+
+ @override
+ def load_from_checkpoint(
+ self,
+ start: SecondsSinceUnixEpoch,
+ end: SecondsSinceUnixEpoch,
+ checkpoint: BitbucketConnectorCheckpoint,
+ ) -> CheckpointOutput[BitbucketConnectorCheckpoint]:
+ """Resumable PR ingestion across repos and pages within a time window.
+
+ Yields Documents (or ConnectorFailure for per-PR mapping failures) and returns
+ an updated checkpoint that records repo position and next page URL.
+ """
+ new_checkpoint = copy.deepcopy(checkpoint)
+
+ with self._client() as client:
+ # Materialize target repositories once
+ if not new_checkpoint.repos_queue:
+ # Preserve explicit order; otherwise ensure deterministic ordering
+ repos_list = list(self._iter_target_repositories(client))
+ new_checkpoint.repos_queue = sorted(set(repos_list))
+ new_checkpoint.current_repo_index = 0
+ new_checkpoint.next_url = None
+
+ repos = new_checkpoint.repos_queue
+ if not repos or new_checkpoint.current_repo_index >= len(repos):
+ new_checkpoint.has_more = False
+ return new_checkpoint
+
+ repo_slug = repos[new_checkpoint.current_repo_index]
+
+ first_page_params = self._build_params(
+ fields=PR_LIST_RESPONSE_FIELDS,
+ start=start,
+ end=end,
+ )
+
+ def _on_page(next_url: str | None) -> None:
+ new_checkpoint.next_url = next_url
+
+ for pr in self._iter_pull_requests_for_repo(
+ client,
+ repo_slug,
+ params=first_page_params,
+ start_url=new_checkpoint.next_url,
+ on_page=_on_page,
+ ):
+ try:
+ document = map_pr_to_document(pr, self.workspace, repo_slug)
+ yield document
+ except Exception as e:
+ pr_id = pr.get("id")
+ pr_link = (
+ f"https://bitbucket.org/{self.workspace}/{repo_slug}/pull-requests/{pr_id}"
+ if pr_id is not None
+ else None
+ )
+ yield ConnectorFailure(
+ failed_document=DocumentFailure(
+ document_id=(
+ f"{DocumentSource.BITBUCKET.value}:{self.workspace}:{repo_slug}:pr:{pr_id}"
+ if pr_id is not None
+ else f"{DocumentSource.BITBUCKET.value}:{self.workspace}:{repo_slug}:pr:unknown"
+ ),
+ document_link=pr_link,
+ ),
+ failure_message=f"Failed to process Bitbucket PR: {e}",
+ exception=e,
+ )
+
+ # Advance to next repository (if any) and set has_more accordingly
+ new_checkpoint.current_repo_index += 1
+ new_checkpoint.next_url = None
+ new_checkpoint.has_more = new_checkpoint.current_repo_index < len(repos)
+
+ return new_checkpoint
+
+ @override
+ def build_dummy_checkpoint(self) -> BitbucketConnectorCheckpoint:
+ """Create an initial checkpoint with work remaining."""
+ return BitbucketConnectorCheckpoint(has_more=True)
+
+ @override
+ def validate_checkpoint_json(
+ self, checkpoint_json: str
+ ) -> BitbucketConnectorCheckpoint:
+ """Validate and deserialize a checkpoint instance from JSON."""
+ return BitbucketConnectorCheckpoint.model_validate_json(checkpoint_json)
+
+ def retrieve_all_slim_docs_perm_sync(
+ self,
+ start: SecondsSinceUnixEpoch | None = None,
+ end: SecondsSinceUnixEpoch | None = None,
+ callback: IndexingHeartbeatInterface | None = None,
+ ) -> Iterator[list[SlimDocument]]:
+ """Return only document IDs for all existing pull requests."""
+ batch: list[SlimDocument] = []
+ params = self._build_params(
+ fields=SLIM_PR_LIST_RESPONSE_FIELDS,
+ start=start,
+ end=end,
+ )
+ with self._client() as client:
+ for slug in self._iter_target_repositories(client):
+ for pr in self._iter_pull_requests_for_repo(
+ client, slug, params=params
+ ):
+ pr_id = pr["id"]
+ doc_id = f"{DocumentSource.BITBUCKET.value}:{self.workspace}:{slug}:pr:{pr_id}"
+ batch.append(SlimDocument(id=doc_id))
+ if len(batch) >= self.batch_size:
+ yield batch
+ batch = []
+ if callback:
+ if callback.should_stop():
+ # Note: this is not actually used for permission sync yet, just pruning
+ raise RuntimeError(
+ "bitbucket_pr_sync: Stop signal detected"
+ )
+ callback.progress("bitbucket_pr_sync", len(batch))
+ if batch:
+ yield batch
+
+ def validate_connector_settings(self) -> None:
+ """Validate Bitbucket credentials and workspace access by probing a lightweight endpoint.
+
+ Raises:
+ CredentialExpiredError: on HTTP 401
+ InsufficientPermissionsError: on HTTP 403
+ UnexpectedValidationError: on any other failure
+ """
+ try:
+ with self._client() as client:
+ url = f"https://api.bitbucket.org/2.0/repositories/{self.workspace}"
+ resp = client.get(
+ url,
+ params={"pagelen": 1, "fields": "pagelen"},
+ timeout=REQUEST_TIMEOUT_SECONDS,
+ )
+ if resp.status_code == 401:
+ raise CredentialExpiredError(
+ "Invalid or expired Bitbucket credentials (HTTP 401)."
+ )
+ if resp.status_code == 403:
+ raise InsufficientPermissionsError(
+ "Insufficient permissions to access Bitbucket workspace (HTTP 403)."
+ )
+ if resp.status_code < 200 or resp.status_code >= 300:
+ raise UnexpectedValidationError(
+ f"Unexpected Bitbucket error (status={resp.status_code})."
+ )
+ except Exception as e:
+ # Network or other unexpected errors
+ if isinstance(
+ e,
+ (
+ CredentialExpiredError,
+ InsufficientPermissionsError,
+ UnexpectedValidationError,
+ ConnectorMissingCredentialError,
+ ),
+ ):
+ raise
+ raise UnexpectedValidationError(
+ f"Unexpected error while validating Bitbucket settings: {e}"
+ )
+
+if __name__ == "__main__":
+ bitbucket = BitbucketConnector(
+ workspace=""
+ )
+
+ bitbucket.load_credentials({
+ "bitbucket_email": "",
+ "bitbucket_api_token": "",
+ })
+
+ bitbucket.validate_connector_settings()
+ print("Credentials validated successfully.")
+
+ start_time = datetime.fromtimestamp(0, tz=timezone.utc)
+ end_time = datetime.now(timezone.utc)
+
+ for doc_batch in bitbucket.retrieve_all_slim_docs_perm_sync(
+ start=start_time.timestamp(),
+ end=end_time.timestamp(),
+ ):
+ for doc in doc_batch:
+ print(doc)
+
+
+ bitbucket_checkpoint = bitbucket.build_dummy_checkpoint()
+
+ while bitbucket_checkpoint.has_more:
+ gen = bitbucket.load_from_checkpoint(
+ start=start_time.timestamp(),
+ end=end_time.timestamp(),
+ checkpoint=bitbucket_checkpoint,
+ )
+
+ while True:
+ try:
+ doc = next(gen)
+ print(doc)
+ except StopIteration as e:
+ bitbucket_checkpoint = e.value
+ break
+
\ No newline at end of file
diff --git a/common/data_source/bitbucket/utils.py b/common/data_source/bitbucket/utils.py
new file mode 100644
index 00000000000..4667a960066
--- /dev/null
+++ b/common/data_source/bitbucket/utils.py
@@ -0,0 +1,288 @@
+from __future__ import annotations
+
+import time
+from collections.abc import Callable
+from collections.abc import Iterator
+from datetime import datetime
+from datetime import timezone
+from typing import Any
+
+import httpx
+
+from common.data_source.config import REQUEST_TIMEOUT_SECONDS, DocumentSource
+from common.data_source.cross_connector_utils.rate_limit_wrapper import (
+ rate_limit_builder,
+)
+from common.data_source.utils import sanitize_filename
+from common.data_source.models import BasicExpertInfo, Document
+from common.data_source.cross_connector_utils.retry_wrapper import retry_builder
+
+# Fields requested from Bitbucket PR list endpoint to ensure rich PR data
+PR_LIST_RESPONSE_FIELDS: str = ",".join(
+ [
+ "next",
+ "page",
+ "pagelen",
+ "values.author",
+ "values.close_source_branch",
+ "values.closed_by",
+ "values.comment_count",
+ "values.created_on",
+ "values.description",
+ "values.destination",
+ "values.draft",
+ "values.id",
+ "values.links",
+ "values.merge_commit",
+ "values.participants",
+ "values.reason",
+ "values.rendered",
+ "values.reviewers",
+ "values.source",
+ "values.state",
+ "values.summary",
+ "values.task_count",
+ "values.title",
+ "values.type",
+ "values.updated_on",
+ ]
+)
+
+# Minimal fields for slim retrieval (IDs only)
+SLIM_PR_LIST_RESPONSE_FIELDS: str = ",".join(
+ [
+ "next",
+ "page",
+ "pagelen",
+ "values.id",
+ ]
+)
+
+
+# Minimal fields for repository list calls
+REPO_LIST_RESPONSE_FIELDS: str = ",".join(
+ [
+ "next",
+ "page",
+ "pagelen",
+ "values.slug",
+ "values.full_name",
+ "values.project.key",
+ ]
+)
+
+
+class BitbucketRetriableError(Exception):
+ """Raised for retriable Bitbucket conditions (429, 5xx)."""
+
+
+class BitbucketNonRetriableError(Exception):
+ """Raised for non-retriable Bitbucket client errors (4xx except 429)."""
+
+
+@retry_builder(
+ tries=6,
+ delay=1,
+ backoff=2,
+ max_delay=30,
+ exceptions=(BitbucketRetriableError, httpx.RequestError),
+)
+@rate_limit_builder(max_calls=60, period=60)
+def bitbucket_get(
+ client: httpx.Client, url: str, params: dict[str, Any] | None = None
+) -> httpx.Response:
+ """Perform a GET against Bitbucket with retry and rate limiting.
+
+ Retries on 429 and 5xx responses, and on transport errors. Honors
+ `Retry-After` header for 429 when present by sleeping before retrying.
+ """
+ try:
+ response = client.get(url, params=params, timeout=REQUEST_TIMEOUT_SECONDS)
+ except httpx.RequestError:
+ # Allow retry_builder to handle retries of transport errors
+ raise
+
+ try:
+ response.raise_for_status()
+ except httpx.HTTPStatusError as e:
+ status = e.response.status_code if e.response is not None else None
+ if status == 429:
+ retry_after = e.response.headers.get("Retry-After") if e.response else None
+ if retry_after is not None:
+ try:
+ time.sleep(int(retry_after))
+ except (TypeError, ValueError):
+ pass
+ raise BitbucketRetriableError("Bitbucket rate limit exceeded (429)") from e
+ if status is not None and 500 <= status < 600:
+ raise BitbucketRetriableError(f"Bitbucket server error: {status}") from e
+ if status is not None and 400 <= status < 500:
+ raise BitbucketNonRetriableError(f"Bitbucket client error: {status}") from e
+ # Unknown status, propagate
+ raise
+
+ return response
+
+
+def build_auth_client(email: str, api_token: str) -> httpx.Client:
+ """Create an authenticated httpx client for Bitbucket Cloud API."""
+ return httpx.Client(auth=(email, api_token), http2=True)
+
+
+def paginate(
+ client: httpx.Client,
+ url: str,
+ params: dict[str, Any] | None = None,
+ start_url: str | None = None,
+ on_page: Callable[[str | None], None] | None = None,
+) -> Iterator[dict[str, Any]]:
+ """Iterate over paginated Bitbucket API responses yielding individual values.
+
+ Args:
+ client: Authenticated HTTP client.
+ url: Base collection URL (first page when start_url is None).
+ params: Query params for the first page.
+ start_url: If provided, start from this absolute URL (ignores params).
+ on_page: Optional callback invoked after each page with the next page URL.
+ """
+ next_url = start_url or url
+ # If resuming from a next URL, do not pass params again
+ query = params.copy() if params else None
+ query = None if start_url else query
+ while next_url:
+ resp = bitbucket_get(client, next_url, params=query)
+ data = resp.json()
+ values = data.get("values", [])
+ for item in values:
+ yield item
+ next_url = data.get("next")
+ if on_page is not None:
+ on_page(next_url)
+ # only include params on first call, next_url will contain all necessary params
+ query = None
+
+
+def list_repositories(
+ client: httpx.Client, workspace: str, project_key: str | None = None
+) -> Iterator[dict[str, Any]]:
+ """List repositories in a workspace, optionally filtered by project key."""
+ base_url = f"https://api.bitbucket.org/2.0/repositories/{workspace}"
+ params: dict[str, Any] = {
+ "fields": REPO_LIST_RESPONSE_FIELDS,
+ "pagelen": 100,
+ # Ensure deterministic ordering
+ "sort": "full_name",
+ }
+ if project_key:
+ params["q"] = f'project.key="{project_key}"'
+ yield from paginate(client, base_url, params)
+
+
+def map_pr_to_document(pr: dict[str, Any], workspace: str, repo_slug: str) -> Document:
+ """Map a Bitbucket pull request JSON to Onyx Document."""
+ pr_id = pr["id"]
+ title = pr.get("title") or f"PR {pr_id}"
+ description = pr.get("description") or ""
+ state = pr.get("state")
+ draft = pr.get("draft", False)
+ author = pr.get("author", {})
+ reviewers = pr.get("reviewers", [])
+ participants = pr.get("participants", [])
+
+ link = pr.get("links", {}).get("html", {}).get("href") or (
+ f"https://bitbucket.org/{workspace}/{repo_slug}/pull-requests/{pr_id}"
+ )
+
+ created_on = pr.get("created_on")
+ updated_on = pr.get("updated_on")
+ updated_dt = (
+ datetime.fromisoformat(updated_on.replace("Z", "+00:00")).astimezone(
+ timezone.utc
+ )
+ if isinstance(updated_on, str)
+ else None
+ )
+
+ source_branch = pr.get("source", {}).get("branch", {}).get("name", "")
+ destination_branch = pr.get("destination", {}).get("branch", {}).get("name", "")
+
+ approved_by = [
+ _get_user_name(p.get("user", {})) for p in participants if p.get("approved")
+ ]
+
+ primary_owner = None
+ if author:
+ primary_owner = BasicExpertInfo(
+ display_name=_get_user_name(author),
+ )
+
+ # secondary_owners = [
+ # BasicExpertInfo(display_name=_get_user_name(r)) for r in reviewers
+ # ] or None
+
+ reviewer_names = [_get_user_name(r) for r in reviewers]
+
+ # Create a concise summary of key PR info
+ created_date = created_on.split("T")[0] if created_on else "N/A"
+ updated_date = updated_on.split("T")[0] if updated_on else "N/A"
+ content_text = (
+ "Pull Request Information:\n"
+ f"- Pull Request ID: {pr_id}\n"
+ f"- Title: {title}\n"
+ f"- State: {state or 'N/A'} {'(Draft)' if draft else ''}\n"
+ )
+ if state == "DECLINED":
+ content_text += f"- Reason: {pr.get('reason', 'N/A')}\n"
+ content_text += (
+ f"- Author: {_get_user_name(author) if author else 'N/A'}\n"
+ f"- Reviewers: {', '.join(reviewer_names) if reviewer_names else 'N/A'}\n"
+ f"- Branch: {source_branch} -> {destination_branch}\n"
+ f"- Created: {created_date}\n"
+ f"- Updated: {updated_date}"
+ )
+ if description:
+ content_text += f"\n\nDescription:\n{description}"
+
+ metadata: dict[str, str | list[str]] = {
+ "object_type": "PullRequest",
+ "workspace": workspace,
+ "repository": repo_slug,
+ "pr_key": f"{workspace}/{repo_slug}#{pr_id}",
+ "id": str(pr_id),
+ "title": title,
+ "state": state or "",
+ "draft": str(bool(draft)),
+ "link": link,
+ "author": _get_user_name(author) if author else "",
+ "reviewers": reviewer_names,
+ "approved_by": approved_by,
+ "comment_count": str(pr.get("comment_count", "")),
+ "task_count": str(pr.get("task_count", "")),
+ "created_on": created_on or "",
+ "updated_on": updated_on or "",
+ "source_branch": source_branch,
+ "destination_branch": destination_branch,
+ "closed_by": (
+ _get_user_name(pr.get("closed_by", {})) if pr.get("closed_by") else ""
+ ),
+ "close_source_branch": str(bool(pr.get("close_source_branch", False))),
+ }
+
+ name = sanitize_filename(title, "md")
+
+ return Document(
+ id=f"{DocumentSource.BITBUCKET.value}:{workspace}:{repo_slug}:pr:{pr_id}",
+ blob=content_text.encode("utf-8"),
+ source=DocumentSource.BITBUCKET,
+ extension=".md",
+ semantic_identifier=f"#{pr_id}: {name}",
+ size_bytes=len(content_text.encode("utf-8")),
+ doc_updated_at=updated_dt,
+ primary_owners=[primary_owner] if primary_owner else None,
+ # secondary_owners=secondary_owners,
+ metadata=metadata,
+ )
+
+
+def _get_user_name(user: dict[str, Any]) -> str:
+ return user.get("display_name") or user.get("nickname") or "unknown"
\ No newline at end of file
diff --git a/common/data_source/config.py b/common/data_source/config.py
index bca13b5bed6..b05d8af24af 100644
--- a/common/data_source/config.py
+++ b/common/data_source/config.py
@@ -13,6 +13,9 @@ def get_current_tz_offset() -> int:
return round(time_diff.total_seconds() / 3600)
+# Default request timeout, mostly used by connectors
+REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS") or 60)
+
ONE_MINUTE = 60
ONE_HOUR = 3600
ONE_DAY = ONE_HOUR * 24
@@ -58,8 +61,13 @@ class DocumentSource(str, Enum):
GITHUB = "github"
GITLAB = "gitlab"
IMAP = "imap"
+ BITBUCKET = "bitbucket"
+ ZENDESK = "zendesk"
+ SEAFILE = "seafile"
+ MYSQL = "mysql"
+ POSTGRESQL = "postgresql"
+
-
class FileOrigin(str, Enum):
"""File origins"""
CONNECTOR = "connector"
@@ -271,6 +279,10 @@ class HtmlBasedConnectorTransformLinksStrategy(str, Enum):
os.environ.get("IMAP_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
+ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS = os.environ.get(
+ "ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS", ""
+).split(",")
+
_USER_NOT_FOUND = "Unknown Confluence User"
_COMMENT_EXPANSION_FIELDS = ["body.storage.value"]
diff --git a/graphrag/general/__init__.py b/common/data_source/cross_connector_utils/__init__.py
similarity index 100%
rename from graphrag/general/__init__.py
rename to common/data_source/cross_connector_utils/__init__.py
diff --git a/common/data_source/cross_connector_utils/rate_limit_wrapper.py b/common/data_source/cross_connector_utils/rate_limit_wrapper.py
new file mode 100644
index 00000000000..bc0e0b470d6
--- /dev/null
+++ b/common/data_source/cross_connector_utils/rate_limit_wrapper.py
@@ -0,0 +1,126 @@
+import time
+import logging
+from collections.abc import Callable
+from functools import wraps
+from typing import Any
+from typing import cast
+from typing import TypeVar
+
+import requests
+
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+class RateLimitTriedTooManyTimesError(Exception):
+ pass
+
+
+class _RateLimitDecorator:
+ """Builds a generic wrapper/decorator for calls to external APIs that
+ prevents making more than `max_calls` requests per `period`
+
+ Implementation inspired by the `ratelimit` library:
+ https://github.com/tomasbasham/ratelimit.
+
+ NOTE: is not thread safe.
+ """
+
+ def __init__(
+ self,
+ max_calls: int,
+ period: float, # in seconds
+ sleep_time: float = 2, # in seconds
+ sleep_backoff: float = 2, # applies exponential backoff
+ max_num_sleep: int = 0,
+ ):
+ self.max_calls = max_calls
+ self.period = period
+ self.sleep_time = sleep_time
+ self.sleep_backoff = sleep_backoff
+ self.max_num_sleep = max_num_sleep
+
+ self.call_history: list[float] = []
+ self.curr_calls = 0
+
+ def __call__(self, func: F) -> F:
+ @wraps(func)
+ def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any:
+ # cleanup calls which are no longer relevant
+ self._cleanup()
+
+ # check if we've exceeded the rate limit
+ sleep_cnt = 0
+ while len(self.call_history) == self.max_calls:
+ sleep_time = self.sleep_time * (self.sleep_backoff**sleep_cnt)
+ logging.warning(
+ f"Rate limit exceeded for function {func.__name__}. "
+ f"Waiting {sleep_time} seconds before retrying."
+ )
+ time.sleep(sleep_time)
+ sleep_cnt += 1
+ if self.max_num_sleep != 0 and sleep_cnt >= self.max_num_sleep:
+ raise RateLimitTriedTooManyTimesError(
+ f"Exceeded '{self.max_num_sleep}' retries for function '{func.__name__}'"
+ )
+
+ self._cleanup()
+
+ # add the current call to the call history
+ self.call_history.append(time.monotonic())
+ return func(*args, **kwargs)
+
+ return cast(F, wrapped_func)
+
+ def _cleanup(self) -> None:
+ curr_time = time.monotonic()
+ time_to_expire_before = curr_time - self.period
+ self.call_history = [
+ call_time
+ for call_time in self.call_history
+ if call_time > time_to_expire_before
+ ]
+
+
+rate_limit_builder = _RateLimitDecorator
+
+
+"""If you want to allow the external service to tell you when you've hit the rate limit,
+use the following instead"""
+
+R = TypeVar("R", bound=Callable[..., requests.Response])
+
+
+def wrap_request_to_handle_ratelimiting(
+ request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30
+) -> R:
+ def wrapped_request(*args: list, **kwargs: dict[str, Any]) -> requests.Response:
+ for _ in range(max_waits):
+ response = request_fn(*args, **kwargs)
+ if response.status_code == 429:
+ try:
+ wait_time = int(
+ response.headers.get("Retry-After", default_wait_time_sec)
+ )
+ except ValueError:
+ wait_time = default_wait_time_sec
+
+ time.sleep(wait_time)
+ continue
+
+ return response
+
+ raise RateLimitTriedTooManyTimesError(f"Exceeded '{max_waits}' retries")
+
+ return cast(R, wrapped_request)
+
+
+_rate_limited_get = wrap_request_to_handle_ratelimiting(requests.get)
+_rate_limited_post = wrap_request_to_handle_ratelimiting(requests.post)
+
+
+class _RateLimitedRequest:
+ get = _rate_limited_get
+ post = _rate_limited_post
+
+
+rl_requests = _RateLimitedRequest
\ No newline at end of file
diff --git a/common/data_source/cross_connector_utils/retry_wrapper.py b/common/data_source/cross_connector_utils/retry_wrapper.py
new file mode 100644
index 00000000000..a055847975d
--- /dev/null
+++ b/common/data_source/cross_connector_utils/retry_wrapper.py
@@ -0,0 +1,88 @@
+from collections.abc import Callable
+import logging
+from logging import Logger
+from typing import Any
+from typing import cast
+from typing import TypeVar
+import requests
+from retry import retry
+
+from common.data_source.config import REQUEST_TIMEOUT_SECONDS
+
+
+F = TypeVar("F", bound=Callable[..., Any])
+logger = logging.getLogger(__name__)
+
+def retry_builder(
+ tries: int = 20,
+ delay: float = 0.1,
+ max_delay: float | None = 60,
+ backoff: float = 2,
+ jitter: tuple[float, float] | float = 1,
+ exceptions: type[Exception] | tuple[type[Exception], ...] = (Exception,),
+) -> Callable[[F], F]:
+ """Builds a generic wrapper/decorator for calls to external APIs that
+ may fail due to rate limiting, flakes, or other reasons. Applies exponential
+ backoff with jitter to retry the call."""
+
+ def retry_with_default(func: F) -> F:
+ @retry(
+ tries=tries,
+ delay=delay,
+ max_delay=max_delay,
+ backoff=backoff,
+ jitter=jitter,
+ logger=cast(Logger, logger),
+ exceptions=exceptions,
+ )
+ def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any:
+ return func(*args, **kwargs)
+
+ return cast(F, wrapped_func)
+
+ return retry_with_default
+
+
+def request_with_retries(
+ method: str,
+ url: str,
+ *,
+ data: dict[str, Any] | None = None,
+ headers: dict[str, Any] | None = None,
+ params: dict[str, Any] | None = None,
+ timeout: int = REQUEST_TIMEOUT_SECONDS,
+ stream: bool = False,
+ tries: int = 8,
+ delay: float = 1,
+ backoff: float = 2,
+) -> requests.Response:
+ @retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
+ def _make_request() -> requests.Response:
+ response = requests.request(
+ method=method,
+ url=url,
+ data=data,
+ headers=headers,
+ params=params,
+ timeout=timeout,
+ stream=stream,
+ )
+ try:
+ response.raise_for_status()
+ except requests.exceptions.HTTPError:
+ logging.exception(
+ "Request failed:\n%s",
+ {
+ "method": method,
+ "url": url,
+ "data": data,
+ "headers": headers,
+ "params": params,
+ "timeout": timeout,
+ "stream": stream,
+ },
+ )
+ raise
+ return response
+
+ return _make_request()
\ No newline at end of file
diff --git a/common/data_source/github/connector.py b/common/data_source/github/connector.py
index 2e6d5f2af93..6a9b96740bc 100644
--- a/common/data_source/github/connector.py
+++ b/common/data_source/github/connector.py
@@ -19,7 +19,7 @@
from github.PullRequest import PullRequest
from pydantic import BaseModel
from typing_extensions import override
-from common.data_source.google_util.util import sanitize_filename
+from common.data_source.utils import sanitize_filename
from common.data_source.config import DocumentSource, GITHUB_CONNECTOR_BASE_URL
from common.data_source.exceptions import (
ConnectorMissingCredentialError,
diff --git a/common/data_source/gmail_connector.py b/common/data_source/gmail_connector.py
index e64db984714..1421f9f4bf1 100644
--- a/common/data_source/gmail_connector.py
+++ b/common/data_source/gmail_connector.py
@@ -8,10 +8,10 @@
from common.data_source.google_util.auth import get_google_creds
from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS, USER_FIELDS
from common.data_source.google_util.resource import get_admin_service, get_gmail_service
-from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval, sanitize_filename, clean_string
+from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval, clean_string
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync
from common.data_source.models import BasicExpertInfo, Document, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument, TextSection
-from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, gmail_time_str_to_utc
+from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, gmail_time_str_to_utc, sanitize_filename
# Constants for Gmail API fields
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"
diff --git a/common/data_source/google_util/util.py b/common/data_source/google_util/util.py
index b1f0162a4cb..187c06d6d84 100644
--- a/common/data_source/google_util/util.py
+++ b/common/data_source/google_util/util.py
@@ -191,42 +191,6 @@ def get_credentials_from_env(email: str, oauth: bool = False, source="drive") ->
DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded",
}
-def sanitize_filename(name: str, extension: str = "txt") -> str:
- """
- Soft sanitize for MinIO/S3:
- - Replace only prohibited characters with a space.
- - Preserve readability (no ugly underscores).
- - Collapse multiple spaces.
- """
- if name is None:
- return f"file.{extension}"
-
- name = str(name).strip()
-
- # Characters that MUST NOT appear in S3/MinIO object keys
- # Replace them with a space (not underscore)
- forbidden = r'[\\\?\#\%\*\:\|\<\>"]'
- name = re.sub(forbidden, " ", name)
-
- # Replace slashes "/" (S3 interprets as folder) with space
- name = name.replace("/", " ")
-
- # Collapse multiple spaces into one
- name = re.sub(r"\s+", " ", name)
-
- # Trim both ends
- name = name.strip()
-
- # Enforce reasonable max length
- if len(name) > 200:
- base, ext = os.path.splitext(name)
- name = base[:180].rstrip() + ext
-
- if not os.path.splitext(name)[1]:
- name += f".{extension}"
-
- return name
-
def clean_string(text: str | None) -> str | None:
"""
diff --git a/common/data_source/imap_connector.py b/common/data_source/imap_connector.py
index f3371ee2493..acaba7e01ec 100644
--- a/common/data_source/imap_connector.py
+++ b/common/data_source/imap_connector.py
@@ -12,6 +12,7 @@
from enum import Enum
from typing import Any
from typing import cast
+import uuid
import bs4
from pydantic import BaseModel
@@ -635,7 +636,6 @@ def _parse_singular_addr(raw_header: str) -> tuple[str, str]:
if __name__ == "__main__":
import time
- import uuid
from types import TracebackType
from common.data_source.utils import load_all_docs_from_checkpoint_connector
diff --git a/common/data_source/rdbms_connector.py b/common/data_source/rdbms_connector.py
new file mode 100644
index 00000000000..944bfdb551a
--- /dev/null
+++ b/common/data_source/rdbms_connector.py
@@ -0,0 +1,405 @@
+"""RDBMS (MySQL/PostgreSQL) data source connector for importing data from relational databases."""
+
+import hashlib
+import json
+import logging
+from datetime import datetime, timezone
+from enum import Enum
+from typing import Any, Dict, Generator, Optional, Union
+
+from common.data_source.config import DocumentSource, INDEX_BATCH_SIZE
+from common.data_source.exceptions import (
+ ConnectorMissingCredentialError,
+ ConnectorValidationError,
+)
+from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
+from common.data_source.models import Document
+
+
+class DatabaseType(str, Enum):
+ """Supported database types."""
+ MYSQL = "mysql"
+ POSTGRESQL = "postgresql"
+
+
+class RDBMSConnector(LoadConnector, PollConnector):
+ """
+ RDBMS connector for importing data from MySQL and PostgreSQL databases.
+
+ This connector allows users to:
+ 1. Connect to a MySQL or PostgreSQL database
+ 2. Execute a SQL query to extract data
+ 3. Map columns to content (for vectorization) and metadata
+ 4. Sync data in batch or incremental mode using a timestamp column
+ """
+ def __init__(
+ self,
+ db_type: str,
+ host: str,
+ port: int,
+ database: str,
+ query: str,
+ content_columns: str,
+ metadata_columns: Optional[str] = None,
+ id_column: Optional[str] = None,
+ timestamp_column: Optional[str] = None,
+ batch_size: int = INDEX_BATCH_SIZE,
+ ) -> None:
+ """
+ Initialize the RDBMS connector.
+
+ Args:
+ db_type: Database type ('mysql' or 'postgresql')
+ host: Database host
+ port: Database port
+ database: Database name
+ query: SQL query to execute (e.g., "SELECT * FROM products WHERE status = 'active'")
+ content_columns: Comma-separated column names to use for document content
+ metadata_columns: Comma-separated column names to use as metadata (optional)
+ id_column: Column to use as unique document ID (optional, will generate hash if not provided)
+ timestamp_column: Column to use for incremental sync (optional, must be datetime/timestamp type)
+ batch_size: Number of documents per batch
+ """
+ self.db_type = DatabaseType(db_type.lower())
+ self.host = host.strip()
+ self.port = port
+ self.database = database.strip()
+ self.query = query.strip()
+ self.content_columns = [c.strip() for c in content_columns.split(",") if c.strip()]
+ self.metadata_columns = [c.strip() for c in (metadata_columns or "").split(",") if c.strip()]
+ self.id_column = id_column.strip() if id_column else None
+ self.timestamp_column = timestamp_column.strip() if timestamp_column else None
+ self.batch_size = batch_size
+
+ self._connection = None
+ self._credentials: Dict[str, Any] = {}
+
+ def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
+ """Load database credentials."""
+ logging.debug(f"Loading credentials for {self.db_type} database: {self.database}")
+
+ required_keys = ["username", "password"]
+ for key in required_keys:
+ if not credentials.get(key):
+ raise ConnectorMissingCredentialError(f"RDBMS ({self.db_type}): missing {key}")
+
+ self._credentials = credentials
+ return None
+
+ def _get_connection(self):
+ """Create and return a database connection."""
+ if self._connection is not None:
+ return self._connection
+
+ username = self._credentials.get("username")
+ password = self._credentials.get("password")
+
+ if self.db_type == DatabaseType.MYSQL:
+ try:
+ import mysql.connector
+ except ImportError:
+ raise ConnectorValidationError(
+ "MySQL connector not installed. Please install mysql-connector-python."
+ )
+ try:
+ self._connection = mysql.connector.connect(
+ host=self.host,
+ port=self.port,
+ database=self.database,
+ user=username,
+ password=password,
+ charset='utf8mb4',
+ use_unicode=True,
+ )
+ except Exception as e:
+ raise ConnectorValidationError(f"Failed to connect to MySQL: {e}")
+ elif self.db_type == DatabaseType.POSTGRESQL:
+ try:
+ import psycopg2
+ except ImportError:
+ raise ConnectorValidationError(
+ "PostgreSQL connector not installed. Please install psycopg2-binary."
+ )
+ try:
+ self._connection = psycopg2.connect(
+ host=self.host,
+ port=self.port,
+ dbname=self.database,
+ user=username,
+ password=password,
+ )
+ except Exception as e:
+ raise ConnectorValidationError(f"Failed to connect to PostgreSQL: {e}")
+
+ return self._connection
+
+ def _close_connection(self):
+ """Close the database connection."""
+ if self._connection is not None:
+ try:
+ self._connection.close()
+ except Exception:
+ pass
+ self._connection = None
+
+ def _get_tables(self) -> list[str]:
+ """Get list of all tables in the database."""
+ connection = self._get_connection()
+ cursor = connection.cursor()
+
+ try:
+ if self.db_type == DatabaseType.MYSQL:
+ cursor.execute("SHOW TABLES")
+ else:
+ cursor.execute(
+ "SELECT table_name FROM information_schema.tables "
+ "WHERE table_schema = 'public' AND table_type = 'BASE TABLE'"
+ )
+ tables = [row[0] for row in cursor.fetchall()]
+ return tables
+ finally:
+ cursor.close()
+
+ def _build_query_with_time_filter(
+ self,
+ start: Optional[datetime] = None,
+ end: Optional[datetime] = None,
+ ) -> str:
+ """Build the query with optional time filtering for incremental sync."""
+ if not self.query:
+ return "" # Will be handled by table discovery
+ base_query = self.query.rstrip(";")
+
+ if not self.timestamp_column or (start is None and end is None):
+ return base_query
+
+ has_where = "where" in base_query.lower()
+ connector = " AND" if has_where else " WHERE"
+
+ time_conditions = []
+ if start is not None:
+ if self.db_type == DatabaseType.MYSQL:
+ time_conditions.append(f"{self.timestamp_column} > '{start.strftime('%Y-%m-%d %H:%M:%S')}'")
+ else:
+ time_conditions.append(f"{self.timestamp_column} > '{start.isoformat()}'")
+
+ if end is not None:
+ if self.db_type == DatabaseType.MYSQL:
+ time_conditions.append(f"{self.timestamp_column} <= '{end.strftime('%Y-%m-%d %H:%M:%S')}'")
+ else:
+ time_conditions.append(f"{self.timestamp_column} <= '{end.isoformat()}'")
+
+ if time_conditions:
+ return f"{base_query}{connector} {' AND '.join(time_conditions)}"
+
+ return base_query
+
+ def _row_to_document(self, row: Union[tuple, list, Dict[str, Any]], column_names: list) -> Document:
+ """Convert a database row to a Document."""
+ row_dict = dict(zip(column_names, row)) if isinstance(row, (list, tuple)) else row
+
+ content_parts = []
+ for col in self.content_columns:
+ if col in row_dict and row_dict[col] is not None:
+ value = row_dict[col]
+ if isinstance(value, (dict, list)):
+ value = json.dumps(value, ensure_ascii=False)
+ # Use brackets around field name to ensure it's distinguishable
+ # after chunking (TxtParser strips \n delimiters during merge)
+ content_parts.append(f"【{col}】: {value}")
+
+ content = "\n".join(content_parts)
+
+ if self.id_column and self.id_column in row_dict:
+ doc_id = f"{self.db_type}:{self.database}:{row_dict[self.id_column]}"
+ else:
+ content_hash = hashlib.md5(content.encode()).hexdigest()
+ doc_id = f"{self.db_type}:{self.database}:{content_hash}"
+
+ metadata = {}
+ for col in self.metadata_columns:
+ if col in row_dict and row_dict[col] is not None:
+ value = row_dict[col]
+ if isinstance(value, datetime):
+ value = value.isoformat()
+ elif isinstance(value, (dict, list)):
+ value = json.dumps(value, ensure_ascii=False)
+ else:
+ value = str(value)
+ metadata[col] = value
+
+ doc_updated_at = datetime.now(timezone.utc)
+ if self.timestamp_column and self.timestamp_column in row_dict:
+ ts_value = row_dict[self.timestamp_column]
+ if isinstance(ts_value, datetime):
+ if ts_value.tzinfo is None:
+ doc_updated_at = ts_value.replace(tzinfo=timezone.utc)
+ else:
+ doc_updated_at = ts_value
+
+ first_content_col = self.content_columns[0] if self.content_columns else "record"
+ semantic_id = str(row_dict.get(first_content_col, "database_record"))[:100]
+
+ return Document(
+ id=doc_id,
+ blob=content.encode("utf-8"),
+ source=DocumentSource(self.db_type.value),
+ semantic_identifier=semantic_id,
+ extension=".txt",
+ doc_updated_at=doc_updated_at,
+ size_bytes=len(content.encode("utf-8")),
+ metadata=metadata if metadata else None,
+ )
+
+ def _yield_documents_from_query(
+ self,
+ query: str,
+ ) -> Generator[list[Document], None, None]:
+ """Generate documents from a single query."""
+ connection = self._get_connection()
+ cursor = connection.cursor()
+
+ try:
+ logging.info(f"Executing query: {query[:200]}...")
+ cursor.execute(query)
+ column_names = [desc[0] for desc in cursor.description]
+
+ batch: list[Document] = []
+ for row in cursor:
+ try:
+ doc = self._row_to_document(row, column_names)
+ batch.append(doc)
+
+ if len(batch) >= self.batch_size:
+ yield batch
+ batch = []
+ except Exception as e:
+ logging.warning(f"Error converting row to document: {e}")
+ continue
+
+ if batch:
+ yield batch
+
+ finally:
+ try:
+ cursor.fetchall()
+ except Exception:
+ pass
+ cursor.close()
+
+ def _yield_documents(
+ self,
+ start: Optional[datetime] = None,
+ end: Optional[datetime] = None,
+ ) -> Generator[list[Document], None, None]:
+ """Generate documents from database query results."""
+ if self.query:
+ query = self._build_query_with_time_filter(start, end)
+ yield from self._yield_documents_from_query(query)
+ else:
+ tables = self._get_tables()
+ logging.info(f"No query specified. Loading all {len(tables)} tables: {tables}")
+ for table in tables:
+ query = f"SELECT * FROM {table}"
+ logging.info(f"Loading table: {table}")
+ yield from self._yield_documents_from_query(query)
+
+ self._close_connection()
+
+ def load_from_state(self) -> Generator[list[Document], None, None]:
+ """Load all documents from the database (full sync)."""
+ logging.debug(f"Loading all records from {self.db_type} database: {self.database}")
+ return self._yield_documents()
+
+ def poll_source(
+ self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
+ ) -> Generator[list[Document], None, None]:
+ """Poll for new/updated documents since the last sync (incremental sync)."""
+ if not self.timestamp_column:
+ logging.warning(
+ "No timestamp column configured for incremental sync. "
+ "Falling back to full sync."
+ )
+ return self.load_from_state()
+
+ start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
+ end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
+
+ logging.debug(
+ f"Polling {self.db_type} database {self.database} "
+ f"from {start_datetime} to {end_datetime}"
+ )
+
+ return self._yield_documents(start_datetime, end_datetime)
+
+ def validate_connector_settings(self) -> None:
+ """Validate connector settings by testing the connection."""
+ if not self._credentials:
+ raise ConnectorMissingCredentialError("RDBMS credentials not loaded.")
+
+ if not self.host:
+ raise ConnectorValidationError("Database host is required.")
+
+ if not self.database:
+ raise ConnectorValidationError("Database name is required.")
+
+ if not self.content_columns:
+ raise ConnectorValidationError(
+ "At least one content column must be specified."
+ )
+
+ try:
+ connection = self._get_connection()
+ cursor = connection.cursor()
+
+ test_query = "SELECT 1"
+ cursor.execute(test_query)
+ cursor.fetchone()
+ cursor.close()
+
+ logging.info(f"Successfully connected to {self.db_type} database: {self.database}")
+
+ except ConnectorValidationError:
+ self._close_connection()
+ raise
+ except Exception as e:
+ self._close_connection()
+ raise ConnectorValidationError(
+ f"Failed to connect to {self.db_type} database: {str(e)}"
+ )
+ finally:
+ self._close_connection()
+
+
+if __name__ == "__main__":
+ import os
+
+ credentials_dict = {
+ "username": os.environ.get("DB_USERNAME", "root"),
+ "password": os.environ.get("DB_PASSWORD", ""),
+ }
+
+ connector = RDBMSConnector(
+ db_type="mysql",
+ host=os.environ.get("DB_HOST", "localhost"),
+ port=int(os.environ.get("DB_PORT", "3306")),
+ database=os.environ.get("DB_NAME", "test"),
+ query="SELECT * FROM products LIMIT 10",
+ content_columns="name,description",
+ metadata_columns="id,category,price",
+ id_column="id",
+ timestamp_column="updated_at",
+ )
+
+ try:
+ connector.load_credentials(credentials_dict)
+ connector.validate_connector_settings()
+
+ for batch in connector.load_from_state():
+ print(f"Batch of {len(batch)} documents:")
+ for doc in batch:
+ print(f" - {doc.id}: {doc.semantic_identifier}")
+ break
+
+ except Exception as e:
+ print(f"Error: {e}")
diff --git a/common/data_source/seafile_connector.py b/common/data_source/seafile_connector.py
new file mode 100644
index 00000000000..0181269e858
--- /dev/null
+++ b/common/data_source/seafile_connector.py
@@ -0,0 +1,390 @@
+"""SeaFile connector"""
+import logging
+from datetime import datetime, timezone
+from typing import Any, Optional
+
+from retry import retry
+
+from common.data_source.utils import (
+ get_file_ext,
+ rl_requests,
+)
+from common.data_source.config import (
+ DocumentSource,
+ INDEX_BATCH_SIZE,
+ BLOB_STORAGE_SIZE_THRESHOLD,
+)
+from common.data_source.exceptions import (
+ ConnectorMissingCredentialError,
+ ConnectorValidationError,
+ CredentialExpiredError,
+ InsufficientPermissionsError,
+)
+from common.data_source.interfaces import LoadConnector, PollConnector
+from common.data_source.models import (
+ Document,
+ SecondsSinceUnixEpoch,
+ GenerateDocumentsOutput,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class SeaFileConnector(LoadConnector, PollConnector):
+ """SeaFile connector for syncing files from SeaFile servers"""
+
+ def __init__(
+ self,
+ seafile_url: str,
+ batch_size: int = INDEX_BATCH_SIZE,
+ include_shared: bool = True,
+ ) -> None:
+ """Initialize SeaFile connector.
+
+ Args:
+ seafile_url: Base URL of the SeaFile server (e.g., https://seafile.example.com)
+ batch_size: Number of documents to yield per batch
+ include_shared: Whether to include shared libraries
+ """
+
+ self.seafile_url = seafile_url.rstrip("/")
+ self.api_url = f"{self.seafile_url}/api2"
+ self.batch_size = batch_size
+ self.include_shared = include_shared
+ self.token: Optional[str] = None
+ self.current_user_email: Optional[str] = None
+ self.size_threshold: int = BLOB_STORAGE_SIZE_THRESHOLD
+
+ def _get_headers(self) -> dict[str, str]:
+ """Get authorization headers for API requests"""
+ if not self.token:
+ raise ConnectorMissingCredentialError("SeaFile token not set")
+ return {
+ "Authorization": f"Token {self.token}",
+ "Accept": "application/json",
+ }
+
+ def _make_get_request(self, endpoint: str, params: Optional[dict] = None):
+ """Make authenticated GET request"""
+ url = f"{self.api_url}/{endpoint.lstrip('/')}"
+ response = rl_requests.get(
+ url,
+ headers=self._get_headers(),
+ params=params,
+ timeout=60,
+ )
+ return response
+
+ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
+ """Load and validate SeaFile credentials.
+
+ Args:
+ credentials: Dictionary containing 'seafile_token' or 'username'/'password'
+
+ Returns:
+ None
+
+ Raises:
+ ConnectorMissingCredentialError: If required credentials are missing
+ """
+ logger.debug(f"Loading credentials for SeaFile server {self.seafile_url}")
+
+ token = credentials.get("seafile_token")
+ username = credentials.get("username")
+ password = credentials.get("password")
+
+ if token:
+ self.token = token
+ elif username and password:
+ self.token = self._authenticate_with_password(username, password)
+ else:
+ raise ConnectorMissingCredentialError(
+ "SeaFile requires 'seafile_token' or 'username'/'password' credentials"
+ )
+
+ # Validate token and get current user info
+ try:
+ self._validate_token()
+ except Exception as e:
+ raise CredentialExpiredError(f"SeaFile token validation failed: {e}")
+
+ return None
+
+ def _authenticate_with_password(self, username: str, password: str) -> str:
+ """Authenticate with username/password and return API token"""
+ try:
+ response = rl_requests.post(
+ f"{self.api_url}/auth-token/",
+ data={"username": username, "password": password},
+ timeout=30,
+ )
+ response.raise_for_status()
+ data = response.json()
+ token = data.get("token")
+ if not token:
+ raise CredentialExpiredError("No token returned from SeaFile")
+ return token
+ except Exception as e:
+ raise ConnectorMissingCredentialError(
+ f"Failed to authenticate with SeaFile: {e}"
+ )
+
+ def _validate_token(self) -> dict:
+ """Validate token by fetching account info"""
+ response = self._make_get_request("/account/info/")
+ response.raise_for_status()
+ account_info = response.json()
+ self.current_user_email = account_info.get("email")
+ logger.info(f"SeaFile authenticated as: {self.current_user_email}")
+ return account_info
+
+ def validate_connector_settings(self) -> None:
+ """Validate SeaFile connector settings"""
+ if self.token is None:
+ raise ConnectorMissingCredentialError("SeaFile credentials not loaded.")
+
+ if not self.seafile_url:
+ raise ConnectorValidationError("No SeaFile URL was provided.")
+
+ try:
+ account_info = self._validate_token()
+ if not account_info.get("email"):
+ raise InsufficientPermissionsError("Invalid SeaFile API response")
+
+ # Check if we can list libraries
+ libraries = self._get_libraries()
+ logger.info(f"SeaFile connection validated. Found {len(libraries)} libraries.")
+
+ except Exception as e:
+ status = None
+ resp = getattr(e, "response", None)
+ if resp is not None:
+ status = getattr(resp, "status_code", None)
+
+ if status == 401:
+ raise CredentialExpiredError("SeaFile token is invalid or expired.")
+ if status == 403:
+ raise InsufficientPermissionsError(
+ "Insufficient permissions to access SeaFile API."
+ )
+ raise ConnectorValidationError(f"SeaFile validation failed: {repr(e)}")
+
+ @retry(tries=3, delay=1, backoff=2)
+ def _get_libraries(self) -> list[dict]:
+ """Fetch all accessible libraries (repos)"""
+ response = self._make_get_request("/repos/")
+ response.raise_for_status()
+ libraries = response.json()
+
+ logger.debug(f"Found {len(libraries)} total libraries")
+
+ if not self.include_shared and self.current_user_email:
+ # Filter to only owned libraries
+ owned_libraries = [
+ lib for lib in libraries
+ if lib.get("owner") == self.current_user_email
+ or lib.get("owner_email") == self.current_user_email
+ ]
+ logger.debug(
+ f"Filtered to {len(owned_libraries)} owned libraries "
+ f"(excluded {len(libraries) - len(owned_libraries)} shared)"
+ )
+ return owned_libraries
+
+ return libraries
+
+ @retry(tries=3, delay=1, backoff=2)
+ def _get_directory_entries(self, repo_id: str, path: str = "/") -> list[dict]:
+ """Fetch directory entries for a given path"""
+ try:
+ response = self._make_get_request(
+ f"/repos/{repo_id}/dir/",
+ params={"p": path},
+ )
+ response.raise_for_status()
+ return response.json()
+ except Exception as e:
+ logger.warning(f"Error fetching directory {path} in repo {repo_id}: {e}")
+ return []
+
+ @retry(tries=3, delay=1, backoff=2)
+ def _get_file_download_link(self, repo_id: str, path: str) -> Optional[str]:
+ """Get download link for a file"""
+ try:
+ response = self._make_get_request(
+ f"/repos/{repo_id}/file/",
+ params={"p": path, "reuse": 1},
+ )
+ response.raise_for_status()
+ return response.text.strip('"')
+ except Exception as e:
+ logger.warning(f"Error getting download link for {path}: {e}")
+ return None
+
+ def _list_files_recursive(
+ self,
+ repo_id: str,
+ repo_name: str,
+ path: str,
+ start: datetime,
+ end: datetime,
+ ) -> list[tuple[str, dict, dict]]:
+ """Recursively list all files in the given path within time range.
+
+ Returns:
+ List of tuples: (file_path, file_entry, library_info)
+ """
+ files = []
+ entries = self._get_directory_entries(repo_id, path)
+
+ for entry in entries:
+ entry_type = entry.get("type")
+ entry_name = entry.get("name", "")
+ entry_path = f"{path.rstrip('/')}/{entry_name}"
+
+ if entry_type == "dir":
+ # Recursively process subdirectories
+ files.extend(
+ self._list_files_recursive(repo_id, repo_name, entry_path, start, end)
+ )
+ elif entry_type == "file":
+ # Check modification time
+ mtime = entry.get("mtime", 0)
+ if mtime:
+ modified = datetime.fromtimestamp(mtime, tz=timezone.utc)
+ if start < modified <= end:
+ files.append((entry_path, entry, {"id": repo_id, "name": repo_name}))
+
+ return files
+
+ def _yield_seafile_documents(
+ self,
+ start: datetime,
+ end: datetime,
+ ) -> GenerateDocumentsOutput:
+ """Generate documents from SeaFile server.
+
+ Args:
+ start: Start datetime for filtering
+ end: End datetime for filtering
+
+ Yields:
+ Batches of documents
+ """
+ logger.info(f"Searching for files between {start} and {end}")
+
+ libraries = self._get_libraries()
+ logger.info(f"Processing {len(libraries)} libraries")
+
+ all_files = []
+ for lib in libraries:
+ repo_id = lib.get("id")
+ repo_name = lib.get("name", "Unknown")
+
+ if not repo_id:
+ continue
+
+ logger.debug(f"Scanning library: {repo_name}")
+ try:
+ files = self._list_files_recursive(repo_id, repo_name, "/", start, end)
+ all_files.extend(files)
+ logger.debug(f"Found {len(files)} files in {repo_name}")
+ except Exception as e:
+ logger.error(f"Error processing library {repo_name}: {e}")
+
+ logger.info(f"Found {len(all_files)} total files matching time criteria")
+
+ batch: list[Document] = []
+ for file_path, file_entry, library in all_files:
+ file_name = file_entry.get("name", "")
+ file_size = file_entry.get("size", 0)
+ file_id = file_entry.get("id", "")
+ mtime = file_entry.get("mtime", 0)
+ repo_id = library["id"]
+ repo_name = library["name"]
+
+ # Skip files that are too large
+ if file_size > self.size_threshold:
+ logger.warning(
+ f"Skipping large file: {file_path} ({file_size} bytes)"
+ )
+ continue
+
+ try:
+ # Get download link
+ download_link = self._get_file_download_link(repo_id, file_path)
+ if not download_link:
+ logger.warning(f"Could not get download link for {file_path}")
+ continue
+
+ # Download file content
+ logger.debug(f"Downloading: {file_path}")
+ response = rl_requests.get(download_link, timeout=120)
+ response.raise_for_status()
+ blob = response.content
+
+ if not blob:
+ logger.warning(f"Downloaded content is empty for {file_path}")
+ continue
+
+ # Build semantic identifier
+ semantic_id = f"{repo_name}{file_path}"
+
+ # Get modification time
+ modified = datetime.fromtimestamp(mtime, tz=timezone.utc) if mtime else datetime.now(timezone.utc)
+
+ batch.append(
+ Document(
+ id=f"seafile:{repo_id}:{file_id}",
+ blob=blob,
+ source=DocumentSource.SEAFILE,
+ semantic_identifier=semantic_id,
+ extension=get_file_ext(file_name),
+ doc_updated_at=modified,
+ size_bytes=len(blob),
+ )
+ )
+
+ if len(batch) >= self.batch_size:
+ yield batch
+ batch = []
+
+ except Exception as e:
+ logger.error(f"Error downloading file {file_path}: {e}")
+
+ if batch:
+ yield batch
+
+ def load_from_state(self) -> GenerateDocumentsOutput:
+ """Load all documents from SeaFile server.
+
+ Yields:
+ Batches of documents
+ """
+ logger.info(f"Loading all documents from SeaFile server {self.seafile_url}")
+ return self._yield_seafile_documents(
+ start=datetime(1970, 1, 1, tzinfo=timezone.utc),
+ end=datetime.now(timezone.utc),
+ )
+
+ def poll_source(
+ self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
+ ) -> GenerateDocumentsOutput:
+ """Poll SeaFile server for updated documents.
+
+ Args:
+ start: Start timestamp (seconds since Unix epoch)
+ end: End timestamp (seconds since Unix epoch)
+
+ Yields:
+ Batches of documents
+ """
+ start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
+ end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
+
+ logger.info(f"Polling SeaFile for updates from {start_datetime} to {end_datetime}")
+
+ for batch in self._yield_seafile_documents(start_datetime, end_datetime):
+ yield batch
+
+
diff --git a/common/data_source/utils.py b/common/data_source/utils.py
index f69ecbd7863..4cc3cce43c8 100644
--- a/common/data_source/utils.py
+++ b/common/data_source/utils.py
@@ -315,14 +315,13 @@ def _refresh_credentials() -> dict[str, str]:
region_name=credentials["region"],
)
elif bucket_type == BlobType.S3_COMPATIBLE:
- addressing_style = credentials.get("addressing_style", "virtual")
return boto3.client(
"s3",
endpoint_url=credentials["endpoint_url"],
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"],
- config=Config(s3={'addressing_style': addressing_style}),
+ config=Config(s3={'addressing_style': credentials["addressing_style"]}),
)
else:
@@ -1149,3 +1148,137 @@ def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R
future_to_index[executor.submit(_next_or_none, ind, gens[ind])] = next_ind
next_ind += 1
del future_to_index[future]
+
+
+def sanitize_filename(name: str, extension: str = "txt") -> str:
+ """
+ Soft sanitize for MinIO/S3:
+ - Replace only prohibited characters with a space.
+ - Preserve readability (no ugly underscores).
+ - Collapse multiple spaces.
+ """
+ if name is None:
+ return f"file.{extension}"
+
+ name = str(name).strip()
+
+ # Characters that MUST NOT appear in S3/MinIO object keys
+ # Replace them with a space (not underscore)
+ forbidden = r'[\\\?\#\%\*\:\|\<\>"]'
+ name = re.sub(forbidden, " ", name)
+
+ # Replace slashes "/" (S3 interprets as folder) with space
+ name = name.replace("/", " ")
+
+ # Collapse multiple spaces into one
+ name = re.sub(r"\s+", " ", name)
+
+ # Trim both ends
+ name = name.strip()
+
+ # Enforce reasonable max length
+ if len(name) > 200:
+ base, ext = os.path.splitext(name)
+ name = base[:180].rstrip() + ext
+
+ if not os.path.splitext(name)[1]:
+ name += f".{extension}"
+
+ return name
+F = TypeVar("F", bound=Callable[..., Any])
+
+class _RateLimitDecorator:
+ """Builds a generic wrapper/decorator for calls to external APIs that
+ prevents making more than `max_calls` requests per `period`
+
+ Implementation inspired by the `ratelimit` library:
+ https://github.com/tomasbasham/ratelimit.
+
+ NOTE: is not thread safe.
+ """
+
+ def __init__(
+ self,
+ max_calls: int,
+ period: float, # in seconds
+ sleep_time: float = 2, # in seconds
+ sleep_backoff: float = 2, # applies exponential backoff
+ max_num_sleep: int = 0,
+ ):
+ self.max_calls = max_calls
+ self.period = period
+ self.sleep_time = sleep_time
+ self.sleep_backoff = sleep_backoff
+ self.max_num_sleep = max_num_sleep
+
+ self.call_history: list[float] = []
+ self.curr_calls = 0
+
+ def __call__(self, func: F) -> F:
+ @wraps(func)
+ def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any:
+ # cleanup calls which are no longer relevant
+ self._cleanup()
+
+ # check if we've exceeded the rate limit
+ sleep_cnt = 0
+ while len(self.call_history) == self.max_calls:
+ sleep_time = self.sleep_time * (self.sleep_backoff**sleep_cnt)
+ logging.warning(
+ f"Rate limit exceeded for function {func.__name__}. "
+ f"Waiting {sleep_time} seconds before retrying."
+ )
+ time.sleep(sleep_time)
+ sleep_cnt += 1
+ if self.max_num_sleep != 0 and sleep_cnt >= self.max_num_sleep:
+ raise RateLimitTriedTooManyTimesError(
+ f"Exceeded '{self.max_num_sleep}' retries for function '{func.__name__}'"
+ )
+
+ self._cleanup()
+
+ # add the current call to the call history
+ self.call_history.append(time.monotonic())
+ return func(*args, **kwargs)
+
+ return cast(F, wrapped_func)
+
+ def _cleanup(self) -> None:
+ curr_time = time.monotonic()
+ time_to_expire_before = curr_time - self.period
+ self.call_history = [
+ call_time
+ for call_time in self.call_history
+ if call_time > time_to_expire_before
+ ]
+
+rate_limit_builder = _RateLimitDecorator
+
+def retry_builder(
+ tries: int = 20,
+ delay: float = 0.1,
+ max_delay: float | None = 60,
+ backoff: float = 2,
+ jitter: tuple[float, float] | float = 1,
+ exceptions: type[Exception] | tuple[type[Exception], ...] = (Exception,),
+) -> Callable[[F], F]:
+ """Builds a generic wrapper/decorator for calls to external APIs that
+ may fail due to rate limiting, flakes, or other reasons. Applies exponential
+ backoff with jitter to retry the call."""
+
+ def retry_with_default(func: F) -> F:
+ @retry(
+ tries=tries,
+ delay=delay,
+ max_delay=max_delay,
+ backoff=backoff,
+ jitter=jitter,
+ logger=logging.getLogger(__name__),
+ exceptions=exceptions,
+ )
+ def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any:
+ return func(*args, **kwargs)
+
+ return cast(F, wrapped_func)
+
+ return retry_with_default
diff --git a/common/data_source/webdav_connector.py b/common/data_source/webdav_connector.py
index f8e61578900..ec06a64e192 100644
--- a/common/data_source/webdav_connector.py
+++ b/common/data_source/webdav_connector.py
@@ -82,10 +82,6 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None
base_url=self.base_url,
auth=(username, password)
)
-
- # Test connection
- self.client.exists(self.remote_path)
-
except Exception as e:
logging.error(f"Failed to connect to WebDAV server: {e}")
raise ConnectorMissingCredentialError(
@@ -308,60 +304,79 @@ def poll_source(
yield batch
def validate_connector_settings(self) -> None:
- """Validate WebDAV connector settings
-
- Raises:
- ConnectorMissingCredentialError: If credentials are not loaded
- ConnectorValidationError: If settings are invalid
+ """Validate WebDAV connector settings.
+
+ Validation should exercise the same code-paths used by the connector
+ (directory listing / PROPFIND), avoiding exists() which may probe with
+ methods that differ across servers.
"""
if self.client is None:
- raise ConnectorMissingCredentialError(
- "WebDAV credentials not loaded."
- )
+ raise ConnectorMissingCredentialError("WebDAV credentials not loaded.")
if not self.base_url:
- raise ConnectorValidationError(
- "No base URL was provided in connector settings."
- )
+ raise ConnectorValidationError("No base URL was provided in connector settings.")
+
+ # Normalize directory path: for collections, many servers behave better with trailing '/'
+ test_path = self.remote_path or "/"
+ if not test_path.startswith("/"):
+ test_path = f"/{test_path}"
+ if test_path != "/" and not test_path.endswith("/"):
+ test_path = f"{test_path}/"
try:
- if not self.client.exists(self.remote_path):
- raise ConnectorValidationError(
- f"Remote path '{self.remote_path}' does not exist on WebDAV server."
- )
+ # Use the same behavior as real sync: list directory with details (PROPFIND)
+ self.client.ls(test_path, detail=True)
except Exception as e:
- error_message = str(e)
-
- if "401" in error_message or "unauthorized" in error_message.lower():
- raise CredentialExpiredError(
- "WebDAV credentials appear invalid or expired."
- )
-
- if "403" in error_message or "forbidden" in error_message.lower():
+ # Prefer structured status codes if present on the exception/response
+ status = None
+ for attr in ("status_code", "code"):
+ v = getattr(e, attr, None)
+ if isinstance(v, int):
+ status = v
+ break
+ if status is None:
+ resp = getattr(e, "response", None)
+ v = getattr(resp, "status_code", None)
+ if isinstance(v, int):
+ status = v
+
+ # If we can classify by status code, do it
+ if status == 401:
+ raise CredentialExpiredError("WebDAV credentials appear invalid or expired.")
+ if status == 403:
raise InsufficientPermissionsError(
f"Insufficient permissions to access path '{self.remote_path}' on WebDAV server."
)
-
- if "404" in error_message or "not found" in error_message.lower():
+ if status == 404:
raise ConnectorValidationError(
f"Remote path '{self.remote_path}' does not exist on WebDAV server."
)
+ # Fallback: avoid brittle substring matching that caused false positives.
+ # Provide the original exception for diagnosis.
raise ConnectorValidationError(
- f"Unexpected WebDAV client error: {e}"
+ f"WebDAV validation failed for path '{test_path}': {repr(e)}"
)
+
if __name__ == "__main__":
credentials_dict = {
"username": os.environ.get("WEBDAV_USERNAME"),
"password": os.environ.get("WEBDAV_PASSWORD"),
}
+ credentials_dict = {
+ "username": "user",
+ "password": "pass",
+ }
+
+
+
connector = WebDAVConnector(
- base_url=os.environ.get("WEBDAV_URL") or "https://webdav.example.com",
- remote_path=os.environ.get("WEBDAV_PATH") or "/",
+ base_url="http://172.17.0.1:8080/",
+ remote_path="/",
)
try:
diff --git a/common/data_source/zendesk_connector.py b/common/data_source/zendesk_connector.py
new file mode 100644
index 00000000000..85b3426fe3f
--- /dev/null
+++ b/common/data_source/zendesk_connector.py
@@ -0,0 +1,667 @@
+import copy
+import logging
+import time
+from collections.abc import Callable
+from collections.abc import Iterator
+from typing import Any
+
+import requests
+from pydantic import BaseModel
+from requests.exceptions import HTTPError
+from typing_extensions import override
+
+from common.data_source.config import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS, DocumentSource
+from common.data_source.exceptions import ConnectorValidationError, CredentialExpiredError, InsufficientPermissionsError
+from common.data_source.html_utils import parse_html_page_basic
+from common.data_source.interfaces import CheckpointOutput, CheckpointOutputWrapper, CheckpointedConnector, IndexingHeartbeatInterface, SlimConnectorWithPermSync
+from common.data_source.models import BasicExpertInfo, ConnectorCheckpoint, ConnectorFailure, Document, DocumentFailure, GenerateSlimDocumentOutput, SecondsSinceUnixEpoch, SlimDocument
+from common.data_source.utils import retry_builder, time_str_to_utc,rate_limit_builder
+
+MAX_PAGE_SIZE = 30 # Zendesk API maximum
+MAX_AUTHOR_MAP_SIZE = 50_000 # Reset author map cache if it gets too large
+_SLIM_BATCH_SIZE = 1000
+
+
+class ZendeskCredentialsNotSetUpError(PermissionError):
+ def __init__(self) -> None:
+ super().__init__(
+ "Zendesk Credentials are not set up, was load_credentials called?"
+ )
+
+
+class ZendeskClient:
+ def __init__(
+ self,
+ subdomain: str,
+ email: str,
+ token: str,
+ calls_per_minute: int | None = None,
+ ):
+ self.base_url = f"https://{subdomain}.zendesk.com/api/v2"
+ self.auth = (f"{email}/token", token)
+ self.make_request = request_with_rate_limit(self, calls_per_minute)
+
+
+def request_with_rate_limit(
+ client: ZendeskClient, max_calls_per_minute: int | None = None
+) -> Callable[[str, dict[str, Any]], dict[str, Any]]:
+ @retry_builder()
+ @(
+ rate_limit_builder(max_calls=max_calls_per_minute, period=60)
+ if max_calls_per_minute
+ else lambda x: x
+ )
+ def make_request(endpoint: str, params: dict[str, Any]) -> dict[str, Any]:
+ response = requests.get(
+ f"{client.base_url}/{endpoint}", auth=client.auth, params=params
+ )
+
+ if response.status_code == 429:
+ retry_after = response.headers.get("Retry-After")
+ if retry_after is not None:
+ # Sleep for the duration indicated by the Retry-After header
+ time.sleep(int(retry_after))
+
+ elif (
+ response.status_code == 403
+ and response.json().get("error") == "SupportProductInactive"
+ ):
+ return response.json()
+
+ response.raise_for_status()
+ return response.json()
+
+ return make_request
+
+
+class ZendeskPageResponse(BaseModel):
+ data: list[dict[str, Any]]
+ meta: dict[str, Any]
+ has_more: bool
+
+
+def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
+ content_tags: dict[str, str] = {}
+ params = {"page[size]": MAX_PAGE_SIZE}
+
+ try:
+ while True:
+ data = client.make_request("guide/content_tags", params)
+
+ for tag in data.get("records", []):
+ content_tags[tag["id"]] = tag["name"]
+
+ # Check if there are more pages
+ if data.get("meta", {}).get("has_more", False):
+ params["page[after]"] = data["meta"]["after_cursor"]
+ else:
+ break
+
+ return content_tags
+ except Exception as e:
+ raise Exception(f"Error fetching content tags: {str(e)}")
+
+
+def _get_articles(
+ client: ZendeskClient, start_time: int | None = None, page_size: int = MAX_PAGE_SIZE
+) -> Iterator[dict[str, Any]]:
+ params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"}
+ if start_time is not None:
+ params["start_time"] = start_time
+
+ while True:
+ data = client.make_request("help_center/articles", params)
+ for article in data["articles"]:
+ yield article
+
+ if not data.get("meta", {}).get("has_more"):
+ break
+ params["page[after]"] = data["meta"]["after_cursor"]
+
+
+def _get_article_page(
+ client: ZendeskClient,
+ start_time: int | None = None,
+ after_cursor: str | None = None,
+ page_size: int = MAX_PAGE_SIZE,
+) -> ZendeskPageResponse:
+ params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"}
+ if start_time is not None:
+ params["start_time"] = start_time
+ if after_cursor is not None:
+ params["page[after]"] = after_cursor
+
+ data = client.make_request("help_center/articles", params)
+ return ZendeskPageResponse(
+ data=data["articles"],
+ meta=data["meta"],
+ has_more=bool(data["meta"].get("has_more", False)),
+ )
+
+
+def _get_tickets(
+ client: ZendeskClient, start_time: int | None = None
+) -> Iterator[dict[str, Any]]:
+ params = {"start_time": start_time or 0}
+
+ while True:
+ data = client.make_request("incremental/tickets.json", params)
+ for ticket in data["tickets"]:
+ yield ticket
+
+ if not data.get("end_of_stream", False):
+ params["start_time"] = data["end_time"]
+ else:
+ break
+
+
+# TODO: maybe these don't need to be their own functions?
+def _get_tickets_page(
+ client: ZendeskClient, start_time: int | None = None
+) -> ZendeskPageResponse:
+ params = {"start_time": start_time or 0}
+
+ # NOTE: for some reason zendesk doesn't seem to be respecting the start_time param
+ # in my local testing with very few tickets. We'll look into it if this becomes an
+ # issue in larger deployments
+ data = client.make_request("incremental/tickets.json", params)
+ if data.get("error") == "SupportProductInactive":
+ raise ValueError(
+ "Zendesk Support Product is not active for this account, No tickets to index"
+ )
+ return ZendeskPageResponse(
+ data=data["tickets"],
+ meta={"end_time": data["end_time"]},
+ has_more=not bool(data.get("end_of_stream", False)),
+ )
+
+
+def _fetch_author(
+ client: ZendeskClient, author_id: str | int
+) -> BasicExpertInfo | None:
+ # Skip fetching if author_id is invalid
+ # cast to str to avoid issues with zendesk changing their types
+ if not author_id or str(author_id) == "-1":
+ return None
+
+ try:
+ author_data = client.make_request(f"users/{author_id}", {})
+ user = author_data.get("user")
+ return (
+ BasicExpertInfo(display_name=user.get("name"), email=user.get("email"))
+ if user and user.get("name") and user.get("email")
+ else None
+ )
+ except requests.exceptions.HTTPError:
+ # Handle any API errors gracefully
+ return None
+
+
+def _article_to_document(
+ article: dict[str, Any],
+ content_tags: dict[str, str],
+ author_map: dict[str, BasicExpertInfo],
+ client: ZendeskClient,
+) -> tuple[dict[str, BasicExpertInfo] | None, Document]:
+ author_id = article.get("author_id")
+ if not author_id:
+ author = None
+ else:
+ author = (
+ author_map.get(author_id)
+ if author_id in author_map
+ else _fetch_author(client, author_id)
+ )
+
+ new_author_mapping = {author_id: author} if author_id and author else None
+
+ updated_at = article.get("updated_at")
+ update_time = time_str_to_utc(updated_at) if updated_at else None
+
+ text = parse_html_page_basic(article.get("body") or "")
+ blob = text.encode("utf-8", errors="replace")
+ # Build metadata
+ metadata: dict[str, str | list[str]] = {
+ "labels": [str(label) for label in article.get("label_names", []) if label],
+ "content_tags": [
+ content_tags[tag_id]
+ for tag_id in article.get("content_tag_ids", [])
+ if tag_id in content_tags
+ ],
+ }
+
+ # Remove empty values
+ metadata = {k: v for k, v in metadata.items() if v}
+
+ return new_author_mapping, Document(
+ id=f"article:{article['id']}",
+ source=DocumentSource.ZENDESK,
+ semantic_identifier=article["title"],
+ extension=".txt",
+ blob=blob,
+ size_bytes=len(blob),
+ doc_updated_at=update_time,
+ primary_owners=[author] if author else None,
+ metadata=metadata,
+ )
+
+
+def _get_comment_text(
+ comment: dict[str, Any],
+ author_map: dict[str, BasicExpertInfo],
+ client: ZendeskClient,
+) -> tuple[dict[str, BasicExpertInfo] | None, str]:
+ author_id = comment.get("author_id")
+ if not author_id:
+ author = None
+ else:
+ author = (
+ author_map.get(author_id)
+ if author_id in author_map
+ else _fetch_author(client, author_id)
+ )
+
+ new_author_mapping = {author_id: author} if author_id and author else None
+
+ comment_text = f"Comment{' by ' + author.display_name if author and author.display_name else ''}"
+ comment_text += f"{' at ' + comment['created_at'] if comment.get('created_at') else ''}:\n{comment['body']}"
+
+ return new_author_mapping, comment_text
+
+
+def _ticket_to_document(
+ ticket: dict[str, Any],
+ author_map: dict[str, BasicExpertInfo],
+ client: ZendeskClient,
+) -> tuple[dict[str, BasicExpertInfo] | None, Document]:
+ submitter_id = ticket.get("submitter")
+ if not submitter_id:
+ submitter = None
+ else:
+ submitter = (
+ author_map.get(submitter_id)
+ if submitter_id in author_map
+ else _fetch_author(client, submitter_id)
+ )
+
+ new_author_mapping = (
+ {submitter_id: submitter} if submitter_id and submitter else None
+ )
+
+ updated_at = ticket.get("updated_at")
+ update_time = time_str_to_utc(updated_at) if updated_at else None
+
+ metadata: dict[str, str | list[str]] = {}
+ if status := ticket.get("status"):
+ metadata["status"] = status
+ if priority := ticket.get("priority"):
+ metadata["priority"] = priority
+ if tags := ticket.get("tags"):
+ metadata["tags"] = tags
+ if ticket_type := ticket.get("type"):
+ metadata["ticket_type"] = ticket_type
+
+ # Fetch comments for the ticket
+ comments_data = client.make_request(f"tickets/{ticket.get('id')}/comments", {})
+ comments = comments_data.get("comments", [])
+
+ comment_texts = []
+ for comment in comments:
+ new_author_mapping, comment_text = _get_comment_text(
+ comment, author_map, client
+ )
+ if new_author_mapping:
+ author_map.update(new_author_mapping)
+ comment_texts.append(comment_text)
+
+ comments_text = "\n\n".join(comment_texts)
+
+ subject = ticket.get("subject")
+ full_text = f"Ticket Subject:\n{subject}\n\nComments:\n{comments_text}"
+
+ blob = full_text.encode("utf-8", errors="replace")
+ return new_author_mapping, Document(
+ id=f"zendesk_ticket_{ticket['id']}",
+ blob=blob,
+ extension=".txt",
+ size_bytes=len(blob),
+ source=DocumentSource.ZENDESK,
+ semantic_identifier=f"Ticket #{ticket['id']}: {subject or 'No Subject'}",
+ doc_updated_at=update_time,
+ primary_owners=[submitter] if submitter else None,
+ metadata=metadata,
+ )
+
+
+class ZendeskConnectorCheckpoint(ConnectorCheckpoint):
+ # We use cursor-based paginated retrieval for articles
+ after_cursor_articles: str | None
+
+ # We use timestamp-based paginated retrieval for tickets
+ next_start_time_tickets: int | None
+
+ cached_author_map: dict[str, BasicExpertInfo] | None
+ cached_content_tags: dict[str, str] | None
+
+
+class ZendeskConnector(
+ SlimConnectorWithPermSync, CheckpointedConnector[ZendeskConnectorCheckpoint]
+):
+ def __init__(
+ self,
+ content_type: str = "articles",
+ calls_per_minute: int | None = None,
+ ) -> None:
+ self.content_type = content_type
+ self.subdomain = ""
+ # Fetch all tags ahead of time
+ self.content_tags: dict[str, str] = {}
+ self.calls_per_minute = calls_per_minute
+
+ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
+ # Subdomain is actually the whole URL
+ subdomain = (
+ credentials["zendesk_subdomain"]
+ .replace("https://", "")
+ .split(".zendesk.com")[0]
+ )
+ self.subdomain = subdomain
+
+ self.client = ZendeskClient(
+ subdomain,
+ credentials["zendesk_email"],
+ credentials["zendesk_token"],
+ calls_per_minute=self.calls_per_minute,
+ )
+ return None
+
+ @override
+ def load_from_checkpoint(
+ self,
+ start: SecondsSinceUnixEpoch,
+ end: SecondsSinceUnixEpoch,
+ checkpoint: ZendeskConnectorCheckpoint,
+ ) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
+ if self.client is None:
+ raise ZendeskCredentialsNotSetUpError()
+ if checkpoint.cached_content_tags is None:
+ checkpoint.cached_content_tags = _get_content_tag_mapping(self.client)
+ return checkpoint # save the content tags to the checkpoint
+ self.content_tags = checkpoint.cached_content_tags
+
+ if self.content_type == "articles":
+ checkpoint = yield from self._retrieve_articles(start, end, checkpoint)
+ return checkpoint
+ elif self.content_type == "tickets":
+ checkpoint = yield from self._retrieve_tickets(start, end, checkpoint)
+ return checkpoint
+ else:
+ raise ValueError(f"Unsupported content_type: {self.content_type}")
+
+ def _retrieve_articles(
+ self,
+ start: SecondsSinceUnixEpoch | None,
+ end: SecondsSinceUnixEpoch | None,
+ checkpoint: ZendeskConnectorCheckpoint,
+ ) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
+ checkpoint = copy.deepcopy(checkpoint)
+ # This one is built on the fly as there may be more many more authors than tags
+ author_map: dict[str, BasicExpertInfo] = checkpoint.cached_author_map or {}
+ after_cursor = checkpoint.after_cursor_articles
+ doc_batch: list[Document] = []
+
+ response = _get_article_page(
+ self.client,
+ start_time=int(start) if start else None,
+ after_cursor=after_cursor,
+ )
+ articles = response.data
+ has_more = response.has_more
+ after_cursor = response.meta.get("after_cursor")
+ for article in articles:
+ if (
+ article.get("body") is None
+ or article.get("draft")
+ or any(
+ label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
+ for label in article.get("label_names", [])
+ )
+ ):
+ continue
+
+ try:
+ new_author_map, document = _article_to_document(
+ article, self.content_tags, author_map, self.client
+ )
+ except Exception as e:
+ logging.error(f"Error processing article {article['id']}: {e}")
+ yield ConnectorFailure(
+ failed_document=DocumentFailure(
+ document_id=f"{article.get('id')}",
+ document_link=article.get("html_url", ""),
+ ),
+ failure_message=str(e),
+ exception=e,
+ )
+ continue
+
+ if new_author_map:
+ author_map.update(new_author_map)
+ updated_at = document.doc_updated_at
+ updated_ts = updated_at.timestamp() if updated_at else None
+ if updated_ts is not None:
+ if start is not None and updated_ts <= start:
+ continue
+ if end is not None and updated_ts > end:
+ continue
+
+ doc_batch.append(document)
+
+ if not has_more:
+ yield from doc_batch
+ checkpoint.has_more = False
+ return checkpoint
+
+ # Sometimes no documents are retrieved, but the cursor
+ # is still updated so the connector makes progress.
+ yield from doc_batch
+ checkpoint.after_cursor_articles = after_cursor
+
+ last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None
+ checkpoint.has_more = bool(
+ end is None
+ or last_doc_updated_at is None
+ or last_doc_updated_at.timestamp() <= end
+ )
+ checkpoint.cached_author_map = (
+ author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None
+ )
+ return checkpoint
+
+ def _retrieve_tickets(
+ self,
+ start: SecondsSinceUnixEpoch | None,
+ end: SecondsSinceUnixEpoch | None,
+ checkpoint: ZendeskConnectorCheckpoint,
+ ) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
+ checkpoint = copy.deepcopy(checkpoint)
+ if self.client is None:
+ raise ZendeskCredentialsNotSetUpError()
+
+ author_map: dict[str, BasicExpertInfo] = checkpoint.cached_author_map or {}
+
+ doc_batch: list[Document] = []
+ next_start_time = int(checkpoint.next_start_time_tickets or start or 0)
+ ticket_response = _get_tickets_page(self.client, start_time=next_start_time)
+
+ tickets = ticket_response.data
+ has_more = ticket_response.has_more
+ next_start_time = ticket_response.meta["end_time"]
+ for ticket in tickets:
+ if ticket.get("status") == "deleted":
+ continue
+
+ try:
+ new_author_map, document = _ticket_to_document(
+ ticket=ticket,
+ author_map=author_map,
+ client=self.client,
+ )
+ except Exception as e:
+ logging.error(f"Error processing ticket {ticket['id']}: {e}")
+ yield ConnectorFailure(
+ failed_document=DocumentFailure(
+ document_id=f"{ticket.get('id')}",
+ document_link=ticket.get("url", ""),
+ ),
+ failure_message=str(e),
+ exception=e,
+ )
+ continue
+
+ if new_author_map:
+ author_map.update(new_author_map)
+
+ updated_at = document.doc_updated_at
+ updated_ts = updated_at.timestamp() if updated_at else None
+
+ if updated_ts is not None:
+ if start is not None and updated_ts <= start:
+ continue
+ if end is not None and updated_ts > end:
+ continue
+
+ doc_batch.append(document)
+
+ if not has_more:
+ yield from doc_batch
+ checkpoint.has_more = False
+ return checkpoint
+
+ yield from doc_batch
+ checkpoint.next_start_time_tickets = next_start_time
+ last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None
+ checkpoint.has_more = bool(
+ end is None
+ or last_doc_updated_at is None
+ or last_doc_updated_at.timestamp() <= end
+ )
+ checkpoint.cached_author_map = (
+ author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None
+ )
+ return checkpoint
+
+ def retrieve_all_slim_docs_perm_sync(
+ self,
+ start: SecondsSinceUnixEpoch | None = None,
+ end: SecondsSinceUnixEpoch | None = None,
+ callback: IndexingHeartbeatInterface | None = None,
+ ) -> GenerateSlimDocumentOutput:
+ slim_doc_batch: list[SlimDocument] = []
+ if self.content_type == "articles":
+ articles = _get_articles(
+ self.client, start_time=int(start) if start else None
+ )
+ for article in articles:
+ slim_doc_batch.append(
+ SlimDocument(
+ id=f"article:{article['id']}",
+ )
+ )
+ if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
+ yield slim_doc_batch
+ slim_doc_batch = []
+ elif self.content_type == "tickets":
+ tickets = _get_tickets(
+ self.client, start_time=int(start) if start else None
+ )
+ for ticket in tickets:
+ slim_doc_batch.append(
+ SlimDocument(
+ id=f"zendesk_ticket_{ticket['id']}",
+ )
+ )
+ if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
+ yield slim_doc_batch
+ slim_doc_batch = []
+ else:
+ raise ValueError(f"Unsupported content_type: {self.content_type}")
+ if slim_doc_batch:
+ yield slim_doc_batch
+
+ @override
+ def validate_connector_settings(self) -> None:
+ if self.client is None:
+ raise ZendeskCredentialsNotSetUpError()
+
+ try:
+ _get_article_page(self.client, start_time=0)
+ except HTTPError as e:
+ # Check for HTTP status codes
+ if e.response.status_code == 401:
+ raise CredentialExpiredError(
+ "Your Zendesk credentials appear to be invalid or expired (HTTP 401)."
+ ) from e
+ elif e.response.status_code == 403:
+ raise InsufficientPermissionsError(
+ "Your Zendesk token does not have sufficient permissions (HTTP 403)."
+ ) from e
+ elif e.response.status_code == 404:
+ raise ConnectorValidationError(
+ "Zendesk resource not found (HTTP 404)."
+ ) from e
+ else:
+ raise ConnectorValidationError(
+ f"Unexpected Zendesk error (status={e.response.status_code}): {e}"
+ ) from e
+
+ @override
+ def validate_checkpoint_json(
+ self, checkpoint_json: str
+ ) -> ZendeskConnectorCheckpoint:
+ return ZendeskConnectorCheckpoint.model_validate_json(checkpoint_json)
+
+ @override
+ def build_dummy_checkpoint(self) -> ZendeskConnectorCheckpoint:
+ return ZendeskConnectorCheckpoint(
+ after_cursor_articles=None,
+ next_start_time_tickets=None,
+ cached_author_map=None,
+ cached_content_tags=None,
+ has_more=True,
+ )
+
+
+if __name__ == "__main__":
+ import os
+
+ connector = ZendeskConnector(content_type="articles")
+ connector.load_credentials(
+ {
+ "zendesk_subdomain": os.environ["ZENDESK_SUBDOMAIN"],
+ "zendesk_email": os.environ["ZENDESK_EMAIL"],
+ "zendesk_token": os.environ["ZENDESK_TOKEN"],
+ }
+ )
+
+ current = time.time()
+ one_day_ago = current - 24 * 60 * 60 # 1 day
+
+ checkpoint = connector.build_dummy_checkpoint()
+
+ while checkpoint.has_more:
+ gen = connector.load_from_checkpoint(
+ one_day_ago, current, checkpoint
+ )
+
+ wrapper = CheckpointOutputWrapper()
+ any_doc = False
+
+ for document, failure, next_checkpoint in wrapper(gen):
+ if document:
+ print("got document:", document.id)
+ any_doc = True
+
+ checkpoint = next_checkpoint
+ if any_doc:
+ break
\ No newline at end of file
diff --git a/common/doc_store/doc_store_base.py b/common/doc_store/doc_store_base.py
index fe6304f7579..fd684baef25 100644
--- a/common/doc_store/doc_store_base.py
+++ b/common/doc_store/doc_store_base.py
@@ -164,7 +164,7 @@ def health(self) -> dict:
"""
@abstractmethod
- def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
+ def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None):
"""
Create an index with given name
"""
diff --git a/common/doc_store/es_conn_base.py b/common/doc_store/es_conn_base.py
index cec628c0db5..dccb8a2fe3d 100644
--- a/common/doc_store/es_conn_base.py
+++ b/common/doc_store/es_conn_base.py
@@ -24,6 +24,7 @@
from elasticsearch import NotFoundError
from elasticsearch_dsl import Index
from elastic_transport import ConnectionTimeout
+from elasticsearch.client import IndicesClient
from common.file_utils import get_project_base_directory
from common.misc_utils import convert_bytes
from common.doc_store.doc_store_base import DocStoreConnection, OrderByExpr, MatchExpr
@@ -47,7 +48,8 @@ def __init__(self, mapping_file_name: str="mapping.json", logger_name: str='ragf
msg = f"Elasticsearch mapping file not found at {fp_mapping}"
self.logger.error(msg)
raise Exception(msg)
- self.mapping = json.load(open(fp_mapping, "r"))
+ with open(fp_mapping, "r") as f:
+ self.mapping = json.load(f)
self.logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
def _connect(self):
@@ -123,17 +125,40 @@ def get_cluster_stats(self):
Table operations
"""
- def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
+ def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None):
+ # parser_id is used by Infinity but not needed for ES (kept for interface compatibility)
if self.index_exist(index_name, dataset_id):
return True
try:
- from elasticsearch.client import IndicesClient
return IndicesClient(self.es).create(index=index_name,
settings=self.mapping["settings"],
mappings=self.mapping["mappings"])
except Exception:
self.logger.exception("ESConnection.createIndex error %s" % index_name)
+ def create_doc_meta_idx(self, index_name: str):
+ """
+ Create a document metadata index.
+
+ Index name pattern: ragflow_doc_meta_{tenant_id}
+ - Per-tenant metadata index for storing document metadata fields
+ """
+ if self.index_exist(index_name, ""):
+ return True
+ try:
+ fp_mapping = os.path.join(get_project_base_directory(), "conf", "doc_meta_es_mapping.json")
+ if not os.path.exists(fp_mapping):
+ self.logger.error(f"Document metadata mapping file not found at {fp_mapping}")
+ return False
+
+ with open(fp_mapping, "r") as f:
+ doc_meta_mapping = json.load(f)
+ return IndicesClient(self.es).create(index=index_name,
+ settings=doc_meta_mapping["settings"],
+ mappings=doc_meta_mapping["mappings"])
+ except Exception as e:
+ self.logger.exception(f"Error creating document metadata index {index_name}: {e}")
+
def delete_idx(self, index_name: str, dataset_id: str):
if len(dataset_id) > 0:
# The index need to be alive after any kb deletion since all kb under this tenant are in one index.
diff --git a/common/doc_store/infinity_conn_base.py b/common/doc_store/infinity_conn_base.py
index 82650f81d6f..327f518f5a1 100644
--- a/common/doc_store/infinity_conn_base.py
+++ b/common/doc_store/infinity_conn_base.py
@@ -33,12 +33,13 @@
class InfinityConnectionBase(DocStoreConnection):
- def __init__(self, mapping_file_name: str="infinity_mapping.json", logger_name: str="ragflow.infinity_conn"):
+ def __init__(self, mapping_file_name: str = "infinity_mapping.json", logger_name: str = "ragflow.infinity_conn", table_name_prefix: str="ragflow_"):
from common.doc_store.infinity_conn_pool import INFINITY_CONN
self.dbName = settings.INFINITY.get("db_name", "default_db")
self.mapping_file_name = mapping_file_name
self.logger = logging.getLogger(logger_name)
+ self.table_name_prefix = table_name_prefix
infinity_uri = settings.INFINITY["uri"]
if ":" in infinity_uri:
host, port = infinity_uri.split(":")
@@ -73,9 +74,13 @@ def _migrate_db(self, inf_conn):
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
- schema = json.load(open(fp_mapping))
+ with open(fp_mapping) as f:
+ schema = json.load(f)
table_names = inf_db.list_tables().table_names
for table_name in table_names:
+ if not table_name.startswith(self.table_name_prefix):
+ # Skip tables not created by me
+ continue
inf_table = inf_db.get_table(table_name)
index_names = inf_table.list_indexes().index_names
if "q_vec_idx" not in index_names:
@@ -84,22 +89,43 @@ def _migrate_db(self, inf_conn):
column_names = inf_table.show_columns()["name"]
column_names = set(column_names)
for field_name, field_info in schema.items():
- if field_name in column_names:
- continue
- res = inf_table.add_columns({field_name: field_info})
- assert res.error_code == infinity.ErrorCode.OK
- self.logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}")
- if field_info["type"] != "varchar" or "analyzer" not in field_info:
- continue
- analyzers = field_info["analyzer"]
- if isinstance(analyzers, str):
- analyzers = [analyzers]
- for analyzer in analyzers:
- inf_table.create_index(
- f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
- IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
- ConflictType.Ignore,
- )
+ is_new_column = field_name not in column_names
+ if is_new_column:
+ res = inf_table.add_columns({field_name: field_info})
+ assert res.error_code == infinity.ErrorCode.OK
+ self.logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}")
+
+ if field_info["type"] == "varchar" and "analyzer" in field_info:
+ analyzers = field_info["analyzer"]
+ if isinstance(analyzers, str):
+ analyzers = [analyzers]
+ for analyzer in analyzers:
+ inf_table.create_index(
+ f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
+ IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
+ ConflictType.Ignore,
+ )
+
+ if "index_type" in field_info:
+ index_config = field_info["index_type"]
+ if isinstance(index_config, str) and index_config == "secondary":
+ inf_table.create_index(
+ f"sec_{field_name}",
+ IndexInfo(field_name, IndexType.Secondary),
+ ConflictType.Ignore,
+ )
+ self.logger.info(f"INFINITY created secondary index sec_{field_name} for field {field_name}")
+ elif isinstance(index_config, dict):
+ if index_config.get("type") == "secondary":
+ params = {}
+ if "cardinality" in index_config:
+ params = {"cardinality": index_config["cardinality"]}
+ inf_table.create_index(
+ f"sec_{field_name}",
+ IndexInfo(field_name, IndexType.Secondary, params),
+ ConflictType.Ignore,
+ )
+ self.logger.info(f"INFINITY created secondary index sec_{field_name} for field {field_name} with params {params}")
"""
Dataframe and fields convert
@@ -228,15 +254,27 @@ def health(self) -> dict:
Table operations
"""
- def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
+ def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None):
table_name = f"{index_name}_{dataset_id}"
+ self.logger.debug(f"CREATE_IDX: Creating table {table_name}, parser_id: {parser_id}")
+
inf_conn = self.connPool.get_conn()
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
+ # Use configured schema
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
+
+ if parser_id is not None:
+ from common.constants import ParserType
+
+ if parser_id == ParserType.TABLE.value:
+ # Table parser: add chunk_data JSON column to store table-specific fields
+ schema["chunk_data"] = {"type": "json", "default": "{}"}
+ self.logger.info("Added chunk_data column for TABLE parser")
+
vector_name = f"q_{vector_size}_vec"
schema[vector_name] = {"type": f"vector,{vector_size},float"}
inf_table = inf_db.create_table(
@@ -270,12 +308,95 @@ def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
ConflictType.Ignore,
)
+
+ # Create secondary indexes for fields with index_type
+ for field_name, field_info in schema.items():
+ if "index_type" not in field_info:
+ continue
+ index_config = field_info["index_type"]
+ if isinstance(index_config, str) and index_config == "secondary":
+ inf_table.create_index(
+ f"sec_{field_name}",
+ IndexInfo(field_name, IndexType.Secondary),
+ ConflictType.Ignore,
+ )
+ self.logger.info(f"INFINITY created secondary index sec_{field_name} for field {field_name}")
+ elif isinstance(index_config, dict):
+ if index_config.get("type") == "secondary":
+ params = {}
+ if "cardinality" in index_config:
+ params = {"cardinality": index_config["cardinality"]}
+ inf_table.create_index(
+ f"sec_{field_name}",
+ IndexInfo(field_name, IndexType.Secondary, params),
+ ConflictType.Ignore,
+ )
+ self.logger.info(f"INFINITY created secondary index sec_{field_name} for field {field_name} with params {params}")
+
self.connPool.release_conn(inf_conn)
self.logger.info(f"INFINITY created table {table_name}, vector size {vector_size}")
return True
+ def create_doc_meta_idx(self, index_name: str):
+ """
+ Create a document metadata table.
+
+ Table name pattern: ragflow_doc_meta_{tenant_id}
+ - Per-tenant metadata table for storing document metadata fields
+ """
+ table_name = index_name
+ inf_conn = self.connPool.get_conn()
+ inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
+ try:
+ fp_mapping = os.path.join(get_project_base_directory(), "conf", "doc_meta_infinity_mapping.json")
+ if not os.path.exists(fp_mapping):
+ self.logger.error(f"Document metadata mapping file not found at {fp_mapping}")
+ return False
+ with open(fp_mapping) as f:
+ schema = json.load(f)
+ inf_db.create_table(
+ table_name,
+ schema,
+ ConflictType.Ignore,
+ )
+
+ # Create secondary indexes on id and kb_id for better query performance
+ inf_table = inf_db.get_table(table_name)
+
+ try:
+ inf_table.create_index(
+ f"idx_{table_name}_id",
+ IndexInfo("id", IndexType.Secondary),
+ ConflictType.Ignore,
+ )
+ self.logger.debug(f"INFINITY created secondary index on id for table {table_name}")
+ except Exception as e:
+ self.logger.warning(f"Failed to create index on id for {table_name}: {e}")
+
+ try:
+ inf_table.create_index(
+ f"idx_{table_name}_kb_id",
+ IndexInfo("kb_id", IndexType.Secondary),
+ ConflictType.Ignore,
+ )
+ self.logger.debug(f"INFINITY created secondary index on kb_id for table {table_name}")
+ except Exception as e:
+ self.logger.warning(f"Failed to create index on kb_id for {table_name}: {e}")
+
+ self.connPool.release_conn(inf_conn)
+ self.logger.debug(f"INFINITY created document metadata table {table_name} with secondary indexes")
+ return True
+
+ except Exception as e:
+ self.connPool.release_conn(inf_conn)
+ self.logger.exception(f"Error creating document metadata table {table_name}: {e}")
+ return False
+
def delete_idx(self, index_name: str, dataset_id: str):
- table_name = f"{index_name}_{dataset_id}"
+ if index_name.startswith("ragflow_doc_meta_"):
+ table_name = index_name
+ else:
+ table_name = f"{index_name}_{dataset_id}"
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
db_instance.drop_table(table_name, ConflictType.Ignore)
@@ -283,7 +404,10 @@ def delete_idx(self, index_name: str, dataset_id: str):
self.logger.info(f"INFINITY dropped table {table_name}")
def index_exist(self, index_name: str, dataset_id: str) -> bool:
- table_name = f"{index_name}_{dataset_id}"
+ if index_name.startswith("ragflow_doc_meta_"):
+ table_name = index_name
+ else:
+ table_name = f"{index_name}_{dataset_id}"
try:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
@@ -330,7 +454,10 @@ def update(self, condition: dict, new_value: dict, index_name: str, dataset_id:
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
- table_name = f"{index_name}_{dataset_id}"
+ if index_name.startswith("ragflow_doc_meta_"):
+ table_name = index_name
+ else:
+ table_name = f"{index_name}_{dataset_id}"
try:
table_instance = db_instance.get_table(table_name)
except Exception:
@@ -367,7 +494,10 @@ def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords:
num_rows = len(res)
column_id = res["id"]
if field_name not in res:
- return {}
+ if field_name == "content_with_weight" and "content" in res:
+ field_name = "content"
+ else:
+ return {}
for i in range(num_rows):
id = column_id[i]
txt = res[field_name][i]
@@ -450,4 +580,174 @@ def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, field_na
"""
def sql(self, sql: str, fetch_size: int, format: str):
- raise NotImplementedError("Not implemented")
+ """
+ Execute SQL query on Infinity database via psql command.
+ Transform text-to-sql for Infinity's SQL syntax.
+ """
+ import subprocess
+
+ try:
+ self.logger.debug(f"InfinityConnection.sql get sql: {sql}")
+
+ # Clean up SQL
+ sql = re.sub(r"[ `]+", " ", sql)
+ sql = sql.replace("%", "")
+
+ # Transform SELECT field aliases to actual stored field names
+ # Build field mapping from infinity_mapping.json comment field
+ field_mapping = {}
+ # Also build reverse mapping for column names in result
+ reverse_mapping = {}
+ fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
+ if os.path.exists(fp_mapping):
+ with open(fp_mapping) as f:
+ schema = json.load(f)
+ for field_name, field_info in schema.items():
+ if "comment" in field_info:
+ # Parse comma-separated aliases from comment
+ # e.g., "docnm_kwd, title_tks, title_sm_tks"
+ aliases = [a.strip() for a in field_info["comment"].split(",")]
+ for alias in aliases:
+ field_mapping[alias] = field_name
+ reverse_mapping[field_name] = alias # Store first alias for reverse mapping
+
+ # Replace field names in SELECT clause
+ select_match = re.search(r"(select\s+.*?)(from\s+)", sql, re.IGNORECASE)
+ if select_match:
+ select_clause = select_match.group(1)
+ from_clause = select_match.group(2)
+
+ # Apply field transformations
+ for alias, actual in field_mapping.items():
+ select_clause = re.sub(rf"(^|[, ]){alias}([, ]|$)", rf"\1{actual}\2", select_clause)
+
+ sql = select_clause + from_clause + sql[select_match.end() :]
+
+ # Also replace field names in WHERE, ORDER BY, GROUP BY, and HAVING clauses
+ for alias, actual in field_mapping.items():
+ # Transform in WHERE clause
+ sql = re.sub(rf"(\bwhere\s+[^;]*?)(\b){re.escape(alias)}\b", rf"\1{actual}", sql, flags=re.IGNORECASE)
+ # Transform in ORDER BY clause
+ sql = re.sub(rf"(\border by\s+[^;]*?)(\b){re.escape(alias)}\b", rf"\1{actual}", sql, flags=re.IGNORECASE)
+ # Transform in GROUP BY clause
+ sql = re.sub(rf"(\bgroup by\s+[^;]*?)(\b){re.escape(alias)}\b", rf"\1{actual}", sql, flags=re.IGNORECASE)
+ # Transform in HAVING clause
+ sql = re.sub(rf"(\bhaving\s+[^;]*?)(\b){re.escape(alias)}\b", rf"\1{actual}", sql, flags=re.IGNORECASE)
+
+ self.logger.debug(f"InfinityConnection.sql to execute: {sql}")
+
+ # Get connection parameters from the Infinity connection pool wrapper
+ # We need to use INFINITY_CONN singleton, not the raw ConnectionPool
+ from common.doc_store.infinity_conn_pool import INFINITY_CONN
+
+ conn_info = INFINITY_CONN.get_conn_uri()
+
+ # Parse host and port from conn_info
+ if conn_info and "host=" in conn_info:
+ host_match = re.search(r"host=(\S+)", conn_info)
+ if host_match:
+ host = host_match.group(1)
+ else:
+ host = "infinity"
+ else:
+ host = "infinity"
+
+ # Parse port from conn_info, default to 5432 if not found
+ if conn_info and "port=" in conn_info:
+ port_match = re.search(r"port=(\d+)", conn_info)
+ if port_match:
+ port = port_match.group(1)
+ else:
+ port = "5432"
+ else:
+ port = "5432"
+
+ # Use psql command to execute SQL
+ # Use full path to psql to avoid PATH issues
+ psql_path = "/usr/bin/psql"
+ # Check if psql exists at expected location, otherwise try to find it
+ import shutil
+
+ psql_from_path = shutil.which("psql")
+ if psql_from_path:
+ psql_path = psql_from_path
+
+ # Execute SQL with psql to get both column names and data in one call
+ psql_cmd = [
+ psql_path,
+ "-h",
+ host,
+ "-p",
+ port,
+ "-c",
+ sql,
+ ]
+
+ self.logger.debug(f"Executing psql command: {' '.join(psql_cmd)}")
+
+ result = subprocess.run(
+ psql_cmd,
+ capture_output=True,
+ text=True,
+ timeout=10, # 10 second timeout
+ )
+
+ if result.returncode != 0:
+ error_msg = result.stderr.strip()
+ raise Exception(f"psql command failed: {error_msg}\nSQL: {sql}")
+
+ # Parse the output
+ output = result.stdout.strip()
+ if not output:
+ # No results
+ return {"columns": [], "rows": []} if format == "json" else []
+
+ # Parse psql table output which has format:
+ # col1 | col2 | col3
+ # -----+-----+-----
+ # val1 | val2 | val3
+ lines = output.split("\n")
+
+ # Extract column names from first line
+ columns = []
+ rows = []
+
+ if len(lines) >= 1:
+ header_line = lines[0]
+ for col_name in header_line.split("|"):
+ col_name = col_name.strip()
+ if col_name:
+ columns.append({"name": col_name})
+
+ # Data starts after the separator line (line with dashes)
+ data_start = 2 if len(lines) >= 2 and "-" in lines[1] else 1
+ for i in range(data_start, len(lines)):
+ line = lines[i].strip()
+ # Skip empty lines and footer lines like "(1 row)"
+ if not line or re.match(r"^\(\d+ row", line):
+ continue
+ # Split by | and strip each cell
+ row = [cell.strip() for cell in line.split("|")]
+ # Ensure row matches column count
+ if len(row) == len(columns):
+ rows.append(row)
+ elif len(row) > len(columns):
+ # Row has more cells than columns - truncate
+ rows.append(row[: len(columns)])
+ elif len(row) < len(columns):
+ # Row has fewer cells - pad with empty strings
+ rows.append(row + [""] * (len(columns) - len(row)))
+
+ if format == "json":
+ result = {"columns": columns, "rows": rows[:fetch_size] if fetch_size > 0 else rows}
+ else:
+ result = rows[:fetch_size] if fetch_size > 0 else rows
+
+ return result
+
+ except subprocess.TimeoutExpired:
+ self.logger.exception(f"InfinityConnection.sql timeout. SQL:\n{sql}")
+ raise Exception(f"SQL timeout\n\nSQL: {sql}")
+ except Exception as e:
+ self.logger.exception(f"InfinityConnection.sql got exception. SQL:\n{sql}")
+ raise Exception(f"SQL error: {e}\n\nSQL: {sql}")
diff --git a/common/doc_store/infinity_conn_pool.py b/common/doc_store/infinity_conn_pool.py
index f74e244096d..1aa3f81254d 100644
--- a/common/doc_store/infinity_conn_pool.py
+++ b/common/doc_store/infinity_conn_pool.py
@@ -31,7 +31,11 @@ def __init__(self):
if hasattr(settings, "INFINITY"):
self.INFINITY_CONFIG = settings.INFINITY
else:
- self.INFINITY_CONFIG = settings.get_base_config("infinity", {"uri": "infinity:23817"})
+ self.INFINITY_CONFIG = settings.get_base_config("infinity", {
+ "uri": "infinity:23817",
+ "postgres_port": 5432,
+ "db_name": "default_db"
+ })
infinity_uri = self.INFINITY_CONFIG["uri"]
if ":" in infinity_uri:
@@ -61,6 +65,19 @@ def __init__(self):
def get_conn_pool(self):
return self.conn_pool
+ def get_conn_uri(self):
+ """
+ Get connection URI for PostgreSQL protocol.
+ """
+ infinity_uri = self.INFINITY_CONFIG["uri"]
+ postgres_port = self.INFINITY_CONFIG["postgres_port"]
+ db_name = self.INFINITY_CONFIG["db_name"]
+
+ if ":" in infinity_uri:
+ host, _ = infinity_uri.split(":")
+ return f"host={host} port={postgres_port} dbname={db_name}"
+ return f"host=localhost port={postgres_port} dbname={db_name}"
+
def refresh_conn_pool(self):
try:
inf_conn = self.conn_pool.get_conn()
diff --git a/common/doc_store/ob_conn_base.py b/common/doc_store/ob_conn_base.py
new file mode 100644
index 00000000000..0b95770ca5b
--- /dev/null
+++ b/common/doc_store/ob_conn_base.py
@@ -0,0 +1,739 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import json
+import logging
+import os
+import re
+import threading
+import time
+from abc import abstractmethod
+from typing import Any
+
+from pymysql.converters import escape_string
+from pyobvector import ObVecClient, FtsIndexParam, FtsParser, VECTOR
+from sqlalchemy import Column, Table
+
+from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr
+
+ATTEMPT_TIME = 2
+
+# Common templates for OceanBase
+index_name_template = "ix_%s_%s"
+fulltext_index_name_template = "fts_idx_%s"
+fulltext_search_template = "MATCH (%s) AGAINST ('%s' IN NATURAL LANGUAGE MODE)"
+vector_search_template = "cosine_distance(%s, '%s')"
+vector_column_pattern = re.compile(r"q_(?P\d+)_vec")
+
+
+def get_value_str(value: Any) -> str:
+ """Convert value to SQL string representation."""
+ if isinstance(value, str):
+ # escape_string already handles all necessary escaping for MySQL/OceanBase
+ # including backslashes, quotes, newlines, etc.
+ return f"'{escape_string(value)}'"
+ elif isinstance(value, bool):
+ return "true" if value else "false"
+ elif value is None:
+ return "NULL"
+ elif isinstance(value, (list, dict)):
+ json_str = json.dumps(value, ensure_ascii=False)
+ return f"'{escape_string(json_str)}'"
+ else:
+ return str(value)
+
+
+def _try_with_lock(lock_name: str, process_func, check_func, timeout: int = None):
+ """Execute function with distributed lock."""
+ if not timeout:
+ timeout = int(os.environ.get("OB_DDL_TIMEOUT", "60"))
+
+ if not check_func():
+ from rag.utils.redis_conn import RedisDistributedLock
+ lock = RedisDistributedLock(lock_name)
+ if lock.acquire():
+ try:
+ process_func()
+ return
+ except Exception as e:
+ if "Duplicate" in str(e):
+ return
+ raise
+ finally:
+ lock.release()
+
+ if not check_func():
+ time.sleep(1)
+ count = 1
+ while count < timeout and not check_func():
+ count += 1
+ time.sleep(1)
+ if count >= timeout and not check_func():
+ raise Exception(f"Timeout to wait for process complete for {lock_name}.")
+
+
+class OBConnectionBase(DocStoreConnection):
+ """Base class for OceanBase document store connections."""
+
+ def __init__(self, logger_name: str = 'ragflow.ob_conn'):
+ from common.doc_store.ob_conn_pool import OB_CONN
+
+ self.logger = logging.getLogger(logger_name)
+ self.client: ObVecClient = OB_CONN.get_client()
+ self.es = OB_CONN.get_hybrid_search_client()
+ self.db_name = OB_CONN.get_db_name()
+ self.uri = OB_CONN.get_uri()
+
+ self._load_env_vars()
+
+ self._table_exists_cache: set[str] = set()
+ self._table_exists_cache_lock = threading.RLock()
+
+ # Cache for vector columns: stores (table_name, vector_size) tuples
+ self._vector_column_cache: set[tuple[str, int]] = set()
+ self._vector_column_cache_lock = threading.RLock()
+
+ self.logger.info(f"OceanBase {self.uri} connection initialized.")
+
+ def _load_env_vars(self):
+ def is_true(var: str, default: str) -> bool:
+ return os.getenv(var, default).lower() in ['true', '1', 'yes', 'y']
+
+ self.enable_fulltext_search = is_true('ENABLE_FULLTEXT_SEARCH', 'true')
+ self.use_fulltext_hint = is_true('USE_FULLTEXT_HINT', 'true')
+ self.search_original_content = is_true("SEARCH_ORIGINAL_CONTENT", 'true')
+ self.enable_hybrid_search = is_true('ENABLE_HYBRID_SEARCH', 'false')
+ self.use_fulltext_first_fusion_search = is_true('USE_FULLTEXT_FIRST_FUSION_SEARCH', 'true')
+
+ # Adjust settings based on hybrid search availability
+ if self.es is not None and self.search_original_content:
+ self.logger.info("HybridSearch is enabled, forcing search_original_content to False")
+ self.search_original_content = False
+
+ """
+ Template methods - must be implemented by subclasses
+ """
+
+ @abstractmethod
+ def get_index_columns(self) -> list[str]:
+ """Return list of column names that need regular indexes."""
+ raise NotImplementedError("Not implemented")
+
+ @abstractmethod
+ def get_fulltext_columns(self) -> list[str]:
+ """Return list of column names that need fulltext indexes (without weight suffix)."""
+ raise NotImplementedError("Not implemented")
+
+ @abstractmethod
+ def get_column_definitions(self) -> list[Column]:
+ """Return list of column definitions for table creation."""
+ raise NotImplementedError("Not implemented")
+
+ def get_extra_columns(self) -> list[Column]:
+ """Return list of extra columns to add after table creation. Override if needed."""
+ return []
+
+ def get_table_name(self, index_name: str, dataset_id: str) -> str:
+ """Return the actual table name given index_name and dataset_id."""
+ return index_name
+
+ @abstractmethod
+ def get_lock_prefix(self) -> str:
+ """Return the lock name prefix for distributed locking."""
+ raise NotImplementedError("Not implemented")
+
+ """
+ Database operations
+ """
+
+ def db_type(self) -> str:
+ return "oceanbase"
+
+ def health(self) -> dict:
+ return {
+ "uri": self.uri,
+ "version_comment": self._get_variable_value("version_comment")
+ }
+
+ def _get_variable_value(self, var_name: str) -> Any:
+ rows = self.client.perform_raw_text_sql(f"SHOW VARIABLES LIKE '{var_name}'")
+ for row in rows:
+ return row[1]
+ raise Exception(f"Variable '{var_name}' not found.")
+
+ """
+ Table operations - common implementation using template methods
+ """
+
+ def _check_table_exists_cached(self, table_name: str) -> bool:
+ """
+ Check table existence with cache to reduce INFORMATION_SCHEMA queries.
+ Thread-safe implementation using RLock.
+ """
+ if table_name in self._table_exists_cache:
+ return True
+
+ try:
+ if not self.client.check_table_exists(table_name):
+ return False
+
+ # Check regular indexes
+ for column_name in self.get_index_columns():
+ if not self._index_exists(table_name, index_name_template % (table_name, column_name)):
+ return False
+
+ # Check fulltext indexes
+ for column_name in self.get_fulltext_columns():
+ if not self._index_exists(table_name, fulltext_index_name_template % column_name):
+ return False
+
+ # Check extra columns
+ for column in self.get_extra_columns():
+ if not self._column_exist(table_name, column.name):
+ return False
+
+ except Exception as e:
+ raise Exception(f"OBConnection._check_table_exists_cached error: {str(e)}")
+
+ with self._table_exists_cache_lock:
+ if table_name not in self._table_exists_cache:
+ self._table_exists_cache.add(table_name)
+ return True
+
+ def _create_table(self, table_name: str):
+ """Create table using column definitions from subclass."""
+ self._create_table_with_columns(table_name, self.get_column_definitions())
+
+ def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None):
+ """Create index/table with all necessary indexes."""
+ table_name = self.get_table_name(index_name, dataset_id)
+ lock_prefix = self.get_lock_prefix()
+
+ try:
+ _try_with_lock(
+ lock_name=f"{lock_prefix}create_table_{table_name}",
+ check_func=lambda: self.client.check_table_exists(table_name),
+ process_func=lambda: self._create_table(table_name),
+ )
+
+ for column_name in self.get_index_columns():
+ _try_with_lock(
+ lock_name=f"{lock_prefix}add_idx_{table_name}_{column_name}",
+ check_func=lambda cn=column_name: self._index_exists(table_name,
+ index_name_template % (table_name, cn)),
+ process_func=lambda cn=column_name: self._add_index(table_name, cn),
+ )
+
+ for column_name in self.get_fulltext_columns():
+ _try_with_lock(
+ lock_name=f"{lock_prefix}add_fulltext_idx_{table_name}_{column_name}",
+ check_func=lambda cn=column_name: self._index_exists(table_name, fulltext_index_name_template % cn),
+ process_func=lambda cn=column_name: self._add_fulltext_index(table_name, cn),
+ )
+
+ # Add vector column and index (skip metadata refresh, will be done in finally)
+ self._ensure_vector_column_exists(table_name, vector_size, refresh_metadata=False)
+
+ # Add extra columns if any
+ for column in self.get_extra_columns():
+ _try_with_lock(
+ lock_name=f"{lock_prefix}add_{column.name}_{table_name}",
+ check_func=lambda c=column: self._column_exist(table_name, c.name),
+ process_func=lambda c=column: self._add_column(table_name, c),
+ )
+
+ except Exception as e:
+ raise Exception(f"OBConnection.create_idx error: {str(e)}")
+ finally:
+ self.client.refresh_metadata([table_name])
+
+ def create_doc_meta_idx(self, index_name: str):
+ """
+ Create a document metadata table.
+
+ Table name pattern: ragflow_doc_meta_{tenant_id}
+ - Per-tenant metadata table for storing document metadata fields
+ """
+ from sqlalchemy import JSON
+ from sqlalchemy.dialects.mysql import VARCHAR
+
+ table_name = index_name
+ lock_prefix = self.get_lock_prefix()
+
+ # Define columns for document metadata table
+ doc_meta_columns = [
+ Column("id", VARCHAR(256), primary_key=True, comment="document id"),
+ Column("kb_id", VARCHAR(256), nullable=False, comment="knowledge base id"),
+ Column("meta_fields", JSON, nullable=True, comment="document metadata fields"),
+ ]
+
+ try:
+ # Create table with distributed lock
+ _try_with_lock(
+ lock_name=f"{lock_prefix}create_doc_meta_table_{table_name}",
+ check_func=lambda: self.client.check_table_exists(table_name),
+ process_func=lambda: self._create_table_with_columns(table_name, doc_meta_columns),
+ )
+
+ # Create index on kb_id for better query performance
+ _try_with_lock(
+ lock_name=f"{lock_prefix}add_idx_{table_name}_kb_id",
+ check_func=lambda: self._index_exists(table_name, index_name_template % (table_name, "kb_id")),
+ process_func=lambda: self._add_index(table_name, "kb_id"),
+ )
+
+ self.logger.info(f"Created document metadata table '{table_name}'.")
+ return True
+
+ except Exception as e:
+ self.logger.error(f"OBConnection.create_doc_meta_idx error: {str(e)}")
+ return False
+ finally:
+ self.client.refresh_metadata([table_name])
+
+ def delete_idx(self, index_name: str, dataset_id: str):
+ """Delete index/table."""
+ # For doc_meta tables, use index_name directly as table name
+ if index_name.startswith("ragflow_doc_meta_"):
+ table_name = index_name
+ else:
+ table_name = self.get_table_name(index_name, dataset_id)
+ try:
+ if self.client.check_table_exists(table_name=table_name):
+ self.client.drop_table_if_exist(table_name)
+ self.logger.info(f"Dropped table '{table_name}'.")
+ except Exception as e:
+ raise Exception(f"OBConnection.delete_idx error: {str(e)}")
+
+ def index_exist(self, index_name: str, dataset_id: str = None) -> bool:
+ """Check if index/table exists."""
+ # For doc_meta tables, use index_name directly as table name
+ if index_name.startswith("ragflow_doc_meta_"):
+ table_name = index_name
+ else:
+ table_name = self.get_table_name(index_name, dataset_id) if dataset_id else index_name
+ return self._check_table_exists_cached(table_name)
+
+ """
+ Table operations - helper methods
+ """
+
+ def _get_count(self, table_name: str, filter_list: list[str] = None) -> int:
+ where_clause = "WHERE " + " AND ".join(filter_list) if filter_list and len(filter_list) > 0 else ""
+ (count,) = self.client.perform_raw_text_sql(
+ f"SELECT COUNT(*) FROM {table_name} {where_clause}"
+ ).fetchone()
+ return count
+
+ def _column_exist(self, table_name: str, column_name: str) -> bool:
+ return self._get_count(
+ table_name="INFORMATION_SCHEMA.COLUMNS",
+ filter_list=[
+ f"TABLE_SCHEMA = '{self.db_name}'",
+ f"TABLE_NAME = '{table_name}'",
+ f"COLUMN_NAME = '{column_name}'",
+ ]) > 0
+
+ def _index_exists(self, table_name: str, idx_name: str) -> bool:
+ return self._get_count(
+ table_name="INFORMATION_SCHEMA.STATISTICS",
+ filter_list=[
+ f"TABLE_SCHEMA = '{self.db_name}'",
+ f"TABLE_NAME = '{table_name}'",
+ f"INDEX_NAME = '{idx_name}'",
+ ]) > 0
+
+ def _create_table_with_columns(self, table_name: str, columns: list[Column]):
+ """Create table with specified columns."""
+ if table_name in self.client.metadata_obj.tables:
+ self.client.metadata_obj.remove(Table(table_name, self.client.metadata_obj))
+
+ table_options = {
+ "mysql_charset": "utf8mb4",
+ "mysql_collate": "utf8mb4_unicode_ci",
+ "mysql_organization": "heap",
+ }
+
+ self.client.create_table(
+ table_name=table_name,
+ columns=[c.copy() for c in columns],
+ **table_options,
+ )
+ self.logger.info(f"Created table '{table_name}'.")
+
+ def _add_index(self, table_name: str, column_name: str):
+ idx_name = index_name_template % (table_name, column_name)
+ self.client.create_index(
+ table_name=table_name,
+ is_vec_index=False,
+ index_name=idx_name,
+ column_names=[column_name],
+ )
+ self.logger.info(f"Created index '{idx_name}' on table '{table_name}'.")
+
+ def _add_fulltext_index(self, table_name: str, column_name: str):
+ fulltext_idx_name = fulltext_index_name_template % column_name
+ self.client.create_fts_idx_with_fts_index_param(
+ table_name=table_name,
+ fts_idx_param=FtsIndexParam(
+ index_name=fulltext_idx_name,
+ field_names=[column_name],
+ parser_type=FtsParser.IK,
+ ),
+ )
+ self.logger.info(f"Created full text index '{fulltext_idx_name}' on table '{table_name}'.")
+
+ def _add_vector_column(self, table_name: str, vector_size: int):
+ vector_field_name = f"q_{vector_size}_vec"
+ self.client.add_columns(
+ table_name=table_name,
+ columns=[Column(vector_field_name, VECTOR(vector_size), nullable=True)],
+ )
+ self.logger.info(f"Added vector column '{vector_field_name}' to table '{table_name}'.")
+
+ def _add_vector_index(self, table_name: str, vector_field_name: str):
+ vector_idx_name = f"{vector_field_name}_idx"
+ self.client.create_index(
+ table_name=table_name,
+ is_vec_index=True,
+ index_name=vector_idx_name,
+ column_names=[vector_field_name],
+ vidx_params="distance=cosine, type=hnsw, lib=vsag",
+ )
+ self.logger.info(
+ f"Created vector index '{vector_idx_name}' on table '{table_name}' with column '{vector_field_name}'."
+ )
+
+ def _add_column(self, table_name: str, column: Column):
+ try:
+ self.client.add_columns(
+ table_name=table_name,
+ columns=[column.copy()],
+ )
+ self.logger.info(f"Added column '{column.name}' to table '{table_name}'.")
+ except Exception as e:
+ self.logger.warning(f"Failed to add column '{column.name}' to table '{table_name}': {str(e)}")
+
+ def _ensure_vector_column_exists(self, table_name: str, vector_size: int, refresh_metadata: bool = True):
+ """
+ Ensure vector column and index exist for the given vector size.
+ This method is safe to call multiple times - it will skip if already exists.
+ Uses cache to avoid repeated INFORMATION_SCHEMA queries.
+
+ Args:
+ table_name: Name of the table
+ vector_size: Size of the vector column
+ refresh_metadata: Whether to refresh SQLAlchemy metadata after changes (default True)
+ """
+ if vector_size <= 0:
+ return
+
+ cache_key = (table_name, vector_size)
+
+ # Check cache first
+ if cache_key in self._vector_column_cache:
+ return
+
+ lock_prefix = self.get_lock_prefix()
+ vector_field_name = f"q_{vector_size}_vec"
+ vector_index_name = f"{vector_field_name}_idx"
+
+ # Check if already exists (may have been created by another process)
+ column_exists = self._column_exist(table_name, vector_field_name)
+ index_exists = self._index_exists(table_name, vector_index_name)
+
+ if column_exists and index_exists:
+ # Already exists, add to cache and return
+ with self._vector_column_cache_lock:
+ self._vector_column_cache.add(cache_key)
+ return
+
+ # Create column if needed
+ if not column_exists:
+ _try_with_lock(
+ lock_name=f"{lock_prefix}add_vector_column_{table_name}_{vector_field_name}",
+ check_func=lambda: self._column_exist(table_name, vector_field_name),
+ process_func=lambda: self._add_vector_column(table_name, vector_size),
+ )
+
+ # Create index if needed
+ if not index_exists:
+ _try_with_lock(
+ lock_name=f"{lock_prefix}add_vector_idx_{table_name}_{vector_field_name}",
+ check_func=lambda: self._index_exists(table_name, vector_index_name),
+ process_func=lambda: self._add_vector_index(table_name, vector_field_name),
+ )
+
+ if refresh_metadata:
+ self.client.refresh_metadata([table_name])
+
+ # Add to cache after successful creation
+ with self._vector_column_cache_lock:
+ self._vector_column_cache.add(cache_key)
+
+ def _execute_search_sql(self, sql: str) -> tuple[list, float]:
+ start_time = time.time()
+ res = self.client.perform_raw_text_sql(sql)
+ rows = res.fetchall()
+ elapsed_time = time.time() - start_time
+ return rows, elapsed_time
+
+ def _parse_fulltext_columns(
+ self,
+ fulltext_query: str,
+ fulltext_columns: list[str]
+ ) -> tuple[dict[str, str], dict[str, float]]:
+ """
+ Parse fulltext search columns with optional weight suffix and build search expressions.
+
+ Args:
+ fulltext_query: The escaped fulltext query string
+ fulltext_columns: List of column names, optionally with weight suffix (e.g., "col^0.5")
+
+ Returns:
+ Tuple of (fulltext_search_expr dict, fulltext_search_weight dict)
+ where weights are normalized to 0~1
+ """
+ fulltext_search_expr: dict[str, str] = {}
+ fulltext_search_weight: dict[str, float] = {}
+
+ # get fulltext match expression and weight values
+ for field in fulltext_columns:
+ parts = field.split("^")
+ column_name: str = parts[0]
+ column_weight: float = float(parts[1]) if (len(parts) > 1 and parts[1]) else 1.0
+
+ fulltext_search_weight[column_name] = column_weight
+ fulltext_search_expr[column_name] = fulltext_search_template % (column_name, fulltext_query)
+
+ # adjust the weight to 0~1
+ weight_sum = sum(fulltext_search_weight.values())
+ n = len(fulltext_search_weight)
+ if weight_sum <= 0 < n:
+ # All weights are 0 (e.g. "col^0"); use equal weights to avoid ZeroDivisionError
+ for column_name in fulltext_search_weight:
+ fulltext_search_weight[column_name] = 1.0 / n
+ else:
+ for column_name in fulltext_search_weight:
+ fulltext_search_weight[column_name] = fulltext_search_weight[column_name] / weight_sum
+
+ return fulltext_search_expr, fulltext_search_weight
+
+ def _build_vector_search_sql(
+ self,
+ table_name: str,
+ fields_expr: str,
+ vector_search_score_expr: str,
+ filters_expr: str,
+ vector_search_filter: str,
+ vector_search_expr: str,
+ limit: int,
+ vector_topn: int,
+ offset: int = 0
+ ) -> str:
+ sql = (
+ f"SELECT {fields_expr}, {vector_search_score_expr} AS _score"
+ f" FROM {table_name}"
+ f" WHERE {filters_expr} AND {vector_search_filter}"
+ f" ORDER BY {vector_search_expr}"
+ f" APPROXIMATE LIMIT {limit if limit != 0 else vector_topn}"
+ )
+ if offset != 0:
+ sql += f" OFFSET {offset}"
+ return sql
+
+ def _build_fulltext_search_sql(
+ self,
+ table_name: str,
+ fields_expr: str,
+ fulltext_search_score_expr: str,
+ filters_expr: str,
+ fulltext_search_filter: str,
+ offset: int,
+ limit: int,
+ fulltext_topn: int,
+ hint: str = ""
+ ) -> str:
+ hint_expr = f"{hint} " if hint else ""
+ return (
+ f"SELECT {hint_expr}{fields_expr}, {fulltext_search_score_expr} AS _score"
+ f" FROM {table_name}"
+ f" WHERE {filters_expr} AND {fulltext_search_filter}"
+ f" ORDER BY _score DESC"
+ f" LIMIT {offset}, {limit if limit != 0 else fulltext_topn}"
+ )
+
+ def _build_filter_search_sql(
+ self,
+ table_name: str,
+ fields_expr: str,
+ filters_expr: str,
+ order_by_expr: str = "",
+ limit_expr: str = ""
+ ) -> str:
+ return (
+ f"SELECT {fields_expr}"
+ f" FROM {table_name}"
+ f" WHERE {filters_expr}"
+ f" {order_by_expr} {limit_expr}"
+ )
+
+ def _build_count_sql(
+ self,
+ table_name: str,
+ filters_expr: str,
+ extra_filter: str = "",
+ hint: str = ""
+ ) -> str:
+ hint_expr = f"{hint} " if hint else ""
+ where_clause = f"{filters_expr} AND {extra_filter}" if extra_filter else filters_expr
+ return f"SELECT {hint_expr}COUNT(id) FROM {table_name} WHERE {where_clause}"
+
+ def _row_to_entity(self, data, fields: list[str]) -> dict:
+ entity = {}
+ for i, field in enumerate(fields):
+ value = data[i]
+ if value is None:
+ continue
+ entity[field] = value
+ return entity
+
+ def _get_dataset_id_field(self) -> str:
+ return "kb_id"
+
+ def _get_filters(self, condition: dict) -> list[str]:
+ filters: list[str] = []
+ for k, v in condition.items():
+ if not v:
+ continue
+ if k == "exists":
+ filters.append(f"{v} IS NOT NULL")
+ elif k == "must_not" and isinstance(v, dict) and "exists" in v:
+ filters.append(f"{v.get('exists')} IS NULL")
+ elif isinstance(v, list):
+ values: list[str] = []
+ for item in v:
+ values.append(get_value_str(item))
+ value = ", ".join(values)
+ filters.append(f"{k} IN ({value})")
+ else:
+ filters.append(f"{k} = {get_value_str(v)}")
+ return filters
+
+ def get(self, doc_id: str, index_name: str, dataset_ids: list[str]) -> dict | None:
+ if not self._check_table_exists_cached(index_name):
+ return None
+ try:
+ res = self.client.get(
+ table_name=index_name,
+ ids=[doc_id],
+ )
+ row = res.fetchone()
+ if row is None:
+ return None
+ return self._row_to_entity(row, fields=list(res.keys()))
+ except Exception as e:
+ self.logger.exception(f"OBConnectionBase.get({doc_id}) got exception")
+ raise e
+
+ def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
+ if not self._check_table_exists_cached(index_name):
+ return 0
+ # For doc_meta tables, don't add dataset_id to condition
+ if not index_name.startswith("ragflow_doc_meta_"):
+ condition[self._get_dataset_id_field()] = dataset_id
+ try:
+ from sqlalchemy import text
+ res = self.client.get(
+ table_name=index_name,
+ ids=None,
+ where_clause=[text(f) for f in self._get_filters(condition)],
+ output_column_name=["id"],
+ )
+ rows = res.fetchall()
+ if len(rows) == 0:
+ return 0
+ ids = [row[0] for row in rows]
+ self.logger.debug(f"OBConnection.delete, filters: {condition}, ids: {ids}")
+ self.client.delete(
+ table_name=index_name,
+ ids=ids,
+ )
+ return len(ids)
+ except Exception as e:
+ self.logger.error(f"OBConnection.delete error: {str(e)}")
+ return 0
+
+ """
+ Abstract CRUD methods that must be implemented by subclasses
+ """
+
+ @abstractmethod
+ def search(
+ self,
+ select_fields: list[str],
+ highlight_fields: list[str],
+ condition: dict,
+ match_expressions: list[MatchExpr],
+ order_by: OrderByExpr,
+ offset: int,
+ limit: int,
+ index_names: str | list[str],
+ knowledgebase_ids: list[str],
+ agg_fields: list[str] | None = None,
+ rank_feature: dict | None = None,
+ **kwargs,
+ ):
+ raise NotImplementedError("Not implemented")
+
+ @abstractmethod
+ def insert(self, documents: list[dict], index_name: str, dataset_id: str = None) -> list[str]:
+ raise NotImplementedError("Not implemented")
+
+ @abstractmethod
+ def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
+ raise NotImplementedError("Not implemented")
+
+ """
+ Helper functions for search result - abstract methods
+ """
+
+ @abstractmethod
+ def get_total(self, res) -> int:
+ raise NotImplementedError("Not implemented")
+
+ @abstractmethod
+ def get_doc_ids(self, res) -> list[str]:
+ raise NotImplementedError("Not implemented")
+
+ @abstractmethod
+ def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
+ raise NotImplementedError("Not implemented")
+
+ @abstractmethod
+ def get_highlight(self, res, keywords: list[str], field_name: str):
+ raise NotImplementedError("Not implemented")
+
+ @abstractmethod
+ def get_aggregation(self, res, field_name: str):
+ raise NotImplementedError("Not implemented")
+
+ """
+ SQL - can be overridden by subclasses
+ """
+
+ def sql(self, sql: str, fetch_size: int, format: str):
+ """Execute SQL query - default implementation."""
+ return None
diff --git a/common/doc_store/ob_conn_pool.py b/common/doc_store/ob_conn_pool.py
new file mode 100644
index 00000000000..5cb995edb50
--- /dev/null
+++ b/common/doc_store/ob_conn_pool.py
@@ -0,0 +1,191 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+import os
+import time
+
+from pyobvector import ObVecClient
+from pyobvector.client import ClusterVersionException
+from pyobvector.client.hybrid_search import HybridSearch
+from pyobvector.util import ObVersion
+
+from common import settings
+from common.decorator import singleton
+
+ATTEMPT_TIME = 2
+OB_QUERY_TIMEOUT = int(os.environ.get("OB_QUERY_TIMEOUT", "100_000_000"))
+
+logger = logging.getLogger('ragflow.ob_conn_pool')
+
+
+@singleton
+class OceanBaseConnectionPool:
+
+ def __init__(self):
+ self.client = None
+ self.es = None # HybridSearch client
+
+ if hasattr(settings, "OB"):
+ self.OB_CONFIG = settings.OB
+ else:
+ self.OB_CONFIG = settings.get_base_config("oceanbase", {})
+
+ scheme = self.OB_CONFIG.get("scheme")
+ ob_config = self.OB_CONFIG.get("config", {})
+
+ if scheme and scheme.lower() == "mysql":
+ mysql_config = settings.get_base_config("mysql", {})
+ logger.info("Use MySQL scheme to create OceanBase connection.")
+ host = mysql_config.get("host", "localhost")
+ port = mysql_config.get("port", 2881)
+ self.username = mysql_config.get("user", "root@test")
+ self.password = mysql_config.get("password", "infini_rag_flow")
+ max_connections = mysql_config.get("max_connections", 300)
+ else:
+ logger.info("Use customized config to create OceanBase connection.")
+ host = ob_config.get("host", "localhost")
+ port = ob_config.get("port", 2881)
+ self.username = ob_config.get("user", "root@test")
+ self.password = ob_config.get("password", "infini_rag_flow")
+ max_connections = ob_config.get("max_connections", 300)
+
+ self.db_name = ob_config.get("db_name", "test")
+ self.uri = f"{host}:{port}"
+
+ logger.info(f"Use OceanBase '{self.uri}' as the doc engine.")
+
+ max_overflow = int(os.environ.get("OB_MAX_OVERFLOW", max(max_connections // 2, 10)))
+ pool_timeout = int(os.environ.get("OB_POOL_TIMEOUT", "30"))
+
+ for _ in range(ATTEMPT_TIME):
+ try:
+ self.client = ObVecClient(
+ uri=self.uri,
+ user=self.username,
+ password=self.password,
+ db_name=self.db_name,
+ pool_pre_ping=True,
+ pool_recycle=3600,
+ pool_size=max_connections,
+ max_overflow=max_overflow,
+ pool_timeout=pool_timeout,
+ )
+ break
+ except Exception as e:
+ logger.warning(f"{str(e)}. Waiting OceanBase {self.uri} to be healthy.")
+ time.sleep(5)
+
+ if self.client is None:
+ msg = f"OceanBase {self.uri} connection failed after {ATTEMPT_TIME} attempts."
+ logger.error(msg)
+ raise Exception(msg)
+
+ self._check_ob_version()
+ self._try_to_update_ob_query_timeout()
+ self._init_hybrid_search(max_connections, max_overflow, pool_timeout)
+
+ logger.info(f"OceanBase {self.uri} is healthy.")
+
+ def _check_ob_version(self):
+ try:
+ res = self.client.perform_raw_text_sql("SELECT OB_VERSION() FROM DUAL").fetchone()
+ version_str = res[0] if res else None
+ logger.info(f"OceanBase {self.uri} version is {version_str}")
+ except Exception as e:
+ raise Exception(f"Failed to get OceanBase version from {self.uri}, error: {str(e)}")
+
+ if not version_str:
+ raise Exception(f"Failed to get OceanBase version from {self.uri}.")
+
+ ob_version = ObVersion.from_db_version_string(version_str)
+ if ob_version < ObVersion.from_db_version_nums(4, 3, 5, 1):
+ raise Exception(
+ f"The version of OceanBase needs to be higher than or equal to 4.3.5.1, current version is {version_str}"
+ )
+
+ def _try_to_update_ob_query_timeout(self):
+ try:
+ rows = self.client.perform_raw_text_sql("SHOW VARIABLES LIKE 'ob_query_timeout'")
+ for row in rows:
+ val = row[1]
+ if val and int(val) >= OB_QUERY_TIMEOUT:
+ return
+ except Exception as e:
+ logger.warning("Failed to get 'ob_query_timeout' variable: %s", str(e))
+
+ try:
+ self.client.perform_raw_text_sql(f"SET GLOBAL ob_query_timeout={OB_QUERY_TIMEOUT}")
+ logger.info("Set GLOBAL variable 'ob_query_timeout' to %d.", OB_QUERY_TIMEOUT)
+ self.client.engine.dispose()
+ logger.info("Disposed all connections in engine pool to refresh connection pool")
+ except Exception as e:
+ logger.warning(f"Failed to set 'ob_query_timeout' variable: {str(e)}")
+
+ def _init_hybrid_search(self, max_connections, max_overflow, pool_timeout):
+ enable_hybrid_search = os.getenv('ENABLE_HYBRID_SEARCH', 'false').lower() in ['true', '1', 'yes', 'y']
+ if enable_hybrid_search:
+ try:
+ self.es = HybridSearch(
+ uri=self.uri,
+ user=self.username,
+ password=self.password,
+ db_name=self.db_name,
+ pool_pre_ping=True,
+ pool_recycle=3600,
+ pool_size=max_connections,
+ max_overflow=max_overflow,
+ pool_timeout=pool_timeout,
+ )
+ logger.info("OceanBase Hybrid Search feature is enabled")
+ except ClusterVersionException as e:
+ logger.info("Failed to initialize HybridSearch client, fallback to use SQL", exc_info=e)
+ self.es = None
+
+ def get_client(self) -> ObVecClient:
+ return self.client
+
+ def get_hybrid_search_client(self) -> HybridSearch | None:
+ return self.es
+
+ def get_db_name(self) -> str:
+ return self.db_name
+
+ def get_uri(self) -> str:
+ return self.uri
+
+ def refresh_client(self) -> ObVecClient:
+ try:
+ self.client.perform_raw_text_sql("SELECT 1 FROM DUAL")
+ return self.client
+ except Exception as e:
+ logger.warning(f"OceanBase connection unhealthy: {str(e)}, refreshing...")
+ self.client.engine.dispose()
+ return self.client
+
+ def __del__(self):
+ if hasattr(self, "client") and self.client:
+ try:
+ self.client.engine.dispose()
+ except Exception:
+ pass
+ if hasattr(self, "es") and self.es:
+ try:
+ self.es.engine.dispose()
+ except Exception:
+ pass
+
+
+OB_CONN = OceanBaseConnectionPool()
diff --git a/common/float_utils.py b/common/float_utils.py
index 74db3b1cfdf..d7ef42fbbe5 100644
--- a/common/float_utils.py
+++ b/common/float_utils.py
@@ -14,6 +14,7 @@
# limitations under the License.
#
+
def get_float(v):
"""
Convert a value to float, handling None and exceptions gracefully.
@@ -39,8 +40,19 @@ def get_float(v):
42.0
"""
if v is None:
- return float('-inf')
+ return float("-inf")
try:
return float(v)
except Exception:
- return float('-inf')
\ No newline at end of file
+ return float("-inf")
+
+
+def normalize_overlapped_percent(overlapped_percent):
+ try:
+ value = float(overlapped_percent)
+ except (TypeError, ValueError):
+ return 0
+ if 0 < value < 1:
+ value *= 100
+ value = int(value)
+ return max(0, min(value, 90))
diff --git a/common/mcp_tool_call_conn.py b/common/mcp_tool_call_conn.py
index 0e8cd5128bf..9033c79c4ab 100644
--- a/common/mcp_tool_call_conn.py
+++ b/common/mcp_tool_call_conn.py
@@ -42,9 +42,10 @@ def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
class MCPToolCallSession(ToolCallSession):
_ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()
- def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None:
+ def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None, custom_header = None) -> None:
self.__class__._ALL_INSTANCES.add(self)
+ self._custom_header = custom_header
self._mcp_server = mcp_server
self._server_variables = server_variables or {}
self._queue = asyncio.Queue()
@@ -59,6 +60,7 @@ def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = No
async def _mcp_server_loop(self) -> None:
url = self._mcp_server.url.strip()
raw_headers: dict[str, str] = self._mcp_server.headers or {}
+ custom_header: dict[str, str] = self._custom_header or {}
headers: dict[str, str] = {}
for h, v in raw_headers.items():
@@ -67,6 +69,11 @@ async def _mcp_server_loop(self) -> None:
if nh.strip() and nv.strip().strip("Bearer"):
headers[nh] = nv
+ for h, v in custom_header.items():
+ nh = Template(h).safe_substitute(custom_header)
+ nv = Template(v).safe_substitute(custom_header)
+ headers[nh] = nv
+
if self._mcp_server.server_type == MCPServerType.SSE:
# SSE transport
try:
diff --git a/common/metadata_utils.py b/common/metadata_utils.py
index fdca6b9356f..c919bd186af 100644
--- a/common/metadata_utils.py
+++ b/common/metadata_utils.py
@@ -19,15 +19,15 @@
import json_repair
-from rag.prompts.generator import gen_meta_filter
-
-
def convert_conditions(metadata_condition):
if metadata_condition is None:
metadata_condition = {}
op_mapping = {
"is": "=",
- "not is": "≠"
+ "not is": "≠",
+ ">=": "≥",
+ "<=": "≤",
+ "!=": "≠"
}
return [
{
@@ -47,24 +47,66 @@ def filter_out(v2docs, operator, value):
for input, docids in v2docs.items():
if operator in ["=", "≠", ">", "<", "≥", "≤"]:
- try:
- if isinstance(input, list):
- input = input[0]
- input = ast.literal_eval(input)
- value = ast.literal_eval(value)
- except Exception:
- pass
- if isinstance(input, str):
- input = input.lower()
- if isinstance(value, str):
- value = value.lower()
+ # Check if input is in YYYY-MM-DD date format
+ input_str = str(input).strip()
+ value_str = str(value).strip()
+
+ # Strict date format detection: YYYY-MM-DD (must be 10 chars with correct format)
+ is_input_date = (
+ len(input_str) == 10 and
+ input_str[4] == '-' and
+ input_str[7] == '-' and
+ input_str[:4].isdigit() and
+ input_str[5:7].isdigit() and
+ input_str[8:10].isdigit()
+ )
+
+ is_value_date = (
+ len(value_str) == 10 and
+ value_str[4] == '-' and
+ value_str[7] == '-' and
+ value_str[:4].isdigit() and
+ value_str[5:7].isdigit() and
+ value_str[8:10].isdigit()
+ )
+
+ if is_value_date:
+ # Query value is in date format
+ if is_input_date:
+ # Data is also in date format: perform date comparison
+ input = input_str
+ value = value_str
+ else:
+ # Data is not in date format: skip this record (no match)
+ continue
+ else:
+ # Query value is not in date format: use original logic
+ try:
+ if isinstance(input, list):
+ input = input[0]
+ input = ast.literal_eval(input)
+ value = ast.literal_eval(value)
+ except Exception:
+ pass
+
+ # Convert strings to lowercase
+ if isinstance(input, str):
+ input = input.lower()
+ if isinstance(value, str):
+ value = value.lower()
+ else:
+ # Non-comparison operators: maintain original logic
+ if isinstance(input, str):
+ input = input.lower()
+ if isinstance(value, str):
+ value = value.lower()
matched = False
try:
if operator == "contains":
- matched = input in value if not isinstance(input, list) else all(i in value for i in input)
+ matched = str(input).find(value) >= 0 if not isinstance(input, list) else any(str(i).find(value) >= 0 for i in input)
elif operator == "not contains":
- matched = input not in value if not isinstance(input, list) else all(i not in value for i in input)
+ matched = str(input).find(value) == -1 if not isinstance(input, list) else all(str(i).find(value) == -1 for i in input)
elif operator == "in":
matched = input in value if not isinstance(input, list) else all(i in value for i in input)
elif operator == "not in":
@@ -96,20 +138,24 @@ def filter_out(v2docs, operator, value):
ids.extend(docids)
return ids
- for k, v2docs in metas.items():
- for f in filters:
- if k != f["key"]:
- continue
+ for f in filters:
+ k = f["key"]
+ if k not in metas:
+ # Key not found in metas: treat as no match
+ ids = []
+ else:
+ v2docs = metas[k]
ids = filter_out(v2docs, f["op"], f["value"])
- if not doc_ids:
- doc_ids = set(ids)
+
+ if not doc_ids:
+ doc_ids = set(ids)
+ else:
+ if logic == "and":
+ doc_ids = doc_ids & set(ids)
+ if not doc_ids:
+ return []
else:
- if logic == "and":
- doc_ids = doc_ids & set(ids)
- else:
- doc_ids = doc_ids | set(ids)
- if not doc_ids:
- return []
+ doc_ids = doc_ids | set(ids)
return list(doc_ids)
@@ -133,6 +179,8 @@ async def apply_meta_data_filter(
list of doc_ids, ["-999"] when manual filters yield no result, or None
when auto/semi_auto filters return empty.
"""
+ from rag.prompts.generator import gen_meta_filter # move from the top of the file to avoid circular import
+
doc_ids = list(base_doc_ids) if base_doc_ids else []
if not meta_data_filter:
@@ -146,11 +194,22 @@ async def apply_meta_data_filter(
if not doc_ids:
return None
elif method == "semi_auto":
- selected_keys = meta_data_filter.get("semi_auto", [])
+ selected_keys = []
+ constraints = {}
+ for item in meta_data_filter.get("semi_auto", []):
+ if isinstance(item, str):
+ selected_keys.append(item)
+ elif isinstance(item, dict):
+ key = item.get("key")
+ op = item.get("op")
+ selected_keys.append(key)
+ if op:
+ constraints[key] = op
+
if selected_keys:
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
if filtered_metas:
- filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
+ filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question, constraints=constraints)
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not doc_ids:
return None
@@ -212,7 +271,7 @@ def update_metadata_to(metadata, meta):
return metadata
-def metadata_schema(metadata: list|None) -> Dict[str, Any]:
+def metadata_schema(metadata: dict|list|None) -> Dict[str, Any]:
if not metadata:
return {}
properties = {}
@@ -238,3 +297,47 @@ def metadata_schema(metadata: list|None) -> Dict[str, Any]:
json_schema["additionalProperties"] = False
return json_schema
+
+
+def _is_json_schema(obj: dict) -> bool:
+ if not isinstance(obj, dict):
+ return False
+ if "$schema" in obj:
+ return True
+ return obj.get("type") == "object" and isinstance(obj.get("properties"), dict)
+
+
+def _is_metadata_list(obj: list) -> bool:
+ if not isinstance(obj, list) or not obj:
+ return False
+ for item in obj:
+ if not isinstance(item, dict):
+ return False
+ key = item.get("key")
+ if not isinstance(key, str) or not key:
+ return False
+ if "enum" in item and not isinstance(item["enum"], list):
+ return False
+ if "description" in item and not isinstance(item["description"], str):
+ return False
+ if "descriptions" in item and not isinstance(item["descriptions"], str):
+ return False
+ return True
+
+
+def turn2jsonschema(obj: dict | list) -> Dict[str, Any]:
+ if isinstance(obj, dict) and _is_json_schema(obj):
+ return obj
+ if isinstance(obj, list) and _is_metadata_list(obj):
+ normalized = []
+ for item in obj:
+ description = item.get("description", item.get("descriptions", ""))
+ normalized_item = {
+ "key": item.get("key"),
+ "description": description,
+ }
+ if "enum" in item:
+ normalized_item["enum"] = item["enum"]
+ normalized.append(normalized_item)
+ return metadata_schema(normalized)
+ return {}
diff --git a/common/misc_utils.py b/common/misc_utils.py
index ae56fe5c484..19b608ca7fe 100644
--- a/common/misc_utils.py
+++ b/common/misc_utils.py
@@ -14,15 +14,20 @@
# limitations under the License.
#
+import asyncio
import base64
+import functools
import hashlib
-import uuid
-import requests
-import threading
+import logging
+import os
import subprocess
import sys
-import os
-import logging
+import threading
+import uuid
+
+from concurrent.futures import ThreadPoolExecutor
+
+import requests
def get_uuid():
return uuid.uuid1().hex
@@ -106,3 +111,23 @@ def pip_install_torch():
logging.info("Installing pytorch")
pkg_names = ["torch>=2.5.0,<3.0.0"]
subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names])
+
+
+@once
+def _thread_pool_executor():
+ max_workers_env = os.getenv("THREAD_POOL_MAX_WORKERS", "128")
+ try:
+ max_workers = int(max_workers_env)
+ except ValueError:
+ max_workers = 128
+ if max_workers < 1:
+ max_workers = 1
+ return ThreadPoolExecutor(max_workers=max_workers)
+
+
+async def thread_pool_exec(func, *args, **kwargs):
+ loop = asyncio.get_running_loop()
+ if kwargs:
+ func = functools.partial(func, *args, **kwargs)
+ return await loop.run_in_executor(_thread_pool_executor(), func)
+ return await loop.run_in_executor(_thread_pool_executor(), func, *args)
diff --git a/common/parser_config_utils.py b/common/parser_config_utils.py
index 0a79f3ad177..0bc7ffc28b3 100644
--- a/common/parser_config_utils.py
+++ b/common/parser_config_utils.py
@@ -26,5 +26,8 @@ def normalize_layout_recognizer(layout_recognizer_raw: Any) -> tuple[Any, str |
if lowered.endswith("@mineru"):
parser_model_name = layout_recognizer_raw.rsplit("@", 1)[0]
layout_recognizer = "MinerU"
+ elif lowered.endswith("@paddleocr"):
+ parser_model_name = layout_recognizer_raw.rsplit("@", 1)[0]
+ layout_recognizer = "PaddleOCR"
return layout_recognizer, parser_model_name
diff --git a/common/settings.py b/common/settings.py
index 7b19357ad4a..97be3c5215f 100644
--- a/common/settings.py
+++ b/common/settings.py
@@ -41,6 +41,7 @@
import memory.utils.es_conn as memory_es_conn
import memory.utils.infinity_conn as memory_infinity_conn
+import memory.utils.ob_conn as memory_ob_conn
LLM = None
LLM_FACTORY = None
@@ -79,6 +80,7 @@
OAUTH_CONFIG = None
DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
DOC_ENGINE_INFINITY = (DOC_ENGINE.lower() == "infinity")
+DOC_ENGINE_OCEANBASE = (DOC_ENGINE.lower() == "oceanbase")
docStoreConn = None
@@ -241,15 +243,20 @@ def init_settings():
FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
OAUTH_CONFIG = get_base_config("oauth", {})
- global DOC_ENGINE, DOC_ENGINE_INFINITY, docStoreConn, ES, OB, OS, INFINITY
+ global DOC_ENGINE, DOC_ENGINE_INFINITY, DOC_ENGINE_OCEANBASE, docStoreConn, ES, OB, OS, INFINITY
DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
DOC_ENGINE_INFINITY = (DOC_ENGINE.lower() == "infinity")
+ DOC_ENGINE_OCEANBASE = (DOC_ENGINE.lower() == "oceanbase")
lower_case_doc_engine = DOC_ENGINE.lower()
if lower_case_doc_engine == "elasticsearch":
ES = get_base_config("es", {})
docStoreConn = rag.utils.es_conn.ESConnection()
elif lower_case_doc_engine == "infinity":
- INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
+ INFINITY = get_base_config("infinity", {
+ "uri": "infinity:23817",
+ "postgres_port": 5432,
+ "db_name": "default_db"
+ })
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
elif lower_case_doc_engine == "opensearch":
OS = get_base_config("os", {})
@@ -257,6 +264,9 @@ def init_settings():
elif lower_case_doc_engine == "oceanbase":
OB = get_base_config("oceanbase", {})
docStoreConn = rag.utils.ob_conn.OBConnection()
+ elif lower_case_doc_engine == "seekdb":
+ OB = get_base_config("seekdb", {})
+ docStoreConn = rag.utils.ob_conn.OBConnection()
else:
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
@@ -266,8 +276,14 @@ def init_settings():
ES = get_base_config("es", {})
msgStoreConn = memory_es_conn.ESConnection()
elif DOC_ENGINE == "infinity":
- INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
+ INFINITY = get_base_config("infinity", {
+ "uri": "infinity:23817",
+ "postgres_port": 5432,
+ "db_name": "default_db"
+ })
msgStoreConn = memory_infinity_conn.InfinityConnection()
+ elif lower_case_doc_engine in ["oceanbase", "seekdb"]:
+ msgStoreConn = memory_ob_conn.OBConnection()
global AZURE, S3, MINIO, OSS, GCS
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
@@ -306,7 +322,7 @@ def init_settings():
global retriever, kg_retriever
retriever = search.Dealer(docStoreConn)
- from graphrag import search as kg_search
+ from rag.graphrag import search as kg_search
kg_retriever = kg_search.KGSearch(docStoreConn)
diff --git a/conf/doc_meta_es_mapping.json b/conf/doc_meta_es_mapping.json
new file mode 100644
index 00000000000..eeab3b985a7
--- /dev/null
+++ b/conf/doc_meta_es_mapping.json
@@ -0,0 +1,29 @@
+{
+ "settings": {
+ "index": {
+ "number_of_shards": 2,
+ "number_of_replicas": 0,
+ "refresh_interval": "1000ms"
+ }
+ },
+ "mappings": {
+ "_source": {
+ "enabled": true
+ },
+ "dynamic": "runtime",
+ "properties": {
+ "id": {
+ "type": "keyword",
+ "store": true
+ },
+ "kb_id": {
+ "type": "keyword",
+ "store": true
+ },
+ "meta_fields": {
+ "type": "object",
+ "dynamic": true
+ }
+ }
+ }
+}
diff --git a/conf/doc_meta_infinity_mapping.json b/conf/doc_meta_infinity_mapping.json
new file mode 100644
index 00000000000..471912c6e20
--- /dev/null
+++ b/conf/doc_meta_infinity_mapping.json
@@ -0,0 +1,5 @@
+{
+ "id": {"type": "varchar", "default": ""},
+ "kb_id": {"type": "varchar", "default": ""},
+ "meta_fields": {"type": "json", "default": "{}"}
+}
\ No newline at end of file
diff --git a/conf/infinity_mapping.json b/conf/infinity_mapping.json
index de2dd3a17e9..83e3d5f9828 100644
--- a/conf/infinity_mapping.json
+++ b/conf/infinity_mapping.json
@@ -1,7 +1,7 @@
{
"id": {"type": "varchar", "default": ""},
"doc_id": {"type": "varchar", "default": ""},
- "kb_id": {"type": "varchar", "default": ""},
+ "kb_id": {"type": "varchar", "default": "", "index_type": {"type": "secondary", "cardinality": "low"}},
"mom_id": {"type": "varchar", "default": ""},
"create_time": {"type": "varchar", "default": ""},
"create_timestamp_flt": {"type": "float", "default": 0.0},
@@ -9,6 +9,7 @@
"docnm": {"type": "varchar", "default": "", "analyzer": ["rag-coarse", "rag-fine"], "comment": "docnm_kwd, title_tks, title_sm_tks"},
"name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
+ "important_kwd_empty_count": {"type": "integer", "default": 0},
"important_keywords": {"type": "varchar", "default": "", "analyzer": ["rag-coarse", "rag-fine"], "comment": "important_kwd, important_tks"},
"questions": {"type": "varchar", "default": "", "analyzer": ["rag-coarse", "rag-fine"], "comment": "question_kwd, question_tks"},
"content": {"type": "varchar", "default": "", "analyzer": ["rag-coarse", "rag-fine"], "comment": "content_with_weight, content_ltks, content_sm_ltks"},
@@ -20,7 +21,7 @@
"weight_flt": {"type": "float", "default": 0.0},
"rank_int": {"type": "integer", "default": 0},
"rank_flt": {"type": "float", "default": 0},
- "available_int": {"type": "integer", "default": 1},
+ "available_int": {"type": "integer", "default": 1, "index_type": {"type": "secondary", "cardinality": "low"}},
"knowledge_graph_kwd": {"type": "varchar", "default": ""},
"entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"pagerank_fea": {"type": "integer", "default": 0},
diff --git a/conf/llm_factories.json b/conf/llm_factories.json
index 451c8f45235..be9e7322d77 100644
--- a/conf/llm_factories.json
+++ b/conf/llm_factories.json
@@ -994,6 +994,13 @@
"model_type": "chat",
"is_tools": true
},
+ {
+ "llm_name": "kimi-k2.5",
+ "tags": "LLM,CHAT,256k",
+ "max_tokens": 256000,
+ "model_type": "chat",
+ "is_tools": true
+ },
{
"llm_name": "kimi-latest",
"tags": "LLM,CHAT,8k,32k,128k",
@@ -1467,58 +1474,53 @@
"rank": "980",
"llm": [
{
- "llm_name": "gemini-3-pro-preview",
- "tags": "LLM,CHAT,1M,IMAGE2TEXT",
- "max_tokens": 1048576,
- "model_type": "image2text",
- "is_tools": true
+ "llm_name": "gemini-3-pro-preview",
+ "tags": "LLM,CHAT,1M,IMAGE2TEXT",
+ "max_tokens": 1048576,
+ "model_type": "image2text",
+ "is_tools": true
},
{
- "llm_name": "gemini-2.5-flash",
- "tags": "LLM,CHAT,1024K,IMAGE2TEXT",
- "max_tokens": 1048576,
- "model_type": "image2text",
- "is_tools": true
+ "llm_name": "gemini-2.5-flash",
+ "tags": "LLM,CHAT,1024K,IMAGE2TEXT",
+ "max_tokens": 1048576,
+ "model_type": "image2text",
+ "is_tools": true
},
{
- "llm_name": "gemini-2.5-pro",
- "tags": "LLM,CHAT,IMAGE2TEXT,1024K",
- "max_tokens": 1048576,
- "model_type": "image2text",
- "is_tools": true
+ "llm_name": "gemini-2.5-pro",
+ "tags": "LLM,CHAT,IMAGE2TEXT,1024K",
+ "max_tokens": 1048576,
+ "model_type": "image2text",
+ "is_tools": true
},
{
- "llm_name": "gemini-2.5-flash-lite",
- "tags": "LLM,CHAT,1024K,IMAGE2TEXT",
- "max_tokens": 1048576,
- "model_type": "image2text",
- "is_tools": true
+ "llm_name": "gemini-2.5-flash-lite",
+ "tags": "LLM,CHAT,1024K,IMAGE2TEXT",
+ "max_tokens": 1048576,
+ "model_type": "image2text",
+ "is_tools": true
},
{
- "llm_name": "gemini-2.0-flash",
- "tags": "LLM,CHAT,1024K",
- "max_tokens": 1048576,
- "model_type": "image2text",
- "is_tools": true
+ "llm_name": "gemini-2.0-flash",
+ "tags": "LLM,CHAT,1024K",
+ "max_tokens": 1048576,
+ "model_type": "image2text",
+ "is_tools": true
},
{
- "llm_name": "gemini-2.0-flash-lite",
- "tags": "LLM,CHAT,1024K",
- "max_tokens": 1048576,
- "model_type": "image2text",
- "is_tools": true
+ "llm_name": "gemini-2.0-flash-lite",
+ "tags": "LLM,CHAT,1024K",
+ "max_tokens": 1048576,
+ "model_type": "image2text",
+ "is_tools": true
},
+
{
- "llm_name": "text-embedding-004",
- "tags": "TEXT EMBEDDING",
- "max_tokens": 2048,
- "model_type": "embedding"
- },
- {
- "llm_name": "embedding-001",
- "tags": "TEXT EMBEDDING",
- "max_tokens": 2048,
- "model_type": "embedding"
+ "llm_name": "gemini-embedding-001",
+ "tags": "TEXT EMBEDDING",
+ "max_tokens": 2048,
+ "model_type": "embedding"
}
]
},
@@ -1593,9 +1595,30 @@
{
"name": "StepFun",
"logo": "",
- "tags": "LLM",
+ "tags": "LLM,IMAGE2TEXT,SPEECH2TEXT,TTS",
"status": "1",
"llm": [
+ {
+ "llm_name": "step-3",
+ "tags": "LLM,CHAT,IMAGE2TEXT,64k",
+ "max_tokens": 65536,
+ "model_type": "image2text",
+ "is_tools": true
+ },
+ {
+ "llm_name": "step-2-mini",
+ "tags": "LLM,CHAT,32k",
+ "max_tokens": 32768,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "step-2-16k",
+ "tags": "LLM,CHAT,16k",
+ "max_tokens": 16384,
+ "model_type": "chat",
+ "is_tools": true
+ },
{
"llm_name": "step-1-8k",
"tags": "LLM,CHAT,8k",
@@ -1610,13 +1633,6 @@
"model_type": "chat",
"is_tools": true
},
- {
- "llm_name": "step-1-128k",
- "tags": "LLM,CHAT,128k",
- "max_tokens": 131072,
- "model_type": "chat",
- "is_tools": true
- },
{
"llm_name": "step-1-256k",
"tags": "LLM,CHAT,256k",
@@ -1624,12 +1640,61 @@
"model_type": "chat",
"is_tools": true
},
+ {
+ "llm_name": "step-r1-v-mini",
+ "tags": "LLM,CHAT,IMAGE2TEXT,100k",
+ "max_tokens": 102400,
+ "model_type": "image2text",
+ "is_tools": true
+ },
{
"llm_name": "step-1v-8k",
- "tags": "LLM,CHAT,IMAGE2TEXT",
+ "tags": "LLM,CHAT,IMAGE2TEXT,8k",
"max_tokens": 8192,
"model_type": "image2text",
"is_tools": true
+ },
+ {
+ "llm_name": "step-1v-32k",
+ "tags": "LLM,CHAT,IMAGE2TEXT,32k",
+ "max_tokens": 32768,
+ "model_type": "image2text",
+ "is_tools": true
+ },
+ {
+ "llm_name": "step-1o-vision-32k",
+ "tags": "LLM,CHAT,IMAGE2TEXT,32k",
+ "max_tokens": 32768,
+ "model_type": "image2text",
+ "is_tools": true
+ },
+ {
+ "llm_name": "step-1o-turbo-vision",
+ "tags": "LLM,CHAT,IMAGE2TEXT,32k",
+ "max_tokens": 32768,
+ "model_type": "image2text",
+ "is_tools": true
+ },
+ {
+ "llm_name": "step-tts-mini",
+ "tags": "TTS,1000c",
+ "max_tokens": 1000,
+ "model_type": "tts",
+ "is_tools": false
+ },
+ {
+ "llm_name": "step-tts-vivid",
+ "tags": "TTS,1000c",
+ "max_tokens": 1000,
+ "model_type": "tts",
+ "is_tools": false
+ },
+ {
+ "llm_name": "step-asr",
+ "tags": "SPEECH2TEXT,100MB",
+ "max_tokens": 32768,
+ "model_type": "speech2text",
+ "is_tools": false
}
]
},
@@ -3731,23 +3796,23 @@
},
{
"llm_name": "Qwen3-Reranker-8B",
- "tags": "TEXT EMBEDDING,TEXT RE-RANK",
+ "tags": "TEXT RE-RANK,32K",
"max_tokens": 32768,
- "model_type": "embedding",
+ "model_type": "reranker",
"is_tools": false
},
{
"llm_name": "Qwen3-Reranker-4B",
- "tags": "TEXT EMBEDDING,TEXT RE-RANK",
+ "tags": "TEXT RE-RANK,32K",
"max_tokens": 32768,
- "model_type": "embedding",
+ "model_type": "reranker",
"is_tools": false
},
{
"llm_name": "Qwen3-Reranker-0.6B",
- "tags": "TEXT EMBEDDING,TEXT RE-RANK",
+ "tags": "TEXT RE-RANK,32K",
"max_tokens": 32768,
- "model_type": "embedding",
+ "model_type": "reranker",
"is_tools": false
},
{
@@ -3787,9 +3852,9 @@
},
{
"llm_name": "jina-reranker-m0",
- "tags": "TEXT EMBEDDING,TEXT RE-RANK",
+ "tags": "TEXT RE-RANK,10K",
"max_tokens": 10240,
- "model_type": "embedding",
+ "model_type": "reranker",
"is_tools": false
},
{
@@ -3801,9 +3866,9 @@
},
{
"llm_name": "bce-reranker-base_v1",
- "tags": "TEXT EMBEDDING,TEXT RE-RANK",
+ "tags": "TEXT RE-RANK",
"max_tokens": 512,
- "model_type": "embedding",
+ "model_type": "reranker",
"is_tools": false
},
{
@@ -3815,9 +3880,9 @@
},
{
"llm_name": "bge-reranker-v2-m3",
- "tags": "TEXT EMBEDDING,TEXT RE-RANK",
+ "tags": "TEXT RE-RANK",
"max_tokens": 8192,
- "model_type": "embedding",
+ "model_type": "reranker",
"is_tools": false
},
{
@@ -5531,6 +5596,51 @@
"status": "1",
"rank": "900",
"llm": []
+ },
+ {
+ "name": "PaddleOCR",
+ "logo": "",
+ "tags": "OCR",
+ "status": "1",
+ "rank": "910",
+ "llm": []
+ },
+ {
+ "name": "n1n",
+ "logo": "",
+ "tags": "LLM",
+ "status": "1",
+ "rank": "900",
+ "llm": [
+ {
+ "llm_name": "gpt-4o-mini",
+ "tags": "LLM,CHAT,128K,IMAGE2TEXT",
+ "max_tokens": 128000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "gpt-4o",
+ "tags": "LLM,CHAT,128K,IMAGE2TEXT",
+ "max_tokens": 128000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "gpt-3.5-turbo",
+ "tags": "LLM,CHAT,4K",
+ "max_tokens": 4096,
+ "model_type": "chat",
+ "is_tools": false
+ },
+ {
+ "llm_name": "deepseek-chat",
+ "tags": "LLM,CHAT,128K",
+ "max_tokens": 128000,
+ "model_type": "chat",
+ "is_tools": true
+ }
+ ]
}
]
}
diff --git a/conf/service_conf.yaml b/conf/service_conf.yaml
index afd9b98bcb0..b303d69ae75 100644
--- a/conf/service_conf.yaml
+++ b/conf/service_conf.yaml
@@ -29,6 +29,7 @@ os:
password: 'infini_rag_flow_OS_01'
infinity:
uri: 'localhost:23817'
+ postgres_port: 5432
db_name: 'default_db'
oceanbase:
scheme: 'oceanbase' # set 'mysql' to create connection using mysql config
@@ -67,9 +68,11 @@ user_default_llm:
# oss:
# access_key: 'access_key'
# secret_key: 'secret_key'
-# endpoint_url: 'http://oss-cn-hangzhou.aliyuncs.com'
+# endpoint_url: 'https://s3.oss-cn-hangzhou.aliyuncs.com'
# region: 'cn-hangzhou'
# bucket: 'bucket_name'
+# signature_version: 's3'
+# addressing_style: 'virtual'
# azure:
# auth_type: 'sas'
# container_url: 'container_url'
diff --git a/conf/system_settings.json b/conf/system_settings.json
new file mode 100644
index 00000000000..f546aa1436b
--- /dev/null
+++ b/conf/system_settings.json
@@ -0,0 +1,88 @@
+{
+ "system_settings": [
+ {
+ "name": "enable_whitelist",
+ "source": "variable",
+ "data_type": "bool",
+ "value": "true"
+ },
+ {
+ "name": "default_role",
+ "source": "variable",
+ "data_type": "string",
+ "value": ""
+ },
+ {
+ "name": "mail.server",
+ "source": "variable",
+ "data_type": "string",
+ "value": ""
+ },
+ {
+ "name": "mail.port",
+ "source": "variable",
+ "data_type": "integer",
+ "value": ""
+ },
+ {
+ "name": "mail.use_ssl",
+ "source": "variable",
+ "data_type": "bool",
+ "value": "false"
+ },
+ {
+ "name": "mail.use_tls",
+ "source": "variable",
+ "data_type": "bool",
+ "value": "false"
+ },
+ {
+ "name": "mail.username",
+ "source": "variable",
+ "data_type": "string",
+ "value": ""
+ },
+ {
+ "name": "mail.password",
+ "source": "variable",
+ "data_type": "string",
+ "value": ""
+ },
+ {
+ "name": "mail.timeout",
+ "source": "variable",
+ "data_type": "integer",
+ "value": "10"
+ },
+ {
+ "name": "mail.default_sender",
+ "source": "variable",
+ "data_type": "string",
+ "value": ""
+ },
+ {
+ "name": "sandbox.provider_type",
+ "source": "variable",
+ "data_type": "string",
+ "value": "self_managed"
+ },
+ {
+ "name": "sandbox.self_managed",
+ "source": "variable",
+ "data_type": "json",
+ "value": "{\"endpoint\": \"http://localhost:9385\", \"timeout\": 30, \"max_retries\": 3, \"pool_size\": 10}"
+ },
+ {
+ "name": "sandbox.aliyun_codeinterpreter",
+ "source": "variable",
+ "data_type": "json",
+ "value": "{}"
+ },
+ {
+ "name": "sandbox.e2b",
+ "source": "variable",
+ "data_type": "json",
+ "value": "{}"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/deepdoc/README.md b/deepdoc/README.md
index 9a5e44089aa..db70e30d805 100644
--- a/deepdoc/README.md
+++ b/deepdoc/README.md
@@ -103,6 +103,31 @@ We use vision information to resolve problems as human being.
+
+ - **Table Auto-Rotation**. For scanned PDFs where tables may be incorrectly oriented (rotated 90°, 180°, or 270°),
+ the PDF parser automatically detects the best rotation angle using OCR confidence scores before performing
+ table structure recognition. This significantly improves OCR accuracy and table structure detection for rotated tables.
+
+ The feature evaluates 4 rotation angles (0°, 90°, 180°, 270°) and selects the one with highest OCR confidence.
+ After determining the best orientation, it re-performs OCR on the correctly rotated table image.
+
+ This feature is **enabled by default**. You can control it via environment variable:
+ ```bash
+ # Disable table auto-rotation
+ export TABLE_AUTO_ROTATE=false
+
+ # Enable table auto-rotation (default)
+ export TABLE_AUTO_ROTATE=true
+ ```
+
+ Or via API parameter:
+ ```python
+ from deepdoc.parser import PdfParser
+
+ parser = PdfParser()
+ # Disable auto-rotation for this call
+ boxes, tables = parser(pdf_path, auto_rotate_tables=False)
+ ```
## 3. Parser
diff --git a/deepdoc/README_zh.md b/deepdoc/README_zh.md
index 4ada7edb201..3eb38e3ddda 100644
--- a/deepdoc/README_zh.md
+++ b/deepdoc/README_zh.md
@@ -102,6 +102,30 @@ export HF_ENDPOINT=https://hf-mirror.com
+
+ - **表格自动旋转(Table Auto-Rotation)**。对于扫描的 PDF 文档,表格可能存在方向错误(旋转了 90°、180° 或 270°),
+ PDF 解析器会在进行表格结构识别之前,自动使用 OCR 置信度来检测最佳旋转角度。这大大提高了旋转表格的 OCR 准确性和表格结构检测效果。
+
+ 该功能会评估 4 个旋转角度(0°、90°、180°、270°),并选择 OCR 置信度最高的角度。
+ 确定最佳方向后,会对旋转后的表格图像重新进行 OCR 识别。
+
+ 此功能**默认启用**。您可以通过环境变量控制:
+ ```bash
+ # 禁用表格自动旋转
+ export TABLE_AUTO_ROTATE=false
+
+ # 启用表格自动旋转(默认)
+ export TABLE_AUTO_ROTATE=true
+ ```
+
+ 或通过 API 参数控制:
+ ```python
+ from deepdoc.parser import PdfParser
+
+ parser = PdfParser()
+ # 禁用此次调用的自动旋转
+ boxes, tables = parser(pdf_path, auto_rotate_tables=False)
+ ```
## 3. 解析器
diff --git a/deepdoc/parser/excel_parser.py b/deepdoc/parser/excel_parser.py
index a8087ef8197..2fe3420192c 100644
--- a/deepdoc/parser/excel_parser.py
+++ b/deepdoc/parser/excel_parser.py
@@ -156,6 +156,55 @@ def _extract_images_from_worksheet(ws, sheetname=None):
continue
return raw_items
+ @staticmethod
+ def _get_actual_row_count(ws):
+ max_row = ws.max_row
+ if not max_row:
+ return 0
+ if max_row <= 10000:
+ return max_row
+
+ max_col = min(ws.max_column or 1, 50)
+
+ def row_has_data(row_idx):
+ for col_idx in range(1, max_col + 1):
+ cell = ws.cell(row=row_idx, column=col_idx)
+ if cell.value is not None and str(cell.value).strip():
+ return True
+ return False
+
+ if not any(row_has_data(i) for i in range(1, min(101, max_row + 1))):
+ return 0
+
+ left, right = 1, max_row
+ last_data_row = 1
+
+ while left <= right:
+ mid = (left + right) // 2
+ found = False
+ for r in range(mid, min(mid + 10, max_row + 1)):
+ if row_has_data(r):
+ found = True
+ last_data_row = max(last_data_row, r)
+ break
+ if found:
+ left = mid + 1
+ else:
+ right = mid - 1
+
+ for r in range(last_data_row, min(last_data_row + 500, max_row + 1)):
+ if row_has_data(r):
+ last_data_row = r
+
+ return last_data_row
+
+ @staticmethod
+ def _get_rows_limited(ws):
+ actual_rows = RAGFlowExcelParser._get_actual_row_count(ws)
+ if actual_rows == 0:
+ return []
+ return list(ws.iter_rows(min_row=1, max_row=actual_rows))
+
def html(self, fnm, chunk_rows=256):
from html import escape
@@ -171,7 +220,7 @@ def _fmt(v):
for sheetname in wb.sheetnames:
ws = wb[sheetname]
try:
- rows = list(ws.rows)
+ rows = RAGFlowExcelParser._get_rows_limited(ws)
except Exception as e:
logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}")
continue
@@ -223,7 +272,7 @@ def __call__(self, fnm):
for sheetname in wb.sheetnames:
ws = wb[sheetname]
try:
- rows = list(ws.rows)
+ rows = RAGFlowExcelParser._get_rows_limited(ws)
except Exception as e:
logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}")
continue
@@ -238,6 +287,8 @@ def __call__(self, fnm):
t = str(ti[i].value) if i < len(ti) else ""
t += (":" if t else "") + str(c.value)
fields.append(t)
+ if not fields:
+ continue
line = "; ".join(fields)
if sheetname.lower().find("sheet") < 0:
line += " ——" + sheetname
@@ -249,14 +300,14 @@ def row_number(fnm, binary):
if fnm.split(".")[-1].lower().find("xls") >= 0:
wb = RAGFlowExcelParser._load_excel_to_workbook(BytesIO(binary))
total = 0
-
+
for sheetname in wb.sheetnames:
- try:
- ws = wb[sheetname]
- total += len(list(ws.rows))
- except Exception as e:
- logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}")
- continue
+ try:
+ ws = wb[sheetname]
+ total += RAGFlowExcelParser._get_actual_row_count(ws)
+ except Exception as e:
+ logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}")
+ continue
return total
if fnm.split(".")[-1].lower() in ["csv", "txt"]:
diff --git a/deepdoc/parser/figure_parser.py b/deepdoc/parser/figure_parser.py
index 8dfcd02d2c5..ec5e333de28 100644
--- a/deepdoc/parser/figure_parser.py
+++ b/deepdoc/parser/figure_parser.py
@@ -14,6 +14,7 @@
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor, as_completed
+import logging
from PIL import Image
@@ -21,9 +22,10 @@
from api.db.services.llm_service import LLMBundle
from common.connection_utils import timeout
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
-from rag.prompts.generator import vision_llm_figure_describe_prompt
-
+from rag.prompts.generator import vision_llm_figure_describe_prompt, vision_llm_figure_describe_prompt_with_context
+from rag.nlp import append_context2table_image4pdf
+# need to delete before pr
def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
if not figures_data_without_positions:
return []
@@ -36,7 +38,6 @@ def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
if isinstance(figure_data[1], Image.Image)
]
-
def vision_figure_parser_docx_wrapper(sections, tbls, callback=None,**kwargs):
if not sections:
return tbls
@@ -84,20 +85,36 @@ def vision_figure_parser_figure_xlsx_wrapper(images,callback=None, **kwargs):
def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs):
if not tbls:
return []
+ sections = kwargs.get("sections")
+ parser_config = kwargs.get("parser_config", {})
+ context_size = max(0, int(parser_config.get("image_context_size", 0) or 0))
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
except Exception:
vision_model = None
if vision_model:
+
def is_figure_item(item):
- return (
- isinstance(item[0][0], Image.Image) and
- isinstance(item[0][1], list)
- )
+ return isinstance(item[0][0], Image.Image) and isinstance(item[0][1], list)
+
figures_data = [item for item in tbls if is_figure_item(item)]
+ figure_contexts = []
+ if sections and figures_data and context_size > 0:
+ figure_contexts = append_context2table_image4pdf(
+ sections,
+ figures_data,
+ context_size,
+ return_context=True,
+ )
try:
- docx_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs)
+ docx_vision_parser = VisionFigureParser(
+ vision_model=vision_model,
+ figures_data=figures_data,
+ figure_contexts=figure_contexts,
+ context_size=context_size,
+ **kwargs,
+ )
boosted_figures = docx_vision_parser(callback=callback)
tbls = [item for item in tbls if not is_figure_item(item)]
tbls.extend(boosted_figures)
@@ -106,12 +123,57 @@ def is_figure_item(item):
return tbls
-shared_executor = ThreadPoolExecutor(max_workers=10)
+def vision_figure_parser_docx_wrapper_naive(chunks, idx_lst, callback=None, **kwargs):
+ if not chunks:
+ return []
+ try:
+ vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
+ callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
+ except Exception:
+ vision_model = None
+ if vision_model:
+ @timeout(30, 3)
+ def worker(idx, ck):
+ context_above = ck.get("context_above", "")
+ context_below = ck.get("context_below", "")
+ if context_above or context_below:
+ prompt = vision_llm_figure_describe_prompt_with_context(
+ # context_above + caption if any
+ context_above=ck.get("context_above") + ck.get("text", ""),
+ context_below=ck.get("context_below"),
+ )
+ logging.info(f"[VisionFigureParser] figure={idx} context_above_len={len(context_above)} context_below_len={len(context_below)} prompt=with_context")
+ logging.info(f"[VisionFigureParser] figure={idx} context_above_snippet={context_above[:512]}")
+ logging.info(f"[VisionFigureParser] figure={idx} context_below_snippet={context_below[:512]}")
+ else:
+ prompt = vision_llm_figure_describe_prompt()
+ logging.info(f"[VisionFigureParser] figure={idx} context_len=0 prompt=default")
+ description_text = picture_vision_llm_chunk(
+ binary=ck.get("image"),
+ vision_model=vision_model,
+ prompt=prompt,
+ callback=callback,
+ )
+ return idx, description_text
+
+ with ThreadPoolExecutor(max_workers=10) as executor:
+ futures = [
+ executor.submit(worker, idx, chunks[idx])
+ for idx in idx_lst
+ ]
+
+ for future in as_completed(futures):
+ idx, description = future.result()
+ chunks[idx]['text'] += description
+
+shared_executor = ThreadPoolExecutor(max_workers=10)
class VisionFigureParser:
def __init__(self, vision_model, figures_data, *args, **kwargs):
self.vision_model = vision_model
+ self.figure_contexts = kwargs.get("figure_contexts") or []
+ self.context_size = max(0, int(kwargs.get("context_size", 0) or 0))
self._extract_figures_info(figures_data)
assert len(self.figures) == len(self.descriptions)
assert not self.positions or (len(self.figures) == len(self.positions))
@@ -156,10 +218,25 @@ def __call__(self, **kwargs):
@timeout(30, 3)
def process(figure_idx, figure_binary):
+ context_above = ""
+ context_below = ""
+ if figure_idx < len(self.figure_contexts):
+ context_above, context_below = self.figure_contexts[figure_idx]
+ if context_above or context_below:
+ prompt = vision_llm_figure_describe_prompt_with_context(
+ context_above=context_above,
+ context_below=context_below,
+ )
+ logging.info(f"[VisionFigureParser] figure={figure_idx} context_size={self.context_size} context_above_len={len(context_above)} context_below_len={len(context_below)} prompt=with_context")
+ logging.info(f"[VisionFigureParser] figure={figure_idx} context_above_snippet={context_above[:512]}")
+ logging.info(f"[VisionFigureParser] figure={figure_idx} context_below_snippet={context_below[:512]}")
+ else:
+ prompt = vision_llm_figure_describe_prompt()
+ logging.info(f"[VisionFigureParser] figure={figure_idx} context_size={self.context_size} context_len=0 prompt=default")
description_text = picture_vision_llm_chunk(
binary=figure_binary,
vision_model=self.vision_model,
- prompt=vision_llm_figure_describe_prompt(),
+ prompt=prompt,
callback=callback,
)
return figure_idx, description_text
diff --git a/deepdoc/parser/mineru_parser.py b/deepdoc/parser/mineru_parser.py
index aba237dd1b2..cc4c99c76b8 100644
--- a/deepdoc/parser/mineru_parser.py
+++ b/deepdoc/parser/mineru_parser.py
@@ -17,6 +17,7 @@
import logging
import os
import re
+import shutil
import sys
import tempfile
import threading
@@ -138,39 +139,58 @@ def __init__(self, mineru_path: str = "mineru", mineru_api: str = "", mineru_ser
self.outlines = []
self.logger = logging.getLogger(self.__class__.__name__)
+ @staticmethod
+ def _is_zipinfo_symlink(member: zipfile.ZipInfo) -> bool:
+ return (member.external_attr >> 16) & 0o170000 == 0o120000
+
def _extract_zip_no_root(self, zip_path, extract_to, root_dir):
self.logger.info(f"[MinerU] Extract zip: zip_path={zip_path}, extract_to={extract_to}, root_hint={root_dir}")
+ base_dir = Path(extract_to).resolve()
with zipfile.ZipFile(zip_path, "r") as zip_ref:
+ members = zip_ref.infolist()
if not root_dir:
- files = zip_ref.namelist()
- if files and files[0].endswith("/"):
- root_dir = files[0]
+ if members and members[0].filename.endswith("/"):
+ root_dir = members[0].filename
else:
root_dir = None
-
- if not root_dir or not root_dir.endswith("/"):
- self.logger.info(f"[MinerU] No root directory found, extracting all (root_hint={root_dir})")
- zip_ref.extractall(extract_to)
- return
-
- root_len = len(root_dir)
- for member in zip_ref.infolist():
- filename = member.filename
- if filename == root_dir:
+ if root_dir:
+ root_dir = root_dir.replace("\\", "/")
+ if not root_dir.endswith("/"):
+ root_dir += "/"
+
+ for member in members:
+ if member.flag_bits & 0x1:
+ raise RuntimeError(f"[MinerU] Encrypted zip entry not supported: {member.filename}")
+ if self._is_zipinfo_symlink(member):
+ raise RuntimeError(f"[MinerU] Symlink zip entry not supported: {member.filename}")
+
+ name = member.filename.replace("\\", "/")
+ if root_dir and name == root_dir:
self.logger.info("[MinerU] Ignore root folder...")
continue
+ if root_dir and name.startswith(root_dir):
+ name = name[len(root_dir) :]
+ if not name:
+ continue
+ if name.startswith("/") or name.startswith("//") or re.match(r"^[A-Za-z]:", name):
+ raise RuntimeError(f"[MinerU] Unsafe zip path (absolute): {member.filename}")
- path = filename
- if path.startswith(root_dir):
- path = path[root_len:]
+ parts = [p for p in name.split("/") if p not in ("", ".")]
+ if any(p == ".." for p in parts):
+ raise RuntimeError(f"[MinerU] Unsafe zip path (traversal): {member.filename}")
+
+ rel_path = os.path.join(*parts) if parts else ""
+ dest_path = (Path(extract_to) / rel_path).resolve(strict=False)
+ if dest_path != base_dir and base_dir not in dest_path.parents:
+ raise RuntimeError(f"[MinerU] Unsafe zip path (escape): {member.filename}")
- full_path = os.path.join(extract_to, path)
if member.is_dir():
- os.makedirs(full_path, exist_ok=True)
- else:
- os.makedirs(os.path.dirname(full_path), exist_ok=True)
- with open(full_path, "wb") as f:
- f.write(zip_ref.read(filename))
+ os.makedirs(dest_path, exist_ok=True)
+ continue
+
+ os.makedirs(dest_path.parent, exist_ok=True)
+ with zip_ref.open(member) as src, open(dest_path, "wb") as dst:
+ shutil.copyfileobj(src, dst)
@staticmethod
def _is_http_endpoint_valid(url, timeout=5):
@@ -237,8 +257,6 @@ def _run_mineru_api(
output_path = tempfile.mkdtemp(prefix=f"{pdf_file_name}_{options.method}_", dir=str(output_dir))
output_zip_path = os.path.join(str(output_dir), f"{Path(output_path).name}.zip")
- files = {"files": (pdf_file_name + ".pdf", open(pdf_file_path, "rb"), "application/pdf")}
-
data = {
"output_dir": "./output",
"lang_list": options.lang,
@@ -270,26 +288,35 @@ def _run_mineru_api(
self.logger.info(f"[MinerU] invoke api: {self.mineru_api}/file_parse backend={options.backend} server_url={data.get('server_url')}")
if callback:
callback(0.20, f"[MinerU] invoke api: {self.mineru_api}/file_parse")
- response = requests.post(url=f"{self.mineru_api}/file_parse", files=files, data=data, headers=headers,
- timeout=1800)
-
- response.raise_for_status()
- if response.headers.get("Content-Type") == "application/zip":
- self.logger.info(f"[MinerU] zip file returned, saving to {output_zip_path}...")
-
- if callback:
- callback(0.30, f"[MinerU] zip file returned, saving to {output_zip_path}...")
-
- with open(output_zip_path, "wb") as f:
- f.write(response.content)
-
- self.logger.info(f"[MinerU] Unzip to {output_path}...")
- self._extract_zip_no_root(output_zip_path, output_path, pdf_file_name + "/")
-
- if callback:
- callback(0.40, f"[MinerU] Unzip to {output_path}...")
- else:
- self.logger.warning(f"[MinerU] not zip returned from api: {response.headers.get('Content-Type')}")
+ with open(pdf_file_path, "rb") as pdf_file:
+ files = {"files": (pdf_file_name + ".pdf", pdf_file, "application/pdf")}
+ with requests.post(
+ url=f"{self.mineru_api}/file_parse",
+ files=files,
+ data=data,
+ headers=headers,
+ timeout=1800,
+ stream=True,
+ ) as response:
+ response.raise_for_status()
+ content_type = response.headers.get("Content-Type", "")
+ if content_type.startswith("application/zip"):
+ self.logger.info(f"[MinerU] zip file returned, saving to {output_zip_path}...")
+
+ if callback:
+ callback(0.30, f"[MinerU] zip file returned, saving to {output_zip_path}...")
+
+ with open(output_zip_path, "wb") as f:
+ response.raw.decode_content = True
+ shutil.copyfileobj(response.raw, f)
+
+ self.logger.info(f"[MinerU] Unzip to {output_path}...")
+ self._extract_zip_no_root(output_zip_path, output_path, pdf_file_name + "/")
+
+ if callback:
+ callback(0.40, f"[MinerU] Unzip to {output_path}...")
+ else:
+ self.logger.warning(f"[MinerU] not zip returned from api: {content_type}")
except Exception as e:
raise RuntimeError(f"[MinerU] api failed with exception {e}")
self.logger.info("[MinerU] Api completed successfully.")
diff --git a/deepdoc/parser/paddleocr_parser.py b/deepdoc/parser/paddleocr_parser.py
new file mode 100644
index 00000000000..85db63b862d
--- /dev/null
+++ b/deepdoc/parser/paddleocr_parser.py
@@ -0,0 +1,560 @@
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import annotations
+
+import base64
+import logging
+import os
+import re
+from dataclasses import asdict, dataclass, field, fields
+from io import BytesIO
+from os import PathLike
+from pathlib import Path
+from typing import Any, Callable, ClassVar, Literal, Optional, Union, Tuple, List
+
+import numpy as np
+import pdfplumber
+import requests
+from PIL import Image
+
+try:
+ from deepdoc.parser.pdf_parser import RAGFlowPdfParser
+except Exception:
+
+ class RAGFlowPdfParser:
+ pass
+
+
+AlgorithmType = Literal["PaddleOCR-VL"]
+SectionTuple = tuple[str, ...]
+TableTuple = tuple[str, ...]
+ParseResult = tuple[list[SectionTuple], list[TableTuple]]
+
+
+_MARKDOWN_IMAGE_PATTERN = re.compile(
+ r"""
+ ]*>\s*
+
]*/>\s*
+
+ |
+ ]*/>
+ """,
+ re.IGNORECASE | re.VERBOSE | re.DOTALL,
+)
+
+
+def _remove_images_from_markdown(markdown: str) -> str:
+ return _MARKDOWN_IMAGE_PATTERN.sub("", markdown)
+
+
+@dataclass
+class PaddleOCRVLConfig:
+ """Configuration for PaddleOCR-VL algorithm."""
+
+ use_doc_orientation_classify: Optional[bool] = False
+ use_doc_orientation_classify: Optional[bool] = False
+ use_doc_unwarping: Optional[bool] = False
+ use_layout_detection: Optional[bool] = None
+ use_chart_recognition: Optional[bool] = None
+ use_seal_recognition: Optional[bool] = None
+ use_ocr_for_image_block: Optional[bool] = None
+ layout_threshold: Optional[Union[float, dict]] = None
+ layout_nms: Optional[bool] = None
+ layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]] = None
+ layout_merge_bboxes_mode: Optional[Union[str, dict]] = None
+ layout_shape_mode: Optional[str] = None
+ prompt_label: Optional[str] = None
+ format_block_content: Optional[bool] = True
+ repetition_penalty: Optional[float] = None
+ temperature: Optional[float] = None
+ top_p: Optional[float] = None
+ min_pixels: Optional[int] = None
+ max_pixels: Optional[int] = None
+ max_new_tokens: Optional[int] = None
+ merge_layout_blocks: Optional[bool] = False
+ markdown_ignore_labels: Optional[List[str]] = None
+ vlm_extra_args: Optional[dict] = None
+ restructure_pages: Optional[bool] = False
+ merge_tables: Optional[bool] = None
+ relevel_titles: Optional[bool] = None
+
+
+@dataclass
+class PaddleOCRConfig:
+ """Main configuration for PaddleOCR parser."""
+
+ api_url: str = ""
+ access_token: Optional[str] = None
+ algorithm: AlgorithmType = "PaddleOCR-VL"
+ request_timeout: int = 600
+ prettify_markdown: bool = True
+ show_formula_number: bool = True
+ visualize: bool = False
+ additional_params: dict[str, Any] = field(default_factory=dict)
+ algorithm_config: dict[str, Any] = field(default_factory=dict)
+
+ @classmethod
+ def from_dict(cls, config: Optional[dict[str, Any]]) -> "PaddleOCRConfig":
+ """Create configuration from dictionary."""
+ if not config:
+ return cls()
+
+ cfg = config.copy()
+ algorithm = cfg.get("algorithm", "PaddleOCR-VL")
+
+ # Validate algorithm
+ if algorithm not in ("PaddleOCR-VL"):
+ raise ValueError(f"Unsupported algorithm: {algorithm}")
+
+ # Extract algorithm-specific configuration
+ algorithm_config: dict[str, Any] = {}
+ if algorithm == "PaddleOCR-VL":
+ algorithm_config = asdict(PaddleOCRVLConfig())
+ algorithm_config_user = cfg.get("algorithm_config")
+ if isinstance(algorithm_config_user, dict):
+ algorithm_config.update({k: v for k, v in algorithm_config_user.items() if v is not None})
+
+ # Remove processed keys
+ cfg.pop("algorithm_config", None)
+
+ # Prepare initialization arguments
+ field_names = {field.name for field in fields(cls)}
+ init_kwargs: dict[str, Any] = {}
+
+ for field_name in field_names:
+ if field_name in cfg:
+ init_kwargs[field_name] = cfg[field_name]
+
+ init_kwargs["algorithm_config"] = algorithm_config
+
+ return cls(**init_kwargs)
+
+ @classmethod
+ def from_kwargs(cls, **kwargs: Any) -> "PaddleOCRConfig":
+ """Create configuration from keyword arguments."""
+ return cls.from_dict(kwargs)
+
+
+class PaddleOCRParser(RAGFlowPdfParser):
+ """Parser for PDF documents using PaddleOCR API."""
+
+ _ZOOMIN = 2
+
+ _COMMON_FIELD_MAPPING: ClassVar[dict[str, str]] = {
+ "prettify_markdown": "prettifyMarkdown",
+ "show_formula_number": "showFormulaNumber",
+ "visualize": "visualize",
+ }
+
+ _ALGORITHM_FIELD_MAPPINGS: ClassVar[dict[str, dict[str, str]]] = {
+ "PaddleOCR-VL": {
+ "use_doc_orientation_classify": "useDocOrientationClassify",
+ "use_doc_unwarping": "useDocUnwarping",
+ "use_layout_detection": "useLayoutDetection",
+ "use_chart_recognition": "useChartRecognition",
+ "use_seal_recognition": "useSealRecognition",
+ "use_ocr_for_image_block": "useOcrForImageBlock",
+ "layout_threshold": "layoutThreshold",
+ "layout_nms": "layoutNms",
+ "layout_unclip_ratio": "layoutUnclipRatio",
+ "layout_merge_bboxes_mode": "layoutMergeBboxesMode",
+ "layout_shape_mode": "layoutShapeMode",
+ "prompt_label": "promptLabel",
+ "format_block_content": "formatBlockContent",
+ "repetition_penalty": "repetitionPenalty",
+ "temperature": "temperature",
+ "top_p": "topP",
+ "min_pixels": "minPixels",
+ "max_pixels": "maxPixels",
+ "max_new_tokens": "maxNewTokens",
+ "merge_layout_blocks": "mergeLayoutBlocks",
+ "markdown_ignore_labels": "markdownIgnoreLabels",
+ "vlm_extra_args": "vlmExtraArgs",
+ "restructure_pages": "restructurePages",
+ "merge_tables": "mergeTables",
+ "relevel_titles": "relevelTitles",
+ },
+ }
+
+ def __init__(
+ self,
+ api_url: Optional[str] = None,
+ access_token: Optional[str] = None,
+ algorithm: AlgorithmType = "PaddleOCR-VL",
+ *,
+ request_timeout: int = 600,
+ ):
+ """Initialize PaddleOCR parser."""
+ super().__init__()
+
+ self.api_url = api_url.rstrip("/") if api_url else os.getenv("PADDLEOCR_API_URL", "")
+ self.access_token = access_token or os.getenv("PADDLEOCR_ACCESS_TOKEN")
+ self.algorithm = algorithm
+ self.request_timeout = request_timeout
+ self.logger = logging.getLogger(self.__class__.__name__)
+
+ # Force PDF file type
+ self.file_type = 0
+
+ # Initialize page images for cropping
+ self.page_images: list[Image.Image] = []
+ self.page_from = 0
+
+ # Public methods
+ def check_installation(self) -> tuple[bool, str]:
+ """Check if the parser is properly installed and configured."""
+ if not self.api_url:
+ return False, "[PaddleOCR] API URL not configured"
+
+ # TODO [@Bobholamovic]: Check URL availability and token validity
+
+ return True, ""
+
+ def parse_pdf(
+ self,
+ filepath: str | PathLike[str],
+ binary: BytesIO | bytes | None = None,
+ callback: Optional[Callable[[float, str], None]] = None,
+ *,
+ parse_method: str = "raw",
+ api_url: Optional[str] = None,
+ access_token: Optional[str] = None,
+ algorithm: Optional[AlgorithmType] = None,
+ request_timeout: Optional[int] = None,
+ prettify_markdown: Optional[bool] = None,
+ show_formula_number: Optional[bool] = None,
+ visualize: Optional[bool] = None,
+ additional_params: Optional[dict[str, Any]] = None,
+ algorithm_config: Optional[dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> ParseResult:
+ """Parse PDF document using PaddleOCR API."""
+ # Create configuration - pass all kwargs to capture VL config parameters
+ config_dict = {
+ "api_url": api_url if api_url is not None else self.api_url,
+ "access_token": access_token if access_token is not None else self.access_token,
+ "algorithm": algorithm if algorithm is not None else self.algorithm,
+ "request_timeout": request_timeout if request_timeout is not None else self.request_timeout,
+ }
+ if prettify_markdown is not None:
+ config_dict["prettify_markdown"] = prettify_markdown
+ if show_formula_number is not None:
+ config_dict["show_formula_number"] = show_formula_number
+ if visualize is not None:
+ config_dict["visualize"] = visualize
+ if additional_params is not None:
+ config_dict["additional_params"] = additional_params
+ if algorithm_config is not None:
+ config_dict["algorithm_config"] = algorithm_config
+
+ cfg = PaddleOCRConfig.from_dict(config_dict)
+
+ if not cfg.api_url:
+ raise RuntimeError("[PaddleOCR] API URL missing")
+
+ # Prepare file data and generate page images for cropping
+ data_bytes = self._prepare_file_data(filepath, binary)
+
+ # Generate page images for cropping functionality
+ input_source = filepath if binary is None else binary
+ try:
+ self.__images__(input_source, callback=callback)
+ except Exception as e:
+ self.logger.warning(f"[PaddleOCR] Failed to generate page images for cropping: {e}")
+
+ # Build and send request
+ result = self._send_request(data_bytes, cfg, callback)
+
+ # Process response
+ sections = self._transfer_to_sections(result, algorithm=cfg.algorithm, parse_method=parse_method)
+ if callback:
+ callback(0.9, f"[PaddleOCR] done, sections: {len(sections)}")
+
+ tables = self._transfer_to_tables(result)
+ if callback:
+ callback(1.0, f"[PaddleOCR] done, tables: {len(tables)}")
+
+ return sections, tables
+
+ def _prepare_file_data(self, filepath: str | PathLike[str], binary: BytesIO | bytes | None) -> bytes:
+ """Prepare file data for API request."""
+ source_path = Path(filepath)
+
+ if binary is not None:
+ if isinstance(binary, (bytes, bytearray)):
+ return binary
+ return binary.getbuffer().tobytes()
+
+ if not source_path.exists():
+ raise FileNotFoundError(f"[PaddleOCR] file not found: {source_path}")
+
+ return source_path.read_bytes()
+
+ def _build_payload(self, data: bytes, file_type: int, config: PaddleOCRConfig) -> dict[str, Any]:
+ """Build payload for API request."""
+ payload: dict[str, Any] = {
+ "file": base64.b64encode(data).decode("ascii"),
+ "fileType": file_type,
+ }
+
+ # Add common parameters
+ for param_key, param_value in [
+ ("prettify_markdown", config.prettify_markdown),
+ ("show_formula_number", config.show_formula_number),
+ ("visualize", config.visualize),
+ ]:
+ if param_value is not None:
+ api_param = self._COMMON_FIELD_MAPPING[param_key]
+ payload[api_param] = param_value
+
+ # Add algorithm-specific parameters
+ algorithm_mapping = self._ALGORITHM_FIELD_MAPPINGS.get(config.algorithm, {})
+ for param_key, param_value in config.algorithm_config.items():
+ if param_value is not None and param_key in algorithm_mapping:
+ api_param = algorithm_mapping[param_key]
+ payload[api_param] = param_value
+
+ # Add any additional parameters
+ if config.additional_params:
+ payload.update(config.additional_params)
+
+ return payload
+
+ def _send_request(self, data: bytes, config: PaddleOCRConfig, callback: Optional[Callable[[float, str], None]]) -> dict[str, Any]:
+ """Send request to PaddleOCR API and parse response."""
+ # Build payload
+ payload = self._build_payload(data, self.file_type, config)
+
+ # Prepare headers
+ headers = {"Content-Type": "application/json", "Client-Platform": "ragflow"}
+ if config.access_token:
+ headers["Authorization"] = f"token {config.access_token}"
+
+ self.logger.info("[PaddleOCR] invoking API")
+ if callback:
+ callback(0.1, "[PaddleOCR] submitting request")
+
+ # Send request
+ try:
+ resp = requests.post(config.api_url, json=payload, headers=headers, timeout=self.request_timeout)
+ resp.raise_for_status()
+ except Exception as exc:
+ if callback:
+ callback(-1, f"[PaddleOCR] request failed: {exc}")
+ raise RuntimeError(f"[PaddleOCR] request failed: {exc}")
+
+ # Parse response
+ try:
+ response_data = resp.json()
+ except Exception as exc:
+ raise RuntimeError(f"[PaddleOCR] response is not JSON: {exc}") from exc
+
+ if callback:
+ callback(0.8, "[PaddleOCR] response received")
+
+ # Validate response format
+ if response_data.get("errorCode") != 0 or not isinstance(response_data.get("result"), dict):
+ if callback:
+ callback(-1, "[PaddleOCR] invalid response format")
+ raise RuntimeError("[PaddleOCR] invalid response format")
+
+ return response_data["result"]
+
+ def _transfer_to_sections(self, result: dict[str, Any], algorithm: AlgorithmType, parse_method: str) -> list[SectionTuple]:
+ """Convert API response to section tuples."""
+ sections: list[SectionTuple] = []
+
+ if algorithm in ("PaddleOCR-VL",):
+ layout_parsing_results = result.get("layoutParsingResults", [])
+
+ for page_idx, layout_result in enumerate(layout_parsing_results):
+ pruned_result = layout_result.get("prunedResult", {})
+ parsing_res_list = pruned_result.get("parsing_res_list", [])
+
+ for block in parsing_res_list:
+ block_content = block.get("block_content", "").strip()
+ if not block_content:
+ continue
+
+ # Remove images
+ block_content = _remove_images_from_markdown(block_content)
+
+ label = block.get("block_label", "")
+ block_bbox = block.get("block_bbox", [0, 0, 0, 0])
+
+ tag = f"@@{page_idx + 1}\t{block_bbox[0] // self._ZOOMIN}\t{block_bbox[2] // self._ZOOMIN}\t{block_bbox[1] // self._ZOOMIN}\t{block_bbox[3] // self._ZOOMIN}##"
+
+ if parse_method == "manual":
+ sections.append((block_content, label, tag))
+ elif parse_method == "paper":
+ sections.append((block_content + tag, label))
+ else:
+ sections.append((block_content, tag))
+
+ return sections
+
+ def _transfer_to_tables(self, result: dict[str, Any]) -> list[TableTuple]:
+ """Convert API response to table tuples."""
+ return []
+
+ def __images__(self, fnm, page_from=0, page_to=100, callback=None):
+ """Generate page images from PDF for cropping."""
+ self.page_from = page_from
+ self.page_to = page_to
+ try:
+ with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf:
+ self.pdf = pdf
+ self.page_images = [p.to_image(resolution=72, antialias=True).original for i, p in enumerate(self.pdf.pages[page_from:page_to])]
+ except Exception as e:
+ self.page_images = None
+ self.logger.exception(e)
+
+ @staticmethod
+ def extract_positions(txt: str):
+ """Extract position information from text tags."""
+ poss = []
+ for tag in re.findall(r"@@[0-9-]+\t[0-9.\t]+##", txt):
+ pn, left, right, top, bottom = tag.strip("#").strip("@").split("\t")
+ left, right, top, bottom = float(left), float(right), float(top), float(bottom)
+ poss.append(([int(p) - 1 for p in pn.split("-")], left, right, top, bottom))
+ return poss
+
+ def crop(self, text: str, need_position: bool = False):
+ """Crop images from PDF based on position tags in text."""
+ imgs = []
+ poss = self.extract_positions(text)
+
+ if not poss:
+ if need_position:
+ return None, None
+ return
+
+ if not getattr(self, "page_images", None):
+ self.logger.warning("[PaddleOCR] crop called without page images; skipping image generation.")
+ if need_position:
+ return None, None
+ return
+
+ page_count = len(self.page_images)
+
+ filtered_poss = []
+ for pns, left, right, top, bottom in poss:
+ if not pns:
+ self.logger.warning("[PaddleOCR] Empty page index list in crop; skipping this position.")
+ continue
+ valid_pns = [p for p in pns if 0 <= p < page_count]
+ if not valid_pns:
+ self.logger.warning(f"[PaddleOCR] All page indices {pns} out of range for {page_count} pages; skipping.")
+ continue
+ filtered_poss.append((valid_pns, left, right, top, bottom))
+
+ poss = filtered_poss
+ if not poss:
+ self.logger.warning("[PaddleOCR] No valid positions after filtering; skip cropping.")
+ if need_position:
+ return None, None
+ return
+
+ max_width = max(np.max([right - left for (_, left, right, _, _) in poss]), 6)
+ GAP = 6
+ pos = poss[0]
+ first_page_idx = pos[0][0]
+ poss.insert(0, ([first_page_idx], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0)))
+ pos = poss[-1]
+ last_page_idx = pos[0][-1]
+ if not (0 <= last_page_idx < page_count):
+ self.logger.warning(f"[PaddleOCR] Last page index {last_page_idx} out of range for {page_count} pages; skipping crop.")
+ if need_position:
+ return None, None
+ return
+ last_page_height = self.page_images[last_page_idx].size[1]
+ poss.append(
+ (
+ [last_page_idx],
+ pos[1],
+ pos[2],
+ min(last_page_height, pos[4] + GAP),
+ min(last_page_height, pos[4] + 120),
+ )
+ )
+
+ positions = []
+ for ii, (pns, left, right, top, bottom) in enumerate(poss):
+ right = left + max_width
+
+ if bottom <= top:
+ bottom = top + 2
+
+ for pn in pns[1:]:
+ if 0 <= pn - 1 < page_count:
+ bottom += self.page_images[pn - 1].size[1]
+ else:
+ self.logger.warning(f"[PaddleOCR] Page index {pn}-1 out of range for {page_count} pages during crop; skipping height accumulation.")
+
+ if not (0 <= pns[0] < page_count):
+ self.logger.warning(f"[PaddleOCR] Base page index {pns[0]} out of range for {page_count} pages during crop; skipping this segment.")
+ continue
+
+ img0 = self.page_images[pns[0]]
+ x0, y0, x1, y1 = int(left), int(top), int(right), int(min(bottom, img0.size[1]))
+ crop0 = img0.crop((x0, y0, x1, y1))
+ imgs.append(crop0)
+ if 0 < ii < len(poss) - 1:
+ positions.append((pns[0] + self.page_from, x0, x1, y0, y1))
+
+ bottom -= img0.size[1]
+ for pn in pns[1:]:
+ if not (0 <= pn < page_count):
+ self.logger.warning(f"[PaddleOCR] Page index {pn} out of range for {page_count} pages during crop; skipping this page.")
+ continue
+ page = self.page_images[pn]
+ x0, y0, x1, y1 = int(left), 0, int(right), int(min(bottom, page.size[1]))
+ cimgp = page.crop((x0, y0, x1, y1))
+ imgs.append(cimgp)
+ if 0 < ii < len(poss) - 1:
+ positions.append((pn + self.page_from, x0, x1, y0, y1))
+ bottom -= page.size[1]
+
+ if not imgs:
+ if need_position:
+ return None, None
+ return
+
+ height = 0
+ for img in imgs:
+ height += img.size[1] + GAP
+ height = int(height)
+ width = int(np.max([i.size[0] for i in imgs]))
+ pic = Image.new("RGB", (width, height), (245, 245, 245))
+ height = 0
+ for ii, img in enumerate(imgs):
+ if ii == 0 or ii + 1 == len(imgs):
+ img = img.convert("RGBA")
+ overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
+ overlay.putalpha(128)
+ img = Image.alpha_composite(img, overlay).convert("RGB")
+ pic.paste(img, (0, int(height)))
+ height += img.size[1] + GAP
+
+ if need_position:
+ return pic, positions
+ return pic
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.INFO)
+ parser = PaddleOCRParser(api_url=os.getenv("PADDLEOCR_API_URL", ""), algorithm=os.getenv("PADDLEOCR_ALGORITHM", "PaddleOCR-VL"))
+ ok, reason = parser.check_installation()
+ print("PaddleOCR available:", ok, reason)
diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py
index ce6b9298b1f..6681e4a893a 100644
--- a/deepdoc/parser/pdf_parser.py
+++ b/deepdoc/parser/pdf_parser.py
@@ -43,6 +43,10 @@
from rag.prompts.generator import vision_llm_describe_prompt
from common import settings
+
+
+from common.misc_utils import thread_pool_exec
+
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
if LOCK_KEY_pdfplumber not in sys.modules:
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
@@ -88,6 +92,7 @@ def __init__(self, **kwargs):
try:
pip_install_torch()
import torch.cuda
+
if torch.cuda.is_available():
self.updown_cnt_mdl.set_param({"device": "cuda"})
except Exception:
@@ -192,13 +197,125 @@ def _has_color(self, o):
return False
return True
- def _table_transformer_job(self, ZM):
+ def _evaluate_table_orientation(self, table_img, sample_ratio=0.3):
+ """
+ Evaluate the best rotation orientation for a table image.
+
+ Tests 4 rotation angles (0°, 90°, 180°, 270°) and uses OCR
+ confidence scores to determine the best orientation.
+
+ Args:
+ table_img: PIL Image object of the table region
+ sample_ratio: Sampling ratio for quick evaluation
+
+ Returns:
+ tuple: (best_angle, best_img, confidence_scores)
+ - best_angle: Best rotation angle (0, 90, 180, 270)
+ - best_img: Image rotated to best orientation
+ - confidence_scores: Dict of scores for each angle
+ """
+
+ rotations = [
+ (0, "original"),
+ (90, "rotate_90"), # clockwise 90°
+ (180, "rotate_180"), # 180°
+ (270, "rotate_270"), # clockwise 270° (counter-clockwise 90°)
+ ]
+
+ results = {}
+ best_score = -1
+ best_angle = 0
+ best_img = table_img
+ score_0 = None
+
+ for angle, name in rotations:
+ # Rotate image
+ if angle == 0:
+ rotated_img = table_img
+ else:
+ # PIL's rotate is counter-clockwise, use negative angle for clockwise
+ rotated_img = table_img.rotate(-angle, expand=True)
+
+ # Convert to numpy array for OCR
+ img_array = np.array(rotated_img)
+
+ # Perform OCR detection and recognition
+ try:
+ ocr_results = self.ocr(img_array)
+
+ if ocr_results:
+ # Calculate average confidence
+ scores = [conf for _, (_, conf) in ocr_results]
+ avg_score = sum(scores) / len(scores) if scores else 0
+ total_regions = len(scores)
+
+ # Combined score: considers both average confidence and number of regions
+ # More regions + higher confidence = better orientation
+ combined_score = avg_score * (1 + 0.1 * min(total_regions, 50) / 50)
+ else:
+ avg_score = 0
+ total_regions = 0
+ combined_score = 0
+
+ except Exception as e:
+ logging.warning(f"OCR failed for angle {angle}: {e}")
+ avg_score = 0
+ total_regions = 0
+ combined_score = 0
+
+ results[angle] = {"avg_confidence": avg_score, "total_regions": total_regions, "combined_score": combined_score}
+ if angle == 0:
+ score_0 = combined_score
+
+ logging.debug(f"Table orientation {angle}°: avg_conf={avg_score:.4f}, regions={total_regions}, combined={combined_score:.4f}")
+
+ if combined_score > best_score:
+ best_score = combined_score
+ best_angle = angle
+ best_img = rotated_img
+
+ # Absolute threshold rule:
+ # Only choose non-0° if it exceeds 0° by more than 0.2 and 0° score is below 0.8.
+ if best_angle != 0 and score_0 is not None:
+ if not (best_score - score_0 > 0.2 and score_0 < 0.8):
+ best_angle = 0
+ best_img = table_img
+ best_score = score_0
+
+ results[best_angle] = results.get(best_angle, {"avg_confidence": 0, "total_regions": 0, "combined_score": 0})
+
+ logging.info(f"Best table orientation: {best_angle}° (score={best_score:.4f})")
+
+ return best_angle, best_img, results
+
+ def _table_transformer_job(self, ZM, auto_rotate=True):
+ """
+ Process table structure recognition.
+
+ When auto_rotate=True, the complete workflow:
+ 1. Evaluate table orientation and select the best rotation angle
+ 2. Use rotated image for table structure recognition (TSR)
+ 3. Re-OCR the rotated image
+ 4. Match new OCR results with TSR cell coordinates
+
+ Args:
+ ZM: Zoom factor
+ auto_rotate: Whether to enable auto orientation correction
+ """
logging.debug("Table processing...")
imgs, pos = [], []
tbcnt = [0]
MARGIN = 10
self.tb_cpns = []
+ self.table_rotations = {} # Store rotation info for each table
+ self.rotated_table_imgs = {} # Store rotated table images
+
assert len(self.page_layout) == len(self.page_images)
+
+ # Collect layout info for all tables
+ table_layouts = [] # [(page, table_layout, left, top, right, bott), ...]
+
+ table_index = 0
for p, tbls in enumerate(self.page_layout): # for page
tbls = [f for f in tbls if f["type"] == "table"]
tbcnt.append(len(tbls))
@@ -210,29 +327,68 @@ def _table_transformer_job(self, ZM):
top *= ZM
right *= ZM
bott *= ZM
- pos.append((left, top))
- imgs.append(self.page_images[p].crop((left, top, right, bott)))
+ pos.append((left, top, p, table_index)) # Add page and table_index
+
+ # Record table layout info
+ table_layouts.append({"page": p, "table_index": table_index, "layout": tb, "coords": (left, top, right, bott)})
+
+ # Crop table image
+ table_img = self.page_images[p].crop((left, top, right, bott))
+
+ if auto_rotate:
+ # Evaluate table orientation
+ logging.debug(f"Evaluating orientation for table {table_index} on page {p}")
+ best_angle, rotated_img, rotation_scores = self._evaluate_table_orientation(table_img)
+
+ # Store rotation info
+ self.table_rotations[table_index] = {
+ "page": p,
+ "original_pos": (left, top, right, bott),
+ "best_angle": best_angle,
+ "scores": rotation_scores,
+ "rotated_size": rotated_img.size, # (width, height)
+ }
+
+ # Store the rotated image
+ self.rotated_table_imgs[table_index] = rotated_img
+ imgs.append(rotated_img)
+
+ else:
+ imgs.append(table_img)
+ self.table_rotations[table_index] = {"page": p, "original_pos": (left, top, right, bott), "best_angle": 0, "scores": {}, "rotated_size": table_img.size}
+ self.rotated_table_imgs[table_index] = table_img
+
+ table_index += 1
assert len(self.page_images) == len(tbcnt) - 1
if not imgs:
return
+
+ # Perform table structure recognition (TSR)
recos = self.tbl_det(imgs)
+
+ # If tables were rotated, re-OCR the rotated images and replace table boxes
+ if auto_rotate:
+ self._ocr_rotated_tables(ZM, table_layouts, recos, tbcnt)
+
+ # Process TSR results (keep original logic but handle rotated coordinates)
tbcnt = np.cumsum(tbcnt)
for i in range(len(tbcnt) - 1): # for page
pg = []
for j, tb_items in enumerate(recos[tbcnt[i] : tbcnt[i + 1]]): # for table
poss = pos[tbcnt[i] : tbcnt[i + 1]]
for it in tb_items: # for table components
- it["x0"] = it["x0"] + poss[j][0]
- it["x1"] = it["x1"] + poss[j][0]
- it["top"] = it["top"] + poss[j][1]
- it["bottom"] = it["bottom"] + poss[j][1]
- for n in ["x0", "x1", "top", "bottom"]:
- it[n] /= ZM
- it["top"] += self.page_cum_height[i]
- it["bottom"] += self.page_cum_height[i]
- it["pn"] = i
+ # TSR coordinates are relative to rotated image, need to record
+ it["x0_rotated"] = it["x0"]
+ it["x1_rotated"] = it["x1"]
+ it["top_rotated"] = it["top"]
+ it["bottom_rotated"] = it["bottom"]
+
+ # For rotated tables, coordinate transformation to page space requires rotation
+ # Since we already re-OCR'd on rotated image, keep simple processing here
+ it["pn"] = poss[j][2] # page number
it["layoutno"] = j
+ it["table_index"] = poss[j][3] # table index
pg.append(it)
self.tb_cpns.extend(pg)
@@ -245,8 +401,9 @@ def gather(kwd, fzy=10, ption=0.6):
headers = gather(r".*header$")
rows = gather(r".* (row|header)")
spans = gather(r".*spanning")
- clmns = sorted([r for r in self.tb_cpns if re.match(r"table column$", r["label"])], key=lambda x: (x["pn"], x["layoutno"], x["x0"]))
+ clmns = sorted([r for r in self.tb_cpns if re.match(r"table column$", r["label"])], key=lambda x: (x["pn"], x["layoutno"], x["x0_rotated"] if "x0_rotated" in x else x["x0"]))
clmns = Recognizer.layouts_cleanup(self.boxes, clmns, 5, 0.5)
+
for b in self.boxes:
if b.get("layout_type", "") != "table":
continue
@@ -278,6 +435,153 @@ def gather(kwd, fzy=10, ption=0.6):
b["H_right"] = spans[ii]["x1"]
b["SP"] = ii
+ def _ocr_rotated_tables(self, ZM, table_layouts, tsr_results, tbcnt):
+ """
+ Re-OCR rotated table images and update self.boxes.
+
+ Args:
+ ZM: Zoom factor
+ table_layouts: List of table layout info
+ tsr_results: TSR recognition results
+ tbcnt: Cumulative table count per page
+ """
+ tbcnt = np.cumsum(tbcnt)
+
+ def _table_region(layout, page_index):
+ table_x0 = layout["x0"]
+ table_top = layout["top"]
+ table_x1 = layout["x1"]
+ table_bottom = layout["bottom"]
+ table_top_cum = table_top + self.page_cum_height[page_index]
+ table_bottom_cum = table_bottom + self.page_cum_height[page_index]
+ return table_x0, table_top, table_x1, table_bottom, table_top_cum, table_bottom_cum
+
+ def _collect_table_boxes(page_index, table_x0, table_x1, table_top_cum, table_bottom_cum):
+ indices = [
+ i
+ for i, b in enumerate(self.boxes)
+ if (
+ b.get("page_number") == page_index + self.page_from
+ and b.get("layout_type") == "table"
+ and b["x0"] >= table_x0 - 5
+ and b["x1"] <= table_x1 + 5
+ and b["top"] >= table_top_cum - 5
+ and b["bottom"] <= table_bottom_cum + 5
+ )
+ ]
+ original_boxes = [self.boxes[i] for i in indices]
+ insert_at = indices[0] if indices else len(self.boxes)
+ for i in reversed(indices):
+ self.boxes.pop(i)
+ return original_boxes, insert_at
+
+ def _restore_boxes(original_boxes, insert_at):
+ for b in original_boxes:
+ self.boxes.insert(insert_at, b)
+ insert_at += 1
+ return insert_at
+
+ def _map_rotated_point(x, y, angle, width, height):
+ # Map a point from rotated image coords back to original image coords.
+ if angle == 0:
+ return x, y
+ if angle == 90:
+ # clockwise 90: original->rotated (x', y') = (y, width - x)
+ # inverse:
+ return width - y, x
+ if angle == 180:
+ return width - x, height - y
+ if angle == 270:
+ # clockwise 270: original->rotated (x', y') = (height - y, x)
+ # inverse:
+ return y, height - x
+ return x, y
+
+ def _insert_ocr_boxes(ocr_results, page_index, table_x0, table_top, insert_at, table_index, best_angle, table_w_px, table_h_px):
+ added = 0
+ for bbox, (text, conf) in ocr_results:
+ if conf < 0.5:
+ continue
+ mapped = [_map_rotated_point(p[0], p[1], best_angle, table_w_px, table_h_px) for p in bbox]
+ x_coords = [p[0] for p in mapped]
+ y_coords = [p[1] for p in mapped]
+ box_x0 = min(x_coords) / ZM
+ box_x1 = max(x_coords) / ZM
+ box_top = min(y_coords) / ZM
+ box_bottom = max(y_coords) / ZM
+ new_box = {
+ "text": text,
+ "x0": box_x0 + table_x0,
+ "x1": box_x1 + table_x0,
+ "top": box_top + table_top + self.page_cum_height[page_index],
+ "bottom": box_bottom + table_top + self.page_cum_height[page_index],
+ "page_number": page_index + self.page_from,
+ "layout_type": "table",
+ "layoutno": f"table-{table_index}",
+ "_rotated": True,
+ "_rotation_angle": best_angle,
+ "_table_index": table_index,
+ "_rotated_x0": box_x0,
+ "_rotated_x1": box_x1,
+ "_rotated_top": box_top,
+ "_rotated_bottom": box_bottom,
+ }
+ self.boxes.insert(insert_at, new_box)
+ insert_at += 1
+ added += 1
+ return added
+
+ for tbl_info in table_layouts:
+ table_index = tbl_info["table_index"]
+ page = tbl_info["page"]
+ layout = tbl_info["layout"]
+ left, top, right, bott = tbl_info["coords"]
+
+ rotation_info = self.table_rotations.get(table_index, {})
+ best_angle = rotation_info.get("best_angle", 0)
+
+ # Get the rotated table image
+ rotated_img = self.rotated_table_imgs.get(table_index)
+ if rotated_img is None:
+ continue
+
+ # If no rotation, keep original OCR boxes untouched.
+ if best_angle == 0:
+ continue
+
+ # Table region is defined by layout's x0, top, x1, bottom (page-local coords)
+ table_x0, table_top, table_x1, table_bottom, table_top_cum, table_bottom_cum = _table_region(layout, page)
+ original_boxes, insert_at = _collect_table_boxes(page, table_x0, table_x1, table_top_cum, table_bottom_cum)
+
+ logging.info(f"Re-OCR table {table_index} on page {page} with rotation {best_angle}°")
+
+ # Perform OCR on rotated image
+ img_array = np.array(rotated_img)
+ ocr_results = self.ocr(img_array)
+
+ if not ocr_results:
+ logging.warning(f"No OCR results for rotated table {table_index}, restoring originals")
+ _restore_boxes(original_boxes, insert_at)
+ continue
+
+ # Add new OCR results to self.boxes
+ # OCR coordinates are relative to rotated image, map back to original table coords
+ table_w_px = right - left
+ table_h_px = bott - top
+ added = _insert_ocr_boxes(
+ ocr_results,
+ page,
+ table_x0,
+ table_top,
+ insert_at,
+ table_index,
+ best_angle,
+ table_w_px,
+ table_h_px,
+ )
+
+ logging.info(f"Added {added} OCR results from rotated table {table_index}")
+
def __ocr(self, pagenum, img, chars, ZM=3, device_id: int | None = None):
start = timer()
bxs = self.ocr.detect(np.array(img), device_id)
@@ -408,11 +712,9 @@ def _assign_column(self, boxes, zoomin=3):
page_cols[pg] = best_k
logging.info(f"[Page {pg}] best_score={best_score:.2f}, best_k={best_k}")
-
global_cols = Counter(page_cols.values()).most_common(1)[0][0]
logging.info(f"Global column_num decided by majority: {global_cols}")
-
for pg, bxs in by_page.items():
if not bxs:
continue
@@ -476,7 +778,7 @@ def start_with(b, txts):
self.boxes = bxs
def _naive_vertical_merge(self, zoomin=3):
- #bxs = self._assign_column(self.boxes, zoomin)
+ # bxs = self._assign_column(self.boxes, zoomin)
bxs = self.boxes
grouped = defaultdict(list)
@@ -553,7 +855,8 @@ def _naive_vertical_merge(self, zoomin=3):
merged_boxes.extend(bxs)
- #self.boxes = sorted(merged_boxes, key=lambda x: (x["page_number"], x.get("col_id", 0), x["top"]))
+ # self.boxes = sorted(merged_boxes, key=lambda x: (x["page_number"], x.get("col_id", 0), x["top"]))
+ self.boxes = merged_boxes
def _final_reading_order_merge(self, zoomin=3):
if not self.boxes:
@@ -855,7 +1158,30 @@ def nearest(tbls):
def cropout(bxs, ltype, poss):
nonlocal ZM
- pn = set([b["page_number"] - 1 for b in bxs])
+ max_page_index = len(self.page_images) - 1
+
+ def local_page_index(page_number):
+ idx = page_number - 1 if page_number > 0 else 0
+ if idx > max_page_index and self.page_from:
+ idx = page_number - 1 - self.page_from
+ return idx
+
+ pn = set()
+ for b in bxs:
+ idx = local_page_index(b["page_number"])
+ if 0 <= idx <= max_page_index:
+ pn.add(idx)
+ else:
+ logging.warning(
+ "Skip out-of-range page_number %s (page_from=%s, pages=%s)",
+ b.get("page_number"),
+ self.page_from,
+ len(self.page_images),
+ )
+
+ if not pn:
+ return None
+
if len(pn) < 2:
pn = list(pn)[0]
ht = self.page_cum_height[pn]
@@ -874,12 +1200,16 @@ def cropout(bxs, ltype, poss):
return self.page_images[pn].crop((left * ZM, top * ZM, right * ZM, bott * ZM))
pn = {}
for b in bxs:
- p = b["page_number"] - 1
- if p not in pn:
- pn[p] = []
- pn[p].append(b)
+ p = local_page_index(b["page_number"])
+ if 0 <= p <= max_page_index:
+ if p not in pn:
+ pn[p] = []
+ pn[p].append(b)
pn = sorted(pn.items(), key=lambda x: x[0])
imgs = [cropout(arr, ltype, poss) for p, arr in pn]
+ imgs = [img for img in imgs if img is not None]
+ if not imgs:
+ return None
pic = Image.new("RGB", (int(np.max([i.size[0] for i in imgs])), int(np.sum([m.size[1] for m in imgs]))), (245, 245, 245))
height = 0
for img in imgs:
@@ -900,10 +1230,16 @@ def cropout(bxs, ltype, poss):
poss = []
if separate_tables_figures:
- figure_results.append((cropout(bxs, "figure", poss), [txt]))
+ img = cropout(bxs, "figure", poss)
+ if img is None:
+ continue
+ figure_results.append((img, [txt]))
figure_positions.append(poss)
else:
- res.append((cropout(bxs, "figure", poss), [txt]))
+ img = cropout(bxs, "figure", poss)
+ if img is None:
+ continue
+ res.append((img, [txt]))
positions.append(poss)
for k, bxs in tables.items():
@@ -913,7 +1249,10 @@ def cropout(bxs, ltype, poss):
poss = []
- res.append((cropout(bxs, "table", poss), self.tbl_det.construct_table(bxs, html=return_html, is_english=self.is_english)))
+ img = cropout(bxs, "table", poss)
+ if img is None:
+ continue
+ res.append((img, self.tbl_det.construct_table(bxs, html=return_html, is_english=self.is_english)))
positions.append(poss)
if separate_tables_figures:
@@ -1113,7 +1452,7 @@ async def __img_ocr(i, id, img, chars, limiter):
if limiter:
async with limiter:
- await asyncio.to_thread(self.__ocr, i + 1, img, chars, zoomin, id)
+ await thread_pool_exec(self.__ocr, i + 1, img, chars, zoomin, id)
else:
self.__ocr(i + 1, img, chars, zoomin, id)
@@ -1179,10 +1518,26 @@ async def wrapper(i=i, img=img, chars=chars, semaphore=semaphore):
if len(self.boxes) == 0 and zoomin < 9:
self.__images__(fnm, zoomin * 3, page_from, page_to, callback)
- def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
+ def __call__(self, fnm, need_image=True, zoomin=3, return_html=False, auto_rotate_tables=None):
+ """
+ Parse a PDF file.
+
+ Args:
+ fnm: PDF file path or binary content
+ need_image: Whether to extract images
+ zoomin: Zoom factor
+ return_html: Whether to return tables in HTML format
+ auto_rotate_tables: Whether to enable auto orientation correction for tables.
+ None: Use TABLE_AUTO_ROTATE env var setting (default: True)
+ True: Enable auto orientation correction
+ False: Disable auto orientation correction
+ """
+ if auto_rotate_tables is None:
+ auto_rotate_tables = os.getenv("TABLE_AUTO_ROTATE", "true").lower() in ("true", "1", "yes")
+
self.__images__(fnm, zoomin)
self._layouts_rec(zoomin)
- self._table_transformer_job(zoomin)
+ self._table_transformer_job(zoomin, auto_rotate=auto_rotate_tables)
self._text_merge()
self._concat_downward()
self._filter_forpages()
@@ -1200,8 +1555,11 @@ def parse_into_bboxes(self, fnm, callback=None, zoomin=3):
if callback:
callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))
+ # Read table auto-rotation setting from environment variable
+ auto_rotate_tables = os.getenv("TABLE_AUTO_ROTATE", "true").lower() in ("true", "1", "yes")
+
start = timer()
- self._table_transformer_job(zoomin)
+ self._table_transformer_job(zoomin, auto_rotate=auto_rotate_tables)
if callback:
callback(0.83, "Table analysis ({:.2f}s)".format(timer() - start))
@@ -1493,10 +1851,7 @@ def __call__(self, filename, from_page=0, to_page=100000, **kwargs):
if text:
width, height = self.page_images[idx].size
- all_docs.append((
- text,
- f"@@{pdf_page_num + 1}\t{0.0:.1f}\t{width / zoomin:.1f}\t{0.0:.1f}\t{height / zoomin:.1f}##"
- ))
+ all_docs.append((text, f"@@{pdf_page_num + 1}\t{0.0:.1f}\t{width / zoomin:.1f}\t{0.0:.1f}\t{height / zoomin:.1f}##"))
return all_docs, []
diff --git a/deepdoc/parser/ppt_parser.py b/deepdoc/parser/ppt_parser.py
index 1b04b4d7c31..afff23d7de6 100644
--- a/deepdoc/parser/ppt_parser.py
+++ b/deepdoc/parser/ppt_parser.py
@@ -22,6 +22,16 @@
class RAGFlowPptParser:
def __init__(self):
super().__init__()
+ self._shape_cache = {}
+
+ def __sort_shapes(self, shapes):
+ cache_key = id(shapes)
+ if cache_key not in self._shape_cache:
+ self._shape_cache[cache_key] = sorted(
+ shapes,
+ key=lambda x: ((x.top if x.top is not None else 0) // 10, x.left if x.left is not None else 0)
+ )
+ return self._shape_cache[cache_key]
def __get_bulleted_text(self, paragraph):
is_bulleted = bool(paragraph._p.xpath("./a:pPr/a:buChar")) or bool(paragraph._p.xpath("./a:pPr/a:buAutoNum")) or bool(paragraph._p.xpath("./a:pPr/a:buBlip"))
@@ -62,7 +72,7 @@ def __extract(self, shape):
# Handle group shape
if shape_type == 6:
texts = []
- for p in sorted(shape.shapes, key=lambda x: (x.top // 10, x.left)):
+ for p in self.__sort_shapes(shape.shapes):
t = self.__extract(p)
if t:
texts.append(t)
@@ -86,8 +96,7 @@ def __call__(self, fnm, from_page, to_page, callback=None):
if i >= to_page:
break
texts = []
- for shape in sorted(
- slide.shapes, key=lambda x: ((x.top if x.top is not None else 0) // 10, x.left if x.left is not None else 0)):
+ for shape in self.__sort_shapes(slide.shapes):
txt = self.__extract(shape)
if txt:
texts.append(txt)
diff --git a/deepdoc/parser/tcadp_parser.py b/deepdoc/parser/tcadp_parser.py
index 8d704baed29..af1c9034895 100644
--- a/deepdoc/parser/tcadp_parser.py
+++ b/deepdoc/parser/tcadp_parser.py
@@ -17,6 +17,7 @@
import json
import logging
import os
+import re
import shutil
import tempfile
import time
@@ -48,10 +49,10 @@ def __init__(self, secret_id, secret_key, region):
self.secret_key = secret_key
self.region = region
self.outlines = []
-
+
# Create credentials
self.cred = credential.Credential(secret_id, secret_key)
-
+
# Instantiate an http option, optional, can be skipped if no special requirements
self.httpProfile = HttpProfile()
self.httpProfile.endpoint = "lkeap.tencentcloudapi.com"
@@ -59,7 +60,7 @@ def __init__(self, secret_id, secret_key, region):
# Instantiate a client option, optional, can be skipped if no special requirements
self.clientProfile = ClientProfile()
self.clientProfile.httpProfile = self.httpProfile
-
+
# Instantiate the client object for the product to be requested, clientProfile is optional
self.client = lkeap_client.LkeapClient(self.cred, region, self.clientProfile)
@@ -68,14 +69,14 @@ def reconstruct_document_sse(self, file_type, file_url=None, file_base64=None, f
try:
# Instantiate a request object, each interface corresponds to a request object
req = models.ReconstructDocumentSSERequest()
-
+
# Build request parameters
params = {
"FileType": file_type,
"FileStartPageNumber": file_start_page,
"FileEndPageNumber": file_end_page,
}
-
+
# According to Tencent Cloud API documentation, either FileUrl or FileBase64 parameter must be provided, if both are provided only FileUrl will be used
if file_url:
params["FileUrl"] = file_url
@@ -94,7 +95,7 @@ def reconstruct_document_sse(self, file_type, file_url=None, file_base64=None, f
# The returned resp is an instance of ReconstructDocumentSSEResponse, corresponding to the request object
resp = self.client.ReconstructDocumentSSE(req)
parser_result = {}
-
+
# Output json format string response
if isinstance(resp, types.GeneratorType): # Streaming response
logging.info("[TCADP] Detected streaming response")
@@ -104,7 +105,7 @@ def reconstruct_document_sse(self, file_type, file_url=None, file_base64=None, f
try:
data_dict = json.loads(event['data'])
logging.info(f"[TCADP] Parsed data: {data_dict}")
-
+
if data_dict.get('Progress') == "100":
parser_result = data_dict
logging.info("[TCADP] Document parsing completed!")
@@ -118,14 +119,14 @@ def reconstruct_document_sse(self, file_type, file_url=None, file_base64=None, f
logging.warning("[TCADP] Failed parsing pages:")
for page in failed_pages:
logging.warning(f"[TCADP] Page number: {page.get('PageNumber')}, Error: {page.get('ErrorMsg')}")
-
+
# Check if there is a download link
download_url = data_dict.get("DocumentRecognizeResultUrl")
if download_url:
logging.info(f"[TCADP] Got download link: {download_url}")
else:
logging.warning("[TCADP] No download link obtained")
-
+
break # Found final result, exit loop
else:
# Print progress information
@@ -168,9 +169,6 @@ def download_result_file(self, download_url, output_dir):
return None
try:
- response = requests.get(download_url)
- response.raise_for_status()
-
# Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)
@@ -179,29 +177,36 @@ def download_result_file(self, download_url, output_dir):
filename = f"tcadp_result_{timestamp}.zip"
file_path = os.path.join(output_dir, filename)
- # Save file
- with open(file_path, "wb") as f:
- f.write(response.content)
+ with requests.get(download_url, stream=True) as response:
+ response.raise_for_status()
+ with open(file_path, "wb") as f:
+ response.raw.decode_content = True
+ shutil.copyfileobj(response.raw, f)
logging.info(f"[TCADP] Document parsing result downloaded to: {os.path.basename(file_path)}")
return file_path
- except requests.exceptions.RequestException as e:
+ except Exception as e:
logging.error(f"[TCADP] Failed to download file: {e}")
+ try:
+ if "file_path" in locals() and os.path.exists(file_path):
+ os.unlink(file_path)
+ except Exception:
+ pass
return None
class TCADPParser(RAGFlowPdfParser):
- def __init__(self, secret_id: str = None, secret_key: str = None, region: str = "ap-guangzhou",
+ def __init__(self, secret_id: str = None, secret_key: str = None, region: str = "ap-guangzhou",
table_result_type: str = None, markdown_image_response_type: str = None):
super().__init__()
-
+
# First initialize logger
self.logger = logging.getLogger(self.__class__.__name__)
-
+
# Log received parameters
self.logger.info(f"[TCADP] Initializing with parameters - table_result_type: {table_result_type}, markdown_image_response_type: {markdown_image_response_type}")
-
+
# Priority: read configuration from RAGFlow configuration system (service_conf.yaml)
try:
tcadp_parser = get_base_config("tcadp_config", {})
@@ -212,7 +217,7 @@ def __init__(self, secret_id: str = None, secret_key: str = None, region: str =
# Set table_result_type and markdown_image_response_type from config or parameters
self.table_result_type = table_result_type if table_result_type is not None else tcadp_parser.get("table_result_type", "1")
self.markdown_image_response_type = markdown_image_response_type if markdown_image_response_type is not None else tcadp_parser.get("markdown_image_response_type", "1")
-
+
else:
self.logger.error("[TCADP] Please configure tcadp_config in service_conf.yaml first")
# If config file is empty, use provided parameters or defaults
@@ -237,6 +242,10 @@ def __init__(self, secret_id: str = None, secret_key: str = None, region: str =
if not self.secret_id or not self.secret_key:
raise ValueError("[TCADP] Please set Tencent Cloud API keys, configure tcadp_config in service_conf.yaml")
+ @staticmethod
+ def _is_zipinfo_symlink(member: zipfile.ZipInfo) -> bool:
+ return (member.external_attr >> 16) & 0o170000 == 0o120000
+
def check_installation(self) -> bool:
"""Check if Tencent Cloud API configuration is correct"""
try:
@@ -255,7 +264,7 @@ def check_installation(self) -> bool:
def _file_to_base64(self, file_path: str, binary: bytes = None) -> str:
"""Convert file to Base64 format"""
-
+
if binary:
# If binary data is directly available, convert directly
return base64.b64encode(binary).decode('utf-8')
@@ -271,23 +280,34 @@ def _extract_content_from_zip(self, zip_path: str) -> list[dict[str, Any]]:
try:
with zipfile.ZipFile(zip_path, "r") as zip_file:
- # Find JSON result files
- json_files = [f for f in zip_file.namelist() if f.endswith(".json")]
-
- for json_file in json_files:
- with zip_file.open(json_file) as f:
- data = json.load(f)
- if isinstance(data, list):
- results.extend(data)
+ members = zip_file.infolist()
+ for member in members:
+ name = member.filename.replace("\\", "/")
+ if member.is_dir():
+ continue
+ if member.flag_bits & 0x1:
+ raise RuntimeError(f"[TCADP] Encrypted zip entry not supported: {member.filename}")
+ if self._is_zipinfo_symlink(member):
+ raise RuntimeError(f"[TCADP] Symlink zip entry not supported: {member.filename}")
+ if name.startswith("/") or name.startswith("//") or re.match(r"^[A-Za-z]:", name):
+ raise RuntimeError(f"[TCADP] Unsafe zip path (absolute): {member.filename}")
+ parts = [p for p in name.split("/") if p not in ("", ".")]
+ if any(p == ".." for p in parts):
+ raise RuntimeError(f"[TCADP] Unsafe zip path (traversal): {member.filename}")
+
+ if not (name.endswith(".json") or name.endswith(".md")):
+ continue
+
+ with zip_file.open(member) as f:
+ if name.endswith(".json"):
+ data = json.load(f)
+ if isinstance(data, list):
+ results.extend(data)
+ else:
+ results.append(data)
else:
- results.append(data)
-
- # Find Markdown files
- md_files = [f for f in zip_file.namelist() if f.endswith(".md")]
- for md_file in md_files:
- with zip_file.open(md_file) as f:
- content = f.read().decode("utf-8")
- results.append({"type": "text", "content": content, "file": md_file})
+ content = f.read().decode("utf-8")
+ results.append({"type": "text", "content": content, "file": name})
except Exception as e:
self.logger.error(f"[TCADP] Failed to extract ZIP file content: {e}")
@@ -395,7 +415,7 @@ def parse_pdf(
# Convert file to Base64 format
if callback:
callback(0.2, "[TCADP] Converting file to Base64 format")
-
+
file_base64 = self._file_to_base64(file_path, binary)
if callback:
callback(0.25, f"[TCADP] File converted to Base64, size: {len(file_base64)} characters")
@@ -420,23 +440,23 @@ def parse_pdf(
"TableResultType": self.table_result_type,
"MarkdownImageResponseType": self.markdown_image_response_type
}
-
+
self.logger.info(f"[TCADP] API request config - TableResultType: {self.table_result_type}, MarkdownImageResponseType: {self.markdown_image_response_type}")
result = client.reconstruct_document_sse(
- file_type=file_type,
- file_base64=file_base64,
- file_start_page=file_start_page,
- file_end_page=file_end_page,
+ file_type=file_type,
+ file_base64=file_base64,
+ file_start_page=file_start_page,
+ file_end_page=file_end_page,
config=config
)
-
+
if result:
self.logger.info(f"[TCADP] Attempt {attempt + 1} successful")
break
else:
self.logger.warning(f"[TCADP] Attempt {attempt + 1} failed, result is None")
-
+
except Exception as e:
self.logger.error(f"[TCADP] Attempt {attempt + 1} exception: {e}")
if attempt == max_retries - 1:
diff --git a/deepdoc/vision/ocr.py b/deepdoc/vision/ocr.py
index afa6921272e..1f573bda595 100644
--- a/deepdoc/vision/ocr.py
+++ b/deepdoc/vision/ocr.py
@@ -96,8 +96,9 @@ def cuda_is_available():
options = ort.SessionOptions()
options.enable_cpu_mem_arena = False
options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
- options.intra_op_num_threads = 2
- options.inter_op_num_threads = 2
+ # Prevent CPU oversubscription by allowing explicit thread control in multi-worker environments
+ options.intra_op_num_threads = int(os.environ.get("OCR_INTRA_OP_NUM_THREADS", "2"))
+ options.inter_op_num_threads = int(os.environ.get("OCR_INTER_OP_NUM_THREADS", "2"))
# https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580
# Shrink GPU memory after execution
@@ -117,6 +118,11 @@ def cuda_is_available():
providers=['CUDAExecutionProvider'],
provider_options=[cuda_provider_options]
)
+ # Explicit arena shrinkage for GPU to release VRAM back to the system after each run
+ if os.environ.get("OCR_GPUMEM_ARENA_SHRINKAGE") == "1":
+ run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", f"gpu:{provider_device_id}")
+ logging.info(
+ f"load_model {model_file_path} enabled GPU memory arena shrinkage on device {provider_device_id}")
logging.info(f"load_model {model_file_path} uses GPU (device {provider_device_id}, gpu_mem_limit={cuda_provider_options['gpu_mem_limit']}, arena_strategy={arena_strategy})")
else:
sess = ort.InferenceSession(
diff --git a/deepdoc/vision/t_ocr.py b/deepdoc/vision/t_ocr.py
index d3b33b12244..58ada1b15e4 100644
--- a/deepdoc/vision/t_ocr.py
+++ b/deepdoc/vision/t_ocr.py
@@ -18,6 +18,10 @@
import logging
import os
import sys
+
+
+from common.misc_utils import thread_pool_exec
+
sys.path.insert(
0,
os.path.abspath(
@@ -64,9 +68,9 @@ async def __ocr_thread(i, id, img, limiter = None):
if limiter:
async with limiter:
print(f"Task {i} use device {id}")
- await asyncio.to_thread(__ocr, i, id, img)
+ await thread_pool_exec(__ocr, i, id, img)
else:
- await asyncio.to_thread(__ocr, i, id, img)
+ await thread_pool_exec(__ocr, i, id, img)
async def __ocr_launcher():
diff --git a/docker/.env b/docker/.env
index 2d31177d759..7e1bdf801bc 100644
--- a/docker/.env
+++ b/docker/.env
@@ -16,6 +16,7 @@
# - `infinity` (https://github.com/infiniflow/infinity)
# - `oceanbase` (https://github.com/oceanbase/oceanbase)
# - `opensearch` (https://github.com/opensearch-project/OpenSearch)
+# - `seekdb` (https://github.com/oceanbase/seekdb)
DOC_ENGINE=${DOC_ENGINE:-elasticsearch}
# Device on which deepdoc inference run.
@@ -92,6 +93,19 @@ OB_SYSTEM_MEMORY=${OB_SYSTEM_MEMORY:-2G}
OB_DATAFILE_SIZE=${OB_DATAFILE_SIZE:-20G}
OB_LOG_DISK_SIZE=${OB_LOG_DISK_SIZE:-20G}
+# The hostname where the SeekDB service is exposed
+SEEKDB_HOST=seekdb
+# The port used to expose the SeekDB service
+SEEKDB_PORT=2881
+# The username for SeekDB
+SEEKDB_USER=root
+# The password for SeekDB
+SEEKDB_PASSWORD=infini_rag_flow
+# The doc database of the SeekDB service to use
+SEEKDB_DOC_DBNAME=ragflow_doc
+# SeekDB memory limit
+SEEKDB_MEMORY_LIMIT=2G
+
# The password for MySQL.
# WARNING: Change this for production!
MYSQL_PASSWORD=infini_rag_flow
@@ -99,9 +113,12 @@ MYSQL_PASSWORD=infini_rag_flow
MYSQL_HOST=mysql
# The database of the MySQL service to use
MYSQL_DBNAME=rag_flow
+# The port used to connect to MySQL from RAGFlow container.
+# Change this if you use external MySQL.
+MYSQL_PORT=3306
# The port used to expose the MySQL service to the host machine,
# allowing EXTERNAL access to the MySQL database running inside the Docker container.
-MYSQL_PORT=5455
+EXPOSE_MYSQL_PORT=5455
# The maximum size of communication packets sent to the MySQL server
MYSQL_MAX_PACKET=1073741824
@@ -137,11 +154,11 @@ ADMIN_SVR_HTTP_PORT=9381
SVR_MCP_PORT=9382
# The RAGFlow Docker image to download. v0.22+ doesn't include embedding models.
-RAGFLOW_IMAGE=infiniflow/ragflow:v0.23.1
+RAGFLOW_IMAGE=infiniflow/ragflow:v0.24.0
# If you cannot download the RAGFlow Docker image:
-# RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:v0.23.1
-# RAGFLOW_IMAGE=registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow:v0.23.1
+# RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:v0.24.0
+# RAGFLOW_IMAGE=registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow:v0.24.0
#
# - For the `nightly` edition, uncomment either of the following:
# RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:nightly
@@ -210,6 +227,7 @@ EMBEDDING_BATCH_SIZE=${EMBEDDING_BATCH_SIZE:-16}
# ENDPOINT=http://oss-cn-hangzhou.aliyuncs.com
# REGION=cn-hangzhou
# BUCKET=ragflow65536
+#
# A user registration switch:
# - Enable registration: 1
@@ -255,3 +273,7 @@ DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
# RAGFLOW_CRYPTO_ENABLED=true
# RAGFLOW_CRYPTO_ALGORITHM=aes-256-cbc # one of aes-256-cbc, aes-128-cbc, sm4-cbc
# RAGFLOW_CRYPTO_KEY=ragflow-crypto-key
+
+
+# Used for ThreadPoolExecutor
+THREAD_POOL_MAX_WORKERS=128
\ No newline at end of file
diff --git a/docker/README.md b/docker/README.md
index f2f59cdf631..c6422bad8c7 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -52,6 +52,8 @@ The [.env](./.env) file contains important environment variables for Docker.
- `MYSQL_PASSWORD`
The password for MySQL.
- `MYSQL_PORT`
+ The port to connect to MySQL from RAGFlow container. Defaults to `3306`. Change this if you use an external MySQL.
+- `EXPOSE_MYSQL_PORT`
The port used to expose the MySQL service to the host machine, allowing **external** access to the MySQL database running inside the Docker container. Defaults to `5455`.
### MinIO
@@ -77,7 +79,7 @@ The [.env](./.env) file contains important environment variables for Docker.
- `SVR_HTTP_PORT`
The port used to expose RAGFlow's HTTP API service to the host machine, allowing **external** access to the service running inside the Docker container. Defaults to `9380`.
- `RAGFLOW-IMAGE`
- The Docker image edition. Defaults to `infiniflow/ragflow:v0.23.1`. The RAGFlow Docker image does not include embedding models.
+ The Docker image edition. Defaults to `infiniflow/ragflow:v0.24.0`. The RAGFlow Docker image does not include embedding models.
> [!TIP]
diff --git a/docker/docker-compose-base.yml b/docker/docker-compose-base.yml
index 11104aef53c..f82f8027333 100644
--- a/docker/docker-compose-base.yml
+++ b/docker/docker-compose-base.yml
@@ -72,7 +72,7 @@ services:
infinity:
profiles:
- infinity
- image: infiniflow/infinity:v0.6.15
+ image: infiniflow/infinity:v0.7.0-dev2
volumes:
- infinity_data:/var/infinity
- ./infinity_conf.toml:/infinity_conf.toml
@@ -121,6 +121,30 @@ services:
- ragflow
restart: unless-stopped
+ seekdb:
+ profiles:
+ - seekdb
+ image: oceanbase/seekdb:latest
+ container_name: seekdb
+ volumes:
+ - ./seekdb:/var/lib/oceanbase
+ ports:
+ - ${SEEKDB_PORT:-2881}:2881
+ env_file: .env
+ environment:
+ - ROOT_PASSWORD=${SEEKDB_PASSWORD:-infini_rag_flow}
+ - MEMORY_LIMIT=${SEEKDB_MEMORY_LIMIT:-2G}
+ - REPORTER=ragflow-seekdb
+ mem_limit: ${MEM_LIMIT}
+ healthcheck:
+ test: ['CMD-SHELL', 'mysql -h127.0.0.1 -P2881 -uroot -p${SEEKDB_PASSWORD:-infini_rag_flow} -e "CREATE DATABASE IF NOT EXISTS ${SEEKDB_DOC_DBNAME:-ragflow_doc};"']
+ interval: 5s
+ retries: 60
+ timeout: 5s
+ networks:
+ - ragflow
+ restart: unless-stopped
+
sandbox-executor-manager:
profiles:
- sandbox
@@ -164,7 +188,7 @@ services:
--init-file /data/application/init.sql
--binlog_expire_logs_seconds=604800
ports:
- - ${MYSQL_PORT}:3306
+ - ${EXPOSE_MYSQL_PORT}:3306
volumes:
- mysql_data:/var/lib/mysql
- ./init.sql:/data/application/init.sql
@@ -283,6 +307,8 @@ volumes:
driver: local
ob_data:
driver: local
+ seekdb_data:
+ driver: local
mysql_data:
driver: local
minio_data:
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
index 1b9a41951e9..a32c2b609ef 100644
--- a/docker/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -39,7 +39,6 @@ services:
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
- ./nginx/proxy.conf:/etc/nginx/proxy.conf
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
- - ../history_data_agent:/ragflow/history_data_agent
- ./service_conf.yaml.template:/ragflow/conf/service_conf.yaml.template
- ./entrypoint.sh:/ragflow/entrypoint.sh
env_file: .env
@@ -88,7 +87,6 @@ services:
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
- ./nginx/proxy.conf:/etc/nginx/proxy.conf
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
- - ../history_data_agent:/ragflow/history_data_agent
- ./service_conf.yaml.template:/ragflow/conf/service_conf.yaml.template
- ./entrypoint.sh:/ragflow/entrypoint.sh
env_file: .env
diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh
index 62e0ed84801..4fb5cbde3dd 100755
--- a/docker/entrypoint.sh
+++ b/docker/entrypoint.sh
@@ -156,8 +156,20 @@ TEMPLATE_FILE="${CONF_DIR}/service_conf.yaml.template"
CONF_FILE="${CONF_DIR}/service_conf.yaml"
rm -f "${CONF_FILE}"
+DEF_ENV_VALUE_PATTERN="\$\{([^:]+):-([^}]+)\}"
while IFS= read -r line || [[ -n "$line" ]]; do
- eval "echo \"$line\"" >> "${CONF_FILE}"
+ if [[ "$line" =~ DEF_ENV_VALUE_PATTERN ]]; then
+ varname="${BASH_REMATCH[1]}"
+ default="${BASH_REMATCH[2]}"
+
+ if [ -n "${!varname}" ]; then
+ eval "echo \"$line"\" >> "${CONF_FILE}"
+ else
+ echo "$line" | sed -E "s/\\\$\{[^:]+:-([^}]+)\}/\1/g" >> "${CONF_FILE}"
+ fi
+ else
+ eval "echo \"$line\"" >> "${CONF_FILE}"
+ fi
done < "${TEMPLATE_FILE}"
export LD_LIBRARY_PATH="/usr/lib/x86_64-linux-gnu/"
@@ -195,10 +207,9 @@ function start_mcp_server() {
function ensure_docling() {
[[ "${USE_DOCLING}" == "true" ]] || { echo "[docling] disabled by USE_DOCLING"; return 0; }
- python3 -c 'import pip' >/dev/null 2>&1 || python3 -m ensurepip --upgrade || true
- DOCLING_PIN="${DOCLING_VERSION:-==2.58.0}"
- python3 -c "import importlib.util,sys; sys.exit(0 if importlib.util.find_spec('docling') else 1)" \
- || python3 -m pip install -i https://pypi.tuna.tsinghua.edu.cn/simple --extra-index-url https://pypi.org/simple --no-cache-dir "docling${DOCLING_PIN}"
+ DOCLING_PIN="${DOCLING_VERSION:-==2.71.0}"
+ "$PY" -c "import importlib.util,sys; sys.exit(0 if importlib.util.find_spec('docling') else 1)" \
+ || uv pip install -i https://pypi.tuna.tsinghua.edu.cn/simple --extra-index-url https://pypi.org/simple --no-cache-dir "docling${DOCLING_PIN}"
}
# -----------------------------------------------------------------------------
diff --git a/docker/infinity_conf.toml b/docker/infinity_conf.toml
index d1dc8bfdc31..661877389e5 100644
--- a/docker/infinity_conf.toml
+++ b/docker/infinity_conf.toml
@@ -1,5 +1,5 @@
[general]
-version = "0.6.15"
+version = "0.7.0"
time_zone = "utc-8"
[network]
diff --git a/docker/service_conf.yaml.template b/docker/service_conf.yaml.template
index 1500c2eaf4f..f283f08530e 100644
--- a/docker/service_conf.yaml.template
+++ b/docker/service_conf.yaml.template
@@ -9,7 +9,7 @@ mysql:
user: '${MYSQL_USER:-root}'
password: '${MYSQL_PASSWORD:-infini_rag_flow}'
host: '${MYSQL_HOST:-mysql}'
- port: 3306
+ port: ${MYSQL_PORT:-3306}
max_connections: 900
stale_timeout: 300
max_allowed_packet: ${MYSQL_MAX_PACKET:-1073741824}
@@ -29,6 +29,7 @@ os:
password: '${OPENSEARCH_PASSWORD:-infini_rag_flow_OS_01}'
infinity:
uri: '${INFINITY_HOST:-infinity}:23817'
+ postgres_port: 5432
db_name: 'default_db'
oceanbase:
scheme: 'oceanbase' # set 'mysql' to create connection using mysql config
@@ -38,6 +39,14 @@ oceanbase:
password: '${OCEANBASE_PASSWORD:-infini_rag_flow}'
host: '${OCEANBASE_HOST:-oceanbase}'
port: ${OCEANBASE_PORT:-2881}
+seekdb:
+ scheme: 'oceanbase' # SeekDB is the lite version of OceanBase
+ config:
+ db_name: '${SEEKDB_DOC_DBNAME:-ragflow_doc}'
+ user: '${SEEKDB_USER:-root}'
+ password: '${SEEKDB_PASSWORD:-infini_rag_flow}'
+ host: '${SEEKDB_HOST:-seekdb}'
+ port: ${SEEKDB_PORT:-2881}
redis:
db: 1
username: '${REDIS_USERNAME:-}'
@@ -72,6 +81,8 @@ user_default_llm:
# region: '${REGION}'
# bucket: '${BUCKET}'
# prefix_path: '${OSS_PREFIX_PATH}'
+# signature_version: 's3'
+# addressing_style: 'virtual'
# azure:
# auth_type: 'sas'
# container_url: 'container_url'
diff --git a/docs/basics/agent_context_engine.md b/docs/basics/agent_context_engine.md
index c00531e2843..4c4f9e2e253 100644
--- a/docs/basics/agent_context_engine.md
+++ b/docs/basics/agent_context_engine.md
@@ -1,6 +1,6 @@
---
sidebar_position: 2
-slug: /what_is_agent_context_engine
+slug: /what-is-agent-context-engine
---
# What is Agent context engine?
@@ -31,7 +31,7 @@ At its core, an Agent Context Engine is built on a triumvirate of next-generatio
2. The Memory Layer: An Agent’s intelligence is defined by its ability to learn from interaction. The Memory Layer is a specialized retrieval system for dynamic, episodic data: conversation history, user preferences, and the agent’s own internal state (e.g., "waiting for human input"). It manages the lifecycle of this data—storing raw dialogue, triggering summarization into semantic memory, and retrieving relevant past interactions to provide continuity and personalization. Technologically, it is a close sibling to RAG, but focused on a temporal stream of data.
-3. The Tool Orchestrator: As MCP (Model Context Protocol) enables the connection of hundreds of internal services as tools, a new problem arises: tool selection. The Context Engine solves this with Tool Retrieval. Instead of dumping all tool descriptions into the prompt, it maintains an index of tools and—critically—an index of Playbooks or Guidelines (best practices on when and how to use tools). For a given task, it retrieves only the most relevant tools and instructions, transforming the LLM’s job from "searching a haystack" to "following a recipe."
+3. The Tool Orchestrator: As MCP (Model Context Protocol) enables the connection of hundreds of internal services as tools, a new problem arises: tool selection. The Context Engine solves this with Tool Retrieval. Instead of dumping all tool descriptions into the prompt, it maintains an index of tools and—critically—an index of Skills (best practices on when and how to use tools). For a given task, it retrieves only the most relevant tools and instructions, transforming the LLM’s job from "searching a haystack" to "following a recipe."
## Why we need a dedicated engine? The case for a unified substrate
@@ -58,4 +58,4 @@ We left behind the label of “yet another RAG system” long ago. From DeepDoc
We believe tomorrow’s enterprise AI advantage will hinge not on who owns the largest model, but on who can feed that model the highest-quality, most real-time, and most relevant context. An Agentic Context Engine is the critical infrastructure that turns this vision into reality.
-In the paradigm shift from “hand-crafted prompts” to “intelligent context,” RAGFlow is determined to be the most steadfast propeller and enabler. We invite every developer, enterprise, and researcher who cares about the future of AI agents to follow RAGFlow’s journey—so together we can witness and build the cornerstone of the next-generation AI stack.
\ No newline at end of file
+In the paradigm shift from “hand-crafted prompts” to “intelligent context,” RAGFlow is determined to be the most steadfast propeller and enabler. We invite every developer, enterprise, and researcher who cares about the future of AI agents to follow RAGFlow’s journey—so together we can witness and build the cornerstone of the next-generation AI stack.
diff --git a/docs/basics/rag.md b/docs/basics/rag.md
index 90054ed56bd..470c6e05903 100644
--- a/docs/basics/rag.md
+++ b/docs/basics/rag.md
@@ -1,9 +1,9 @@
---
sidebar_position: 1
-slug: /what_is_rag
+slug: /what-is-rag
---
-# What is Retreival-Augmented-Generation (RAG)?
+# What is Retrieval-Augmented-Generation (RAG)?
Since large language models (LLMs) became the focus of technology, their ability to handle general knowledge has been astonishing. However, when questions shift to internal corporate documents, proprietary knowledge bases, or real-time data, the limitations of LLMs become glaringly apparent: they cannot access private information outside their training data. Retrieval-Augmented Generation (RAG) was born precisely to address this core need. Before an LLM generates an answer, it first retrieves the most relevant context from an external knowledge base and inputs it as "reference material" to the LLM, thereby guiding it to produce accurate answers. In short, RAG elevates LLMs from "relying on memory" to "having evidence to rely on," significantly improving their accuracy and trustworthiness in specialized fields and real-time information queries.
@@ -104,4 +104,4 @@ The evolution of RAG is unfolding along several clear paths:
3. Towards context engineering 2.0
Current RAG can be viewed as Context Engineering 1.0, whose core is assembling static knowledge context for single Q&A tasks. The forthcoming Context Engineering 2.0 will extend with RAG technology at its core, becoming a system that automatically and dynamically assembles comprehensive context for agents. The context fused by this system will come not only from documents but also include interaction memory, available tools/skills, and real-time environmental information. This marks the transition of agent development from a "handicraft workshop" model to the industrial starting point of automated context engineering.
-The essence of RAG is to build a dedicated, efficient, and trustworthy external data interface for large language models; its core is Retrieval, not Generation. Starting from the practical need to solve private data access, its technical depth is reflected in the optimization of retrieval for complex unstructured data. With its deep integration into agent architectures and its development towards automated context engineering, RAG is evolving from a technology that improves Q&A quality into the core infrastructure for building the next generation of trustworthy, controllable, and scalable intelligent applications.
\ No newline at end of file
+The essence of RAG is to build a dedicated, efficient, and trustworthy external data interface for large language models; its core is Retrieval, not Generation. Starting from the practical need to solve private data access, its technical depth is reflected in the optimization of retrieval for complex unstructured data. With its deep integration into agent architectures and its development towards automated context engineering, RAG is evolving from a technology that improves Q&A quality into the core infrastructure for building the next generation of trustworthy, controllable, and scalable intelligent applications.
diff --git a/docs/configurations.md b/docs/configurations.md
index b55042e8f5b..2b274c8e9b2 100644
--- a/docs/configurations.md
+++ b/docs/configurations.md
@@ -1,8 +1,10 @@
---
sidebar_position: 1
slug: /configurations
+sidebar_custom_props: {
+ sidebarIcon: LucideCog
+}
---
-
# Configuration
Configurations for deploying RAGFlow via Docker.
@@ -70,6 +72,8 @@ The [.env](https://github.com/infiniflow/ragflow/blob/main/docker/.env) file con
- `MYSQL_PASSWORD`
The password for MySQL.
- `MYSQL_PORT`
+ The port to connect to MySQL from RAGFlow container. Defaults to `3306`. Change this if you use an external MySQL.
+- `EXPOSE_MYSQL_PORT`
The port used to expose the MySQL service to the host machine, allowing **external** access to the MySQL database running inside the Docker container. Defaults to `5455`.
### MinIO
@@ -99,7 +103,7 @@ RAGFlow utilizes MinIO as its object storage solution, leveraging its scalabilit
- `SVR_HTTP_PORT`
The port used to expose RAGFlow's HTTP API service to the host machine, allowing **external** access to the service running inside the Docker container. Defaults to `9380`.
- `RAGFLOW-IMAGE`
- The Docker image edition. Defaults to `infiniflow/ragflow:v0.23.1` (the RAGFlow Docker image without embedding models).
+ The Docker image edition. Defaults to `infiniflow/ragflow:v0.24.0` (the RAGFlow Docker image without embedding models).
:::tip NOTE
If you cannot download the RAGFlow Docker image, try the following mirrors.
diff --git a/docs/contribution/_category_.json b/docs/contribution/_category_.json
index 594fe200b4c..a9bd348a8cc 100644
--- a/docs/contribution/_category_.json
+++ b/docs/contribution/_category_.json
@@ -4,5 +4,8 @@
"link": {
"type": "generated-index",
"description": "Miscellaneous contribution guides."
+ },
+ "customProps": {
+ "sidebarIcon": "LucideHandshake"
}
}
diff --git a/docs/contribution/contributing.md b/docs/contribution/contributing.md
index 5d1ec19c1cb..39b5e1a5503 100644
--- a/docs/contribution/contributing.md
+++ b/docs/contribution/contributing.md
@@ -1,8 +1,10 @@
---
sidebar_position: 1
slug: /contributing
+sidebar_custom_props: {
+ categoryIcon: LucideBookA
+}
---
-
# Contribution guidelines
General guidelines for RAGFlow's community contributors.
diff --git a/docs/develop/_category_.json b/docs/develop/_category_.json
index 036bc99a129..c80693175f7 100644
--- a/docs/develop/_category_.json
+++ b/docs/develop/_category_.json
@@ -4,5 +4,8 @@
"link": {
"type": "generated-index",
"description": "Guides for hardcore developers"
+ },
+ "customProps": {
+ "sidebarIcon": "LucideWrench"
}
}
diff --git a/docs/develop/acquire_ragflow_api_key.md b/docs/develop/acquire_ragflow_api_key.md
index 4dc4520fe2b..c01b86bf70b 100644
--- a/docs/develop/acquire_ragflow_api_key.md
+++ b/docs/develop/acquire_ragflow_api_key.md
@@ -1,8 +1,10 @@
---
sidebar_position: 4
slug: /acquire_ragflow_api_key
+sidebar_custom_props: {
+ categoryIcon: LucideKey
+}
---
-
# Acquire RAGFlow API key
An API key is required for the RAGFlow server to authenticate your HTTP/Python or MCP requests. This documents provides instructions on obtaining a RAGFlow API key.
diff --git a/docs/develop/build_docker_image.mdx b/docs/develop/build_docker_image.mdx
index 3d20430f3b1..6cb2dede439 100644
--- a/docs/develop/build_docker_image.mdx
+++ b/docs/develop/build_docker_image.mdx
@@ -1,8 +1,10 @@
---
sidebar_position: 1
slug: /build_docker_image
+sidebar_custom_props: {
+ categoryIcon: LucidePackage
+}
---
-
# Build RAGFlow Docker image
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
@@ -47,7 +49,7 @@ After building the infiniflow/ragflow:nightly image, you are ready to launch a f
1. Edit Docker Compose Configuration
-Open the `docker/.env` file. Find the `RAGFLOW_IMAGE` setting and change the image reference from `infiniflow/ragflow:v0.23.1` to `infiniflow/ragflow:nightly` to use the pre-built image.
+Open the `docker/.env` file. Find the `RAGFLOW_IMAGE` setting and change the image reference from `infiniflow/ragflow:v0.24.0` to `infiniflow/ragflow:nightly` to use the pre-built image.
2. Launch the Service
diff --git a/docs/develop/launch_ragflow_from_source.md b/docs/develop/launch_ragflow_from_source.md
index 0f154252934..c193e2be373 100644
--- a/docs/develop/launch_ragflow_from_source.md
+++ b/docs/develop/launch_ragflow_from_source.md
@@ -1,8 +1,10 @@
---
sidebar_position: 2
slug: /launch_ragflow_from_source
+sidebar_custom_props: {
+ categoryIcon: LucideMonitorPlay
+}
---
-
# Launch service from source
A guide explaining how to set up a RAGFlow service from its source code. By following this guide, you'll be able to debug using the source code.
@@ -114,10 +116,10 @@ docker compose -f docker/docker-compose-base.yml up -d
npm install
```
-2. Update `proxy.target` in **.umirc.ts** to `http://127.0.0.1:9380`:
+2. Update `server.proxy.target` in **vite.config.ts** to `http://127.0.0.1:9380`:
```bash
- vim .umirc.ts
+ vim vite.config.ts
```
3. Start up the RAGFlow frontend service:
diff --git a/docs/develop/mcp/_category_.json b/docs/develop/mcp/_category_.json
index d2f129c23b8..eb7b1444aa9 100644
--- a/docs/develop/mcp/_category_.json
+++ b/docs/develop/mcp/_category_.json
@@ -4,5 +4,8 @@
"link": {
"type": "generated-index",
"description": "Guides and references on accessing RAGFlow's datasets via MCP."
+ },
+ "customProps": {
+ "categoryIcon": "SiModelcontextprotocol"
}
}
diff --git a/docs/develop/mcp/launch_mcp_server.md b/docs/develop/mcp/launch_mcp_server.md
index 2b9f052f06b..72a23aca19e 100644
--- a/docs/develop/mcp/launch_mcp_server.md
+++ b/docs/develop/mcp/launch_mcp_server.md
@@ -1,8 +1,10 @@
---
sidebar_position: 1
slug: /launch_mcp_server
+sidebar_custom_props: {
+ categoryIcon: LucideTvMinimalPlay
+}
---
-
# Launch RAGFlow MCP server
Launch an MCP server from source or via Docker.
diff --git a/docs/develop/mcp/mcp_client_example.md b/docs/develop/mcp/mcp_client_example.md
index 7eb6a1cf9c0..4393eb53bee 100644
--- a/docs/develop/mcp/mcp_client_example.md
+++ b/docs/develop/mcp/mcp_client_example.md
@@ -1,9 +1,11 @@
---
sidebar_position: 3
slug: /mcp_client
+sidebar_custom_props: {
+ categoryIcon: LucideBookMarked
+}
---
-
# RAGFlow MCP client examples
Python and curl MCP client examples.
diff --git a/docs/develop/mcp/mcp_tools.md b/docs/develop/mcp/mcp_tools.md
index a087de62c75..1a8be9f8074 100644
--- a/docs/develop/mcp/mcp_tools.md
+++ b/docs/develop/mcp/mcp_tools.md
@@ -1,8 +1,10 @@
---
sidebar_position: 2
slug: /mcp_tools
+sidebar_custom_props: {
+ categoryIcon: LucideToolCase
+}
---
-
# RAGFlow MCP tools
The MCP server currently offers a specialized tool to assist users in searching for relevant information powered by RAGFlow DeepDoc technology:
diff --git a/docs/develop/migrate_to_single_bucket_mode.md b/docs/develop/migrate_to_single_bucket_mode.md
index ce258d4e8d9..de7c8fe873b 100644
--- a/docs/develop/migrate_to_single_bucket_mode.md
+++ b/docs/develop/migrate_to_single_bucket_mode.md
@@ -1,4 +1,3 @@
-
---
sidebar_position: 20
slug: /migrate_to_single_bucket_mode
diff --git a/docs/develop/switch_doc_engine.md b/docs/develop/switch_doc_engine.md
index ebac20bd686..10ff68eb994 100644
--- a/docs/develop/switch_doc_engine.md
+++ b/docs/develop/switch_doc_engine.md
@@ -1,8 +1,10 @@
---
sidebar_position: 3
slug: /switch_doc_engine
+sidebar_custom_props: {
+ categoryIcon: LucideShuffle
+}
---
-
# Switch document engine
Switch your doc engine from Elasticsearch to Infinity.
diff --git a/docs/faq.mdx b/docs/faq.mdx
index 6c1334d13cf..cc7ab374b57 100644
--- a/docs/faq.mdx
+++ b/docs/faq.mdx
@@ -1,8 +1,10 @@
---
sidebar_position: 10
slug: /faq
+sidebar_custom_props: {
+ sidebarIcon: LucideCircleQuestionMark
+}
---
-
# FAQs
Answers to questions about general features, troubleshooting, usage, and more.
@@ -41,11 +43,11 @@ You can find the RAGFlow version number on the **System** page of the UI:
If you build RAGFlow from source, the version number is also in the system log:
```
- ____ ___ ______ ______ __
+ ____ ___ ______ ______ __
/ __ \ / | / ____// ____// /____ _ __
/ /_/ // /| | / / __ / /_ / // __ \| | /| / /
- / _, _// ___ |/ /_/ // __/ / // /_/ /| |/ |/ /
- /_/ |_|/_/ |_|\____//_/ /_/ \____/ |__/|__/
+ / _, _// ___ |/ /_/ // __/ / // /_/ /| |/ |/ /
+ /_/ |_|/_/ |_|\____//_/ /_/ \____/ |__/|__/
2025-02-18 10:10:43,835 INFO 1445658 RAGFlow version: v0.15.0-50-g6daae7f2
```
@@ -175,7 +177,7 @@ To fix this issue, use https://hf-mirror.com instead:
3. Start up the server:
```bash
- docker compose up -d
+ docker compose up -d
```
---
@@ -208,11 +210,11 @@ You will not log in to RAGFlow unless the server is fully initialized. Run `dock
*The server is successfully initialized, if your system displays the following:*
```
- ____ ___ ______ ______ __
+ ____ ___ ______ ______ __
/ __ \ / | / ____// ____// /____ _ __
/ /_/ // /| | / / __ / /_ / // __ \| | /| / /
- / _, _// ___ |/ /_/ // __/ / // /_/ /| |/ |/ /
- /_/ |_|/_/ |_|\____//_/ /_/ \____/ |__/|__/
+ / _, _// ___ |/ /_/ // __/ / // /_/ /| |/ |/ /
+ /_/ |_|/_/ |_|\____//_/ /_/ \____/ |__/|__/
* Running on all addresses (0.0.0.0)
* Running on http://127.0.0.1:9380
@@ -315,7 +317,7 @@ The status of a Docker container status does not necessarily reflect the status
$ docker ps
```
- *The status of a healthy Elasticsearch component should look as follows:*
+ *The status of a healthy Elasticsearch component should look as follows:*
```
91220e3285dd docker.elastic.co/elasticsearch/elasticsearch:8.11.3 "/bin/tini -- /usr/l…" 11 hours ago Up 11 hours (healthy) 9300/tcp, 0.0.0.0:9200->9200/tcp, :::9200->9200/tcp ragflow-es-01
@@ -368,7 +370,7 @@ Yes, we do. See the Python files under the **rag/app** folder.
$ docker ps
```
- *The status of a healthy Elasticsearch component should look as follows:*
+ *The status of a healthy Elasticsearch component should look as follows:*
```bash
cd29bcb254bc quay.io/minio/minio:RELEASE.2023-12-20T01-00-02Z "/usr/bin/docker-ent…" 2 weeks ago Up 11 hours 0.0.0.0:9001->9001/tcp, :::9001->9001/tcp, 0.0.0.0:9000->9000/tcp, :::9000->9000/tcp ragflow-minio
@@ -451,7 +453,7 @@ See [Upgrade RAGFlow](./guides/upgrade_ragflow.mdx) for more information.
To switch your document engine from Elasticsearch to [Infinity](https://github.com/infiniflow/infinity):
-1. Stop all running containers:
+1. Stop all running containers:
```bash
$ docker compose -f docker/docker-compose.yml down -v
@@ -461,7 +463,7 @@ To switch your document engine from Elasticsearch to [Infinity](https://github.c
:::
2. In **docker/.env**, set `DOC_ENGINE=${DOC_ENGINE:-infinity}`
-3. Restart your Docker image:
+3. Restart your Docker image:
```bash
$ docker compose -f docker-compose.yml up -d
@@ -506,12 +508,12 @@ From v0.22.0 onwards, RAGFlow includes MinerU (≥ 2.6.3) as an optional PDF pa
- `"vlm-mlx-engine"`
- `"vlm-vllm-async-engine"`
- `"vlm-lmdeploy-engine"`.
- - `MINERU_SERVER_URL`: (optional) The downstream vLLM HTTP server (e.g., `http://vllm-host:30000`). Applicable when `MINERU_BACKEND` is set to `"vlm-http-client"`.
+ - `MINERU_SERVER_URL`: (optional) The downstream vLLM HTTP server (e.g., `http://vllm-host:30000`). Applicable when `MINERU_BACKEND` is set to `"vlm-http-client"`.
- `MINERU_OUTPUT_DIR`: (optional) The local directory for holding the outputs of the MinerU API service (zip/JSON) before ingestion.
- `MINERU_DELETE_OUTPUT`: Whether to delete temporary output when a temporary directory is used:
- `1`: Delete.
- `0`: Retain.
-3. In the web UI, navigate to your dataset's **Configuration** page and find the **Ingestion pipeline** section:
+3. In the web UI, navigate to your dataset's **Configuration** page and find the **Ingestion pipeline** section:
- If you decide to use a chunking method from the **Built-in** dropdown, ensure it supports PDF parsing, then select **MinerU** from the **PDF parser** dropdown.
- If you use a custom ingestion pipeline instead, select **MinerU** in the **PDF parser** section of the **Parser** component.
@@ -564,3 +566,82 @@ RAGFlow supports MinerU's `vlm-http-client` backend, enabling you to delegate do
:::tip NOTE
When using the `vlm-http-client` backend, the RAGFlow server requires no GPU, only network connectivity. This enables cost-effective distributed deployment with multiple RAGFlow instances sharing one remote vLLM server.
:::
+
+### How to use PaddleOCR for document parsing?
+
+From v0.24.0 onwards, RAGFlow includes PaddleOCR as an optional PDF parser. Please note that RAGFlow acts only as a *remote client* for PaddleOCR, calling the PaddleOCR API to parse PDFs and reading the returned files.
+
+There are two main ways to configure and use PaddleOCR in RAGFlow:
+
+#### 1. Using PaddleOCR Official API
+
+This method uses PaddleOCR's official API service with an access token.
+
+**Step 1: Configure RAGFlow**
+- **Via Environment Variables:**
+ ```bash
+ # In your docker/.env file:
+ PADDLEOCR_API_URL=https://your-paddleocr-api-endpoint
+ PADDLEOCR_ALGORITHM=PaddleOCR-VL
+ PADDLEOCR_ACCESS_TOKEN=your-access-token-here
+ ```
+
+- **Via UI:**
+ - Navigate to **Model providers** page
+ - Add a new OCR model with factory type "PaddleOCR"
+ - Configure the following fields:
+ - **PaddleOCR API URL**: Your PaddleOCR API endpoint
+ - **PaddleOCR Algorithm**: Select the algorithm corresponding to the API endpoint
+ - **AI Studio Access Token**: Your access token for the PaddleOCR API
+
+**Step 2: Usage in Dataset Configuration**
+- In your dataset's **Configuration** page, find the **Ingestion pipeline** section
+- If using built-in chunking methods that support PDF parsing, select **PaddleOCR** from the **PDF parser** dropdown
+- If using custom ingestion pipeline, select **PaddleOCR** in the **Parser** component
+
+**Notes:**
+- To obtain the API URL, visit the [PaddleOCR official website](https://aistudio.baidu.com/paddleocr), click the **API** button, choose the example code for the specific algorithm you want to use (e.g., PaddleOCR-VL), and copy the `API_URL`.
+- Access tokens can be obtained from the [AI Studio platform](https://aistudio.baidu.com/account/accessToken).
+- This method requires internet connectivity to reach the official PaddleOCR API.
+
+#### 2. Using Self-Hosted PaddleOCR Service
+
+This method allows you to deploy your own PaddleOCR service and use it without an access token.
+
+**Step 1: Deploy PaddleOCR Service**
+Follow the [PaddleOCR serving documentation](https://www.paddleocr.ai/latest/en/version3.x/deployment/serving.html) to deploy your own service. For layout parsing, you can use an endpoint like:
+
+```bash
+http://localhost:8080/layout-parsing
+```
+
+**Step 2: Configure RAGFlow**
+- **Via Environment Variables:**
+ ```bash
+ PADDLEOCR_API_URL=http://localhost:8080/layout-parsing
+ PADDLEOCR_ALGORITHM=PaddleOCR-VL
+ # No access token required for self-hosted service
+ ```
+
+- **Via UI:**
+ - Navigate to **Model providers** page
+ - Add a new OCR model with factory type "PaddleOCR"
+ - Configure the following fields:
+ - **PaddleOCR API URL**: The endpoint of your deployed service
+ - **PaddleOCR Algorithm**: Select the algorithm corresponding to the deployed service
+ - **AI Studio Access Token**: Leave empty
+
+**Step 3: Usage in Dataset Configuration**
+- In your dataset's **Configuration** page, find the **Ingestion pipeline** section
+- If using built-in chunking methods that support PDF parsing, select **PaddleOCR** from the **PDF parser** dropdown
+- If using custom ingestion pipeline, select **PaddleOCR** in the **Parser** component
+
+#### Environment Variables Summary
+
+| Environment Variable | Description | Default | Required |
+|---------------------|-------------|---------|----------|
+| `PADDLEOCR_API_URL` | PaddleOCR API endpoint URL | `""` | Yes, when using environment variables |
+| `PADDLEOCR_ALGORITHM` | Algorithm to use for parsing | `"PaddleOCR-VL"` | No |
+| `PADDLEOCR_ACCESS_TOKEN` | Access token for official API | `None` | Only when using official API |
+
+Environment variables can be used for auto-provisioning, but are not required if configuring via UI. When environment variables are set, these values are used to auto-provision a PaddleOCR model for the tenant on first use.
diff --git a/docs/guides/_category_.json b/docs/guides/_category_.json
index 895506b000c..18f4890a985 100644
--- a/docs/guides/_category_.json
+++ b/docs/guides/_category_.json
@@ -4,5 +4,8 @@
"link": {
"type": "generated-index",
"description": "Guides for RAGFlow users and developers."
+ },
+ "customProps": {
+ "sidebarIcon": "LucideBookMarked"
}
}
diff --git a/docs/guides/admin/_category_.json b/docs/guides/admin/_category_.json
new file mode 100644
index 00000000000..fa6d832fc8d
--- /dev/null
+++ b/docs/guides/admin/_category_.json
@@ -0,0 +1,11 @@
+{
+ "label": "Administration",
+ "position": 6,
+ "link": {
+ "type": "generated-index",
+ "description": "RAGFlow administration"
+ },
+ "customProps": {
+ "categoryIcon": "LucideUserCog"
+ }
+}
diff --git a/docs/guides/admin/admin_service.md b/docs/guides/admin/admin_service.md
new file mode 100644
index 00000000000..35ecabae938
--- /dev/null
+++ b/docs/guides/admin/admin_service.md
@@ -0,0 +1,40 @@
+---
+sidebar_position: 0
+slug: /admin_service
+sidebar_custom_props: {
+ categoryIcon: LucideActivity
+}
+---
+# Admin Service
+
+The Admin Service is the core backend management service of the RAGFlow system, providing comprehensive system administration capabilities through centralized API interfaces for managing and controlling the entire platform. Adopting a client-server architecture, it supports access and operations via both a Web UI and an Admin CLI, ensuring flexible and efficient execution of administrative tasks.
+
+The core functions of the Admin Service include real-time monitoring of the operational status of the RAGFlow server and its critical dependent components—such as MySQL, Elasticsearch, Redis, and MinIO—along with full-featured user management. In administrator mode, it enables key operations such as viewing user information, creating users, updating passwords, modifying activation status, and performing complete user data deletion. These functions remain accessible via the Admin CLI even when the web management interface is disabled, ensuring the system stays under control at all times.
+
+With its unified interface design, the Admin Service combines the convenience of visual administration with the efficiency and stability of command-line operations, serving as a crucial foundation for the reliable operation and secure management of the RAGFlow system.
+
+## Starting the Admin Service
+
+### Launching from source code
+
+1. Before start Admin Service, please make sure RAGFlow system is already started.
+
+2. Launch from source code:
+
+ ```bash
+ python admin/server/admin_server.py
+ ```
+
+ The service will start and listen for incoming connections from the CLI on the configured port.
+
+### Using docker image
+
+1. Before startup, please configure the `docker_compose.yml` file to enable admin server:
+
+ ```bash
+ command:
+ - --enable-adminserver
+ ```
+
+2. Start the containers, the service will start and listen for incoming connections from the CLI on the configured port.
+
diff --git a/docs/guides/accessing_admin_ui.md b/docs/guides/admin/admin_ui.md
similarity index 95%
rename from docs/guides/accessing_admin_ui.md
rename to docs/guides/admin/admin_ui.md
index aafd6e99703..9584bb8cfc7 100644
--- a/docs/guides/accessing_admin_ui.md
+++ b/docs/guides/admin/admin_ui.md
@@ -1,8 +1,10 @@
---
-sidebar_position: 7
-slug: /accessing_admin_ui
+sidebar_position: 1
+slug: /admin_ui
+sidebar_custom_props: {
+ categoryIcon: LucidePalette
+}
---
-
# Admin UI
The RAGFlow Admin UI is a web-based interface that provides comprehensive system status monitoring and user management capabilities.
diff --git a/docs/guides/manage_users_and_services.md b/docs/guides/admin/ragflow_cli.md
similarity index 56%
rename from docs/guides/manage_users_and_services.md
rename to docs/guides/admin/ragflow_cli.md
index 0ec0b112d2c..f682d6be64d 100644
--- a/docs/guides/manage_users_and_services.md
+++ b/docs/guides/admin/ragflow_cli.md
@@ -1,52 +1,22 @@
---
-sidebar_position: 6
-slug: /manage_users_and_services
+sidebar_position: 2
+slug: /admin_cli
+sidebar_custom_props: {
+ categoryIcon: LucideSquareTerminal
+}
---
+# RAGFlow CLI
+The RAGFlow CLI is a command-line-based system administration tool that offers administrators an efficient and flexible method for system interaction and control. Operating on a client-server architecture, it communicates in real-time with the Admin Service, receiving administrator commands and dynamically returning execution results.
-# Admin CLI and Admin Service
-
-
-
-The Admin CLI and Admin Service form a client-server architectural suite for RAGFlow system administration. The Admin CLI serves as an interactive command-line interface that receives instructions and displays execution results from the Admin Service in real-time. This duo enables real-time monitoring of system operational status, supporting visibility into RAGFlow Server services and dependent components including MySQL, Elasticsearch, Redis, and MinIO. In administrator mode, they provide user management capabilities that allow viewing users and performing critical operations—such as user creation, password updates, activation status changes, and comprehensive user data deletion—even when corresponding web interface functionalities are disabled.
-
-
-
-## Starting the Admin Service
-
-### Launching from source code
-
-1. Before start Admin Service, please make sure RAGFlow system is already started.
-
-2. Launch from source code:
-
- ```bash
- python admin/server/admin_server.py
- ```
-
- The service will start and listen for incoming connections from the CLI on the configured port.
-
-### Using docker image
-
-1. Before startup, please configure the `docker_compose.yml` file to enable admin server:
-
- ```bash
- command:
- - --enable-adminserver
- ```
-
-2. Start the containers, the service will start and listen for incoming connections from the CLI on the configured port.
-
-
-
-## Using the Admin CLI
+## Using the RAGFlow CLI
1. Ensure the Admin Service is running.
2. Install ragflow-cli.
```bash
- pip install ragflow-cli==0.23.1
+ pip install ragflow-cli==0.24.0
```
3. Launch the CLI client:
@@ -123,6 +93,21 @@ Commands are case-insensitive and must be terminated with a semicolon(;).
- Changes the user to active or inactive.
- [Example](#example-alter-user-active)
+`GENERATE KEY FOR USER ;`
+
+- Generates a new API key for the specified user.
+- [Example](#example-generate-key)
+
+`LIST KEYS OF ;`
+
+- Lists all API keys associated with the specified user.
+- [Example](#example-list-keys)
+
+`DROP KEY OF ;`
+
+- Deletes a specific API key for the specified user.
+- [Example](#example-drop-key)
+
### Data and Agent Commands
`LIST DATASETS OF ;`
@@ -135,6 +120,40 @@ Commands are case-insensitive and must be terminated with a semicolon(;).
- Lists the agents associated with the specified user.
- [Example](#example-list-agents-of-user)
+### System info
+
+`SHOW VERSION;`
+- Display the current RAGFlow version.
+- [Example](#example-show-version)
+
+`GRANT ADMIN `
+- Grant administrator privileges to the specified user.
+- [Example](#example-grant-admin)
+
+`REVOKE ADMIN `
+- Revoke administrator privileges from the specified user.
+- [Example](#example-revoke-admin)
+
+`LIST VARS`
+- List all system settings.
+- [Example](#example-list-vars)
+
+`SHOW VAR `
+- Display the content of a specific system configuration/setting by its name or name prefix.
+- [Example](#example-show-var)
+
+`SET VAR `
+- Set the value for a specified configuration item.
+- [Example](#example-set-var)
+
+`LIST CONFIGS`
+- List all system configurations.
+- [Example](#example-list-configs)
+
+`LIST ENVS`
+- List all system environments which can accessed by Admin service.
+- [Example](#example-list-environments)
+
### Meta-Commands
- \? or \help
@@ -150,7 +169,7 @@ Commands are case-insensitive and must be terminated with a semicolon(;).
- List all available services.
```
-admin> list services;
+ragflow> list services;
command: list services;
Listing all services
+-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+---------+
@@ -171,7 +190,7 @@ Listing all services
- Show ragflow_server.
```
-admin> show service 0;
+ragflow> show service 0;
command: show service 0;
Showing service: 0
Service ragflow_0 is alive. Detail:
@@ -181,7 +200,7 @@ Confirm elapsed: 26.0 ms.
- Show mysql.
```
-admin> show service 1;
+ragflow> show service 1;
command: show service 1;
Showing service: 1
Service mysql is alive. Detail:
@@ -197,7 +216,7 @@ Service mysql is alive. Detail:
- Show minio.
```
-admin> show service 2;
+ragflow> show service 2;
command: show service 2;
Showing service: 2
Service minio is alive. Detail:
@@ -207,7 +226,7 @@ Confirm elapsed: 2.1 ms.
- Show elasticsearch.
```
-admin> show service 3;
+ragflow> show service 3;
command: show service 3;
Showing service: 3
Service elasticsearch is alive. Detail:
@@ -221,7 +240,7 @@ Service elasticsearch is alive. Detail:
- Show infinity.
```
-admin> show service 4;
+ragflow> show service 4;
command: show service 4;
Showing service: 4
Fail to show service, code: 500, message: Infinity is not in use.
@@ -230,7 +249,7 @@ Fail to show service, code: 500, message: Infinity is not in use.
- Show redis.
```
-admin> show service 5;
+ragflow> show service 5;
command: show service 5;
Showing service: 5
Service redis is alive. Detail:
@@ -245,7 +264,7 @@ Service redis is alive. Detail:
- Show RAGFlow version
```
-admin> show version;
+ragflow> show version;
+-----------------------+
| version |
+-----------------------+
@@ -258,7 +277,7 @@ admin> show version;
- List all user.
```
-admin> list users;
+ragflow> list users;
command: list users;
Listing all users
+-------------------------------+----------------------+-----------+----------+
@@ -274,7 +293,7 @@ Listing all users
- Show specified user.
```
-admin> show user "admin@ragflow.io";
+ragflow> show user "admin@ragflow.io";
command: show user "admin@ragflow.io";
Showing user: admin@ragflow.io
+-------------------------------+------------------+-----------+--------------+------------------+--------------+----------+-----------------+---------------+--------+-------------------------------+
@@ -289,7 +308,7 @@ Showing user: admin@ragflow.io
- Create new user.
```
-admin> create user "example@ragflow.io" "psw";
+ragflow> create user "example@ragflow.io" "psw";
command: create user "example@ragflow.io" "psw";
Create user: example@ragflow.io, password: psw, role: user
+----------------------------------+--------------------+----------------------------------+--------------+---------------+----------+
@@ -304,7 +323,7 @@ Create user: example@ragflow.io, password: psw, role: user
- Alter user password.
```
-admin> alter user password "example@ragflow.io" "newpsw";
+ragflow> alter user password "example@ragflow.io" "newpsw";
command: alter user password "example@ragflow.io" "newpsw";
Alter user: example@ragflow.io, password: newpsw
Password updated successfully!
@@ -315,7 +334,7 @@ Password updated successfully!
- Alter user active, turn off.
```
-admin> alter user active "example@ragflow.io" off;
+ragflow> alter user active "example@ragflow.io" off;
command: alter user active "example@ragflow.io" off;
Alter user example@ragflow.io activate status, turn off.
Turn off user activate status successfully!
@@ -326,7 +345,7 @@ Turn off user activate status successfully!
- Drop user.
```
-admin> Drop user "example@ragflow.io";
+ragflow> Drop user "example@ragflow.io";
command: Drop user "example@ragflow.io";
Drop user: example@ragflow.io
Successfully deleted user. Details:
@@ -341,12 +360,50 @@ Delete done!
Delete user's data at the same time.
+
+
+- Generate API key for user.
+
+```
+admin> generate key for user "example@ragflow.io";
+Generating API key for user: example@ragflow.io
++----------------------------------+-------------------------------+---------------+----------------------------------+-----------------------------------------------------+-------------+-------------+
+| beta | create_date | create_time | tenant_id | token | update_date | update_time |
++----------------------------------+-------------------------------+---------------+----------------------------------+-----------------------------------------------------+-------------+-------------+
+| Es9OpZ6hrnPGeYA3VU1xKUkj6NCb7cp- | Mon, 12 Jan 2026 15:19:11 GMT | 1768227551361 | 5d5ea8a3efc111f0a79b80fa5b90e659 | ragflow-piwVJHEk09M5UN3LS_Xx9HA7yehs3yNOc9GGsD4jzus | None | None |
++----------------------------------+-------------------------------+---------------+----------------------------------+-----------------------------------------------------+-------------+-------------+
+```
+
+
+
+- List all API keys for user.
+
+```
+admin> list keys of "example@ragflow.io";
+Listing API keys for user: example@ragflow.io
++----------------------------------+-------------------------------+---------------+-----------+--------+----------------------------------+-----------------------------------------------------+-------------------------------+---------------+
+| beta | create_date | create_time | dialog_id | source | tenant_id | token | update_date | update_time |
++----------------------------------+-------------------------------+---------------+-----------+--------+----------------------------------+-----------------------------------------------------+-------------------------------+---------------+
+| Es9OpZ6hrnPGeYA3VU1xKUkj6NCb7cp- | Mon, 12 Jan 2026 15:19:11 GMT | 1768227551361 | None | None | 5d5ea8a3efc111f0a79b80fa5b90e659 | ragflow-piwVJHEk09M5UN3LS_Xx9HA7yehs3yNOc9GGsD4jzus | Mon, 12 Jan 2026 15:19:11 GMT | 1768227551361 |
++----------------------------------+-------------------------------+---------------+-----------+--------+----------------------------------+-----------------------------------------------------+-------------------------------+---------------+
+```
+
+
+
+- Drop API key for user.
+
+```
+admin> drop key "ragflow-piwVJHEk09M5UN3LS_Xx9HA7yehs3yNOc9GGsD4jzus" of "example@ragflow.io";
+Dropping API key for user: example@ragflow.io
+API key deleted successfully
+```
+
- List the specified user's dataset.
```
-admin> list datasets of "lynn_inf@hotmail.com";
+ragflow> list datasets of "lynn_inf@hotmail.com";
command: list datasets of "lynn_inf@hotmail.com";
Listing all datasets of user: lynn_inf@hotmail.com
+-----------+-------------------------------+---------+----------+---------------+------------+--------+-----------+-------------------------------+
@@ -362,7 +419,7 @@ Listing all datasets of user: lynn_inf@hotmail.com
- List the specified user's agents.
```
-admin> list agents of "lynn_inf@hotmail.com";
+ragflow> list agents of "lynn_inf@hotmail.com";
command: list agents of "lynn_inf@hotmail.com";
Listing all agents of user: lynn_inf@hotmail.com
+-----------------+-------------+------------+-----------------+
@@ -372,28 +429,157 @@ Listing all agents of user: lynn_inf@hotmail.com
+-----------------+-------------+------------+-----------------+
```
+
+
+- Display the current RAGFlow version.
+
+```
+ragflow> show version;
+show_version
++-----------------------+
+| version |
++-----------------------+
+| v0.24.0-24-g6f60e9f9e |
++-----------------------+
+```
+
+
+
+- Grant administrator privileges to the specified user.
+
+```
+ragflow> grant admin "anakin.skywalker@ragflow.io";
+Grant successfully!
+```
+
+
+
+- Revoke administrator privileges from the specified user.
+
+```
+ragflow> revoke admin "anakin.skywalker@ragflow.io";
+Revoke successfully!
+```
+
+
+
+- List all system settings.
+
+```
+ragflow> list vars;
++-----------+---------------------+--------------+-----------+
+| data_type | name | source | value |
++-----------+---------------------+--------------+-----------+
+| string | default_role | variable | user |
+| bool | enable_whitelist | variable | true |
+| string | mail.default_sender | variable | |
+| string | mail.password | variable | |
+| integer | mail.port | variable | 15 |
+| string | mail.server | variable | localhost |
+| integer | mail.timeout | variable | 10 |
+| bool | mail.use_ssl | variable | true |
+| bool | mail.use_tls | variable | false |
+| string | mail.username | variable | |
++-----------+---------------------+--------------+-----------+
+```
+
+
+
+- Display the content of a specific system configuration/setting by its name or name prefix.
+
+```
+ragflow> show var mail.server;
++-----------+-------------+--------------+-----------+
+| data_type | name | source | value |
++-----------+-------------+--------------+-----------+
+| string | mail.server | variable | localhost |
++-----------+-------------+--------------+-----------+
+```
+
+
+
+- Set the value for a specified configuration item.
+
+```
+ragflow> set var mail.server 127.0.0.1;
+Set variable successfully
+```
+
+
+
+
+- List all system configurations.
+
+```
+ragflow> list configs;
++-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+
+| extra | host | id | name | port | service_type |
++-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+
+| {} | 0.0.0.0 | 0 | ragflow_0 | 9380 | ragflow_server |
+| {'meta_type': 'mysql', 'password': 'infini_rag_flow', 'username': 'root'} | localhost | 1 | mysql | 5455 | meta_data |
+| {'password': 'infini_rag_flow', 'store_type': 'minio', 'user': 'rag_flow'} | localhost | 2 | minio | 9000 | file_store |
+| {'password': 'infini_rag_flow', 'retrieval_type': 'elasticsearch', 'username': 'elastic'} | localhost | 3 | elasticsearch | 1200 | retrieval |
+| {'db_name': 'default_db', 'retrieval_type': 'infinity'} | localhost | 4 | infinity | 23817 | retrieval |
+| {'database': 1, 'mq_type': 'redis', 'password': 'infini_rag_flow'} | localhost | 5 | redis | 6379 | message_queue |
+| {'message_queue_type': 'redis'} | | 6 | task_executor | 0 | task_executor |
++-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+
+```
+
+
+
+- List all system environments which can accessed by Admin service.
+
+```
+ragflow> list envs;
++-------------------------+------------------+
+| env | value |
++-------------------------+------------------+
+| DOC_ENGINE | elasticsearch |
+| DEFAULT_SUPERUSER_EMAIL | admin@ragflow.io |
+| DB_TYPE | mysql |
+| DEVICE | cpu |
+| STORAGE_IMPL | MINIO |
++-------------------------+------------------+
+```
+
+
- Show help information.
```
-admin> \help
+ragflow> \help
command: \help
Commands:
- LIST SERVICES
- SHOW SERVICE
- STARTUP SERVICE
- SHUTDOWN SERVICE
- RESTART SERVICE
- LIST USERS
- SHOW USER
- DROP USER
- CREATE USER
- ALTER USER PASSWORD
- ALTER USER ACTIVE
- LIST DATASETS OF
- LIST AGENTS OF
+LIST SERVICES
+SHOW SERVICE
+STARTUP SERVICE
+SHUTDOWN SERVICE
+RESTART SERVICE
+LIST USERS
+SHOW USER
+DROP USER
+CREATE USER
+ALTER USER PASSWORD
+ALTER USER ACTIVE
+LIST DATASETS OF
+LIST AGENTS OF
+CREATE ROLE
+DROP ROLE
+ALTER ROLE SET DESCRIPTION
+LIST ROLES
+SHOW ROLE
+GRANT ON TO ROLE
+REVOKE ON TO ROLE
+ALTER USER SET ROLE
+SHOW USER PERMISSION
+SHOW VERSION
+GRANT ADMIN
+REVOKE ADMIN
+GENERATE KEY FOR USER
+LIST KEYS OF
+DROP KEY OF
Meta Commands:
\?, \h, \help Show this help
@@ -403,8 +589,7 @@ Meta Commands:
- Exit
```
-admin> \q
+ragflow> \q
command: \q
Goodbye!
```
-
diff --git a/docs/guides/agent/_category_.json b/docs/guides/agent/_category_.json
index 020ba1d3f72..dc81d28a494 100644
--- a/docs/guides/agent/_category_.json
+++ b/docs/guides/agent/_category_.json
@@ -4,5 +4,8 @@
"link": {
"type": "generated-index",
"description": "RAGFlow v0.8.0 introduces an agent mechanism, featuring a no-code workflow editor on the front end and a comprehensive graph-based task orchestration framework on the backend."
+ },
+ "customProps": {
+ "categoryIcon": "RagAiAgent"
}
}
diff --git a/docs/guides/agent/agent_component_reference/_category_.json b/docs/guides/agent/agent_component_reference/_category_.json
index 7548ec8031b..34669a6b76c 100644
--- a/docs/guides/agent/agent_component_reference/_category_.json
+++ b/docs/guides/agent/agent_component_reference/_category_.json
@@ -1,8 +1,11 @@
{
- "label": "Agent Components",
+ "label": "Components",
"position": 20,
"link": {
"type": "generated-index",
"description": "A complete reference for RAGFlow's agent components."
+ },
+ "customProps": {
+ "categoryIcon": "RagAiAgent"
}
}
diff --git a/docs/guides/agent/agent_component_reference/agent.mdx b/docs/guides/agent/agent_component_reference/agent.mdx
index 882c22be12d..e3d6e46a1e4 100644
--- a/docs/guides/agent/agent_component_reference/agent.mdx
+++ b/docs/guides/agent/agent_component_reference/agent.mdx
@@ -1,8 +1,10 @@
---
sidebar_position: 2
slug: /agent_component
+sidebar_custom_props: {
+ categoryIcon: RagAiAgent
+}
---
-
# Agent component
The component equipped with reasoning, tool usage, and multi-agent collaboration capabilities.
@@ -133,7 +135,7 @@ Click the dropdown menu of **Model** to show the model configuration window.
- A higher **frequency penalty** value results in the model being more conservative in its use of repeated tokens.
- Defaults to 0.7.
- **Max tokens**:
- This sets the maximum length of the model's output, measured in the number of tokens (words or pieces of words). It is disabled by default, allowing the model to determine the number of tokens in its responses.
+ - The maximum context size of the model.
:::tip NOTE
- It is not necessary to stick with the same model for all components. If a specific model is not performing well for a particular task, consider using a different one.
diff --git a/docs/guides/agent/agent_component_reference/await_response.mdx b/docs/guides/agent/agent_component_reference/await_response.mdx
index 973e1dfa5e6..f47da3cbd3c 100644
--- a/docs/guides/agent/agent_component_reference/await_response.mdx
+++ b/docs/guides/agent/agent_component_reference/await_response.mdx
@@ -1,8 +1,10 @@
---
sidebar_position: 5
slug: /await_response
+sidebar_custom_props: {
+ categoryIcon: LucideMessageSquareDot
+}
---
-
# Await response component
A component that halts the workflow and awaits user input.
diff --git a/docs/guides/agent/agent_component_reference/begin.mdx b/docs/guides/agent/agent_component_reference/begin.md
similarity index 67%
rename from docs/guides/agent/agent_component_reference/begin.mdx
rename to docs/guides/agent/agent_component_reference/begin.md
index c265bd2c6a8..1368efebdb1 100644
--- a/docs/guides/agent/agent_component_reference/begin.mdx
+++ b/docs/guides/agent/agent_component_reference/begin.md
@@ -1,8 +1,10 @@
---
sidebar_position: 1
slug: /begin_component
+sidebar_custom_props: {
+ categoryIcon: LucideHome
+}
---
-
# Begin component
The starting component in a workflow.
@@ -25,6 +27,50 @@ Mode defines how the workflow is triggered.
- Conversational: The agent is triggered from a conversation.
- Task: The agent starts without a conversation.
+- Webhook: Receive external HTTP requests via webhooks, enabling automated triggers and workflow initiation.
+ *When selected, a unique Webhook URL is generated for the current agent.*
+
+
+
+### Methods
+
+The supported HTTP methods. Available only when **Webhook** is selected as **Mode**.
+
+
+### Security
+
+The authentication method to choose, available *only* when **Webhook** is selected as **Mode**. Including:
+
+- **token**: Token-based authentication.
+- **basic**: Basic authentication.
+- **jwt**: JWT authentication.
+
+### Schema
+
+The schema defines the data structure for HTTP requests received by the system in **Webhook** mode. It configurations include:
+
+- Content type:
+ - `application/json`
+ - `multipart/form-data`
+ - `application/x-www-form-urlencoded`
+ - `text-plain`
+ - `application/octet-stream`
+- Query parameters
+- Header parameters
+- Request body parameters
+
+### Response
+
+Available only when **Webhook** is selected as **Mode**.
+
+The response mode of the workflow, i.e., how the workflow respond to external HTTP requests. Supported options:
+
+- **Accepted response**: When an HTTP request is validated, a success response is returned immediately, and the workflow runs asynchronously in the background.
+ - When selected, you configure the corresponding HTTP status code and message in the **Begin** component.
+ - The HTTP status code to return is in the range of `200-399`.
+- **Final response**: The system returns the final processing result only after the entire workflow completes.
+ - When selected, you configure the corresponding HTTP status code and message in the [message](./message.md) component.
+ - The HTTP status code to return is in the range of `200-399`.
### Opening greeting
diff --git a/docs/guides/agent/agent_component_reference/categorize.mdx b/docs/guides/agent/agent_component_reference/categorize.mdx
index a40cc3731de..57cd14ea7bc 100644
--- a/docs/guides/agent/agent_component_reference/categorize.mdx
+++ b/docs/guides/agent/agent_component_reference/categorize.mdx
@@ -1,8 +1,10 @@
---
sidebar_position: 8
slug: /categorize_component
+sidebar_custom_props: {
+ categoryIcon: LucideSwatchBook
+}
---
-
# Categorize component
A component that classifies user inputs and applies strategies accordingly.
diff --git a/docs/guides/agent/agent_component_reference/chunker_title.md b/docs/guides/agent/agent_component_reference/chunker_title.md
index 27b8a97ce59..787f6602806 100644
--- a/docs/guides/agent/agent_component_reference/chunker_title.md
+++ b/docs/guides/agent/agent_component_reference/chunker_title.md
@@ -1,8 +1,10 @@
---
sidebar_position: 31
slug: /chunker_title_component
+sidebar_custom_props: {
+ categoryIcon: LucideBlocks
+}
---
-
# Title chunker component
A component that splits texts into chunks by heading level.
diff --git a/docs/guides/agent/agent_component_reference/chunker_token.md b/docs/guides/agent/agent_component_reference/chunker_token.md
index d93f0ea4288..ee0c1e79a0f 100644
--- a/docs/guides/agent/agent_component_reference/chunker_token.md
+++ b/docs/guides/agent/agent_component_reference/chunker_token.md
@@ -1,8 +1,10 @@
---
sidebar_position: 32
slug: /chunker_token_component
+sidebar_custom_props: {
+ categoryIcon: LucideBlocks
+}
---
-
# Token chunker component
A component that splits texts into chunks, respecting a maximum token limit and using delimiters to find optimal breakpoints.
diff --git a/docs/guides/agent/agent_component_reference/code.mdx b/docs/guides/agent/agent_component_reference/code.mdx
index ea483158148..a9472ca5e03 100644
--- a/docs/guides/agent/agent_component_reference/code.mdx
+++ b/docs/guides/agent/agent_component_reference/code.mdx
@@ -1,8 +1,10 @@
---
sidebar_position: 13
slug: /code_component
+sidebar_custom_props: {
+ categoryIcon: LucideCodeXml
+}
---
-
# Code component
A component that enables users to integrate Python or JavaScript codes into their Agent for dynamic data processing.
diff --git a/docs/guides/agent/agent_component_reference/execute_sql.md b/docs/guides/agent/agent_component_reference/execute_sql.md
index 47561eccb0f..30c9c9912fa 100644
--- a/docs/guides/agent/agent_component_reference/execute_sql.md
+++ b/docs/guides/agent/agent_component_reference/execute_sql.md
@@ -1,8 +1,10 @@
---
sidebar_position: 25
slug: /execute_sql
+sidebar_custom_props: {
+ categoryIcon: RagSql
+}
---
-
# Execute SQL tool
A tool that execute SQL queries on a specified relational database.
diff --git a/docs/guides/agent/agent_component_reference/http.md b/docs/guides/agent/agent_component_reference/http.md
index 51277f0182d..66ee8067abd 100644
--- a/docs/guides/agent/agent_component_reference/http.md
+++ b/docs/guides/agent/agent_component_reference/http.md
@@ -1,8 +1,10 @@
---
sidebar_position: 30
slug: /http_request_component
+sidebar_custom_props: {
+ categoryIcon: RagHTTP
+}
---
-
# HTTP request component
A component that calls remote services.
diff --git a/docs/guides/agent/agent_component_reference/indexer.md b/docs/guides/agent/agent_component_reference/indexer.md
index 5bc2d925e10..22596773b19 100644
--- a/docs/guides/agent/agent_component_reference/indexer.md
+++ b/docs/guides/agent/agent_component_reference/indexer.md
@@ -1,8 +1,10 @@
---
sidebar_position: 40
slug: /indexer_component
+sidebar_custom_props: {
+ categoryIcon: LucideListPlus
+}
---
-
# Indexer component
A component that defines how chunks are indexed.
diff --git a/docs/guides/agent/agent_component_reference/iteration.mdx b/docs/guides/agent/agent_component_reference/iteration.mdx
index 9d4907d8773..051b923eefb 100644
--- a/docs/guides/agent/agent_component_reference/iteration.mdx
+++ b/docs/guides/agent/agent_component_reference/iteration.mdx
@@ -1,8 +1,10 @@
---
sidebar_position: 7
slug: /iteration_component
+sidebar_custom_props: {
+ categoryIcon: LucideRepeat2
+}
---
-
# Iteration component
A component that splits text input into text segments and iterates a predefined workflow for each one.
diff --git a/docs/guides/agent/agent_component_reference/message.mdx b/docs/guides/agent/agent_component_reference/message.md
similarity index 54%
rename from docs/guides/agent/agent_component_reference/message.mdx
rename to docs/guides/agent/agent_component_reference/message.md
index 9e12ba547d4..45e9324dd51 100644
--- a/docs/guides/agent/agent_component_reference/message.mdx
+++ b/docs/guides/agent/agent_component_reference/message.md
@@ -1,8 +1,10 @@
---
sidebar_position: 4
slug: /message_component
+sidebar_custom_props: {
+ categoryIcon: LucideMessageSquareReply
+}
---
-
# Message component
A component that sends out a static or dynamic message.
@@ -13,9 +15,19 @@ As the final component of the workflow, a Message component returns the workflow
## Configurations
+### Status
+
+The HTTP status code (`200` ~ `399`) to return when the entire workflow completes. Available *only* when you select **Final response** as **Execution mode** in the [Begin](./begin.md) component.
+
### Messages
The message to send out. Click `(x)` or type `/` to quickly insert variables.
Click **+ Add message** to add message options. When multiple messages are supplied, the **Message** component randomly selects one to send.
+### Save to memory
+
+Save the conversation to specified memories. Expand the dropdown list to either select all available memories or specified memories:
+
+
+
\ No newline at end of file
diff --git a/docs/guides/agent/agent_component_reference/parser.md b/docs/guides/agent/agent_component_reference/parser.md
index 0eb0f6bff2d..cdc0a9e1750 100644
--- a/docs/guides/agent/agent_component_reference/parser.md
+++ b/docs/guides/agent/agent_component_reference/parser.md
@@ -1,8 +1,10 @@
---
sidebar_position: 30
slug: /parser_component
+sidebar_custom_props: {
+ categoryIcon: LucideFilePlay
+}
---
-
# Parser component
A component that sets the parsing rules for your dataset.
diff --git a/docs/guides/agent/agent_component_reference/retrieval.mdx b/docs/guides/agent/agent_component_reference/retrieval.mdx
index 1f88669cfa2..5295092ed1d 100644
--- a/docs/guides/agent/agent_component_reference/retrieval.mdx
+++ b/docs/guides/agent/agent_component_reference/retrieval.mdx
@@ -1,8 +1,10 @@
---
sidebar_position: 3
slug: /retrieval_component
+sidebar_custom_props: {
+ categoryIcon: LucideFolderSearch
+}
---
-
# Retrieval component
A component that retrieves information from specified datasets.
@@ -74,13 +76,15 @@ Select the query source for retrieval. Defaults to `sys.query`, which is the def
The **Retrieval** component relies on query variables to specify its queries. All global variables defined before the **Retrieval** component can also be used as queries. Use the `(x)` button or type `/` to show all the available query variables.
-### Knowledge bases
+### Retrieval from
-Select the dataset(s) to retrieve data from.
+Select the dataset(s) and memory to retrieve data from.
- If no dataset is selected, meaning conversations with the agent will not be based on any dataset, ensure that the **Empty response** field is left blank to avoid an error.
- If you select multiple datasets, you must ensure that the datasets you select use the same embedding model; otherwise, an error message would occur.
+
+
### Similarity threshold
RAGFlow employs a combination of weighted keyword similarity and weighted vector cosine similarity during retrieval. This parameter sets the threshold for similarities between the user query and chunks stored in the datasets. Any chunk with a similarity score below this threshold will be excluded from the results.
@@ -129,6 +133,10 @@ Before enabling this feature, ensure you have properly [constructed a knowledge
Whether to use knowledge graph(s) in the specified dataset(s) during retrieval for multi-hop question answering. When enabled, this would involve iterative searches across entity, relationship, and community report chunks, greatly increasing retrieval time.
+### PageIndex
+
+Whether to use the page index structure generated by the large model to enhance retrieval. This approach mimics human information-searching behavior in books.
+
### Output
The global variable name for the output of the **Retrieval** component, which can be referenced by other components in the workflow.
diff --git a/docs/guides/agent/agent_component_reference/switch.mdx b/docs/guides/agent/agent_component_reference/switch.mdx
index 1840e666a49..d98ca82c007 100644
--- a/docs/guides/agent/agent_component_reference/switch.mdx
+++ b/docs/guides/agent/agent_component_reference/switch.mdx
@@ -1,8 +1,10 @@
---
sidebar_position: 6
slug: /switch_component
+sidebar_custom_props: {
+ categoryIcon: LucideSplit
+}
---
-
# Switch component
A component that evaluates whether specified conditions are met and directs the follow of execution accordingly.
diff --git a/docs/guides/agent/agent_component_reference/text_processing.mdx b/docs/guides/agent/agent_component_reference/text_processing.mdx
index 626ae67bf3e..7ecfa19e14d 100644
--- a/docs/guides/agent/agent_component_reference/text_processing.mdx
+++ b/docs/guides/agent/agent_component_reference/text_processing.mdx
@@ -1,8 +1,10 @@
---
sidebar_position: 15
slug: /text_processing
+sidebar_custom_props: {
+ categoryIcon: LucideType
+}
---
-
# Text processing component
A component that merges or splits texts.
diff --git a/docs/guides/agent/agent_component_reference/transformer.md b/docs/guides/agent/agent_component_reference/transformer.md
index ad8274ac4ee..6d64c8f19a6 100644
--- a/docs/guides/agent/agent_component_reference/transformer.md
+++ b/docs/guides/agent/agent_component_reference/transformer.md
@@ -1,8 +1,10 @@
---
sidebar_position: 37
slug: /transformer_component
+sidebar_custom_props: {
+ categoryIcon: LucideFileStack
+}
---
-
# Transformer component
A component that uses an LLM to extract insights from the chunks.
@@ -44,10 +46,10 @@ Click the dropdown menu of **Model** to show the model configuration window.
- A higher **frequency penalty** value results in the model being more conservative in its use of repeated tokens.
- Defaults to 0.7.
- **Max tokens**:
- This sets the maximum length of the model's output, measured in the number of tokens (words or pieces of words). It is disabled by default, allowing the model to determine the number of tokens in its responses.
+ - The maximum context size of the model.
:::tip NOTE
-- It is not necessary to stick with the same model for all components. If a specific model is not performing well for a particular task, consider using a different one.
+- It is *not* necessary to stick with the same model for all components. If a specific model is not performing well for a particular task, consider using a different one.
- If you are uncertain about the mechanism behind **Temperature**, **Top P**, **Presence penalty**, and **Frequency penalty**, simply choose one of the three options of **Creativity**.
:::
diff --git a/docs/guides/agent/agent_introduction.md b/docs/guides/agent/agent_introduction.md
index fa21a781062..f310e503ddf 100644
--- a/docs/guides/agent/agent_introduction.md
+++ b/docs/guides/agent/agent_introduction.md
@@ -1,9 +1,11 @@
---
sidebar_position: 1
slug: /agent_introduction
+sidebar_custom_props: {
+ categoryIcon: LucideBookOpenText
+}
---
-
-# Introduction to agents
+# Introduction
Key concepts, basic operations, a quick view of the agent editor.
diff --git a/docs/guides/agent/agent_quickstarts/_category_.json b/docs/guides/agent/agent_quickstarts/_category_.json
new file mode 100644
index 00000000000..fc5ce9c0ac6
--- /dev/null
+++ b/docs/guides/agent/agent_quickstarts/_category_.json
@@ -0,0 +1,11 @@
+{
+ "label": "Quickstarts",
+ "position": 2,
+ "link": {
+ "type": "generated-index",
+ "description": "Agent-specific quickstart"
+ },
+ "customProps": {
+ "categoryIcon": "LucideRocket"
+ }
+}
diff --git a/docs/guides/agent/agent_quickstarts/build_ecommerce_customer_support_agent.md b/docs/guides/agent/agent_quickstarts/build_ecommerce_customer_support_agent.md
new file mode 100644
index 00000000000..93e66ba3de9
--- /dev/null
+++ b/docs/guides/agent/agent_quickstarts/build_ecommerce_customer_support_agent.md
@@ -0,0 +1,105 @@
+---
+sidebar_position: 3
+slug: /ecommerce_customer_support_agent
+sidebar_custom_props: {
+ categoryIcon: LucideStethoscope
+}
+---
+
+# Build Ecommerce customer support agent
+
+This quickstart guides you through building an intelligent e‑commerce customer support agent. The agent uses RAGFlow’s workflow and Agent framework to automatically handle common customer requests such as product comparisons, usage instructions, and installation bookings—providing fast, accurate, and context-aware responses. In the following sections, we will walk you through the process of building an Ecommerce customer support Agent as shown below:
+
+
+
+## Prerequisites
+
+- Sample datasets (available from [Hugging Face](https://huggingface.co/datasets/InfiniFlow/Ecommerce-Customer-Service-Workflow)).
+
+## Procedures
+
+### Prepare datasets
+
+1. Ensure that the above-mentioned sample datasets are downloaded.
+2. Create two datasets:
+ - Product Information
+ - User Guide
+3. Upload the corresponding documents to each dataset.
+4. On the configurations page of both datasets, choose **Manual** as chunking method.
+ *RAGFlow preserves content integrity by splitting documents at the “smallest heading” level, keeping text and related graphics together.*
+
+### Create an Agent app
+
+1. Navigate to the **Agent** page, create an Agent app to enter the Agent canvas.
+ _A **Begin** component will appear on the canvas._
+2. Configure a greeting message in the **Begin** component, for example:
+
+ ```
+ Hi! What can I do for you?
+ ```
+### Add Categorize component
+
+
+
+This **Categorize** component uses an LLM to recognize user intent and route the conversation to the correct workflow.
+
+### Build a product feature comparison workflow
+
+
+
+1. Add a **Retrieval** component named “Feature Comparison Knowledge Base" and connect it to the “Product Information” dataset.
+2. Add an **Agent** component named “Feature Comparison Agent” after the **Retrieval** component.
+3. Configure the Agent’s System Prompt:
+ ```
+ You are a product specification comparison assistant. Help the user compare products by confirming the models and presenting differences clearly in a structured format.
+ ```
+4. Configure the User Prompt:
+ ```
+ User's query is /(Begin Input) sys.query
+ Schema is /(Feature Comparison Knowledge Base) formalized_content
+ ```
+
+### Build a product user guide workflow
+
+
+
+1. Add a **Retrieval** component named “Usage Guide Knowledge Base” and link it to the “User Guide” dataset.
+2. Add an Agent component named “Usage Guide Agent.”
+3. Set its System Prompt:
+ ```
+ You are a product usage guide assistant. Provide step‑by‑step instructions for setup, operation, and troubleshooting.
+ ```
+4. Set the User Prompt:
+ ```
+ User's query is /(Begin Input) sys.query
+ Schema is /(Usage Guide Knowledge Base) formalized_content
+ ```
+
+### Build an installation booking assistant
+
+
+
+1. Add an **Agent** component named “Installation Booking Agent.”
+2. Configure its System Prompt to collect three details:
+ - Contact number
+ - Preferred installation time
+ - Installation address
+
+ *Once all three are collected, the agent should confirm them and notify the user that a technician will call.*
+
+3. Set the User Prompt:
+ ```
+ User's query is /(Begin Input) sys.query
+
+4. Connect a **Message** component after the three Agent branches.
+ *This component displays the final response to the user.*
+
+ 
+
+5. Click **Save** → **Run** to view execution results and verify that each query is correctly routed and answered.
+6. You can test the workflow by asking:
+ - Product comparison questions
+ - Usage guidance questions
+ - Installation booking requests
+
+
diff --git a/docs/guides/agent/agent_quickstarts/ingestion_pipeline_quickstart.md b/docs/guides/agent/agent_quickstarts/ingestion_pipeline_quickstart.md
new file mode 100644
index 00000000000..452463e8cdb
--- /dev/null
+++ b/docs/guides/agent/agent_quickstarts/ingestion_pipeline_quickstart.md
@@ -0,0 +1,133 @@
+---
+sidebar_position: 5
+slug: /ingestion_pipeline_quickstart
+sidebar_custom_props: {
+ categoryIcon: LucideRoute
+}
+---
+
+# Ingestion pipeline quickstart
+
+RAGFlow's ingestion pipeline is a customizable, step-by-step workflow that prepares your documents for high-quality AI retrieval and answering. You can think of it as building blocks: you connect different processing "components" to create a pipeline tailored to your specific documents and needs.
+
+---
+
+RAGFlow is an open-source RAG platform with strong document processing capabilities. Its built-in module, DeepDoc, uses intelligent parsing to split documents for accurate retrieval. To handle diverse real-world needs—like varied file sources, complex layouts, and richer semantics—RAGFlow now introduces the *ingestion pipeline*.
+
+The ingestion pipeline lets you customize every step of document processing:
+
+- Apply different parsing and splitting rules per scenario
+- Add preprocessing like summarization or keyword extraction
+- Connect to cloud drives and online data sources
+- Use advanced layout-aware models for tables and mixed content
+
+This flexible pipeline adapts to your data, improving answer quality in RAG.
+
+## 1. Understand the core pipeline components
+
+- **Parser** component: Reads and understands your files (PDFs, images, emails, etc.), extracting text and structure.
+- **Transformer** component: Enhances text by using AI to add summaries, keywords, or questions to improve search.
+- **Chunker** component: Splits long text into optimal-sized segments ("chunks") for better AI retrieval.
+- **Indexer** component: The final step. Sends the processed data to the document engine (supports hybrid full-text and vector search).
+
+## 2. Create an ingestion pipeline
+
+1. Go to the **Agent** page.
+2. Click **Create agent** and start from a blank canvas or a pre-built template (recommended for beginners).
+3. On the canvas, drag and connect components from the right-side panel to design your flow (e.g., Parser → Chunker → Transformer → Indexer).
+
+*Now let's build a typical ingestion pipeline!*
+
+## 3. Configure Parser component
+
+A **Parser** component converts your files into structured text while preserving layout, tables, headers, and other formatting. Its supported files 8 categories, 23+ formats including PDF, Image, Audio, Video, Email, Spreadsheet (Excel), Word, PPT, HTML, and Markdown. The following are some key configurations:
+
+- For PDF files, choose one of the following:
+ - **DeepDoc** (Default): RAGFlow's built-in model. Best for scanned documents or complex layouts with tables.
+ - **MinerU**: Industry-leading for complex elements like mathematical formulas and intricate layouts.
+ - **Naive**: Simple text extraction. Use for clean, text-based PDFs without complex elements.
+- For image files: Default uses OCR. Can also configure Vision Language Models (VLMs) for advanced visual understanding.
+- For Email Files: Select specific fields to parse (e.g., "subject", "body") for precise extraction.
+- For Spreadsheets: Outputs in HTML format, preserving row/column structure.
+- For Word/PPT: Outputs in JSON format, retaining document hierarchy (titles, paragraphs, slides).
+- For Text & Markup (HTML/MD): Automatically strips formatting tags, outputting clean text.
+
+
+
+
+## 4. Configure Chunker component
+
+The chunker component splits text intelligently. It's goal is to prevent AI context window overflow and improve semantic accuracy in hybrid search. There are two core methods (Can be used sequentially):
+
+- By Tokens (Default):
+ - Chunk Size: Default is 512 tokens. Balance between retrieval quality and model compatibility.
+ - Overlap: Set **Overlapped percent** to duplicate end of one chunk into start of next. Improves semantic continuity.
+ - Separators: Default uses `\n` (newlines) to split at natural paragraph boundaries first, avoiding mid-sentence cuts.
+- By Title (Hierarchical):
+ - Best for structured documents like manuals, papers, legal contracts.
+ - System splits document by chapter/section structure. Each chunk represents a complete structural unit.
+
+:::caution IMPORTANT
+In the current design, if using both Token and Title methods, connect the **Token chunker** component first, then **Title chunker** component. Connecting **Title chunker** directly to **Parser** may cause format errors for Email, Image, Spreadsheet, and Text files.
+:::
+
+## 5. Configure Transformer component
+
+A **Transformer** component is designed to bridge the "Semantic Gap". Generally speaking, it uses AI models to add semantic metadata, making your content more discoverable during retrieval. It has four generation types:
+
+- Summary: Create concise overviews.
+- Keywords: Extract key terms.
+- Questions: Generate questions each text chunk can answer.
+- Metadata: Custom metadata extraction.
+
+If you have multiple **Transformers**, ensure that you separate **Transformer** components for each function (e.g., one for Summary, another for Keywords).
+
+The following are some key configurations:
+
+- Model modes: (choose one)
+ - Improvise: More creative, good for question generation.
+ - Precise: Strictly faithful to text, good for Summary/Keyword extraction.
+ - Balance: Middle ground for most scenarios.
+- Prompt engineering: System prompts for each generation type are open and customizable.
+- Connection: **Transformer** can connect after **Parser** (processes whole document) OR after **Chunker** (processes each chunk).
+- Variable referencing: The node doesn't auto-acquire content. In the User prompt, manually reference upstream variables by typing `/` and selecting the specific output (e.g., `/{Parser.output}` or `/{Chunker.output}`).
+- Series connection: When chaining **Transformers**, the second **Transformer** component will process the output of the first (e.g., generate Keywords from a Summary) if variables are correctly referenced.
+
+
+
+
+
+## 6. Configure Indexer component
+
+The **Indexer** component indexes for optimal retrieval. It is the final step writes processed data to the search engine (such as Infinity, Elasticsearch, OpenSearch). The following are some key configurations:
+
+- Search methods:
+ - Full-text: Keyword search for exact matches (codes, names).
+ - Embedding: Semantic search using vector similarity.
+ - Hybrid (Recommended): Both methods combined for best recall.
+- Retrieval Strategy:
+ - Processed text (Default): Indexes the chunked text.
+ - Questions: Indexes generated questions. Often yields higher similarity matching than text-to-text.
+ - Augmented context: Indexes summaries instead of raw text. Good for broad topic matching.
+- Filename weight: Slider to include document filename as semantic information in retrieval.
+- Embedding model: Automatically uses the model set when creating the dataset.
+
+
+
+:::caution IMPORTANT
+To search across multiple datasets simultaneously, all selected datasets must use the same embedding model.
+:::
+
+## 7. Test run
+
+Click **Run** on your pipeline canvas to upload a sample file and see the step-by-step results.
+
+## 8. Connect pipeline to a dataset
+
+1. When creating or editing a dataset, find the **Ingestion pipeline** section.
+2. Click **Choose pipeline** and select your saved pipeline.
+
+
+
+*Now, any files uploaded to this dataset will be processed by your custom pipeline.*
+
diff --git a/docs/guides/agent/sandbox_quickstart.md b/docs/guides/agent/agent_quickstarts/sandbox_quickstart.md
similarity index 98%
rename from docs/guides/agent/sandbox_quickstart.md
rename to docs/guides/agent/agent_quickstarts/sandbox_quickstart.md
index 5baa935a844..115ffe88823 100644
--- a/docs/guides/agent/sandbox_quickstart.md
+++ b/docs/guides/agent/agent_quickstarts/sandbox_quickstart.md
@@ -1,8 +1,10 @@
---
sidebar_position: 20
slug: /sandbox_quickstart
+sidebar_custom_props: {
+ categoryIcon: LucideCodesandbox
+}
---
-
# Sandbox quickstart
A secure, pluggable code execution backend designed for RAGFlow and other applications requiring isolated code execution environments.
diff --git a/docs/guides/agent/best_practices/_category_.json b/docs/guides/agent/best_practices/_category_.json
index c788383c044..63edea2af69 100644
--- a/docs/guides/agent/best_practices/_category_.json
+++ b/docs/guides/agent/best_practices/_category_.json
@@ -1,8 +1,11 @@
{
"label": "Best practices",
- "position": 30,
+ "position": 40,
"link": {
"type": "generated-index",
"description": "Best practices on Agent configuration."
+ },
+ "customProps": {
+ "categoryIcon": "LucideStar"
}
}
diff --git a/docs/guides/agent/embed_agent_into_webpage.md b/docs/guides/agent/embed_agent_into_webpage.md
index 1b532c4d724..97dae8b66c0 100644
--- a/docs/guides/agent/embed_agent_into_webpage.md
+++ b/docs/guides/agent/embed_agent_into_webpage.md
@@ -1,9 +1,11 @@
---
-sidebar_position: 3
+sidebar_position: 30
slug: /embed_agent_into_webpage
+sidebar_custom_props: {
+ categoryIcon: LucideMonitorDot
+}
---
-
-# Embed agent into webpage
+# Embed Agent into webpage
You can use iframe to embed an agent into a third-party webpage.
diff --git a/docs/guides/ai_search.md b/docs/guides/ai_search.md
index 6bd5336006d..1f257d29110 100644
--- a/docs/guides/ai_search.md
+++ b/docs/guides/ai_search.md
@@ -1,8 +1,10 @@
---
sidebar_position: 2
slug: /ai_search
+sidebar_custom_props: {
+ categoryIcon: LucideSearch
+}
---
-
# Search
Conduct an AI search.
diff --git a/docs/guides/chat/_category_.json b/docs/guides/chat/_category_.json
index 4b33e0c7b3d..d55b914ec73 100644
--- a/docs/guides/chat/_category_.json
+++ b/docs/guides/chat/_category_.json
@@ -4,5 +4,8 @@
"link": {
"type": "generated-index",
"description": "Chat-specific guides."
+ },
+ "customProps": {
+ "categoryIcon": "LucideMessagesSquare"
}
}
diff --git a/docs/guides/chat/best_practices/_category_.json b/docs/guides/chat/best_practices/_category_.json
index e92bb793db6..a0e97731fba 100644
--- a/docs/guides/chat/best_practices/_category_.json
+++ b/docs/guides/chat/best_practices/_category_.json
@@ -4,5 +4,8 @@
"link": {
"type": "generated-index",
"description": "Best practices on chat assistant configuration."
+ },
+ "customProps": {
+ "categoryIcon": "LucideStar"
}
}
diff --git a/docs/guides/chat/implement_deep_research.md b/docs/guides/chat/implement_deep_research.md
index b5edd2d92f0..2b07a4116e6 100644
--- a/docs/guides/chat/implement_deep_research.md
+++ b/docs/guides/chat/implement_deep_research.md
@@ -1,8 +1,10 @@
---
sidebar_position: 3
slug: /implement_deep_research
+sidebar_custom_props: {
+ categoryIcon: LucideScanSearch
+}
---
-
# Implement deep research
Implements deep research for agentic reasoning.
diff --git a/docs/guides/chat/set_chat_variables.md b/docs/guides/chat/set_chat_variables.md
index 00f1a58c71c..a9bd9dcdcb8 100644
--- a/docs/guides/chat/set_chat_variables.md
+++ b/docs/guides/chat/set_chat_variables.md
@@ -1,8 +1,10 @@
---
sidebar_position: 4
slug: /set_chat_variables
+sidebar_custom_props: {
+ categoryIcon: LucideVariable
+}
---
-
# Set variables
Set variables to be used together with the system prompt for your LLM.
@@ -17,7 +19,7 @@ In RAGFlow, variables are closely linked with the system prompt. When you add a
## Where to set variables
-
+
## 1. Manage variables
diff --git a/docs/guides/chat/start_chat.md b/docs/guides/chat/start_chat.md
index 1e0dd0f10f0..e5066a8b297 100644
--- a/docs/guides/chat/start_chat.md
+++ b/docs/guides/chat/start_chat.md
@@ -1,8 +1,10 @@
---
sidebar_position: 1
slug: /start_chat
+sidebar_custom_props: {
+ categoryIcon: LucideBot
+}
---
-
# Start AI chat
Initiate an AI-powered chat with a configured chat assistant.
diff --git a/docs/guides/dataset/_category_.json b/docs/guides/dataset/_category_.json
index 4c454f51f47..9501311fd68 100644
--- a/docs/guides/dataset/_category_.json
+++ b/docs/guides/dataset/_category_.json
@@ -4,5 +4,8 @@
"link": {
"type": "generated-index",
"description": "Guides on configuring a dataset."
+ },
+ "customProps": {
+ "categoryIcon": "LucideDatabaseZap"
}
}
diff --git a/docs/guides/dataset/add_data_source/_category_.json b/docs/guides/dataset/add_data_source/_category_.json
index 42f2b164a13..71b3d794d30 100644
--- a/docs/guides/dataset/add_data_source/_category_.json
+++ b/docs/guides/dataset/add_data_source/_category_.json
@@ -4,5 +4,8 @@
"link": {
"type": "generated-index",
"description": "Add various data sources"
+ },
+ "customProps": {
+ "categoryIcon": "LucideServer"
}
}
diff --git a/docs/guides/dataset/add_data_source/add_google_drive.md b/docs/guides/dataset/add_data_source/add_google_drive.md
index a1f2d895fe6..57263094845 100644
--- a/docs/guides/dataset/add_data_source/add_google_drive.md
+++ b/docs/guides/dataset/add_data_source/add_google_drive.md
@@ -1,8 +1,10 @@
---
sidebar_position: 3
slug: /add_google_drive
+sidebar_custom_props: {
+ categoryIcon: SiGoogledrive
+}
---
-
# Add Google Drive
## 1. Create a Google Cloud Project
diff --git a/docs/guides/dataset/auto_metadata.md b/docs/guides/dataset/auto_metadata.md
index 35967b935b6..7a7b086361b 100644
--- a/docs/guides/dataset/auto_metadata.md
+++ b/docs/guides/dataset/auto_metadata.md
@@ -1,8 +1,10 @@
---
sidebar_position: -6
slug: /auto_metadata
+sidebar_custom_props: {
+ categoryIcon: LucideFileCodeCorner
+}
---
-
# Auto-extract metadata
Automatically extract metadata from uploaded files.
diff --git a/docs/guides/dataset/autokeyword_autoquestion.mdx b/docs/guides/dataset/autokeyword_autoquestion.mdx
index e917645856f..3165a6a6b14 100644
--- a/docs/guides/dataset/autokeyword_autoquestion.mdx
+++ b/docs/guides/dataset/autokeyword_autoquestion.mdx
@@ -1,8 +1,10 @@
---
sidebar_position: 3
slug: /autokeyword_autoquestion
+sidebar_custom_props: {
+ categoryIcon: LucideSlidersHorizontal
+}
---
-
# Auto-keyword Auto-question
import APITable from '@site/src/components/APITable';
diff --git a/docs/guides/dataset/best_practices/_category_.json b/docs/guides/dataset/best_practices/_category_.json
index 79a1103d5fa..f1fe9fa4100 100644
--- a/docs/guides/dataset/best_practices/_category_.json
+++ b/docs/guides/dataset/best_practices/_category_.json
@@ -4,5 +4,8 @@
"link": {
"type": "generated-index",
"description": "Best practices on configuring a dataset."
+ },
+ "customProps": {
+ "categoryIcon": "LucideStar"
}
}
diff --git a/docs/guides/dataset/configure_child_chunking_strategy.md b/docs/guides/dataset/configure_child_chunking_strategy.md
index 0be4d233034..32a61408ee8 100644
--- a/docs/guides/dataset/configure_child_chunking_strategy.md
+++ b/docs/guides/dataset/configure_child_chunking_strategy.md
@@ -1,8 +1,10 @@
---
sidebar_position: -4
slug: /configure_child_chunking_strategy
+sidebar_custom_props: {
+ categoryIcon: LucideGroup
+}
---
-
# Configure child chunking strategy
Set parent-child chunking strategy to improve retrieval.
diff --git a/docs/guides/dataset/configure_knowledge_base.md b/docs/guides/dataset/configure_knowledge_base.md
index e7aaa50ff8a..92fc1fec9ae 100644
--- a/docs/guides/dataset/configure_knowledge_base.md
+++ b/docs/guides/dataset/configure_knowledge_base.md
@@ -1,8 +1,10 @@
---
sidebar_position: -10
slug: /configure_knowledge_base
+sidebar_custom_props: {
+ categoryIcon: LucideCog
+}
---
-
# Configure dataset
Most of RAGFlow's chat assistants and Agents are based on datasets. Each of RAGFlow's datasets serves as a knowledge source, *parsing* files uploaded from your local machine and file references generated in RAGFlow's File system into the real 'knowledge' for future AI chats. This guide demonstrates some basic usages of the dataset feature, covering the following topics:
@@ -133,7 +135,7 @@ See [Run retrieval test](./run_retrieval_test.md) for details.
## Search for dataset
-As of RAGFlow v0.23.1, the search feature is still in a rudimentary form, supporting only dataset search by name.
+As of RAGFlow v0.24.0, the search feature is still in a rudimentary form, supporting only dataset search by name.

diff --git a/docs/guides/dataset/construct_knowledge_graph.md b/docs/guides/dataset/construct_knowledge_graph.md
index 47108081151..b4eba1fd6b0 100644
--- a/docs/guides/dataset/construct_knowledge_graph.md
+++ b/docs/guides/dataset/construct_knowledge_graph.md
@@ -1,8 +1,10 @@
---
sidebar_position: 8
slug: /construct_knowledge_graph
+sidebar_custom_props: {
+ categoryIcon: LucideWandSparkles
+}
---
-
# Construct knowledge graph
Generate a knowledge graph for your dataset.
diff --git a/docs/guides/dataset/enable_excel2html.md b/docs/guides/dataset/enable_excel2html.md
index 5a7a8fa41f3..9f4f20bec02 100644
--- a/docs/guides/dataset/enable_excel2html.md
+++ b/docs/guides/dataset/enable_excel2html.md
@@ -1,8 +1,10 @@
---
sidebar_position: 4
slug: /enable_excel2html
+sidebar_custom_props: {
+ categoryIcon: LucideToggleRight
+}
---
-
# Enable Excel2HTML
Convert complex Excel spreadsheets into HTML tables.
diff --git a/docs/guides/dataset/enable_raptor.md b/docs/guides/dataset/enable_raptor.md
index 2d8fa245358..54e36d2bf22 100644
--- a/docs/guides/dataset/enable_raptor.md
+++ b/docs/guides/dataset/enable_raptor.md
@@ -1,8 +1,10 @@
---
sidebar_position: 7
slug: /enable_raptor
+sidebar_custom_props: {
+ categoryIcon: LucideNetwork
+}
---
-
# Enable RAPTOR
A recursive abstractive method used in long-context knowledge retrieval and summarization, balancing broad semantic understanding with fine details.
diff --git a/docs/guides/dataset/extract_table_of_contents.md b/docs/guides/dataset/extract_table_of_contents.md
index 58e920613ec..fc86f78f466 100644
--- a/docs/guides/dataset/extract_table_of_contents.md
+++ b/docs/guides/dataset/extract_table_of_contents.md
@@ -1,18 +1,20 @@
---
sidebar_position: 4
slug: /enable_table_of_contents
+sidebar_custom_props: {
+ categoryIcon: LucideTableOfContents
+}
---
-
# Extract table of contents
-Extract table of contents (TOC) from documents to provide long context RAG and improve retrieval.
+Extract PageIndex, namely table of contents, from documents to provide long context RAG and improve retrieval.
---
-During indexing, this technique uses LLM to extract and generate chapter information, which is added to each chunk to provide sufficient global context. At the retrieval stage, it first uses the chunks matched by search, then supplements missing chunks based on the table of contents structure. This addresses issues caused by chunk fragmentation and insufficient context, improving answer quality.
+During indexing, this technique uses LLM to extract and generate chapter information, which is added to each chunk to provide sufficient global context. At the retrieval stage, it first uses the chunks matched by search, then supplements missing chunks based on the PageIndex (table of contents) structure. This addresses issues caused by chunk fragmentation and insufficient context, improving answer quality.
:::danger WARNING
-Enabling TOC extraction requires significant memory, computational resources, and tokens.
+Enabling PageIndex extraction requires significant memory, computational resources, and tokens.
:::
## Prerequisites
@@ -25,15 +27,15 @@ The system's default chat model is used to summarize clustered content. Before p
1. Navigate to the **Configuration** page.
-2. Enable **TOC Enhance**.
+2. Enable **PageIndex**.
3. To use this technique during retrieval, do either of the following:
- - In the **Chat setting** panel of your chat app, switch on the **TOC Enhance** toggle.
- - If you are using an agent, click the **Retrieval** agent component to specify the dataset(s) and switch on the **TOC Enhance** toggle.
+ - In the **Chat setting** panel of your chat app, switch on the **PageIndex** toggle.
+ - If you are using an Agent, click the **Retrieval** agent component to specify the dataset(s) and switch on the **PageIndex** toggle.
## Frequently asked questions
-### Will previously parsed files be searched using the TOC enhancement feature once I enable `TOC Enhance`?
+### Will previously parsed files be searched using the directory enhancement feature once I enable `PageIndex`?
-No. Only files parsed after you enable **TOC Enhance** will be searched using the TOC enhancement feature. To apply this feature to files parsed before enabling **TOC Enhance**, you must reparse them.
\ No newline at end of file
+No. Only files parsed after you enable **PageIndex** will be searched using the directory enhancement feature. To apply this feature to files parsed before enabling **PageIndex**, you must reparse them.
\ No newline at end of file
diff --git a/docs/guides/dataset/manage_metadata.md b/docs/guides/dataset/manage_metadata.md
index a848007fbf7..79b42a47621 100644
--- a/docs/guides/dataset/manage_metadata.md
+++ b/docs/guides/dataset/manage_metadata.md
@@ -1,8 +1,10 @@
---
sidebar_position: -5
slug: /manage_metadata
+sidebar_custom_props: {
+ categoryIcon: LucideCode
+}
---
-
# Manage metadata
Manage metadata for your dataset and for your individual documents.
diff --git a/docs/guides/dataset/run_retrieval_test.md b/docs/guides/dataset/run_retrieval_test.md
index 87bd29835c5..973a2f2ed56 100644
--- a/docs/guides/dataset/run_retrieval_test.md
+++ b/docs/guides/dataset/run_retrieval_test.md
@@ -1,8 +1,10 @@
---
sidebar_position: 10
slug: /run_retrieval_test
+sidebar_custom_props: {
+ categoryIcon: LucideTextSearch
+}
---
-
# Run retrieval test
Conduct a retrieval test on your dataset to check whether the intended chunks can be retrieved.
diff --git a/docs/guides/dataset/select_pdf_parser.md b/docs/guides/dataset/select_pdf_parser.md
index 14831490803..fa2d068cb42 100644
--- a/docs/guides/dataset/select_pdf_parser.md
+++ b/docs/guides/dataset/select_pdf_parser.md
@@ -1,8 +1,10 @@
---
sidebar_position: -3
slug: /select_pdf_parser
+sidebar_custom_props: {
+ categoryIcon: LucideFileText
+}
---
-
# Select PDF parser
Select a visual model for parsing your PDFs.
diff --git a/docs/guides/dataset/set_context_window.md b/docs/guides/dataset/set_context_window.md
index 7f9abdd804c..20d9cb597e7 100644
--- a/docs/guides/dataset/set_context_window.md
+++ b/docs/guides/dataset/set_context_window.md
@@ -1,8 +1,10 @@
---
sidebar_position: -8
slug: /set_context_window
+sidebar_custom_props: {
+ categoryIcon: LucideListChevronsUpDown
+}
---
-
# Set context window size
Set context window size for images and tables to improve long-context RAG performances.
diff --git a/docs/guides/dataset/set_metadata.md b/docs/guides/dataset/set_metadata.md
index 34db390cd29..082fc70b540 100644
--- a/docs/guides/dataset/set_metadata.md
+++ b/docs/guides/dataset/set_metadata.md
@@ -1,8 +1,10 @@
---
sidebar_position: -7
slug: /set_metadata
+sidebar_custom_props: {
+ categoryIcon: LucideCode
+}
---
-
# Set metadata
Manually add metadata to an uploaded file
diff --git a/docs/guides/dataset/set_page_rank.md b/docs/guides/dataset/set_page_rank.md
index 5df848a0e22..de22072ca67 100644
--- a/docs/guides/dataset/set_page_rank.md
+++ b/docs/guides/dataset/set_page_rank.md
@@ -1,8 +1,10 @@
---
sidebar_position: -2
slug: /set_page_rank
+sidebar_custom_props: {
+ categoryIcon: LucideStickyNote
+}
---
-
# Set page rank
Create a step-retrieval strategy using page rank.
diff --git a/docs/guides/dataset/use_tag_sets.md b/docs/guides/dataset/use_tag_sets.md
index 389a97b0a93..af9134b2015 100644
--- a/docs/guides/dataset/use_tag_sets.md
+++ b/docs/guides/dataset/use_tag_sets.md
@@ -1,8 +1,10 @@
---
sidebar_position: 6
slug: /use_tag_sets
+sidebar_custom_props: {
+ categoryIcon: LucideTags
+}
---
-
# Use tag set
Use a tag set to auto-tag chunks in your datasets.
diff --git a/docs/guides/manage_files.md b/docs/guides/manage_files.md
index 27c6f1d3657..bbb5b5ec143 100644
--- a/docs/guides/manage_files.md
+++ b/docs/guides/manage_files.md
@@ -1,8 +1,10 @@
---
sidebar_position: 6
slug: /manage_files
+sidebar_custom_props: {
+ categoryIcon: LucideFolderDot
+}
---
-
# Files
RAGFlow's file management allows you to upload files individually or in bulk. You can then link an uploaded file to multiple target datasets. This guide showcases some basic usages of the file management feature.
@@ -87,4 +89,4 @@ RAGFlow's file management allows you to download an uploaded file:

-> As of RAGFlow v0.23.1, bulk download is not supported, nor can you download an entire folder.
+> As of RAGFlow v0.24.0, bulk download is not supported, nor can you download an entire folder.
diff --git a/docs/guides/memory/_category_.json b/docs/guides/memory/_category_.json
new file mode 100644
index 00000000000..d3b1e49e565
--- /dev/null
+++ b/docs/guides/memory/_category_.json
@@ -0,0 +1,11 @@
+{
+ "label": "Memory",
+ "position": 3.5,
+ "link": {
+ "type": "generated-index",
+ "description": "Guides on using Memory."
+ },
+ "customProps": {
+ "categoryIcon": "LucideBox"
+ }
+}
diff --git a/docs/guides/memory/use_memory.md b/docs/guides/memory/use_memory.md
new file mode 100644
index 00000000000..3979ea55896
--- /dev/null
+++ b/docs/guides/memory/use_memory.md
@@ -0,0 +1,108 @@
+---
+sidebar_position: 1
+slug: /use_memory
+sidebar_custom_props: {
+ categoryIcon: LucideMonitorCog
+}
+---
+
+# Use memory
+
+RAGFlow's Memory module is built to save everything, including conversation that happens while an Agent is working. It keeps the raw logs of conversations, like what a user says and what the AI says back. It also saves extra information created during the chat, like summaries or notes the AI makes about the interaction. Its main jobs are to make conversations flow smoothly from one to the next, to allow the AI to remember personal details about a user, and to let the AI learn from all its past talks.
+
+This module does more than just store the raw data. It is smart enough to sort information into different useful types. It can pull out key facts and meanings (semantic memory), remember specific events and stories from past chats (episodic memory), and hold details needed for the current task (working memory). This turns a simple log into an organized library of past experiences.
+
+Because of this, users can easily bring back any saved information into a new conversation. This past context helps the AI stay on topic and avoid repeating itself, making chats feel more connected and natural. More importantly, it gives the AI a reliable history to think from, which makes its answers more accurate and useful.
+
+## Create memory
+
+The Memory module offers streamlined, centralized management of all memories.
+
+When creating a Memory, users can precisely define which types of information to extract, helping ensure that only relevant data is captured and organized. From the navigation path Overview >> Memory, users can then perform key management actions, including renaming memories, organizing them, and sharing them with team members to support collaborative workflows.
+
+
+
+
+## Configure memory
+
+On the **Memory** page, click the intended memory **>** **Configuration** to view and update its settings.
+
+### Name
+
+The unique name of the memory created.
+
+### Embedding model
+
+The embedding model for converting the memory into embeddings.
+
+### LLM
+
+The chat model for extracting knowledge from the memory.
+
+### Memory type
+
+What is stored in the memory:
+
+`Raw`: The raw dialogue between the user and the Agent (Required by default).
+`Semantic Memory`: General knowledge and facts about the user and world.
+`Episodic Memory`: Time-stamped records of specific events and experiences.
+`Procedural Memory`: Learned skills, habits, and automated procedures.
+
+### Memory size
+
+The default capacity allocated to the memory and the corresponding embeddings in bytes. Defaults to `5242880` (5MB).
+
+:::tip NOTE
+A 1KB message with a 1024-dimension embedding occupies approximately 9KB of memory (1KB + 1024 x 8Bytes = 9KB). With a default limit of 5 MB, the system can store roughly 500 such messages.
+:::
+
+### Permission
+
+- **Only me**: Exclusive to the user.
+- **Team**: Share this memory with the team members.
+
+
+## Manage memory
+
+Within an individual Memory page, you can fine-tune how saved entries are used during Agent calls. Each entry can be selectively enabled or disabled, allowing you to control which pieces of information remain active without permanently removing anything.
+
+When certain details are no longer relevant, you can also choose to forget specific memory entries entirely. This keeps the Memory clean, focused, and easier to maintain over time, ensuring that Agents rely only on up‑to‑date and useful information.
+
+
+
+Manually forgotten memory entries are completely excluded from the results returned by Agent calls, ensuring they no longer influence downstream behavior. This helps keep responses focused on the most relevant and intentionally retained information.
+
+When the Memory reaches its storage limit and the automatic forgetting policy is applied, entries that were previously forgotten manually are also prioritized for removal. This allows the system to reclaim capacity more intelligently while respecting earlier user curation decisions.
+
+## Enhance Agent context
+
+Under [Retrieval](../agent/agent_component_reference/retrieval.mdx) and [Message](../agent/agent_component_reference/message.mdx) component settings, a new Memory invocation capability is available. In the Message component, users can configure the Agent to write selected data into a designated Memory, while the Retrieval component can be set to read from that same Memory to answer future queries. This enables a simple Q&A bot Agent to accumulate context over time and respond with richer, memory-aware answers.
+
+### Retrieve from memory
+
+For any Agent configuration that uses Memory, a **Retrieval** component is required to bring stored information back into the conversation. By including Retrieval alongside Memory-aware components, the Agent can consistently recall and apply relevant past data whenever it is needed.
+
+
+
+### Save to memory
+
+At the same time you have finished **Retrieval** component settings, select the corresponding Memory in the **Message** component under **Save to Memory**:
+
+
+
+
+
+## Frequently asked questions
+
+### Can I share my memory?
+
+Yes, you can. Your memory can be shared between Agents. See these topics:
+
+- [Create memory](#create-memory)
+- [Enhance Agent context](#enhance-agent-context)
+
+If you wish to share your memory with your team members, please ensure you have configured its team permissions. See [Share memory](../team/share_memory.md) for details.
+
+
+
+
diff --git a/docs/guides/migration/_category_.json b/docs/guides/migration/_category_.json
index dcb81271612..1099886f2ee 100644
--- a/docs/guides/migration/_category_.json
+++ b/docs/guides/migration/_category_.json
@@ -4,5 +4,8 @@
"link": {
"type": "generated-index",
"description": "RAGFlow migration guide"
+ },
+ "customProps": {
+ "categoryIcon": "LucideArrowRightLeft"
}
}
diff --git a/docs/guides/models/_category_.json b/docs/guides/models/_category_.json
index 8536f8e4760..b4a996b4fa5 100644
--- a/docs/guides/models/_category_.json
+++ b/docs/guides/models/_category_.json
@@ -4,5 +4,8 @@
"link": {
"type": "generated-index",
"description": "Guides on model settings."
+ },
+ "customProps": {
+ "categoryIcon": "LucideBox"
}
}
diff --git a/docs/guides/models/deploy_local_llm.mdx b/docs/guides/models/deploy_local_llm.mdx
index 7d8e58eee9b..e7e3fbeaee3 100644
--- a/docs/guides/models/deploy_local_llm.mdx
+++ b/docs/guides/models/deploy_local_llm.mdx
@@ -1,8 +1,10 @@
---
sidebar_position: 2
slug: /deploy_local_llm
+sidebar_custom_props: {
+ categoryIcon: LucideMonitorCog
+}
---
-
# Deploy local models
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
diff --git a/docs/guides/models/llm_api_key_setup.md b/docs/guides/models/llm_api_key_setup.md
index f61d71c5830..d2cf67597cc 100644
--- a/docs/guides/models/llm_api_key_setup.md
+++ b/docs/guides/models/llm_api_key_setup.md
@@ -1,8 +1,10 @@
---
sidebar_position: 1
slug: /llm_api_key_setup
+sidebar_custom_props: {
+ categoryIcon: LucideKey
+}
---
-
# Configure model API key
An API key is required for RAGFlow to interact with an online AI model. This guide provides information about setting your model API key in RAGFlow.
diff --git a/docs/guides/team/_category_.json b/docs/guides/team/_category_.json
index 37bbf13073e..f245a5f35b6 100644
--- a/docs/guides/team/_category_.json
+++ b/docs/guides/team/_category_.json
@@ -4,5 +4,8 @@
"link": {
"type": "generated-index",
"description": "Team-specific guides."
+ },
+ "customProps": {
+ "categoryIcon": "LucideUsers"
}
}
diff --git a/docs/guides/team/join_or_leave_team.md b/docs/guides/team/join_or_leave_team.md
index 978523d8018..dfc80ed5a1e 100644
--- a/docs/guides/team/join_or_leave_team.md
+++ b/docs/guides/team/join_or_leave_team.md
@@ -1,8 +1,10 @@
---
sidebar_position: 3
slug: /join_or_leave_team
+sidebar_custom_props: {
+ categoryIcon: LucideLogOut
+}
---
-
# Join or leave a team
Accept an invitation to join a team, decline an invitation, or leave a team.
diff --git a/docs/guides/team/manage_team_members.md b/docs/guides/team/manage_team_members.md
index edd8289cda4..6df75899108 100644
--- a/docs/guides/team/manage_team_members.md
+++ b/docs/guides/team/manage_team_members.md
@@ -1,8 +1,10 @@
---
sidebar_position: 2
slug: /manage_team_members
+sidebar_custom_props: {
+ categoryIcon: LucideUserCog
+}
---
-
# Manage team members
Invite or remove team members.
diff --git a/docs/guides/team/share_agents.md b/docs/guides/team/share_agents.md
index f6be1a7288a..f901f08ebfc 100644
--- a/docs/guides/team/share_agents.md
+++ b/docs/guides/team/share_agents.md
@@ -1,8 +1,10 @@
---
sidebar_position: 6
slug: /share_agent
+sidebar_custom_props: {
+ categoryIcon: LucideShare2
+}
---
-
# Share Agent
Share an Agent with your team members.
diff --git a/docs/guides/team/share_chat_assistant.md b/docs/guides/team/share_chat_assistant.md
index f8f172ee5db..719fbda51ac 100644
--- a/docs/guides/team/share_chat_assistant.md
+++ b/docs/guides/team/share_chat_assistant.md
@@ -1,8 +1,10 @@
---
sidebar_position: 5
slug: /share_chat_assistant
+sidebar_custom_props: {
+ categoryIcon: LucideShare2
+}
---
-
# Share chat assistant
Sharing chat assistant is currently exclusive to RAGFlow Enterprise, but will be made available in due course.
\ No newline at end of file
diff --git a/docs/guides/team/share_knowledge_bases.md b/docs/guides/team/share_knowledge_bases.md
index 4eeccd2643f..3f00c9bd8ea 100644
--- a/docs/guides/team/share_knowledge_bases.md
+++ b/docs/guides/team/share_knowledge_bases.md
@@ -1,8 +1,10 @@
---
sidebar_position: 4
slug: /share_datasets
+sidebar_custom_props: {
+ categoryIcon: LucideShare2
+}
---
-
# Share dataset
Share a dataset with team members.
diff --git a/docs/guides/team/share_memory.md b/docs/guides/team/share_memory.md
new file mode 100644
index 00000000000..fa7a1c51b0a
--- /dev/null
+++ b/docs/guides/team/share_memory.md
@@ -0,0 +1,20 @@
+---
+sidebar_position: 9
+slug: /share_memory
+sidebar_custom_props: {
+ categoryIcon: LucideShare2
+}
+---
+# Share memory
+
+Share a memory with your team members.
+
+---
+
+When ready, you may share your memory with your team members so that they can use it. Please note that your memories are not shared automatically; you must manually enable sharing by selecting the corresponding **Permissions** radio button:
+
+1. Navigate to the **Memory** page, find the intended memory, and click to open its editing canvas.
+2. Click **Configurations**.
+3. Change **Permissions** from **Only me** to **Team**.
+4. Click **Save** to apply your changes.
+ *When completed, your team members will see your shared memories.*
\ No newline at end of file
diff --git a/docs/guides/team/share_model.md b/docs/guides/team/share_model.md
index 459641fcaa8..5a97e671651 100644
--- a/docs/guides/team/share_model.md
+++ b/docs/guides/team/share_model.md
@@ -1,8 +1,10 @@
---
sidebar_position: 7
slug: /share_model
+sidebar_custom_props: {
+ categoryIcon: LucideShare2
+}
---
-
# Share models
Sharing models is currently exclusive to RAGFlow Enterprise.
\ No newline at end of file
diff --git a/docs/guides/tracing.mdx b/docs/guides/tracing.mdx
index c9f37ba7537..13cf99874b8 100644
--- a/docs/guides/tracing.mdx
+++ b/docs/guides/tracing.mdx
@@ -1,8 +1,10 @@
---
sidebar_position: 9
slug: /tracing
+sidebar_custom_props: {
+ categoryIcon: LucideLocateFixed
+}
---
-
# Tracing
Observability & Tracing with Langfuse.
diff --git a/docs/guides/upgrade_ragflow.mdx b/docs/guides/upgrade_ragflow.mdx
index 419fe76e4f4..ef43384ddce 100644
--- a/docs/guides/upgrade_ragflow.mdx
+++ b/docs/guides/upgrade_ragflow.mdx
@@ -1,8 +1,10 @@
---
sidebar_position: 11
slug: /upgrade_ragflow
+sidebar_custom_props: {
+ categoryIcon: LucideArrowBigUpDash
+}
---
-
# Upgrading
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
@@ -60,16 +62,16 @@ To upgrade RAGFlow, you must upgrade **both** your code **and** your Docker imag
git pull
```
-3. Switch to the latest, officially published release, e.g., `v0.23.1`:
+3. Switch to the latest, officially published release, e.g., `v0.24.0`:
```bash
- git checkout -f v0.23.1
+ git checkout -f v0.24.0
```
4. Update **ragflow/docker/.env**:
```bash
- RAGFLOW_IMAGE=infiniflow/ragflow:v0.23.1
+ RAGFLOW_IMAGE=infiniflow/ragflow:v0.24.0
```
5. Update the RAGFlow image and restart RAGFlow:
@@ -90,10 +92,10 @@ No, you do not need to. Upgrading RAGFlow in itself will *not* remove your uploa
1. From an environment with Internet access, pull the required Docker image.
2. Save the Docker image to a **.tar** file.
```bash
- docker save -o ragflow.v0.23.1.tar infiniflow/ragflow:v0.23.1
+ docker save -o ragflow.v0.24.0.tar infiniflow/ragflow:v0.24.0
```
3. Copy the **.tar** file to the target server.
4. Load the **.tar** file into Docker:
```bash
- docker load -i ragflow.v0.23.1.tar
+ docker load -i ragflow.v0.24.0.tar
```
diff --git a/docs/quickstart.mdx b/docs/quickstart.mdx
index 387de9d7906..e1de5fe184a 100644
--- a/docs/quickstart.mdx
+++ b/docs/quickstart.mdx
@@ -1,8 +1,10 @@
---
sidebar_position: 0
slug: /
+sidebar_custom_props: {
+ sidebarIcon: LucideRocket
+}
---
-
# Get started
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
@@ -46,7 +48,7 @@ This section provides instructions on setting up the RAGFlow server on Linux. If
`vm.max_map_count`. This value sets the maximum number of memory map areas a process may have. Its default value is 65530. While most applications require fewer than a thousand maps, reducing this value can result in abnormal behaviors, and the system will throw out-of-memory errors when a process reaches the limitation.
- RAGFlow v0.23.1 uses Elasticsearch or [Infinity](https://github.com/infiniflow/infinity) for multiple recall. Setting the value of `vm.max_map_count` correctly is crucial to the proper functioning of the Elasticsearch component.
+ RAGFlow v0.24.0 uses Elasticsearch or [Infinity](https://github.com/infiniflow/infinity) for multiple recall. Setting the value of `vm.max_map_count` correctly is crucial to the proper functioning of the Elasticsearch component.
- Message: - Value: <>"
}
```
@@ -675,9 +694,10 @@ Failure:
```json
{
- "code": 102,
- "message": "You don't own the dataset."
+ "code":108,
+ "message":"User '' lacks permission for datasets: ''"
}
+
```
---
@@ -894,7 +914,7 @@ Success:
"vector_similarity_weight": 0.3
}
],
- "total": 1
+ "total_datasets": 1
}
```
@@ -2219,8 +2239,14 @@ Success:
"code": 0,
"data": {
"summary": {
- "tags": [["bar", 2], ["foo", 1], ["baz", 1]],
- "author": [["alice", 2], ["bob", 1]]
+ "tags": {
+ "type": "string",
+ "values": [["bar", 2], ["foo", 1], ["baz", 1]]
+ },
+ "author": {
+ "type": "string",
+ "values": [["alice", 2], ["bob", 1]]
+ }
}
}
}
@@ -3939,6 +3965,8 @@ data: {
data:[DONE]
```
+When `extra_body.reference_metadata.include` is `true`, each reference chunk may include a `document_metadata` object.
+
Non-stream:
```json
@@ -4913,6 +4941,1036 @@ Failure:
}
```
+---
+
+
+
+## MEMORY MANAGEMENT
+
+### Create Memory
+
+**POST** `/api/v1/memories`
+
+Create a new memory.
+
+#### Request
+
+- Method: POST
+- URL: `/api/v1/memories`
+- Headers:
+ - `'Content-Type: application/json'`
+ - `'Authorization: Bearer '`
+- Body:
+ - `"name"`: `string`
+ - `"memory_type"`: `list[string]`
+ - `"embd_id"`: `string`.
+ - `"llm_id"`: `string`
+
+##### Request example
+
+```bash
+curl --location 'http://{address}/api/v1/memories' \
+--header 'Content-Type: application/json' \
+--header 'Authorization: Bearer ' \
+--data-raw '{
+ "name": "new_memory_1",
+ "memory_type": ["raw", "semantic"],
+ "embd_id": "BAAI/bge-large-zh-v1.5@BAAI",
+ "llm_id": "glm-4-flash@ZHIPU-AI"
+}'
+```
+
+##### Request parameters
+
+- `name` : (*Body parameter*), `string`, *Required*
+
+ The unique name of the memory to create. It must adhere to the following requirements:
+
+ - Basic Multilingual Plane (BMP) only
+ - Maximum 128 characters
+
+- `memory_type`: (*Body parameter*), `list[enum]`, *Required*
+
+ Specifies the types of memory to extract. Available options:
+
+ - `raw`: The raw dialogue content between the user and the agent . *Required by default*.
+ - `semantic`: General knowledge and facts about the user and world.
+ - `episodic`: Time-stamped records of specific events and experiences.
+ - `procedural`: Learned skills, habits, and automated procedures.
+
+- `embd_id`: (*Body parameter*), `string`, *Required*
+
+ The name of the embedding model to use. For example: `"BAAI/bge-large-zh-v1.5@BAAI"`
+
+ - Maximum 255 characters
+ - Must follow `model_name@model_factory` format
+
+- `llm_id`: (*Body parameter*), `string`, *Required*
+
+ The name of the chat model to use. For example: `"glm-4-flash@ZHIPU-AI"`
+
+ - Maximum 255 characters
+ - Must follow `model_name@model_factory` format
+
+#### Response
+
+Success:
+
+```json
+{
+ "code": 0,
+ "data": {
+ ...your new memory here
+ },
+ "message": true
+}
+```
+
+Failure:
+
+```json
+{
+ "code": 101,
+ "message": "Memory name cannot be empty or whitespace."
+}
+```
+
+
+
+### Update Memory
+
+**PUT** `/api/v1/memories/{memory_id}`
+
+Updates configurations for a specified memory.
+
+#### Request
+
+- Method: PUT
+- URL: `/api/v1/memories/{memory_id}`
+- Headers:
+ - `'Content-Type: application/json'`
+ - `'Authorization: Bearer '`
+- Body:
+ - `"name"`: `string`
+ - `"avatar"`: `string`
+ - `"permission"`: `string`
+ - `"llm_id"`: `string`
+ - `"description"`: `string`
+ - `"memory_size"`: `int`
+ - `"forgetting_policy"`: `string`
+ - `"temperature"`: `float`
+ - `"system_promot"`: `string`
+ - `"user_prompt"`: `string`
+
+##### Request example
+
+```bash
+curl --location --request PUT 'http://{address}/api/v1/memories/d6775d4eeada11f08ca284ba59bc53c7' \
+--header 'Content-Type: application/json' \
+--header 'Authorization: Bearer ' \
+--data '{
+ "name": "name_update",
+}'
+```
+
+##### Request parameters
+
+- `memory_id`: (*Path parameter*)
+
+ The ID of the memory to update.
+
+- `name`: (*Body parameter*), `string`, *Optional*
+
+ The revised name of the memory.
+
+ - Basic Multilingual Plane (BMP) only
+ - Maximum 128 characters, *Optional*
+
+- `avatar`: (*Body parameter*), `string`, *Optional*
+
+ The updated base64 encoding of the avatar.
+
+ - Maximum 65535 characters
+
+- `permission`: (*Body parameter*), `enum`, *Optional*
+
+ The updated memory permission. Available options:
+
+ - `"me"`: (Default) Only you can manage the memory.
+ - `"team"`: All team members can manage the memory.
+
+- `llm_id`: (*Body parameter*), `string`, *Optional*
+
+ The name of the chat model to use. For example: `"glm-4-flash@ZHIPU-AI"`
+
+ - Maximum 255 characters
+ - Must follow `model_name@model_factory` format
+
+- `description`: (*Body parameter*), `string`, *Optional*
+
+ The description of the memory. Defaults to `None`.
+
+- `memory_size`: (*Body parameter*), `int`, *Optional*
+
+ Defaults to `5*1024*1024` Bytes. Accounts for each message's content + its embedding vector (≈ Content + Dimensions × 8 Bytes). Example: A 1 KB message with 1024-dim embedding uses ~9 KB. The 5 MB default limit holds ~500 such messages.
+
+ - Maximum 10 * 1024 * 1024 Bytes
+
+- `forgetting_policy`: (*Body parameter*), `enum`, *Optional*
+
+ Evicts existing data based on the chosen policy when the size limit is reached, freeing up space for new messages. Available options:
+
+ - `"FIFO"`: (Default) Prioritize messages with the earliest `forget_at` time for removal. When the pool of messages that have `forget_at` set is insufficient, it falls back to selecting messages in ascending order of their `valid_at` (oldest first).
+
+- `temperature`: (*Body parameter*), `float`, *Optional*
+
+ Adjusts output randomness. Lower = more deterministic; higher = more creative.
+
+ - Range [0, 1]
+
+- `system_prompt`: (*Body parameter*), `string`, *Optional*
+
+ Defines the system-level instructions and role for the AI assistant. It is automatically assembled based on the selected `memory_type` by `PromptAssembler` in `memory/utils/prompt_util.py`. This prompt sets the foundational behavior and context for the entire conversation.
+
+ - Keep the `OUTPUT REQUIREMENTS` and `OUTPUT FORMAT` parts unchanged.
+
+- `user_prompt`: (*Body parameter*), `string`, *Optional*
+
+ Represents the user's custom setting, which is the specific question or instruction the AI needs to respond to directly. Defaults to `None`.
+
+#### Response
+
+Success:
+
+```json
+{
+ "code": 0,
+ "data": {
+ ...your updated memory here
+ },
+ "message": true
+}
+```
+
+Failure:
+
+```json
+{
+ "code": 101,
+ "message": "Memory name cannot be empty or whitespace."
+}
+```
+
+
+
+### List Memory
+
+**GET** `/api/v1/memories?tenant_id={tenant_ids}&memory_type={memory_types}&storage_type={storage_type}&keywords={keywords}&page={page}&page_size={page_size}`
+
+List memories.
+
+#### Request
+
+- Method: GET
+- URL: `/api/v1/memories?tenant_id={tenant_ids}&memory_type={memory_types}&storage_type={storage_type}&keywords={keywords}&page={page}&page_size={page_size}`
+- Headers:
+ - `'Content-Type: application/json'`
+ - `'Authorization: Bearer '`
+
+##### Request example
+
+```bash
+curl --location 'http://{address}/api/v1/memories?keywords=&page_size=50&page=1&memory_type=semantic%2Cepisodic' \
+--header 'Authorization: Bearer '
+```
+
+##### Request parameters
+
+- `tenant_id`: (*Filter parameter*), `string` or `list[string]`, *Optional*
+
+ The owner's ID, supports search multiple IDs.
+
+- `memory_type`: (*Filter parameter*), `enum` or `list[enum]`, *Optional*
+
+ The type of memory (as set during creation). A memory matches if its type is **included in** the provided value(s). Available options:
+
+ - `raw`
+ - `semantic`
+ - `episodic`
+ - `procedural`
+
+- `storage_type`: (*Filter parameter*), `enum`, *Optional*
+
+ The storage format of messages. Available options:
+
+ - `table`: (Default)
+
+- `keywords`: (*Filter parameter*), `string`, *Optional*
+
+ The name of memory to retrieve, supports fuzzy search.
+
+- `page`: (*Filter parameter*), `int`, *Optional*
+
+ Specifies the page on which the memories will be displayed. Defaults to `1`.
+
+- `page_size`: (*Filter parameter*), `int`, *Optional*
+
+ The number of memories on each page. Defaults to `50`.
+
+#### Response
+
+Success:
+
+```json
+{
+ "code": 0,
+ "data": {
+ "memory_list": [
+ {
+ "avatar": null,
+ "create_date": "Tue, 06 Jan 2026 16:36:47 GMT",
+ "create_time": 1767688607040,
+ "description": null,
+ "id": "d6775d4eeada11f08ca284ba59bc53c7",
+ "memory_type": [
+ "raw",
+ "semantic"
+ ],
+ "name": "new_memory_1",
+ "owner_name": "Lynn",
+ "permissions": "me",
+ "storage_type": "table",
+ "tenant_id": "55777efac9df11f09cd07f49bd527ade"
+ },
+ ...other 3 memories here
+ ],
+ "total_count": 4
+ },
+ "message": true
+}
+```
+
+Failure:
+
+```json
+{
+ "code": 500,
+ "message": "Internal Server Error."
+}
+```
+
+
+
+### Get Memory Config
+
+**GET** `/api/v1/memories/{memory_id}/config`
+
+Get the configuration of a specified memory.
+
+#### Request
+
+- Method: GET
+- URL: `/api/v1/memories/{memory_id}/config`
+- Headers:
+ - `'Content-Type: application/json'`
+ - `'Authorization: Bearer '`
+
+##### Request example
+
+```bash
+curl --location 'http://{address}/api/v1/memories/6c8983badede11f083f184ba59bc53c7/config' \
+--header 'Authorization: Bearer '
+```
+
+##### Request parameters
+
+- `memory_id`: (*Path parameter*), `string`, *Required*
+
+ The ID of the memory.
+
+#### Response
+
+Success
+
+```json
+{
+ "code": 0,
+ "data": {
+ "avatar": null,
+ "create_date": "Mon, 22 Dec 2025 10:32:13 GMT",
+ "create_time": 1766370733354,
+ "description": null,
+ "embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
+ "forgetting_policy": "FIFO",
+ "id": "6c8983badede11f083f184ba59bc53c7",
+ "llm_id": "glm-4.5-flash@ZHIPU-AI",
+ "memory_size": 5242880,
+ "memory_type": [
+ "raw",
+ "semantic",
+ "episodic",
+ "procedural"
+ ],
+ "name": "mem1222",
+ "owner_name": null,
+ "permissions": "me",
+ "storage_type": "table",
+ "system_prompt": ...your prompt here,
+ "temperature": 0.5,
+ "tenant_id": "55777efac9df11f09cd07f49bd527ade",
+ "update_date": null,
+ "update_time": null,
+ "user_prompt": null
+ },
+ "message": true
+}
+```
+
+Failure
+
+```json
+{
+ "code": 404,
+ "data": null,
+ "message": "Memory '{memory_id}' not found."
+}
+```
+
+
+
+### Delete Memory
+
+**DELETE** `/api/v1/memories/{memory_id}`
+
+Delete a specified memory.
+
+#### Request
+
+- Method: DELETE
+- URL: `/api/v1/memories/{memory_id}`
+- Headers:
+- Headers:
+ - `'Content-Type: application/json'`
+ - `'Authorization: Bearer '`
+
+##### Request example
+
+```bash
+curl --location --request DELETE 'http://{address}/api/v1/memories/d6775d4eeada11f08ca284ba59bc53c7' \
+--header 'Authorization: Bearer '
+```
+
+##### Request parameters
+
+- `memory_id`: (*Path parameter*), `string`, *Required*
+
+ The ID of the memory to delete.
+
+#### Response
+
+Success
+
+```json
+{
+ "code": 0,
+ "data": null,
+ "message": true
+}
+```
+
+Failure
+
+```json
+{
+ "code": 404,
+ "data": null,
+ "message": true
+}
+```
+
+
+
+### List messages of a memory
+
+**GET** `/api/v1/memories/{memory_id}?agent_id={agent_id}&keywords={session_id}&page={page}&page_size={page_size}`
+
+List the messages of a specified memory.
+
+#### Request
+
+- Method: GET
+- URL: `/api/v1/memories/{memory_id}?agent_id={agent_id}&keywords={session_id}&page={page}&page_size={page_size}`
+- Headers:
+ - `'Content-Type: application/json'`
+ - `'Authorization: Bearer '`
+
+##### Request example
+
+```bash
+curl --location 'http://{address}/api/v1/memories/6c8983badede11f083f184ba59bc53c?page=1' \
+--header 'Authorization: Bearer '
+```
+
+##### Request parameters
+
+- `memory_id`: (*Path parameter*), `string`, *Required*
+
+ The ID of the memory to show messages.
+
+- `agent_id`: (*Filter parameter*), `string` or `list[string]`, *Optional*
+
+ Filters messages by the ID of their source agent. Supports multiple values.
+
+- `session_id`: (*Filter parameter*), `string`, *Optional*
+
+ Filters messages by their session ID. This field supports fuzzy search.
+
+- `page`: (*Filter parameter*), `int`, *Optional*
+
+ Specifies the page on which the messages will be displayed. Defaults to `1`.
+
+- `page_size`: (*Filter parameter*), `int`, *Optional*
+
+ The number of messages on each page. Defaults to `50`.
+
+#### Response
+
+Success
+
+```json
+{
+ "code": 0,
+ "data": {
+ "messages": {
+ "message_list": [
+ {
+ "agent_id": "8db9c8eddfcc11f0b5da84ba59bc53c7",
+ "agent_name": "memory_agent_1223",
+ "extract": [
+ {
+ "agent_id": "8db9c8eddfcc11f0b5da84ba59bc53c7",
+ "agent_name": "memory_agent_1223",
+ "forget_at": "None",
+ "invalid_at": "None",
+ "memory_id": "6c8983badede11f083f184ba59bc53c7",
+ "message_id": 236,
+ "message_type": "semantic",
+ "session_id": "65b89ab8e96411f08d4e84ba59bc53c7",
+ "source_id": 233,
+ "status": true,
+ "user_id": "",
+ "valid_at": "2026-01-04 19:56:46"
+ },
+ ...other extracted messages
+ ],
+ "forget_at": "None",
+ "invalid_at": "None",
+ "memory_id": "6c8983badede11f083f184ba59bc53c7",
+ "message_id": 233,
+ "message_type": "raw",
+ "session_id": "65b89ab8e96411f08d4e84ba59bc53c7",
+ "source_id": "None",
+ "status": true,
+ "task": {
+ "progress": 1.0,
+ "progress_msg": "\n2026-01-04 19:56:46 Prepared prompts and LLM.\n2026-01-04 19:57:48 Get extracted result from LLM.\n2026-01-04 19:57:48 Extracted 6 messages from raw dialogue.\n2026-01-04 19:57:48 Prepared embedding model.\n2026-01-04 19:57:48 Embedded extracted content.\n2026-01-04 19:57:48 Saved messages to storage.\n2026-01-04 19:57:48 Message saved successfully."
+ },
+ "user_id": "",
+ "valid_at": "2026-01-04 19:56:42"
+ },
+ {
+ "agent_id": "8db9c8eddfcc11f0b5da84ba59bc53c7",
+ "agent_name": "memory_agent_1223",
+ "extract": [],
+ "forget_at": "None",
+ "invalid_at": "None",
+ "memory_id": "6c8983badede11f083f184ba59bc53c7",
+ "message_id": 226,
+ "message_type": "raw",
+ "session_id": "d982a8cbe96111f08a1384ba59bc53c7",
+ "source_id": "None",
+ "status": true,
+ "task": {
+ "progress": -1.0,
+ "progress_msg": "Failed to insert message into memory. Details: 6c8983badede11f083f184ba59bc53c7_228:{'type': 'document_parsing_exception', 'reason': \"[1:230] failed to parse field [valid_at] of type [date] in document with id '6c8983badede11f083f184ba59bc53c7_228'. Preview of field's value: ''\", 'caused_by': {'type': 'illegal_argument_exception', 'reason': 'cannot parse empty date'}}; 6c8983badede11f083f184ba59bc53c7_229:{'type': 'document_parsing_exception', 'reason': \"[1:230] failed to parse field [valid_at] of type [date] in document with id '6c8983badede11f083f184ba59bc53c7_229'. Preview of field's value: ''\", 'caused_by': {'type': 'illegal_argument_exception', 'reason': 'cannot parse empty date'}}; 6c8983badede11f083f184ba59bc53c7_230:{'type': 'document_parsing_exception', 'reason': \"[1:230] failed to parse field [valid_at] of type [date] in document with id '6c8983badede11f083f184ba59bc53c7_230'. Preview of field's value: ''\", 'caused_by': {'type': 'illegal_argument_exception', 'reason': 'cannot parse empty date'}}; 6c8983badede11f083f184ba59bc53c7_231:{'type': 'document_parsing_exception', 'reason': \"[1:230] failed to parse field [valid_at] of type [date] in document with id '6c8983badede11f083f184ba59bc53c7_231'. Preview of field's value: ''\", 'caused_by': {'type': 'illegal_argument_exception', 'reason': 'cannot parse empty date'}}; 6c8983badede11f083f184ba59bc53c7_232:{'type': 'document_parsing_exception', 'reason': \"[1:230] failed to parse field [valid_at] of type [date] in document with id '6c8983badede11f083f184ba59bc53c7_232'. Preview of field's value: ''\", 'caused_by': {'type': 'illegal_argument_exception', 'reason': 'cannot parse empty date'}}"
+ },
+ "user_id": "",
+ "valid_at": "2026-01-04 19:38:26"
+ },
+ ...other 11 messages
+ ],
+ "total_count": 13
+ },
+ "storage_type": "table"
+ },
+ "message": true
+}
+```
+
+Failure
+
+```
+{
+ "code": 404,
+ "data": null,
+ "message": "Memory '{memory_id}' not found."
+}
+```
+
+
+
+### Add Message
+
+**POST** `/api/v1/messages`
+
+Add a message to specified memories.
+
+#### Request
+
+- Method: POST
+- URL: `/api/v1/messages`
+- Headers:
+ - `'Content-Type: application/json'`
+ - `'Authorization: Bearer '`
+- Body:
+ - `"memory_id"`: `list[string]`
+ - `"agent_id"`: `string`
+ - `"session_id"`: `string`
+ - `"user_id"`: `string`
+ - `"user_input"`: `string`
+ - `"agent_response"`: `string`
+
+##### Request example
+
+```bash
+curl --location 'http://{address}/api/v1/messages' \
+--header 'Content-Type: application/json' \
+--header 'Authorization: Bearer ' \
+--data '{
+ "memory_id": ["6c8983badede11f083f184ba59bc53c7", "87ebb892df1711f08d6b84ba59bc53c7"],
+ "agent_id": "8db9c8eddfcc11f0b5da84ba59bc53c7",
+ "session_id": "bf0a50abeb8111f0917884ba59bc53c7",
+ "user_id": "55777efac9df11f09cd07f49bd527ade",
+ "user_input": "your user input here",
+ "agent_response": "your agent response here"
+
+}'
+```
+
+##### Request parameter
+
+- `memory_id`: (*Body parameter*), `list[string]`, *Required*
+
+ The IDs of the memories to save messages.
+
+- `agent_id`: (*Body parameter*), `string`, *Required*
+
+ The ID of the message's source agent.
+
+- `session_id`: (*Body parameter*), `string`, *Required*
+
+ The ID of the message's session.
+
+- `user_id`: (*Body parameter*), `string`, *Optional*
+
+ The user participating in the conversation with the agent. Defaults to `None`.
+
+- `user_input`: (*Body parameter*), `string`, *Required*
+
+ The text input provided by the user.
+
+- `agent_response`: (*Body parameter*), `string`, *Required*
+
+ The text response generated by the AI agent.
+
+#### Response
+
+Success
+
+```json
+{
+ "code": 0,
+ "data": null,
+ "message": "All add to task."
+}
+```
+
+Failure
+
+```json
+{
+ "code": 500,
+ "data": null,
+ "message": "Some messages failed to add. Detail: {fail information}"
+}
+```
+
+
+
+### Forget Message
+
+**DELETE** `/api/v1/messages/{memory_id}:{message_id}`
+
+Forget a specified message. After forgetting, this message will not be retrieved by agents, and it will also be prioritized for cleanup by the forgetting policy.
+
+#### Request
+
+- Method: DELETE
+- URL: `/api/v1/messages/{memory_id}:{message_id}`
+- Headers:
+ - `'Content-Type: application/json'`
+ - `'Authorization: Bearer '`
+
+##### Request example
+
+```bash
+curl --location --request DELETE 'http://{address}/api/v1/messages/6c8983badede11f083f184ba59bc53c7:272' \
+--header 'Authorization: Bearer '
+```
+
+##### Request parameters
+
+- `memory_id`: (*Path parameter*), `string`, *Required*
+
+ The ID of the memory to which the specified message belongs.
+
+- `message_id`: (*Path parameter*), `string`, *Required*
+
+ The ID of the message to forget.
+
+#### Response
+
+Success
+
+```json
+{
+ "code": 0,
+ "data": null,
+ "message": true
+}
+```
+
+Failure
+
+```json
+{
+ "code": 404,
+ "data": null,
+ "message": "Memory '{memory_id}' not found."
+}
+```
+
+
+
+### Update message status
+
+**PUT** `/api/v1/messages/{memory_id}:{message_id}`
+
+Update message status, enable or disable a message. Once a message is disabled, it will not be retrieved by agents.
+
+#### Request
+
+- Method: PUT
+- URL: `/api/v1/messages/{memory_id}:{message_id}`
+- Headers:
+ - `'Content-Type: application/json'`
+ - `'Authorization: Bearer '`
+- Body:
+ - `"status"`: `bool`
+
+##### Request example
+
+```bash
+curl --location --request PUT 'http://{address}/api/v1/messages/6c8983badede11f083f184ba59bc53c7:270' \
+--header 'Content-Type: application/json' \
+--header 'Authorization: Bearer ' \
+--data '{
+ "status": false
+}'
+```
+
+##### Request parameters
+
+- `memory_id`: (*Path parameter*), `string`, *Required*
+
+ The ID of the memory to which the specified message belongs.
+
+- `message_id`: (*Path parameter*), `string`, *Required*
+
+ The ID of the message to enable or disable.
+
+- `status`: (*Body parameter*), `bool`, *Required*
+
+ The status of message. `True` = `enabled`, `False` = `disabled`.
+
+#### Response
+
+Success
+
+```json
+{
+ "code": 0,
+ "data": null,
+ "message": true
+}
+```
+
+Failure
+
+```json
+{
+ "code": 404,
+ "data": null,
+ "message": "Memory '{memory_id}' not found."
+}
+```
+
+### Search Message
+
+**GET** `/api/v1/messages/search?query={question}&memory_id={memory_id}&similarity_threshold={similarity_threshold}&keywords_similarity_weight={keywords_similarity_weight}&top_n={top_n}`
+
+Searches and retrieves messages from memory based on the provided `query` and other configuration parameters.
+
+#### Request
+
+- Method: GET
+- URL: `/api/v1/messages/search?query={question}&memory_id={memory_id}&similarity_threshold={similarity_threshold}&keywords_similarity_weight={keywords_similarity_weight}&top_n={top_n}`
+- Headers:
+ - `'Content-Type: application/json'`
+ - `'Authorization: Bearer '`
+
+##### Request example
+
+```bash
+curl --location 'http://{address}/api/v1/messages/search?query=%22who%20are%20you%3F%22&memory_id=6c8983badede11f083f184ba59bc53c7&similarity_threshold=0.2&keywords_similarity_weight=0.7&top_n=10' \
+--header 'Authorization: Bearer '
+```
+
+##### Request parameters
+
+- `question`: (*Filter parameter*), `string`, *Required*
+
+ The search term or natural language question used to find relevant messages.
+
+- `memory_id`: (*Filter parameter*), `string` or `list[string]`, *Required*
+
+ The IDs of the memories to search. Supports multiple values.
+
+- `agent_id`: (*Filter parameter*), `string`, *Optional*
+
+ The ID of the message's source agent. Defaults to `None`.
+
+- `session_id`: (*Filter parameter*), `string`, *Optional*
+
+ The ID of the message's session. Defaults to `None`.
+
+- `similarity_threshold`: (*Filter parameter*), `float`, *Optional*
+
+ The minimum cosine similarity score required for a message to be considered a match. A higher value yields more precise but fewer results. Defaults to `0.2`.
+
+ - Range [0.0, 1.0]
+
+- `keywords_similarity_weight` : (*Filter parameter*), `float`, *Optional*
+
+ Controls the influence of keyword matching versus semantic (embedding-based) matching in the final relevance score. A value of 0.5 gives them equal weight. Defaults to `0.7`.
+
+ - Range [0.0, 1.0]
+
+- `top_n`: (*Filter parameter*), `int`, *Optional*
+
+ The maximum number of most relevant messages to return. This limits the result set size for efficiency. Defaults to `10`.
+
+#### Response
+
+Success
+
+```json
+{
+ "code": 0,
+ "data": [
+ {
+ "agent_id": "8db9c8eddfcc11f0b5da84ba59bc53c7",
+ "content": "User Input: who am I?\nAgent Response: To address the question \"who am I?\", let's follow the logical steps outlined in the instructions:\n\n1. **Understand the User’s Request**: The user is asking for a clarification or identification of their own self. This is a fundamental question about personal identity.\n\n2. **Decompose the Request**: The request is quite simple and doesn't require complex decomposition. The core task is to provide an answer that identifies the user in some capacity.\n\n3. **Execute the Subtask**:\n - **Identify the nature of the question**: The user is seeking to understand their own existence or their sense of self.\n - **Assess the context**: The context is not explicitly given, so the response will be general.\n - **Provide a response**: The answer should acknowledge the user's inquiry into their identity.\n\n4. **Validate Accuracy and Consistency**: The response should be consistent with the general understanding of the question. Since the user has not provided specific details about their identity, the response should be broad and open-ended.\n\n5. **Summarize the Final Result**: The user is asking \"who am I?\" which is an inquiry into their own identity. The answer is that the user is the individual who is asking the question. Without more specific information, a detailed description of their identity cannot be provided.\n\nSo, the final summary would be:\n\nThe user is asking the question \"who am I?\" to seek an understanding of their own identity. The response to this question is that the user is the individual who is posing the question. Without additional context or details, a more comprehensive description of the user's identity cannot be given.",
+ "forget_at": "None",
+ "invalid_at": "None",
+ "memory_id": "6c8983badede11f083f184ba59bc53c7",
+ "message_id": 61,
+ "message_type": "raw",
+ "session_id": "ebf8025de52211f0b56684ba59bc53c7",
+ "source_id": "None",
+ "status": true,
+ "user_id": "",
+ "valid_at": "2025-12-30 09:57:49"
+ },
+ ...other 2 matched messages here
+ ],
+ "message": true
+}
+```
+
+Failure
+
+```json
+{
+ "code": 500,
+ "message": "Internal Server Error."
+}
+```
+
+
+
+### Get Recent Messages
+
+**GET** `/api/v1/messages?memory_id={memory_id}&agent_id={agent_id}&session_id={session_id}&limit={limit}`
+
+Retrieves the most recent messages from specified memories. Typically accepts a `limit` parameter to control the number of messages returned.
+
+#### Request
+
+- Method: GET
+- URL: `/api/v1/messages?memory_id={memory_id}&agent_id={agent_id}&session_id={session_id}&limit={limit}`
+- Headers:
+ - `'Content-Type: application/json'`
+ - `'Authorization: Bearer '`
+
+##### Request example
+
+```bash
+curl --location 'http://{address}/api/v1/messages?memory_id=6c8983badede11f083f184ba59bc53c7&limit=10' \
+--header 'Authorization: Bearer '
+```
+
+##### Request parameters
+
+- `memory_id`: (*Filter parameter*), `string` or `list[string]`, *Required*
+
+ The IDs of the memories to search. Supports multiple values.
+
+- `agent_id`: (*Filter parameter*), `string`, *Optional*
+
+ The ID of the message's source agent. Defaults to `None`.
+
+- `session_id`: (*Filter parameter*), `string`, *Optional*
+
+ The ID of the message's session. Defaults to `None`.
+
+- `limit`: (*Filter parameter*), `int`, *Optional*
+
+ Control the number of messages returned. Defaults to `10`.
+
+#### Response
+
+Success
+
+```json
+{
+ "code": 0,
+ "data": [
+ {
+ "agent_id": "8db9c8eddfcc11f0b5da84ba59bc53c7",
+ "content": "User Input: what is pineapple?\nAgent Response: A pineapple is a tropical fruit known for its sweet, tangy flavor and distinctive, spiky appearance. Here are the key facts:\nScientific Name: Ananas comosus\nPhysical Description: It has a tough, spiky, diamond-patterned outer skin (rind) that is usually green, yellow, or brownish. Inside, the juicy yellow flesh surrounds a fibrous core.\nGrowth: Unlike most fruits, pineapples do not grow on trees. They grow from a central stem as a composite fruit, meaning they are formed from many individual berries that fuse together around the core. They grow on a short, leafy plant close to the ground.\nUses: Pineapples are eaten fresh, cooked, grilled, juiced, or canned. They are a popular ingredient in desserts, fruit salads, savory dishes (like pizzas or ham glazes), smoothies, and cocktails.\nNutrition: They are a good source of Vitamin C, manganese, and contain an enzyme called bromelain, which aids in digestion and can tenderize meat.\nSymbolism: The pineapple is a traditional symbol of hospitality and welcome in many cultures.\nAre you asking about the fruit itself, or its use in a specific context?",
+ "forget_at": "None",
+ "invalid_at": "None",
+ "memory_id": "6c8983badede11f083f184ba59bc53c7",
+ "message_id": 269,
+ "message_type": "raw",
+ "session_id": "bf0a50abeb8111f0917884ba59bc53c7",
+ "source_id": "None",
+ "status": true,
+ "user_id": "",
+ "valid_at": "2026-01-07 16:49:12"
+ },
+ ...other 9 messages here
+ ],
+ "message": true
+}
+```
+
+Failure
+
+```json
+{
+ "code": 500,
+ "message": "Internal Server Error."
+}
+```
+
+
+
+### Get Message Content
+
+**GET** `/api/v1/messages/{memory_id}:{message_id}/content`
+
+Retrieves the full content and embed vector of a specific message using its unique message ID.
+
+#### Request
+
+- Method: GET
+- URL: `/api/v1/messages/{memory_id}:{message_id}/content`
+- Headers:
+ - `'Content-Type: application/json'`
+ - `'Authorization: Bearer '`
+
+##### Request example
+
+```bash
+curl --location 'http://{address}/api/v1/messages/6c8983badede11f083f184ba59bc53c7:270/content' \
+--header 'Authorization: Bearer '
+```
+
+##### Request parameters
+
+- `memory_id`: (*Path parameter*), `string`, *Required*
+
+ The ID of the memory to which the specified message belongs.
+
+- `message_id`: (*Path parameter*), `string`, *Required*
+
+ The ID of the message.
+
+#### Response
+
+Success
+
+```json
+{
+ "code": 0,
+ "data": {
+ "agent_id": "8db9c8eddfcc11f0b5da84ba59bc53c7",
+ "content": "Pineapples are tropical fruits known for their sweet, tangy flavor and distinctive, spiky appearance",
+ "content_embed": [
+ 0.03641991,
+ ...embed vector here
+ ],
+ "forget_at": null,
+ "id": "6c8983badede11f083f184ba59bc53c7_270",
+ "invalid_at": null,
+ "memory_id": "6c8983badede11f083f184ba59bc53c7",
+ "message_id": 270,
+ "message_type": "semantic",
+ "session_id": "bf0a50abeb8111f0917884ba59bc53c7",
+ "source_id": 269,
+ "status": false,
+ "user_id": "",
+ "valid_at": "2026-01-07 16:48:37",
+ "zone_id": 0
+ },
+ "message": true
+}
+```
+
+Failure
+
+```json
+{
+ "code": 404,
+ "data": null,
+ "message": "Memory '{memory_id}' not found."
+}
+```
+
+
+
---
### System
diff --git a/docs/references/python_api_reference.md b/docs/references/python_api_reference.md
index 3689da3f3eb..c0eeee3b3cc 100644
--- a/docs/references/python_api_reference.md
+++ b/docs/references/python_api_reference.md
@@ -1,8 +1,10 @@
---
sidebar_position: 5
slug: /python_api_reference
+sidebar_custom_props: {
+ categoryIcon: SiPython
+}
---
-
# Python API
A complete reference for RAGFlow's Python APIs. Before proceeding, please ensure you [have your RAGFlow API key ready for authentication](https://ragflow.io/docs/dev/acquire_ragflow_api_key).
@@ -63,8 +65,17 @@ Whether to receive the response as a stream. Set this to `false` explicitly if y
#### Examples
+> **Note**
+> Streaming via `client.chat.completions.create(stream=True, ...)` does not
+> return `reference` currently because `reference` is only exposed in the
+> non-stream response payload. The only way to return `reference` is non-stream
+> mode with `with_raw_response`.
+:::caution NOTE
+Streaming via `client.chat.completions.create(stream=True, ...)` does not return `reference` because it is *only* included in the raw response payload in non-stream mode. To return `reference`, set `stream=False`.
+:::
```python
from openai import OpenAI
+import json
model = "model"
client = OpenAI(api_key="ragflow-api-key", base_url=f"http://ragflow_address/api/v1/chats_openai/")
@@ -72,7 +83,7 @@ client = OpenAI(api_key="ragflow-api-key", base_url=f"http://ragflow_address/api
stream = True
reference = True
-completion = client.chat.completions.create(
+request_kwargs = dict(
model=model,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
@@ -80,22 +91,36 @@ completion = client.chat.completions.create(
{"role": "assistant", "content": "I am an AI assistant named..."},
{"role": "user", "content": "Can you tell me how to install neovim"},
],
- stream=stream,
- extra_body={"reference": reference}
+ extra_body={
+ "extra_body": {
+ "reference": reference,
+ "reference_metadata": {
+ "include": True,
+ "fields": ["author", "year", "source"],
+ },
+ }
+ },
)
if stream:
+ completion = client.chat.completions.create(stream=True, **request_kwargs)
for chunk in completion:
print(chunk)
- if reference and chunk.choices[0].finish_reason == "stop":
- print(f"Reference:\n{chunk.choices[0].delta.reference}")
- print(f"Final content:\n{chunk.choices[0].delta.final_content}")
else:
- print(completion.choices[0].message.content)
- if reference:
- print(completion.choices[0].message.reference)
+ resp = client.chat.completions.with_raw_response.create(
+ stream=False, **request_kwargs
+ )
+ print("status:", resp.http_response.status_code)
+ raw_text = resp.http_response.text
+ print("raw:", raw_text)
+
+ data = json.loads(raw_text)
+ print("assistant:", data["choices"][0]["message"].get("content"))
+ print("reference:", data["choices"][0]["message"].get("reference"))
```
+When `extra_body.reference_metadata.include` is `true`, each reference chunk may include a `document_metadata` object in both streaming and non-streaming responses.
+
## DATASET MANAGEMENT
---
@@ -1516,6 +1541,8 @@ A list of `Chunk` objects representing references to the message, each containin
The ID of the referenced document.
- `document_name` `str`
The name of the referenced document.
+- `document_metadata` `dict`
+ Optional document metadata, returned only when `extra_body.reference_metadata.include` is `true`.
- `position` `list[str]`
The location information of the chunk within the referenced document.
- `dataset_id` `str`
@@ -1641,6 +1668,8 @@ A list of `Chunk` objects representing references to the message, each containin
The ID of the referenced document.
- `document_name` `str`
The name of the referenced document.
+- `document_metadata` `dict`
+ Optional document metadata, returned only when `extra_body.reference_metadata.include` is `true`.
- `position` `list[str]`
The location information of the chunk within the referenced document.
- `dataset_id` `str`
@@ -1960,3 +1989,637 @@ rag_object.delete_agent("58af890a2a8911f0a71a11b922ed82d6")
```
---
+
+
+
+## Memory Management
+
+### Create Memory
+
+```python
+Ragflow.create_memory(
+ name: str,
+ memory_type: list[str],
+ embd_id: str,
+ llm_id: str
+) -> Memory
+```
+
+Create a new memory.
+
+#### Parameters
+
+##### name: `str`, *Required*
+
+The unique name of the memory to create. It must adhere to the following requirements:
+
+- Basic Multilingual Plane (BMP) only
+- Maximum 128 characters
+
+##### memory_type: `list[str]`, *Required*
+
+Specifies the types of memory to extract. Available options:
+
+- `raw`: The raw dialogue content between the user and the agent . *Required by default*.
+- `semantic`: General knowledge and facts about the user and world.
+- `episodic`: Time-stamped records of specific events and experiences.
+- `procedural`: Learned skills, habits, and automated procedures.
+
+##### embd_id: `str`, *Required*
+
+The name of the embedding model to use. For example: `"BAAI/bge-large-zh-v1.5@BAAI"`
+
+- Maximum 255 characters
+- Must follow `model_name@model_factory` format
+
+##### llm_id: `str`, *Required*
+
+The name of the chat model to use. For example: `"glm-4-flash@ZHIPU-AI"`
+
+- Maximum 255 characters
+- Must follow `model_name@model_factory` format
+
+#### Returns
+
+- Success: A `memory` object.
+
+- Failure: `Exception`
+
+#### Examples
+
+```python
+from ragflow_sdk import RAGFlow
+rag_object = RAGFlow(api_key="", base_url="http://:9380")
+memory = rag_obj.create_memory("name", ["raw"], "BAAI/bge-large-zh-v1.5@SILICONFLOW", "glm-4-flash@ZHIPU-AI")
+```
+
+---
+
+
+
+### Update Memory
+
+```python
+Memory.update(
+ update_dict: dict
+) -> Memory
+```
+
+Updates configurations for a specified memory.
+
+#### Parameters
+
+##### update_dict: `dict`, *Required*
+
+Configurations to update. Available configurations:
+
+- `name`: `string`, *Optional*
+
+ The revised name of the memory.
+
+ - Basic Multilingual Plane (BMP) only
+ - Maximum 128 characters, *Optional*
+
+- `avatar`: `string`, *Optional*
+
+ The updated base64 encoding of the avatar.
+
+ - Maximum 65535 characters
+
+- `permission`: `enum`, *Optional*
+
+ The updated memory permission. Available options:
+
+ - `"me"`: (Default) Only you can manage the memory.
+ - `"team"`: All team members can manage the memory.
+
+- `llm_id`: `string`, *Optional*
+
+ The name of the chat model to use. For example: `"glm-4-flash@ZHIPU-AI"`
+
+ - Maximum 255 characters
+ - Must follow `model_name@model_factory` format
+
+- `description`: `string`, *Optional*
+
+ The description of the memory. Defaults to `None`.
+
+- `memory_size`: `int`, *Optional*
+
+ Defaults to `5*1024*1024` Bytes. Accounts for each message's content + its embedding vector (≈ Content + Dimensions × 8 Bytes). Example: A 1 KB message with 1024-dim embedding uses ~9 KB. The 5 MB default limit holds ~500 such messages.
+
+ - Maximum 10 * 1024 * 1024 Bytes
+
+- `forgetting_policy`: `enum`, *Optional*
+
+ Evicts existing data based on the chosen policy when the size limit is reached, freeing up space for new messages. Available options:
+
+ - `"FIFO"`: (Default) Prioritize messages with the earliest `forget_at` time for removal. When the pool of messages that have `forget_at` set is insufficient, it falls back to selecting messages in ascending order of their `valid_at` (oldest first).
+
+- `temperature`: (*Body parameter*), `float`, *Optional*
+
+ Adjusts output randomness. Lower = more deterministic; higher = more creative.
+
+ - Range [0, 1]
+
+- `system_prompt`: (*Body parameter*), `string`, *Optional*
+
+ Defines the system-level instructions and role for the AI assistant. It is automatically assembled based on the selected `memory_type` by `PromptAssembler` in `memory/utils/prompt_util.py`. This prompt sets the foundational behavior and context for the entire conversation.
+
+ - Keep the `OUTPUT REQUIREMENTS` and `OUTPUT FORMAT` parts unchanged.
+
+- `user_prompt`: (*Body parameter*), `string`, *Optional*
+
+ Represents the user's custom setting, which is the specific question or instruction the AI needs to respond to directly. Defaults to `None`.
+
+#### Returns
+
+- Success: A `memory` object.
+
+- Failure: `Exception`
+
+#### Examples
+
+```python
+from ragflow_sdk import Ragflow, Memory
+rag_object = RAGFlow(api_key="", base_url="http://:9380")
+memory_obejct = Memory(rag_object, {"id": "your memory_id"})
+memory_object.update({"name": "New_name"})
+```
+
+---
+
+
+
+### List Memory
+
+```python
+Ragflow.list_memory(
+ page: int = 1,
+ page_size: int = 50,
+ tenant_id: str | list[str] = None,
+ memory_type: str | list[str] = None,
+ storage_type: str = None,
+ keywords: str = None) -> dict
+```
+
+List memories.
+
+#### Parameters
+
+##### page: `int`, *Optional*
+
+Specifies the page on which the datasets will be displayed. Defaults to `1`
+
+##### page_size: `int`, *Optional*
+
+The number of memories on each page. Defaults to `50`.
+
+##### tenant_id: `str` or `list[str]`, *Optional*
+
+The owner's ID, supports search multiple IDs.
+
+##### memory_type: `str` or `list[str]`, *Optional*
+
+The type of memory (as set during creation). A memory matches if its type is **included in** the provided value(s). Available options:
+
+- `raw`
+- `semantic`
+- `episodic`
+- `procedural`
+
+##### storage_type: `str`, *Optional*
+
+The storage format of messages. Available options:
+
+- `table`: (Default)
+
+##### keywords: `str`, *Optional*
+
+The name of memory to retrieve, supports fuzzy search.
+
+#### Returns
+
+Success: A dict of `Memory` object list and total count.
+
+```json
+{"memory_list": list[Memory], "total_count": int}
+```
+
+Failure: `Exception`
+
+#### Examples
+
+```
+from ragflow_sdk import Ragflow, Memory
+rag_object = RAGFlow(api_key="", base_url="http://:9380")
+rag_obejct.list_memory()
+```
+
+---
+
+
+
+### Get Memory Config
+
+```python
+Memory.get_config()
+```
+
+Get the configuration of a specified memory.
+
+#### Parameters
+
+None
+
+#### Returns
+
+Success: A `Memory` object.
+
+Failure: `Exception`
+
+#### Examples
+
+```python
+from ragflow_sdk import Ragflow, Memory
+rag_object = RAGFlow(api_key="", base_url="http://:9380")
+memory_obejct = Memory(rag_object, {"id": "your memory_id"})
+memory_obejct.get_config()
+```
+
+---
+
+
+
+### Delete Memory
+
+```python
+Ragflow.delete_memory(
+ memory_id: str
+) -> None
+```
+
+Delete a specified memory.
+
+#### Parameters
+
+##### memory_id: `str`, *Required*
+
+The ID of the memory.
+
+#### Returns
+
+Success: Nothing
+
+Failure: `Exception`
+
+#### Examples
+
+```python
+from ragflow_sdk import Ragflow, Memory
+rag_object = RAGFlow(api_key="", base_url="http://:9380")
+rag_object.delete_memory("your memory_id")
+```
+
+---
+
+
+
+### List messages of a memory
+
+```python
+Memory.list_memory_messages(
+ agent_id: str | list[str]=None,
+ keywords: str=None,
+ page: int=1,
+ page_size: int=50
+) -> dict
+```
+
+List the messages of a specified memory.
+
+#### Parameters
+
+##### agent_id: `str` or `list[str]`, *Optional*
+
+Filters messages by the ID of their source agent. Supports multiple values.
+
+##### keywords: `str`, *Optional*
+
+Filters messages by their session ID. This field supports fuzzy search.
+
+##### page: `int`, *Optional*
+
+Specifies the page on which the messages will be displayed. Defaults to `1`.
+
+##### page_size: `int`, *Optional*
+
+The number of messages on each page. Defaults to `50`.
+
+#### Returns
+
+Success: a dict of messages and meta info.
+
+```json
+{"messages": {"message_list": [{message dict}], "total_count": int}, "storage_type": "table"}
+```
+
+Failure: `Exception`
+
+#### Examples
+
+```python
+from ragflow_sdk import Ragflow, Memory
+rag_object = RAGFlow(api_key="", base_url="http://:9380")
+memory_obejct = Memory(rag_object, {"id": "your memory_id"})
+memory_obejct.list_memory_messages()
+```
+
+---
+
+
+
+### Add Message
+
+```python
+Ragflow.add_message(
+ memory_id: list[str],
+ agent_id: str,
+ session_id: str,
+ user_input: str,
+ agent_response: str,
+ user_id: str = ""
+) -> str
+```
+
+Add a message to specified memories.
+
+#### Parameters
+
+##### memory_id: `list[str]`, *Required*
+
+The IDs of the memories to save messages.
+
+##### agent_id: `str`, *Required*
+
+The ID of the message's source agent.
+
+##### session_id: `str`, *Required*
+
+The ID of the message's session.
+
+##### user_input: `str`, *Required*
+
+The text input provided by the user.
+
+##### agent_response: `str`, *Required*
+
+The text response generated by the AI agent.
+
+##### user_id: `str`, *Optional*
+
+The user participating in the conversation with the agent. Defaults to `""`.
+
+#### Returns
+
+Success: A text `"All add to task."`
+
+Failure: `Exception`
+
+#### Examples
+
+```python
+from ragflow_sdk import Ragflow, Memory
+rag_object = RAGFlow(api_key="", base_url="http://:9380")
+message_payload = {
+ "memory_id": memory_ids,
+ "agent_id": agent_id,
+ "session_id": session_id,
+ "user_id": "",
+ "user_input": "Your question here",
+ "agent_response": """
+Your agent response here
+"""
+}
+client.add_message(**message_payload)
+```
+
+---
+
+
+
+### Forget Message
+
+```python
+Memory.forget_message(message_id: int) -> bool
+```
+
+Forget a specified message. After forgetting, this message will not be retrieved by agents, and it will also be prioritized for cleanup by the forgetting policy.
+
+#### Parameters
+
+##### message_id: `int`, *Required*
+
+The ID of the message to forget.
+
+#### Returns
+
+Success: True
+
+Failure: `Exception`
+
+#### Examples
+
+```python
+from ragflow_sdk import Ragflow, Memory
+rag_object = RAGFlow(api_key="", base_url="http://:9380")
+memory_object = Memory(rag_object, {"id": "your memory_id"})
+memory_object.forget_message(message_id)
+```
+
+---
+
+
+
+### Update message status
+
+```python
+Memory.update_message_status(message_id: int, status: bool) -> bool
+```
+
+Update message status, enable or disable a message. Once a message is disabled, it will not be retrieved by agents.
+
+#### Parameters
+
+##### message_id: `int`, *Required*
+
+The ID of the message to enable or disable.
+
+##### status: `bool`, *Required*
+
+The status of message. `True` = `enabled`, `False` = `disabled`.
+
+#### Returns
+
+Success: `True`
+
+Failure: `Exception`
+
+#### Examples
+
+```python
+from ragflow_sdk import Ragflow, Memory
+rag_object = RAGFlow(api_key="", base_url="http://:9380")
+memory_object = Memory(rag_object, {"id": "your memory_id"})
+memory_object.update_message_status(message_id, True)
+```
+
+---
+
+
+
+### Search message
+
+```python
+Ragflow.search_message(
+ query: str,
+ memory_id: list[str],
+ agent_id: str=None,
+ session_id: str=None,
+ similarity_threshold: float=0.2,
+ keywords_similarity_weight: float=0.7,
+ top_n: int=10
+) -> list[dict]
+```
+
+Searches and retrieves messages from memory based on the provided `query` and other configuration parameters.
+
+#### Parameters
+
+##### query: `str`, *Required*
+
+The search term or natural language question used to find relevant messages.
+
+##### memory_id: `list[str]`, *Required*
+
+The IDs of the memories to search. Supports multiple values.
+
+##### agent_id: `str`, *Optional*
+
+The ID of the message's source agent. Defaults to `None`.
+
+##### session_id: `str`, *Optional*
+
+The ID of the message's session. Defaults to `None`.
+
+##### similarity_threshold: `float`, *Optional*
+
+The minimum cosine similarity score required for a message to be considered a match. A higher value yields more precise but fewer results. Defaults to `0.2`.
+
+- Range [0.0, 1.0]
+
+##### keywords_similarity_weight: `float`, *Optional*
+
+Controls the influence of keyword matching versus semantic (embedding-based) matching in the final relevance score. A value of 0.5 gives them equal weight. Defaults to `0.7`.
+
+- Range [0.0, 1.0]
+
+##### top_n: `int`, *Optional*
+
+The maximum number of most relevant messages to return. This limits the result set size for efficiency. Defaults to `10`.
+
+#### Returns
+
+Success: A list of `message` dict.
+
+Failure: `Exception`
+
+#### Examples
+
+```python
+from ragflow_sdk import Ragflow
+rag_object = RAGFlow(api_key="", base_url="http://