Kremon96 commited on
Commit
f163b2f
·
verified ·
1 Parent(s): 861c474

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +15 -7
inference.py CHANGED
@@ -57,14 +57,16 @@ def load_fish_ae_from_hf(repo_id: str = 'jordand/fish-s1-dac-min', device: str =
57
  fish_ae = build_ae()
58
 
59
  w_path = hf_hub_download(repo_id, 'pytorch_model.safetensors', token=token)
 
 
 
 
60
  if dtype is not None and dtype != torch.float32:
61
- state = st.load_file(w_path, device='cpu')
62
  state = {k: v.to(dtype=dtype) for k, v in state.items()}
63
- state = {k: v.to(device=device) for k, v in state.items()}
64
- fish_ae.load_state_dict(state, strict=False, assign=True)
65
- else:
66
- state = st.load_file(w_path, device=device)
67
- fish_ae.load_state_dict(state, strict=False, assign=True)
68
 
69
  fish_ae = fish_ae.eval().to(device) # Явно перемещаем модель на устройство
70
 
@@ -83,7 +85,13 @@ class PCAState:
83
 
84
  def load_pca_state_from_hf(repo_id: str = 'jordand/echo-tts-base', device: str = 'cpu', filename: str = 'pca_state.safetensors', token: str | None = None) -> PCAState:
85
  p_path = hf_hub_download(repo_id, filename, token=token)
86
- t = st.load_file(p_path, device=device) # Загружаем напрямую на нужное устройство
 
 
 
 
 
 
87
  return PCAState(
88
  pca_components=t["pca_components"],
89
  pca_mean=t["pca_mean"],
 
57
  fish_ae = build_ae()
58
 
59
  w_path = hf_hub_download(repo_id, 'pytorch_model.safetensors', token=token)
60
+
61
+ # Загружаем на CPU, затем перемещаем на нужное устройство
62
+ state = st.load_file(w_path, device='cpu')
63
+
64
  if dtype is not None and dtype != torch.float32:
 
65
  state = {k: v.to(dtype=dtype) for k, v in state.items()}
66
+
67
+ # Перемещаем на указанное устройство
68
+ state = {k: v.to(device=device) for k, v in state.items()}
69
+ fish_ae.load_state_dict(state, strict=False, assign=True)
 
70
 
71
  fish_ae = fish_ae.eval().to(device) # Явно перемещаем модель на устройство
72
 
 
85
 
86
  def load_pca_state_from_hf(repo_id: str = 'jordand/echo-tts-base', device: str = 'cpu', filename: str = 'pca_state.safetensors', token: str | None = None) -> PCAState:
87
  p_path = hf_hub_download(repo_id, filename, token=token)
88
+
89
+ # Загружаем на CPU, затем перемещаем на нужное устройство
90
+ t = st.load_file(p_path, device='cpu')
91
+
92
+ # Перемещаем тензоры на указанное устройство
93
+ t = {k: v.to(device=device) for k, v in t.items()}
94
+
95
  return PCAState(
96
  pca_components=t["pca_components"],
97
  pca_mean=t["pca_mean"],