Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 76 additions & 46 deletions deployment/exporters/variance_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ def export_attachments(self, path: Path):

@torch.no_grad()
def _torch_export_model(self):
is_torch_113 = torch.__version__.startswith('1.13.')

# Prepare inputs for FastSpeech2 and dur predictor tracing
tokens = torch.LongTensor([[1] * 5]).to(self.device)
ph_dur = torch.LongTensor([[3, 5, 2, 1, 4]]).to(self.device)
Expand Down Expand Up @@ -388,33 +390,47 @@ def _torch_export_model(self):
noise = torch.randn(shape, device=self.device)
condition = torch.rand((1, hparams['hidden_size'], 15), device=self.device)

print(f'Tracing {self.pitch_backbone_class_name} backbone...')
pitch_predictor = self.model.view_as_pitch_predictor()
pitch_predictor.pitch_predictor.set_backbone(
torch.jit.trace(
pitch_predictor.pitch_predictor.backbone,
(
noise,
dummy_time,
condition

if is_torch_113:
print(f'Tracing {self.pitch_backbone_class_name} backbone...')
pitch_predictor.pitch_predictor.set_backbone(
torch.jit.trace(
pitch_predictor.pitch_predictor.backbone,
(
noise,
dummy_time,
condition
)
)
)
)

print(f'Scripting {self.pitch_predictor_class_name}...')
pitch_predictor = torch.jit.script(
pitch_predictor,
example_inputs=[
(
condition.transpose(1, 2),
1 # p_sample branch
),
(
condition.transpose(1, 2),
dummy_steps # p_sample_plms branch
)
]
)
print(f'Scripting {self.pitch_predictor_class_name}...')
pitch_predictor = torch.jit.script(
pitch_predictor,
example_inputs=[
(
condition.transpose(1, 2),
1 # p_sample branch
),
(
condition.transpose(1, 2),
dummy_steps # p_sample_plms branch
)
]
)
else:
print(f'Wrapping {self.pitch_predictor_class_name} for trace-based export...')

class _PitchPredWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.pitch_predictor = model.pitch_predictor

def forward(self, pitch_cond, steps):
return self.pitch_predictor(pitch_cond, steps=steps)

pitch_predictor = _PitchPredWrapper(pitch_predictor)

print(f'Exporting {self.pitch_predictor_class_name}...')
torch.onnx.export(
Expand Down Expand Up @@ -535,33 +551,47 @@ def _torch_export_model(self):
condition = torch.rand((1, hparams['hidden_size'], 15), device=self.device)
step = (torch.rand((1,), device=self.device) * hparams['K_step']).long()

print(f'Tracing {self.variance_backbone_class_name} backbone...')
multi_var_predictor = self.model.view_as_variance_predictor()
multi_var_predictor.variance_predictor.set_backbone(
torch.jit.trace(
multi_var_predictor.variance_predictor.backbone,
(
noise,
step,
condition

if is_torch_113:
print(f'Tracing {self.variance_backbone_class_name} backbone...')
multi_var_predictor.variance_predictor.set_backbone(
torch.jit.trace(
multi_var_predictor.variance_predictor.backbone,
(
noise,
step,
condition
)
)
)
)

print(f'Scripting {self.multi_var_predictor_class_name}...')
multi_var_predictor = torch.jit.script(
multi_var_predictor,
example_inputs=[
(
condition.transpose(1, 2),
1 # p_sample branch
),
(
condition.transpose(1, 2),
dummy_steps # p_sample_plms branch
)
]
)
print(f'Scripting {self.multi_var_predictor_class_name}...')
multi_var_predictor = torch.jit.script(
multi_var_predictor,
example_inputs=[
(
condition.transpose(1, 2),
1 # p_sample branch
),
(
condition.transpose(1, 2),
dummy_steps # p_sample_plms branch
)
]
)
else:
print(f'Wrapping {self.multi_var_predictor_class_name} for trace-based export...')

class _VarPredWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.variance_predictor = model.variance_predictor

def forward(self, variance_cond, steps):
return self.variance_predictor(variance_cond, steps=steps)

multi_var_predictor = _VarPredWrapper(multi_var_predictor)

print(f'Exporting {self.multi_var_predictor_class_name}...')
torch.onnx.export(
Expand Down
7 changes: 5 additions & 2 deletions scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@


def check_pytorch_version():
# Require PyTorch version to be exactly 1.13.x
import warnings
if torch.__version__.startswith('1.13.'):
return
raise RuntimeError('This script requires PyTorch 1.13.x. Please install the correct version.')
warnings.warn(
f'ONNX export is tested on PyTorch 1.13.x, but you have {torch.__version__}. '
f'Export may not behave as expected with this PyTorch version.'
)


def find_exp(exp):
Expand Down