Spaces:
Sleeping
Sleeping
Update inference.py
Browse files- inference.py +15 -7
inference.py
CHANGED
|
@@ -57,14 +57,16 @@ def load_fish_ae_from_hf(repo_id: str = 'jordand/fish-s1-dac-min', device: str =
|
|
| 57 |
fish_ae = build_ae()
|
| 58 |
|
| 59 |
w_path = hf_hub_download(repo_id, 'pytorch_model.safetensors', token=token)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
if dtype is not None and dtype != torch.float32:
|
| 61 |
-
state = st.load_file(w_path, device='cpu')
|
| 62 |
state = {k: v.to(dtype=dtype) for k, v in state.items()}
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
fish_ae.load_state_dict(state, strict=False, assign=True)
|
| 68 |
|
| 69 |
fish_ae = fish_ae.eval().to(device) # Явно перемещаем модель на устройство
|
| 70 |
|
|
@@ -83,7 +85,13 @@ class PCAState:
|
|
| 83 |
|
| 84 |
def load_pca_state_from_hf(repo_id: str = 'jordand/echo-tts-base', device: str = 'cpu', filename: str = 'pca_state.safetensors', token: str | None = None) -> PCAState:
|
| 85 |
p_path = hf_hub_download(repo_id, filename, token=token)
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
return PCAState(
|
| 88 |
pca_components=t["pca_components"],
|
| 89 |
pca_mean=t["pca_mean"],
|
|
|
|
| 57 |
fish_ae = build_ae()
|
| 58 |
|
| 59 |
w_path = hf_hub_download(repo_id, 'pytorch_model.safetensors', token=token)
|
| 60 |
+
|
| 61 |
+
# Загружаем на CPU, затем перемещаем на нужное устройство
|
| 62 |
+
state = st.load_file(w_path, device='cpu')
|
| 63 |
+
|
| 64 |
if dtype is not None and dtype != torch.float32:
|
|
|
|
| 65 |
state = {k: v.to(dtype=dtype) for k, v in state.items()}
|
| 66 |
+
|
| 67 |
+
# Перемещаем на указанное устройство
|
| 68 |
+
state = {k: v.to(device=device) for k, v in state.items()}
|
| 69 |
+
fish_ae.load_state_dict(state, strict=False, assign=True)
|
|
|
|
| 70 |
|
| 71 |
fish_ae = fish_ae.eval().to(device) # Явно перемещаем модель на устройство
|
| 72 |
|
|
|
|
| 85 |
|
| 86 |
def load_pca_state_from_hf(repo_id: str = 'jordand/echo-tts-base', device: str = 'cpu', filename: str = 'pca_state.safetensors', token: str | None = None) -> PCAState:
|
| 87 |
p_path = hf_hub_download(repo_id, filename, token=token)
|
| 88 |
+
|
| 89 |
+
# Загружаем на CPU, затем перемещаем на нужное устройство
|
| 90 |
+
t = st.load_file(p_path, device='cpu')
|
| 91 |
+
|
| 92 |
+
# Перемещаем тензоры на указанное устройство
|
| 93 |
+
t = {k: v.to(device=device) for k, v in t.items()}
|
| 94 |
+
|
| 95 |
return PCAState(
|
| 96 |
pca_components=t["pca_components"],
|
| 97 |
pca_mean=t["pca_mean"],
|