Kremon96 commited on
Commit
f35045e
·
verified ·
1 Parent(s): 4149510

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +3 -3
inference.py CHANGED
@@ -18,7 +18,7 @@ SampleFn = Callable[
18
  ]
19
  ### Loading
20
 
21
- def load_model_from_hf(repo_id: str = 'jordand/echo-tts-base', device: str = 'cuda', dtype: torch.dtype | None = torch.bfloat16, compile: bool = False, token: str | None = None) -> EchoDiT:
22
  with torch.device('meta'):
23
  model = EchoDiT(
24
  latent_size=80, model_size=2048, num_layers=24, num_heads=16,
@@ -50,7 +50,7 @@ def load_model_from_hf(repo_id: str = 'jordand/echo-tts-base', device: str = 'cu
50
 
51
  return model
52
 
53
- def load_fish_ae_from_hf(repo_id: str = 'jordand/fish-s1-dac-min', device: str = 'cuda', dtype: torch.dtype | None = torch.float32, compile: bool = False, token: str | None = None) -> DAC:
54
  # have not tested lower precisions with fish AE yet
55
 
56
  with torch.device('meta'):
@@ -81,7 +81,7 @@ class PCAState:
81
  pca_mean: torch.Tensor
82
  latent_scale: float
83
 
84
- def load_pca_state_from_hf(repo_id: str = 'jordand/echo-tts', device: str = 'cuda', filename: str = 'pca_state.safetensors', token: str | None = None) -> PCAState:
85
  p_path = hf_hub_download(repo_id, filename, token=token)
86
  t = st.load_file(p_path, device=device)
87
  return PCAState(
 
18
  ]
19
  ### Loading
20
 
21
+ def load_model_from_hf(repo_id: str = 'jordand/echo-tts-base', device: str = 'cpu', dtype: torch.dtype | None = torch.bfloat16, compile: bool = False, token: str | None = None) -> EchoDiT:
22
  with torch.device('meta'):
23
  model = EchoDiT(
24
  latent_size=80, model_size=2048, num_layers=24, num_heads=16,
 
50
 
51
  return model
52
 
53
+ def load_fish_ae_from_hf(repo_id: str = 'jordand/fish-s1-dac-min', device: str = 'cpu', dtype: torch.dtype | None = torch.float32, compile: bool = False, token: str | None = None) -> DAC:
54
  # have not tested lower precisions with fish AE yet
55
 
56
  with torch.device('meta'):
 
81
  pca_mean: torch.Tensor
82
  latent_scale: float
83
 
84
+ def load_pca_state_from_hf(repo_id: str = 'jordand/echo-tts-base', device: str = 'cpu', filename: str = 'pca_state.safetensors', token: str | None = None) -> PCAState:
85
  p_path = hf_hub_download(repo_id, filename, token=token)
86
  t = st.load_file(p_path, device=device)
87
  return PCAState(