diff --git a/funasr/datasets/fun_asr_datasets/datasets.py b/funasr/datasets/fun_asr_datasets/datasets.py
index 09bbfa300..af8702a89 100644
--- a/funasr/datasets/fun_asr_datasets/datasets.py
+++ b/funasr/datasets/fun_asr_datasets/datasets.py
@@ -258,8 +258,9 @@ def __getitem__(self, index):
                     f"text_length: {len(target_ids)} > {self.max_target_length}, drop it: {item}"
                 )
         # simulate prev-token fixed output
+        target_labels = target_ids.copy()
         if np.random.rand() < self.use_dynamic_output_ratio:
-            max_len = len(target_ids)
+            max_len = len(target_labels)
             min_output_mask_token_len = min(self.min_output_mask_token_len, max_len)
             min_output_non_mask_token_len = min(self.min_output_non_mask_token_len, max_len)
             if max_len - min_output_non_mask_token_len > min_output_mask_token_len:
@@ -268,10 +269,10 @@ def __getitem__(self, index):
             else:
                 end_index = max_len - min_output_non_mask_token_len
             if end_index > 0:
-                target_ids[:end_index] = [-100] * end_index
+                target_labels[:end_index] = [-100] * end_index

         input_ids += source_ids + target_ids
-        labels += source_mask + target_ids
+        labels += source_mask + target_labels
         if len(speech) > 0:
             fbank.append(speech[0, :, :])
             fbank_lens.append(speech_lengths)