-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcli.py
More file actions
123 lines (99 loc) · 3.85 KB
/
cli.py
File metadata and controls
123 lines (99 loc) · 3.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# cli.py
import argparse
import concurrent.futures
import types
import typing
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

from tts_engine.config import TTSConfig
from tts_engine.registry import TTS_REGISTRY
from utils import (
    TextChunker,
    FileManager,
    get_user_confirmation,
    calculate_cost,
    calculate_total_characters,
)
def _parse_bool(value):
    """Convert a command-line string such as ``true``/``no``/``1`` to a bool.

    argparse cannot use ``type=bool`` because ``bool("False")`` is ``True``:
    any non-empty string would switch the option on.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    lowered = value.strip().lower()
    if lowered in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if lowered in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")


def _argparse_type_kwargs(annotation):
    """Map a config field annotation to ``add_argument`` keyword arguments.

    Raw annotations are not always valid argparse ``type=`` converters, so:

    * ``Optional[X]`` / ``X | None`` is unwrapped to ``X`` (an omitted option
      already stays ``None``); other unions fall back to ``str``.
    * ``bool`` is parsed with ``_parse_bool``.
    * ``Literal[...]`` with values of one type becomes ``choices``.
    * ``list[X]`` / ``set[X]`` / ``frozenset[X]`` accept one or more values.
    * A plain class (``int``, ``str``, ``Path``, ...) is used as the converter.
    * Anything else is passed through as ``str`` for the config model to
      validate.

    Returns:
        dict: keyword arguments such as ``type``, ``choices`` and ``nargs``.
    """
    if annotation is typing.Any:
        return {"type": str}

    origin = typing.get_origin(annotation)
    args = typing.get_args(annotation)

    # typing.Union covers Optional[X]; types.UnionType covers X | None (3.10+).
    union_origins = (typing.Union, getattr(types, "UnionType", typing.Union))
    if origin in union_origins:
        members = [arg for arg in args if arg is not type(None)]
        if len(members) == 1:
            return _argparse_type_kwargs(members[0])
        return {"type": str}

    if annotation is bool:
        return {"type": _parse_bool}

    if origin is typing.Literal:
        kinds = {type(choice) for choice in args}
        if len(kinds) == 1:
            kind = kinds.pop()
            return {
                "type": _parse_bool if kind is bool else kind,
                "choices": list(args),
            }
        return {"type": str}

    if origin in (list, set, frozenset) and len(args) == 1:
        return {"nargs": "+", **_argparse_type_kwargs(args[0])}

    if origin is None and isinstance(annotation, type):
        return {"type": annotation}
    return {"type": str}


def _help_text(*parts):
    """Join the non-empty *parts* into argparse help text.

    argparse applies %-formatting to help strings, so a literal ``%`` in a
    field description (e.g. "0-100%") must be doubled or ``--help`` raises.
    Returns ``None`` when every part is empty, so no help text is shown.
    """
    text = " ".join(part for part in parts if part)
    return text.replace("%", "%%") if text else None


def create_arg_parser():
    """Build the command-line parser for the TTS CLI.

    The parser exposes:

    * a positional ``input_file`` (converted to ``Path``);
    * ``--engine``, restricted to the keys of ``TTS_REGISTRY``
      (default ``"kokoro"``);
    * one ``--option`` per ``TTSConfig`` field except ``engine_config``;
    * an "Engine-specific options" group with one ``--option`` per non-frozen
      field of any registered engine config class that is not also a
      ``TTSConfig`` field. When several engines share a field, the first
      engine's annotation and description are used and the help text lists
      every engine that has it.

    All generated options default to ``None`` so the caller can tell which
    values were actually given on the command line.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description="TTS CLI")

    parser.add_argument("input_file", type=Path, help="Input text file path")
    parser.add_argument(
        "--engine",
        choices=list(TTS_REGISTRY.keys()),
        default="kokoro",
        help="TTS engine to use",
    )

    # Base TTSConfig options. engine_config is not an option itself; its
    # fields are exposed individually in the engine-specific group below.
    for field_name, field in TTSConfig.model_fields.items():
        if field_name == "engine_config":
            continue
        parser.add_argument(
            f"--{field_name.replace('_', '-')}",
            default=None,
            help=_help_text(field.description),
            **_argparse_type_kwargs(field.annotation),
        )

    # Collect the union of engine-specific fields across all engines.
    engine_args = {}
    for engine_name, engine_info in TTS_REGISTRY.items():
        config_class = engine_info["config"]
        for field_name, field in config_class.model_fields.items():
            # Frozen fields are not exposed, and fields shared with TTSConfig
            # already have a base option above.
            if field.frozen or field_name in TTSConfig.model_fields:
                continue
            entry = engine_args.setdefault(
                field_name,
                {
                    "type": field.annotation,
                    "description": field.description,
                    "engines": [],
                },
            )
            entry["engines"].append(engine_name)

    engine_group = parser.add_argument_group("Engine-specific options")
    for field_name, info in engine_args.items():
        engines_str = ", ".join(info["engines"])
        engine_group.add_argument(
            f"--{field_name.replace('_', '-')}",
            default=None,
            help=_help_text(f"({engines_str})", info["description"]),
            **_argparse_type_kwargs(info["type"]),
        )
    return parser
