diff --git a/LICENSE b/LICENSE
index c4d1011..f79bf7b 100644
--- a/LICENSE
+++ b/LICENSE
@@ -25,13 +25,14 @@ The nlp_text_splitter utlity uses the following sentence detection libraries:
 *****************************************************************************
 
-The WtP, "Where the Point", sentence segmentation library falls under the MIT License:
+The WtP, "Where's the Point", and SaT, "Segment any Text", sentence segmentation
+libraries fall under the MIT License:
 
-https://github.com/bminixhofer/wtpsplit/blob/main/LICENSE
+https://github.com/segment-any-text/wtpsplit/blob/main/LICENSE
 
 MIT License
 
-Copyright (c) 2024 Benjamin Minixhofer
+Copyright (c) 2024 Benjamin Minixhofer, Markus Frohmann, Igor Sterner
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
diff --git a/detection/nlp_text_splitter/README.md b/detection/nlp_text_splitter/README.md
index e8c7d14..3dae3be 100644
--- a/detection/nlp_text_splitter/README.md
+++ b/detection/nlp_text_splitter/README.md
@@ -1,8 +1,8 @@
 # Overview
 
 This directory contains the source code, test examples, and installation script
-for the OpenMPF NlpTextSplitter tool, which uses WtP and spaCy libraries
-to detect sentences in a given chunk of text.
+for the OpenMPF NlpTextSplitter tool, which uses **SaT (Segment any Text)**,
+**WtP (Where's the Point)**, and **spaCy** to detect sentences in a given chunk of text.
 
 # Background
 
@@ -10,16 +10,48 @@ Our primary motivation for creating this tool was to find a lightweight, accurat
 sentence detection capability to support a large variety of text processing tasks
 including translation and tagging.
-Through preliminary investigation, we identified the [WtP library ("Where's the
-Point")](https://github.com/bminixhofer/wtpsplit) and [spaCy's multilingual sentence
+Through preliminary investigation, we identified the [WtP/SaT library ("Where's the
+Point"/"Segment any Text")](https://github.com/bminixhofer/wtpsplit) and [spaCy's multilingual sentence
 detection model](https://spacy.io/models) for identifying sentence breaks in a large
 section of text.
 
 WtP models are trained to split up multilingual text by sentence without the need
 of an input language tag. The disadvantage is that the most accurate WtP models will need ~3.5
-GB of GPU memory. On the other hand, spaCy has a single multilingual sentence detection
-that appears to work better for splitting up English text in certain cases. Unfortunately
-this model lacks support handling for Chinese punctuation.
+GB of GPU memory. SaT is the newer successor to WtP from the same authors and
+generally offers better accuracy and efficiency.
+
+On the other hand, spaCy has a single multilingual sentence detection model
+that appears to work better for splitting up English text in certain cases.
+
+This component has been updated to use the Azure Translation Component's NewLineBehavior class,
+which replaces newlines with whitespace or removes them altogether based on the detected script.
+
+We need to consider the script and character encoding because some languages assign different
+meanings to the presence or absence of whitespace between words. For instance, in Chinese,
+
+`电脑` would mean `computer` but `电 脑` would mean `electricity brain`.
+
+When calling the NLP text splitter, users can adjust the following parameters to control
+sentence splitting behavior:
+
+- `split_mode`: set to `DEFAULT` to split by chunk size and `SENTENCE` to split by sentences.
+
+- `newline_behavior`: controls how newlines are handled in the submitted input text.
Options include:
+  - `GUESS` to choose ' ' for space-separated langs; '' for Chinese/Japanese/Korean.
+  - `SPACE` to always replace with a single space.
+  - `REMOVE` to always remove (no space).
+  - `NONE` to leave newlines unchanged.
+
+For instance:
+```
+result = list(TextSplitter.split(input_text,
+                                 ...
+                                 self.sat_model,
+                                 split_mode='DEFAULT',
+                                 newline_behavior='NONE'))
+```
+This will attempt to split using an SaT model, using the default chunking parameters
+and no newline adjustments.
+

 # Installation

@@ -40,13 +72,33 @@ Please note that several customizations are supported:

   setup a PyTorch installation with CUDA (GPU) libraries.

 - `--wtp-models-dir |-m `: Add this parameter to
-  change the default WtP model installation directory
+  change the default WtP/SaT model installation directory
   (default: `/opt/wtp/models`).

 - `--install-wtp-model|-w `: Add this parameter to specify
-  additional WTP models for installation. This parameter can be provided
-  multiple times to install more than one model.
+  additional WtP/SaT models for installation. Accepts both WtP names
+  (e.g., `wtp-bert-mini`) and SaT names (e.g., `sat-6l-sm`).
+  This parameter can be provided multiple times to install more than one model.

 - `--install-spacy-model|-s `: Add this parameter to specify
   additional spaCy models for installation. This parameter can be provided
   multiple times to install more than one model.
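The `GUESS` behavior described earlier can be sketched in isolation. The helper names and the (non-exhaustive) CJK code point ranges below are assumptions for illustration, not the component's actual implementation:

```python
import re

# Illustrative sketch of GUESS: lone newlines become a space for
# space-separated scripts and are removed for CJK text. These ranges
# sample only a few common CJK blocks.
CJK_RANGES = ((0x2E80, 0x9FFF), (0xF900, 0xFAFF), (0xFF00, 0xFF9F))

def is_cjk(ch: str) -> bool:
    # True if the character falls inside one of the sampled CJK ranges.
    code_point = ord(ch)
    return any(lo <= code_point <= hi for lo, hi in CJK_RANGES)

def guess_separator(text: str) -> str:
    # Guess from the first alphabetic character: CJK scripts get '',
    # space-separated scripts get ' '.
    first_alpha = next((ch for ch in text if ch.isalpha()), 'a')
    return '' if is_cjk(first_alpha) else ' '

def normalize_newlines(text: str) -> str:
    # Lone newlines become the guessed separator; a newline that already
    # sits next to other whitespace is simply dropped.
    sep = guess_separator(text)
    def repl(match: re.Match) -> str:
        matched = match.group(0)
        return sep if matched == '\n' else matched.replace('\n', '', 1)
    return re.sub(r'\s?\n\s?', repl, text)
```

Under this sketch, `'电\n脑'` collapses to `电脑` while `'hello\nworld'` becomes `hello world`, matching the `电脑` / `电 脑` distinction above.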
+
+
+# Optimal WtP / SaT model:
+
+Based on testing, `sat-6l-sm` emerges as the recommended model, particularly when accuracy
+is prioritized and GPU resources (~1 GB) are available.
+
+| Metric                      | sat-3l-sm  | sat-6l-sm | Difference                |
+|-----------------------------|------------|-----------|---------------------------|
+| **Accuracy (%)**            | 97.6%      | **98.8%** | +1.2%                     |
+| **GPU Memory Used (MB)**    | **1083.8** | 1125.8    | +42 MB (~4% increase)     |
+| **GPU Processing Time (s)** | **2.81**   | 3.00      | +0.19 s                   |
+| **CPU Processing Time (s)** | **6.57**   | 11.72     | **+5.15 s (significant)** |
+
+Key Considerations:
+- `sat-3l-sm` is slightly less accurate but runs roughly twice as fast on CPU compared to `sat-6l-sm`.
+- When running on GPU, both models use roughly the same GPU memory and runtime.
+- It is therefore generally advantageous to use `sat-6l-sm` when a GPU is available, and to fall
+  back to `sat-3l-sm` if only CPU resources are available.
+
+For detailed results and additional models, see [WtP SaT Text Splitter Analysis](WtP%20SaT%20Text%20Splitter%20Analysis.xlsx).
diff --git a/detection/nlp_text_splitter/install.sh b/detection/nlp_text_splitter/install.sh
index 38d1f5c..749b682 100755
--- a/detection/nlp_text_splitter/install.sh
+++ b/detection/nlp_text_splitter/install.sh
@@ -7,11 +7,11 @@
 #  under contract, and is subject to the Rights in Data-General Clause        #
 #  52.227-14, Alt. IV (DEC 2007).                                             #
 #                                                                             #
-#  Copyright 2024 The MITRE Corporation. All Rights Reserved.                 #
+#  Copyright 2025 The MITRE Corporation. All Rights Reserved.                 #
 #############################################################################
 
 #############################################################################
-# Copyright 2024 The MITRE Corporation                                      #
+# Copyright 2025 The MITRE Corporation                                      #
 #                                                                           #
 # Licensed under the Apache License, Version 2.0 (the "License");           #
 # you may not use this file except in compliance with the License.
 #
@@ -37,7 +37,7 @@ main() {
     fi
     eval set -- "$options"
     local wtp_models_dir=/opt/wtp/models
-    local wtp_models=("wtp-bert-mini")
+    local wtp_models=("wtp-bert-mini" "sat-3l-sm")
     local spacy_models=("xx_sent_ud_sm")
     while true; do
         case "$1" in
@@ -107,10 +107,20 @@ download_wtp_models() {
     for model_name in "${model_names[@]}"; do
         echo "Downloading the $model_name model to $wtp_models_dir."
-        local wtp_model_dir="$wtp_models_dir/$model_name"
+        local model_dir="$wtp_models_dir/$model_name"
+
+        # Decide which HF org to use based on model prefix.
+        #   - WtP: benjamin/
+        #   - SaT: segment-any-text/
+        local hf_owner="benjamin"
+        case "$model_name" in
+            sat-*) hf_owner="segment-any-text" ;;
+        esac
+
         python3 -c \
             "from huggingface_hub import snapshot_download; \
-            snapshot_download('benjamin/$model_name', local_dir='$wtp_model_dir')"
+            snapshot_download(repo_id='${hf_owner}/${model_name}', local_dir='${model_dir}')"
+
     done
 }
@@ -149,12 +159,12 @@ Options
 --text-splitter-dir, -t : Path to text splitter source code.
     (defaults to to the same directory as this script)
 --gpu, -g: Install the GPU version of PyTorch
---wtp-models-dir , -m : Path where WTP models will be stored.
+--wtp-models-dir , -m : Path where WtP/SaT models will be stored.
     (defaults to /opt/wtp/models)
---install-wtp-model, -w : Name of a WTP model to install in addtion to wtp-bert-mini.
+--install-wtp-model, -w : Name of a WtP or SaT model to install in addition to 'wtp-bert-mini' and 'sat-3l-sm'.
     This option can be provided more than once to specify multiple models.
---install-spacy-model | -s : Names of a spaCy model to install in addtion to
+--install-spacy-model | -s : Name of a spaCy model to install in addition to
     xx_sent_ud_sm. The option can be provided more than once to specify
     multiple models.
" diff --git a/detection/nlp_text_splitter/nlp_text_splitter/__init__.py b/detection/nlp_text_splitter/nlp_text_splitter/__init__.py index 3913b9a..d13daff 100644 --- a/detection/nlp_text_splitter/nlp_text_splitter/__init__.py +++ b/detection/nlp_text_splitter/nlp_text_splitter/__init__.py @@ -5,11 +5,11 @@ # under contract, and is subject to the Rights in Data-General Clause # # 52.227-14, Alt. IV (DEC 2007). # # # -# Copyright 2024 The MITRE Corporation. All Rights Reserved. # +# Copyright 2025 The MITRE Corporation. All Rights Reserved. # ############################################################################# ############################################################################# -# Copyright 2024 The MITRE Corporation # +# Copyright 2025 The MITRE Corporation # # # # Licensed under the Apache License, Version 2.0 (the "License"); # # you may not use this file except in compliance with the License. # @@ -30,35 +30,40 @@ from importlib.resources.abc import Traversable import spacy -from wtpsplit import WtP -from typing import Callable, List, Optional, Tuple - -from .wtp_lang_settings import WtpLanguageSettings - import torch +import re +import bisect + +from wtpsplit import WtP, SaT +from typing import Callable, List, Optional, Tuple, Union +from .wtp_lang_settings import WtpLanguageSettings +from .newline_behavior import NewLineBehavior DEFAULT_WTP_MODELS = "/opt/wtp/models" # If we want to package model installation with this utility in the future: -WTP_MODELS_PATH: Traversable = importlib.resources.files(__name__) / 'models' +MODELS_PATH: Traversable = importlib.resources.files(__name__) / 'models' log = logging.getLogger(__name__) +_LAST_WS_RE = re.compile(r"\s(?=\S*$)") + + # These models must have an specified language during sentence splitting. 
-WTP_MANDATORY_ADAPTOR = ['wtp-canine-s-1l', - 'wtp-canine-s-3l', - 'wtp-canine-s-6l', - 'wtp-canine-s-9l', - 'wtp-canine-s-12l'] +WTP_MANDATORY_ADAPTOR = { + 'wtp-canine-s-1l', + 'wtp-canine-s-3l', + 'wtp-canine-s-6l', + 'wtp-canine-s-9l', + 'wtp-canine-s-12l', +} -GPU_AVAILABLE = False -if torch.cuda.is_available(): - GPU_AVAILABLE = True +GPU_AVAILABLE = torch.cuda.is_available() class TextSplitterModel: - # To hold spaCy, WtP, and other potential sentence detection models in cache + # To hold spaCy, WtP, SaT, and other potential sentence detection models in cache def __init__(self, model_name: str, model_setting: str, default_lang: str = "en") -> None: self._model_name = "" @@ -68,68 +73,95 @@ def __init__(self, model_name: str, model_setting: str, default_lang: str = "en" self.split = lambda t, **param: [t] self.update_model(model_name, model_setting, default_lang) - def update_model(self, model_name: str, model_setting: str = "cpu", default_lang: str="en"): - if model_name: - if "wtp" in model_name: - self._update_wtp_model(model_name, model_setting, default_lang) - self.split = self._split_wtp - log.info(f"Setup WtP model: {model_name}") - else: - self._update_spacy_model(model_name) - self.split = self._split_spacy - log.info(f"Setup spaCy model: {model_name}") - - def _update_wtp_model(self, wtp_model_name: str, - model_setting: str, - default_lang: str) -> None: + def update_model(self, model_name: str, model_setting: str = "cpu", default_lang: str = "en"): + if not model_name: + return + + lower_name = model_name.lower() + if lower_name.startswith("wtp"): + self._update_wtp_model(model_name, model_setting, default_lang) + self.split = self._split_wtp + log.info(f"Setup WtP model: {model_name}") + elif lower_name.startswith("sat"): + self._update_sat_model(model_name, model_setting, default_lang) + self.split = self._split_sat + log.info(f"Setup SaT model: {model_name}") + else: + self._update_spacy_model(model_name) + self.split = self._split_spacy + 
log.info(f"Setup spaCy model: {model_name}")
 
-        if model_setting == "gpu" or model_setting == "cuda":
+    def _resolve_cpu_gpu_device(self, model_setting: str) -> str:
+        if model_setting in ("gpu", "cuda"):
             if GPU_AVAILABLE:
-                model_setting = "cuda"
+                return "cuda"
             else:
                 log.warning("PyTorch determined that CUDA is not available. "
                             "You may need to update the NVIDIA driver for the host system, "
                             "or reinstall PyTorch with GPU support by setting "
                             "ARGS BUILD_TYPE=gpu in the Dockerfile when building this component.")
-                model_setting = "cpu"
-        elif model_setting != "cpu":
-            log.warning("Invalid WtP model setting. Only `cpu` and `cuda` "
-                        "(or `gpu`) WtP model options available at this time. "
+                return "cpu"
+        if model_setting != "cpu":
+            log.warning(
+                f"Invalid model setting {model_setting}. Only `cpu` and `cuda` "
+                "(or `gpu`) WtP/SaT model options available at this time. "
                 "Defaulting to `cpu` mode.")
-            model_setting = "cpu"
+            return "cpu"
+        return model_setting
+
+    def _find_local_model_path(self, model_name: str) -> Optional[str]:
+        candidate = MODELS_PATH / model_name
+        if candidate.is_file() or candidate.is_dir():
+            with importlib.resources.as_file(candidate) as path:
+                return str(path)
 
-        if wtp_model_name in WTP_MANDATORY_ADAPTOR:
-            self._mandatory_wtp_language = True
-            self._default_lang = default_lang
+        fallback = os.path.join(DEFAULT_WTP_MODELS, model_name)
+        if os.path.exists(fallback):
+            return fallback
+        return None
 
-        if self._model_name == wtp_model_name and self._model_setting == model_setting:
-            log.info(f"Using cached model, running on {self._model_setting}: "
-                     f"{self._model_name}")
+    def _update_wtp_model(self, wtp_model_name: str,
+                          model_setting: str,
+                          default_lang: str) -> None:
+        device = self._resolve_cpu_gpu_device(model_setting)
+
+        self._model_name = wtp_model_name
+        self._model_setting = device
+        self._default_lang = default_lang
+        self._mandatory_wtp_language = (wtp_model_name in WTP_MANDATORY_ADAPTOR)
+
+        local_path =
self._find_local_model_path(wtp_model_name) + + if local_path: + log.info(f"Using downloaded WtP model at {local_path}") + self.wtp_model = WtP(local_path) else: - self._model_setting = model_setting - self._model_name = wtp_model_name - # Check if model has been downloaded - if (WTP_MODELS_PATH / wtp_model_name).is_file(): - log.info(f"Using downloaded {wtp_model_name} model.") - with importlib.resources.as_file(WTP_MODELS_PATH / wtp_model_name) as path: - self.wtp_model = WtP(str(path)) - elif os.path.exists(os.path.join(DEFAULT_WTP_MODELS, - wtp_model_name)): - - log.info(f"Using downloaded {wtp_model_name} model.") - wtp_model_name = os.path.join(DEFAULT_WTP_MODELS, - wtp_model_name) - self.wtp_model = WtP(wtp_model_name) - else: - log.warning(f"Model {wtp_model_name} not found, " - "downloading from hugging face.") - self.wtp_model = WtP(wtp_model_name) + log.warning(f"WtP model {wtp_model_name} not found locally; downloading from Hugging Face.") + self.wtp_model = WtP(wtp_model_name) + self.wtp_model.to(device) + + def _update_sat_model(self, sat_model_name: str, model_setting: str, default_lang: str) -> None: + device = self._resolve_cpu_gpu_device(model_setting) + + self._model_name = sat_model_name + self._model_setting = device + self._default_lang = default_lang + self._mandatory_wtp_language = (sat_model_name in WTP_MANDATORY_ADAPTOR) + + local_path = self._find_local_model_path(sat_model_name) + + if local_path: + log.info(f"Using downloaded SaT model at {local_path}") + self.sat_model = SaT(local_path) + else: + log.warning(f"SaT model {sat_model_name} not found locally; downloading from Hugging Face.") + self.sat_model = SaT(sat_model_name) + + # Move model to device; SaT benefits from half precision on GPU. + if device == "cuda": + self.sat_model.half().to("cuda") + else: + self.sat_model.to("cpu") - if model_setting != "cpu" and model_setting != "cuda": - log.warning(f"Invalid setting for WtP runtime {model_setting}. 
" - "Defaulting to CPU mode.") - model_setting = "cpu" - self.wtp_model.to(model_setting) def _split_wtp(self, text: str, lang: Optional[str] = None) -> List[str]: if lang: @@ -152,6 +184,10 @@ def _update_spacy_model(self, spacy_model_name: str): self.spacy_model = spacy.load(spacy_model_name, exclude=["parser"]) self.spacy_model.enable_pipe("senter") + def _split_sat(self, text: str, lang: Optional[str] = None) -> List[str]: + # TODO: For now, we'll only use the SaT models that are language agnostic. + return self.sat_model.split(text) + def _split_spacy(self, text: str, lang: Optional[str] = None) -> List[str]: # TODO: We may add an auto model selection for spaCy in the future. # However, the drawback is we will also need to @@ -160,29 +196,53 @@ def _split_spacy(self, text: str, lang: Optional[str] = None) -> List[str]: return [sent.text_with_ws for sent in processed_text.sents] class TextSplitter: + NewLineBehaviorType = Union[ + NewLineBehavior.Behavior, # 'GUESS' | 'SPACE' | 'REMOVE' | 'NONE' | callable | None + ] def __init__( self, text: str, limit: int, num_boundary_chars: int, get_text_size: Callable[[str], int], sentence_model: TextSplitterModel, - in_lang: Optional[str] = None) -> None: + in_lang: Optional[str] = None, + split_mode: str = 'DEFAULT', + newline_behavior: NewLineBehaviorType = 'GUESS', + preferred_limit: int = -1 + ) -> None: + self._sentence_model = sentence_model self._limit = limit self._num_boundary_chars = num_boundary_chars self._get_text_size = get_text_size + self._in_lang = in_lang + self._split_mode = split_mode + + self._newline_fn: Callable[[str, Optional[str]], str] = NewLineBehavior.get(newline_behavior) self._text = "" self._text_full_size = 0 self._overhead_size = 0 self._soft_limit = self._limit - self._in_lang = in_lang + + if preferred_limit > 0: + self._preferred_limit = min(preferred_limit, limit) + else: + self._preferred_limit = limit + if text: self.set_text(text) def set_text(self, text: str): - self._text = text 
- self._text_full_size = self._get_text_size(text) - chars_per_size = len(text) / self._text_full_size + + if text: + self._text = self._newline_fn(text, self._in_lang) + else: + self._text = text + + self._text_full_size = self._get_text_size(self._text) + + text_size = self._text_full_size if self._text_full_size > 0 else 1 + chars_per_size = len(self._text) / text_size self._overhead_size = self._get_text_size('') self._soft_limit = int(self._limit * chars_per_size) - self._overhead_size @@ -194,7 +254,6 @@ def set_text(self, text: str): # before applying chars_per_size weighting. self._soft_limit = max(1, int((self._limit - self._overhead_size) * chars_per_size)) - def _isolate_largest_section(self, text:str) -> str: # Using cached word splitting model, isolate largest section of text string_length = len(text) @@ -206,10 +265,15 @@ def _isolate_largest_section(self, text:str) -> str: start_indx = max(0, string_length - num_chars_to_process) substring = text[start_indx: string_length] - substring_list = self._sentence_model.split(substring, lang = self._in_lang) - div_index = string_length - len(substring_list[-1]) + substring_list = self._sentence_model.split(substring, lang=self._in_lang) + if not substring_list: + return text + last = substring_list[-1] + if not last: + return text + div_index = string_length - len(last) - if div_index==start_indx: + if div_index == start_indx: return text return text[0:div_index] @@ -218,17 +282,62 @@ def _isolate_largest_section(self, text:str) -> str: def split(cls, text: str, limit: int, num_boundary_chars: int, get_text_size: Callable[[str], int], sentence_model: TextSplitterModel, - in_lang: Optional[str] = None - ): - return cls(text, limit, num_boundary_chars, get_text_size, sentence_model, in_lang)._split() - + in_lang: Optional[str] = None, + split_mode: str = 'DEFAULT', + newline_behavior: NewLineBehavior.Behavior = 'GUESS', + preferred_limit: int = -1 + ): + return cls( + text, limit, num_boundary_chars, 
get_text_size, + sentence_model, in_lang, split_mode, newline_behavior, + preferred_limit + )._split() def _split(self): - if self._text_full_size <= self._limit: + if self._split_mode == 'SENTENCE': + yield from self._split_sentences_individually() + else: + yield from self._split_default() + + def _split_default(self): + effective_limit = min(self._preferred_limit, self._limit) + + if self._text_full_size <= effective_limit: yield self._text else: yield from self._split_internal(self._text) + def _split_sentences_individually(self): + """ + Yield one sentence at a time. If any individual sentence exceeds the limit, + reuse the internal chunking logic to subdivide that sentence. + """ + sentences = self._sentence_model.split(self._text, lang=self._in_lang) + for sentence in sentences: + if self._get_text_size(sentence) <= self._limit: + yield sentence + else: + # Split oversized sentence using the default internal logic. + yield from self._split_sentence_text(sentence) + + def _split_sentence_text(self, text: str): + saved = ( + self._text, + self._text_full_size, + self._overhead_size, + self._soft_limit + ) + try: + self.set_text(text) + yield from self._split_internal(text) + finally: + ( + self._text, + self._text_full_size, + self._overhead_size, + self._soft_limit + ) = saved + def _split_internal(self, text): right = text while True: @@ -237,27 +346,119 @@ def _split_internal(self, text): if not right: return + + def _compute_breakpoints_from_sentences(self, text: str, pieces: List[str]) -> List[int]: + """ + Align sentence pieces back onto `text` to produce true breakpoint indices. + This avoids drift when the sentence model trims/normalizes whitespace. + Returns indices `bp` such that `text[:bp]` ends at a real sentence boundary. 
+ """ + break_pts: List[int] = [] + pos = 0 + + for s in pieces: + if not s: + continue + + # Try exact match first + idx = text.find(s, pos) + + if idx == -1: + # Common issue: text models trim surrounding whitespace; try stripped piece + s2 = s.strip() + if not s2: + continue + idx = text.find(s2, pos) + if idx == -1: + # Could not align; stop and use whatever we have so far. + # (Better to have partial breakpoints than wrong ones.) + log.debug("Sentence alignment failed; using partial breakpoints.") + return break_pts + s = s2 + + end = idx + len(s) + if 0 < end <= len(text): + break_pts.append(end) + pos = end + + # Ensure sorted unique breakpoints + return sorted(set(break_pts)) + def _divide(self, text) -> Tuple[str, str]: - limit = self._soft_limit - while True: - left = text[:limit] - left_size = self._get_text_size(left) - - if left_size <= self._limit: - if left != text: - # If dividing into two parts - # Determine soft boundary for left segment - left = self._isolate_largest_section(left) - return left, text[len(left):] + max_limit = self._limit + soft_limit = self._soft_limit + preferred_enabled = (self._preferred_limit < self._limit) - char_per_size = len(left) / left_size + # Always start with the existing max/guess + limit = soft_limit + while True: + left_window = text[:limit] + left_size = self._get_text_size(left_window) + + if left_size <= max_limit: + # If preferred is enabled and this remainder is still larger than preferred, + # split even if left_window == text. 
+ prefer_split = preferred_enabled and (left_size > self._preferred_limit) + + # If not using preferred logic, preserve original behavior: + if not prefer_split: + if left_window != text: + left = self._isolate_largest_section(left_window) + else: + left = left_window + else: + sents = self._sentence_model.split(left_window, lang=self._in_lang) or [] + + break_pts = self._compute_breakpoints_from_sentences(left_window, sents) + + # If left_window == text and we need to split, don't allow choosing full length. + if left_window == text: + desired_max = len(left_window) - 1 if len(left_window) > 1 else 1 + else: + desired_max = len(left_window) + + local_chars_per_token = len(left_window) / max(left_size, 1) + local_target = int(self._preferred_limit * local_chars_per_token) - self._overhead_size + target = max(1, min(desired_max, local_target)) + + # Always end on a breakpoint if any exist. + chosen: Optional[int] = None + if break_pts: + # Prefer the first breakpoint at/after target (slightly over is fine). + i = bisect.bisect_left(break_pts, target) + if i < len(break_pts): + chosen = break_pts[i] + else: + chosen = break_pts[-1] + else: + chosen = target + + # Fallback: + if chosen is None or chosen <= 0: + chosen = target + elif left_window == text and chosen >= len(left_window): + chosen = target + + left = left_window[:chosen] + + cut = len(left) + if 0 < cut < len(text) and text[cut - 1].isalnum() and text[cut].isalnum(): + m = _LAST_WS_RE.search(left) + if m: + left = left[:m.end()] + + # Worst-case, but extremely unlikely to happen. + if left == "" and text != "": + left = text[:1] - limit = int(self._limit * char_per_size) - self._overhead_size + return left, text[len(left):] + char_per_size = len(left_window) / max(left_size, 1) + limit = int(max_limit * char_per_size) - self._overhead_size if limit < 1: - # Caused by an unusually large overhead relative to text. - # This is unlikely to occur except during testing of small text limits. 
- # Recalculate soft limit by subtracting overhead from limit before - # applying chars_per_size weighting. - limit = max(1, int((self._limit - self._overhead_size) * char_per_size)) + # Caused by an unusually large overhead relative to text. + # This is unlikely to occur except during testing of small text limits. + # Recalculate soft limit by subtracting overhead from limit before + # applying chars_per_size weighting. + limit = max(1, int((max_limit - self._overhead_size) * char_per_size)) \ No newline at end of file diff --git a/detection/nlp_text_splitter/nlp_text_splitter/newline_behavior.py b/detection/nlp_text_splitter/nlp_text_splitter/newline_behavior.py new file mode 100644 index 0000000..8424438 --- /dev/null +++ b/detection/nlp_text_splitter/nlp_text_splitter/newline_behavior.py @@ -0,0 +1,155 @@ +############################################################################# +# NOTICE # +# # +# This software (or technical data) was produced for the U.S. Government # +# under contract, and is subject to the Rights in Data-General Clause # +# 52.227-14, Alt. IV (DEC 2007). # +# # +# Copyright 2025 The MITRE Corporation. All Rights Reserved. # +############################################################################# + +############################################################################# +# Copyright 2025 The MITRE Corporation # +# # +# Licensed under the Apache License, Version 2.0 (the "License"); # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http://www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. 
# +############################################################################# + +from __future__ import annotations + +import bisect +import re +from typing import Callable, Literal, Optional, Union + +import mpf_component_api as mpf + +# Languages that typically do NOT use spaces between words +NO_SPACE_LANGS = ('JA', 'YUE', 'ZH-HANS', 'ZH-HANT') + +class ChineseAndJapaneseCodePoints: + # From http://www.unicode.org/charts/ + RANGES = sorted(( + range(0x2e80, 0x2fe0), + range(0x2ff0, 0x3130), + range(0x3190, 0x3300), + range(0x3400, 0x4dc0), + range(0x4e00, 0xa4d0), + range(0xf900, 0xfb00), + range(0xfe10, 0xfe20), + range(0xfe30, 0xfe70), + range(0xff00, 0xffa0), + range(0x16f00, 0x16fa0), + range(0x16fe0, 0x18d09), + range(0x1b000, 0x1b300), + range(0x1f200, 0x1f300), + range(0x20000, 0x2a6de), + range(0x2a700, 0x2ebe1), + range(0x2f800, 0x2fa20), + range(0x30000, 0x3134b) + ), key=lambda r: r.start) + + RANGE_BEGINS = [r.start for r in RANGES] + + @classmethod + def check_char(cls, char: str) -> bool: + """ + Determine whether or not the given character is in the Unicode code point ranges assigned + to Chinese and Japanese. + """ + code_point = ord(char[0]) + if code_point < cls.RANGE_BEGINS[0]: + return False + else: + idx = bisect.bisect_right(cls.RANGE_BEGINS, code_point) + return code_point in cls.RANGES[idx - 1] + + +class NewLineBehavior: + """ + Provides a callable to normalize *single* newline events while preserving intended breaks. + Modes: + - 'GUESS' : choose ' ' for space-separated langs; '' for CJK. + - 'SPACE' : always replace with a single space. + - 'REMOVE' : always remove (no space). + - 'NONE' : no change. + + Users can also provide a custom callable to augment NewLineBehavior. 
+    """
+
+    Behavior = Union[
+        Literal['GUESS', 'SPACE', 'REMOVE', 'NONE'],
+        Callable[[str, Optional[str]], str],
+        None
+    ]
+
+    @classmethod
+    def get(cls, behavior: Behavior) -> Callable[[str, Optional[str]], str]:
+        if callable(behavior):
+            return behavior
+
+        # Default to GUESS if None
+        if behavior is None:
+            behavior = 'GUESS'
+
+        behavior = behavior.upper()
+
+        if behavior == 'GUESS':
+            return lambda s, l: cls._replace_new_lines(s, cls._guess_lang_separator(s, l))
+        elif behavior == 'REMOVE':
+            return lambda s, _: cls._replace_new_lines(s, '')
+        elif behavior == 'SPACE':
+            return lambda s, _: cls._replace_new_lines(s, ' ')
+        elif behavior == 'NONE':
+            return lambda s, _: s
+        else:
+            raise mpf.DetectionError.INVALID_PROPERTY.exception(
+                f'"{behavior}" is not a valid value for the "STRIP_NEW_LINE_BEHAVIOR" property. '
+                'Valid values are GUESS, REMOVE, SPACE, NONE.')
+
+    @staticmethod
+    def _guess_lang_separator(text: str, language: Optional[str]) -> Literal['', ' ']:
+        if language:
+            if language.upper() in NO_SPACE_LANGS:
+                return ''
+            else:
+                return ' '
+        else:
+            first_alpha_letter = next((ch for ch in text if ch.isalpha()), 'a')
+            if ChineseAndJapaneseCodePoints.check_char(first_alpha_letter):
+                return ''
+            else:
+                return ' '
+
+    REPLACE_NEW_LINE_REGEX = re.compile(r'''
+        \s? # Include preceding whitespace character if present
+        \n  # The newline character to replace
+        \s? # Include following whitespace character if present
+        ''', re.VERBOSE)
+
+    @classmethod
+    def _replace_new_lines(cls, text: str, replacement: str) -> str:
+
+        def do_replacement(match: re.Match[str]) -> str:
+            match_text = match.group(0)
+            if match_text == '\n':
+                # Surrounding characters are not whitespace.
+                return replacement
+            else:
+                # There is already whitespace next to newline character, so it can just be removed.
+ return match_text.replace('\n', '', 1) + + return cls.REPLACE_NEW_LINE_REGEX.sub(do_replacement, text) \ No newline at end of file diff --git a/detection/nlp_text_splitter/nlp_text_splitter/wtp_lang_settings.py b/detection/nlp_text_splitter/nlp_text_splitter/wtp_lang_settings.py index c682fd3..3111aa4 100644 --- a/detection/nlp_text_splitter/nlp_text_splitter/wtp_lang_settings.py +++ b/detection/nlp_text_splitter/nlp_text_splitter/wtp_lang_settings.py @@ -5,11 +5,11 @@ # under contract, and is subject to the Rights in Data-General Clause # # 52.227-14, Alt. IV (DEC 2007). # # # -# Copyright 2024 The MITRE Corporation. All Rights Reserved. # +# Copyright 2025 The MITRE Corporation. All Rights Reserved. # ############################################################################# ############################################################################# -# Copyright 2024 The MITRE Corporation # +# Copyright 2025 The MITRE Corporation # # # # Licensed under the Apache License, Version 2.0 (the "License"); # # you may not use this file except in compliance with the License. 
# @@ -234,15 +234,156 @@ class WtpLanguageSettings: 'cmn':'zh' # In some cases we use 'cmn' = 'Mandarin' } + # iso mappings for Flores-200 not recognized by + # WtpLanguageSettings.convert_to_iso() + _flores_to_wtpsplit_iso_639_1 = { + 'ace_arab': 'ar', # Acehnese Arabic + 'ace_latn': 'id', # Acehnese Latin + 'acm_arab': 'ar', # Mesopotamian Arabic + 'acq_arab': 'ar', # Ta’izzi-Adeni Arabic + 'aeb_arab': 'ar', # Tunisian Arabic + 'ajp_arab': 'ar', # South Levantine Arabic + 'aka_latn': 'ak', # Akan + 'als_latn': 'sq', # Albanian (Gheg) + 'apc_arab': 'ar', # North Levantine Arabic + 'arb_arab': 'ar', # Standard Arabic + 'arb_latn': 'ar', # Standard Arabic (Latin script) + 'ars_arab': 'ar', # Najdi Arabic + 'ary_arab': 'ar', # Moroccan Arabic + 'arz_arab': 'ar', # Egyptian Arabic + 'asm_beng': 'bn', # Assamese + 'ast_latn': 'es', # Asturian + 'awa_deva': 'hi', # Awadhi + 'ayr_latn': 'es', # Aymara + 'azb_arab': 'az', # South Azerbaijani + 'azj_latn': 'az', # North Azerbaijani + 'bak_cyrl': 'ru', # Bashkir + 'bam_latn': 'fr', # Bambara + 'ban_latn': 'id', # Balinese + 'bem_latn': 'sw', # Bemba + 'bho_deva': 'hi', # Bhojpuri + 'bjn_latn': 'id', # Banjar + 'bod_tibt': 'bo', # Tibetan + 'bos_latn': 'bs', # Bosnian + 'bug_latn': 'id', # Buginese + 'cjk_latn': 'id', # Chokwe (approx) + 'ckb_arab': 'ku', # Central Kurdish (Sorani) + 'crh_latn': 'tr', # Crimean Tatar + 'dik_latn': 'ar', # Dinka + 'dyu_latn': 'fr', # Dyula + 'dzo_tibt': 'dz', # Dzongkha + 'ewe_latn': 'ee', # Ewe + 'fao_latn': 'fo', # Faroese + 'fij_latn': 'fj', # Fijian + 'fon_latn': 'fr', # Fon + 'fur_latn': 'it', # Friulian + 'fuv_latn': 'ha', # Nigerian Fulfulde + 'gaz_latn': 'om', # Oromo + 'grn_latn': 'es', # Guarani + 'hat_latn': 'fr', # Haitian Creole + 'hne_deva': 'hi', # Chhattisgarhi + 'hrv_latn': 'hr', # Croatian + 'ilo_latn': 'tl', # Ilocano + 'kab_latn': 'fr', # Kabyle + 'kac_latn': 'my', # Jingpho/Kachin + 'kam_latn': 'sw', # Kamba + 'kas_deva': 'hi', # Kashmiri + 'kbp_latn': 'fr', # Kabiyè + 
'kea_latn': 'pt', # Cape Verdean Creole + 'khk_cyrl': 'mn', # Halh Mongolian + 'kik_latn': 'sw', # Kikuyu + 'kin_latn': 'rw', # Kinyarwanda + 'kmb_latn': 'pt', # Kimbundu + 'kmr_latn': 'ku', # Kurmanji Kurdish + 'knc_latn': 'ha', # Kanuri + 'kon_latn': 'fr', # Kongo + 'lao_laoo': 'lo', # Lao + 'lij_latn': 'it', # Ligurian + 'lim_latn': 'nl', # Limburgish + 'lin_latn': 'fr', # Lingala + 'lmo_latn': 'it', # Lombard + 'ltg_latn': 'lv', # Latgalian + 'ltz_latn': 'lb', # Luxembourgish + 'lua_latn': 'fr', # Luba-Kasai + 'lug_latn': 'lg', # Ganda + 'luo_latn': 'luo', # Luo + 'lus_latn': 'hi', # Mizo + 'lvs_latn': 'lv', # Latvian + 'mag_deva': 'hi', # Magahi + 'mai_deva': 'hi', # Maithili + 'min_latn': 'id', # Minangkabau + 'mni_beng': 'bn', # Manipuri (Meitei) + 'mos_latn': 'fr', # Mossi + 'mri_latn': 'mi', # Maori + 'nno_latn': 'no', # Norwegian Nynorsk + 'nob_latn': 'no', # Norwegian Bokmål + 'npi_deva': 'ne', # Nepali + 'nso_latn': 'st', # Northern Sotho + 'nus_latn': 'ar', # Nuer + 'nya_latn': 'ny', # Chichewa + 'oci_latn': 'oc', # Occitan + 'ory_orya': 'or', # Odia + 'pag_latn': 'tl', # Pangasinan + 'pap_latn': 'es', # Papiamento + 'pbt_arab': 'ps', # Southern Pashto + 'pes_arab': 'fa', # Iranian Persian (Farsi) + 'plt_latn': 'mg', # Plateau Malagasy + 'prs_arab': 'fa', # Dari Persian + 'quy_latn': 'qu', # Quechua + 'run_latn': 'rn', # Rundi + 'sag_latn': 'fr', # Sango + 'san_deva': 'sa', # Sanskrit + 'sat_olck': 'hi', # Santali + 'scn_latn': 'it', # Sicilian + 'shn_mymr': 'my', # Shan + 'smo_latn': 'sm', # Samoan + 'sna_latn': 'sn', # Shona + 'snd_arab': 'sd', # Sindhi + 'som_latn': 'so', # Somali + 'sot_latn': 'st', # Southern Sotho + 'srd_latn': 'sc', # Sardinian + 'ssw_latn': 'ss', # Swati + 'sun_latn': 'su', # Sundanese + 'swh_latn': 'sw', # Swahili + 'szl_latn': 'pl', # Silesian + 'taq_latn': 'ber', # Tamasheq + 'tat_cyrl': 'tt', # Tatar + 'tgl_latn': 'tl', # Tagalog + 'tir_ethi': 'ti', # Tigrinya + 'tpi_latn': 'tpi', # Tok Pisin + 'tsn_latn': 'tn', # Tswana + 
'tso_latn': 'ts', # Tsonga + 'tuk_latn': 'tk', # Turkmen + 'tum_latn': 'ny', # Tumbuka + 'twi_latn': 'ak', # Twi + 'tzm_tfng': 'ber', # Central Atlas Tamazight + 'uig_arab': 'ug', # Uyghur + 'umb_latn': 'pt', # Umbundu + 'uzn_latn': 'uz', # Uzbek + 'vec_latn': 'it', # Venetian + 'war_latn': 'tl', # Waray + 'wol_latn': 'wo', # Wolof + 'ydd_hebr': 'yi', # Yiddish + 'yue_hant': 'zh', # Yue Chinese (Cantonese) + 'zsm_latn': 'ms', # Malay + } + + _wtp_iso_set = set(_wtp_lang_map.values()) @classmethod - def convert_to_iso(cls, lang: str) -> Optional[str]: + def convert_to_iso(cls, lang: Optional[str]) -> Optional[str]: # ISO 639-2 (language) is sometimes paired with ISO 15924 (script). # Extract the language portion and check if supported in WtP. if not lang: return None + # 1) Handle Flores/NLLB codes first (e.g. "arb_Arab") + norm = lang.strip().lower().replace('-', '_') + mapped = cls._flores_to_wtpsplit_iso_639_1.get(norm) + if mapped: + lang = mapped + if '-' in lang: lang = lang.split('-')[0] if '_' in lang: @@ -252,8 +393,6 @@ def convert_to_iso(cls, lang: str) -> Optional[str]: if lang in cls._wtp_iso_set: return lang - if lang in cls._wtp_lang_map: return cls._wtp_lang_map[lang] - return None diff --git a/detection/nlp_text_splitter/pyproject.toml b/detection/nlp_text_splitter/pyproject.toml index 4ca28ea..b969657 100644 --- a/detection/nlp_text_splitter/pyproject.toml +++ b/detection/nlp_text_splitter/pyproject.toml @@ -5,11 +5,11 @@ # under contract, and is subject to the Rights in Data-General Clause # # 52.227-14, Alt. IV (DEC 2007). # # # -# Copyright 2024 The MITRE Corporation. All Rights Reserved. # +# Copyright 2025 The MITRE Corporation. All Rights Reserved. 
# ############################################################################# ############################################################################# -# Copyright 2024 The MITRE Corporation # +# Copyright 2025 The MITRE Corporation # # # # Licensed under the Apache License, Version 2.0 (the "License"); # # you may not use this file except in compliance with the License. # diff --git a/detection/nlp_text_splitter/tests/test_text_splitter.py b/detection/nlp_text_splitter/tests/test_text_splitter.py index 9782870..f84c26b 100644 --- a/detection/nlp_text_splitter/tests/test_text_splitter.py +++ b/detection/nlp_text_splitter/tests/test_text_splitter.py @@ -5,11 +5,11 @@ # under contract, and is subject to the Rights in Data-General Clause # # 52.227-14, Alt. IV (DEC 2007). # # # -# Copyright 2024 The MITRE Corporation. All Rights Reserved. # +# Copyright 2025 The MITRE Corporation. All Rights Reserved. # ############################################################################# ############################################################################# -# Copyright 2024 The MITRE Corporation # +# Copyright 2025 The MITRE Corporation # # # # Licensed under the Apache License, Version 2.0 (the "License"); # # you may not use this file except in compliance with the License. # @@ -38,6 +38,21 @@ def setUpClass(cls): cls.wtp_model = TextSplitterModel("wtp-bert-mini", "cpu", "en") cls.wtp_adv_model = TextSplitterModel("wtp-canine-s-1l", "cpu", "zh") cls.spacy_model = TextSplitterModel("xx_sent_ud_sm", "cpu", "en") + cls.sat_model = TextSplitterModel("sat-3l-sm", "cpu", "en") + + def test_sat_basic_sentence_split(self): + input_text = 'Hello, what is your name? My name is John.' + actual = list(TextSplitter.split(input_text, + 100, + 100, + len, + self.sat_model, + split_mode='SENTENCE')) + self.assertEqual(2, len(actual)) + self.assertEqual('Hello, what is your name? 
', actual[0]) + self.assertEqual('My name is John.', actual[1]) + + def test_split_engine_difference(self): # Note: Only WtP's multilingual models @@ -58,8 +73,14 @@ def test_split_engine_difference(self): actual = self.wtp_model._split_wtp(text) self.assertEqual(10, len(actual)) + # SaT splits on additional features beyond newlines, so it produces more segments. + actual = self.sat_model._split_sat(text) + self.assertEqual(16, len(actual)) + def test_guess_split_simple_sentence(self): - input_text = 'Hello, what is your name? My name is John.' + input_text = 'Hello, what is your name? My name is John. C. Finn.' + + # WtP produces a clean split. actual = list(TextSplitter.split(input_text, 28, 28, @@ -71,9 +92,24 @@ def test_guess_split_simple_sentence(self): # "Hello, what is your name?" self.assertEqual('Hello, what is your name? ', actual[0]) # " My name is John." - self.assertEqual('My name is John.', actual[1]) + self.assertEqual('My name is John. C. Finn.', actual[1]) + + # SaT splits a bit more aggressively than WtP. + actual = list(TextSplitter.split(input_text, + 500, + 500, + len, + self.sat_model, + split_mode='SENTENCE')) + self.assertEqual(input_text, ''.join(actual)) + self.assertEqual(3, len(actual)) + + # "Hello, what is your name?" + self.assertEqual('Hello, what is your name? ', actual[0]) + # "My name is John." + self.assertEqual('My name is John. ', actual[1]) + self.assertEqual('C. Finn.', actual[2]) - input_text = 'Hello, what is your name? My name is John.' actual = list(TextSplitter.split(input_text, 28, 28, @@ -85,7 +121,7 @@ def test_guess_split_simple_sentence(self): # "Hello, what is your name?" self.assertEqual('Hello, what is your name? ', actual[0]) # " My name is John." - self.assertEqual('My name is John.', actual[1]) + self.assertEqual('My name is John. C. Finn.', actual[1]) def test_split_sentence_end_punctuation(self): input_text = 'Hello. How are you? 
asdfasdf' @@ -124,7 +160,8 @@ def test_guess_split_edge_cases(self): 30, 30, len, - self.wtp_model)) + self.wtp_model, + newline_behavior="NONE")) self.assertEqual(input_text, ''.join(actual)) self.assertEqual(4, len(actual)) @@ -135,11 +172,30 @@ def test_guess_split_edge_cases(self): self.assertEqual("Maybe...maybe not? \n ", actual[2]) self.assertEqual("All done, I think!", actual[3]) + # Split again with the WtP model, this time letting GUESS handle the newlines. + actual = list(TextSplitter.split(input_text, + 30, + 30, + len, + self.wtp_model, + newline_behavior="GUESS")) + + self.assertEqual(input_text.replace('\n', ''), ''.join(actual)) + self.assertEqual(4, len(actual)) + + # WtP should detect and split out each sentence. + self.assertEqual("This is a sentence (Dr.Test). ", actual[0]) + self.assertEqual("Is this, a sentence as well? ", actual[1]) + self.assertEqual("Maybe...maybe not? ", actual[2]) + self.assertEqual("All done, I think!", actual[3]) + + actual = list(TextSplitter.split(input_text, 35, 35, len, - self.spacy_model)) + self.spacy_model, + newline_behavior="NONE")) self.assertEqual(input_text, ''.join(actual)) self.assertEqual(4, len(actual))
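Reviewer note: the newline handling added in `new_line_behavior.py` can be sanity-checked standalone. The sketch below is an illustrative reconstruction, not the component's module: the condensed `\s?\n\s?` pattern is assumed to be equivalent to the verbose regex in the diff, and the free function `replace_new_lines` stands in for the `NewLineBehavior` class method of the same name.

```python
import re

# Condensed equivalent of the component's verbose regex: capture an optional
# whitespace character on either side of the newline so the callback can tell
# whether the newline is the only whitespace present.
REPLACE_NEW_LINE_REGEX = re.compile(r'\s?\n\s?')


def replace_new_lines(text: str, replacement: str) -> str:
    def do_replacement(match: re.Match) -> str:
        match_text = match.group(0)
        if match_text == '\n':
            # Surrounding characters are not whitespace, so substitute.
            return replacement
        # Whitespace already borders the newline; just drop the newline.
        return match_text.replace('\n', '', 1)

    return REPLACE_NEW_LINE_REGEX.sub(do_replacement, text)
```

With an empty replacement (the GUESS outcome for Chinese/Japanese script), `'电\n脑'` becomes `'电脑'` rather than `'电 脑'`, which is exactly the distinction the README calls out.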
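The lookup order added to `convert_to_iso()` can likewise be illustrated in isolation. The tables below hold only a few representative entries (the real maps in `wtp_lang_settings.py` are far larger), so treat this as a sketch of the control flow, not the full mapping:

```python
from typing import Optional

# Tiny illustrative subsets of the real lookup tables.
_FLORES_TO_ISO = {'arb_arab': 'ar', 'yue_hant': 'zh', 'zsm_latn': 'ms'}
_WTP_ISO_SET = {'ar', 'zh', 'en', 'ms'}
_WTP_LANG_MAP = {'cmn': 'zh'}


def convert_to_iso(lang: Optional[str]) -> Optional[str]:
    if not lang:
        return None
    # 1) Normalize and map Flores/NLLB codes first (e.g. "arb_Arab" -> "ar").
    norm = lang.strip().lower().replace('-', '_')
    if norm in _FLORES_TO_ISO:
        lang = _FLORES_TO_ISO[norm]
    # 2) Strip an ISO 15924 script suffix ("zh-Hans" -> "zh").
    if '-' in lang:
        lang = lang.split('-')[0]
    if '_' in lang:
        lang = lang.split('_')[0]
    # 3) Accept codes WtP supports directly, then fall back to the name map.
    if lang in _WTP_ISO_SET:
        return lang
    if lang in _WTP_LANG_MAP:
        return _WTP_LANG_MAP[lang]
    return None
```

The Flores map is consulted before the script suffix is stripped because codes like `arb_Arab` carry meaning in the language part alone (`arb` is Standard Arabic, not an ISO 639-1 code).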