Spaces:
Running
Running
Update inference.py
Browse files- inference.py +3 -3
inference.py
CHANGED
|
@@ -18,7 +18,7 @@ SampleFn = Callable[
|
|
| 18 |
]
|
| 19 |
### Loading
|
| 20 |
|
| 21 |
-
def load_model_from_hf(repo_id: str = 'jordand/echo-tts-base', device: str = '
|
| 22 |
with torch.device('meta'):
|
| 23 |
model = EchoDiT(
|
| 24 |
latent_size=80, model_size=2048, num_layers=24, num_heads=16,
|
|
@@ -50,7 +50,7 @@ def load_model_from_hf(repo_id: str = 'jordand/echo-tts-base', device: str = 'cu
|
|
| 50 |
|
| 51 |
return model
|
| 52 |
|
| 53 |
-
def load_fish_ae_from_hf(repo_id: str = 'jordand/fish-s1-dac-min', device: str = '
|
| 54 |
# have not tested lower precisions with fish AE yet
|
| 55 |
|
| 56 |
with torch.device('meta'):
|
|
@@ -81,7 +81,7 @@ class PCAState:
|
|
| 81 |
pca_mean: torch.Tensor
|
| 82 |
latent_scale: float
|
| 83 |
|
| 84 |
-
def load_pca_state_from_hf(repo_id: str = 'jordand/echo-tts', device: str = '
|
| 85 |
p_path = hf_hub_download(repo_id, filename, token=token)
|
| 86 |
t = st.load_file(p_path, device=device)
|
| 87 |
return PCAState(
|
|
|
|
| 18 |
]
|
| 19 |
### Loading
|
| 20 |
|
| 21 |
+
def load_model_from_hf(repo_id: str = 'jordand/echo-tts-base', device: str = 'cpu', dtype: torch.dtype | None = torch.bfloat16, compile: bool = False, token: str | None = None) -> EchoDiT:
|
| 22 |
with torch.device('meta'):
|
| 23 |
model = EchoDiT(
|
| 24 |
latent_size=80, model_size=2048, num_layers=24, num_heads=16,
|
|
|
|
| 50 |
|
| 51 |
return model
|
| 52 |
|
| 53 |
+
def load_fish_ae_from_hf(repo_id: str = 'jordand/fish-s1-dac-min', device: str = 'cpu', dtype: torch.dtype | None = torch.float32, compile: bool = False, token: str | None = None) -> DAC:
|
| 54 |
# have not tested lower precisions with fish AE yet
|
| 55 |
|
| 56 |
with torch.device('meta'):
|
|
|
|
| 81 |
pca_mean: torch.Tensor
|
| 82 |
latent_scale: float
|
| 83 |
|
| 84 |
+
def load_pca_state_from_hf(repo_id: str = 'jordand/echo-tts-base', device: str = 'cpu', filename: str = 'pca_state.safetensors', token: str | None = None) -> PCAState:
|
| 85 |
p_path = hf_hub_download(repo_id, filename, token=token)
|
| 86 |
t = st.load_file(p_path, device=device)
|
| 87 |
return PCAState(
|