def main():
    """Run the TTS CLI end to end.

    1. Parse the command line and build a ``TTSConfig`` via
       ``TTSConfig.create`` from the options that were actually supplied
       (options left at ``None`` are dropped).
    2. Read the input file as UTF-8, estimate the cost and ask the user to
       confirm it; return early (printing a message) if they decline.
    3. Split the text into chunks and call ``engine.synthesize`` for each one
       on a ``ThreadPoolExecutor`` with ``config.max_workers`` workers, passing
       ``config.output_dir / "output_chunk_NNNN.wav"`` (1-based index) as the
       output path. Results are printed in completion order.

    An unreadable or non-UTF-8 input file is reported as a usage error
    (``SystemExit`` with status 2). If any chunk raises, or the run is
    interrupted, chunks that have not started yet are cancelled and the
    exception is re-raised.
    """
    parser = create_arg_parser()
    args = parser.parse_args()

    # Forward only options the user supplied; every generated option defaults
    # to None (presumably so TTSConfig.create falls back to its own defaults).
    cli_args = {k: v for k, v in vars(args).items() if v is not None}
    config = TTSConfig.create(args.engine, cli_args)

    # Decode explicitly instead of using the platform's locale encoding, and
    # report a bad path or bad encoding as a usage error, not a traceback.
    try:
        input_text = args.input_file.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as err:
        parser.error(f"cannot read input file {args.input_file}: {err}")

    total_chars = calculate_total_characters(input_text, config.chunk_size)
    total_cost = calculate_cost(total_chars, config.engine_config.cost_per_char)
    if not get_user_confirmation(total_cost):
        print("Operation cancelled by user.")
        return

    engine = TTS_REGISTRY[args.engine]["engine"](config.engine_config)
    chunks = TextChunker(config.chunk_size).process(input_text)
    FileManager.create_output_dir(config.output_dir)

    with ThreadPoolExecutor(max_workers=config.max_workers) as executor:
        futures = [
            executor.submit(
                engine.synthesize,
                chunk,
                config.output_dir / f"output_chunk_{index:04d}.wav",
                chunk_index=index,
            )
            for index, chunk in enumerate(chunks, start=1)
        ]
        try:
            for future in concurrent.futures.as_completed(futures):
                result = future.result()
                print(f"Generated: {result.output_file}")
        except BaseException:
            # Leaving the with-block waits for every queued chunk. Without
            # cancelling, all remaining chunks would still be synthesized (and,
            # for engines with a per-character cost, paid for) before the
            # error or Ctrl+C surfaces. Cancel is a no-op for running or
            # finished futures.
            for pending in futures:
                pending.cancel()
            raise


if __name__ == "__main__":
    main()