| | |
| |
|
| | from wtpsplit import SaT |
| | from typing import List |
| | import torch |
| |
|
| |
|
| | |
# Process-wide SaT instance; stays None until get_sat_model() first loads it.
_sat_model = None


def get_sat_model(model_name: str = "sat-3l-sm", device: str = "cuda") -> SaT:
    """
    Get or create the global SaT model instance.

    The model is built once and cached in the module-level ``_sat_model``.
    Only the first call's arguments take effect: later calls return the
    cached instance and ignore ``model_name`` and ``device``.

    Args:
        model_name: Model name from segment-any-text
        device: Device to run the model on. Any CUDA device string
            ("cuda", "cuda:0", ...) moves the model to that device in half
            precision when CUDA is available. Any other value, or a machine
            without CUDA, leaves the model unmoved (reported as CPU).

    Returns:
        SaT model instance
    """
    global _sat_model

    if _sat_model is None:
        print(f"Loading SaT model: {model_name}")
        _sat_model = SaT(model_name)

        # Match the whole CUDA family rather than the literal "cuda", so an
        # explicit index such as "cuda:1" is honoured instead of silently
        # taking the CPU branch. str() also tolerates a torch.device argument.
        wants_gpu = str(device).startswith("cuda")
        if wants_gpu and torch.cuda.is_available():
            # Half precision is applied on the GPU path only.
            _sat_model.half().to(device)
            print("SaT model loaded on GPU")
        else:
            print("SaT model loaded on CPU")

    return _sat_model
| |
|
| |
|
| | |
| | |
| | |
def _label_sentence_ends(sentences, words):
    """Return 0/1 labels for ``words``, with 1 on the word where each sentence ends."""
    from bisect import bisect_left
    from itertools import accumulate

    labels = [0] * len(words)
    # word_ends[i] is the number of non-whitespace characters up to and
    # including words[i]; it is strictly increasing because no word is empty.
    word_ends = list(accumulate(len(word) for word in words))

    consumed = 0
    for sentence in sentences:
        size = sum(len(piece) for piece in sentence.split())
        if not size:
            # Whitespace-only pieces carry no words and mark no boundary.
            continue
        consumed += size
        # Aligning on character counts instead of per-sentence word counts
        # keeps later labels in place even if a boundary falls inside a
        # token: the boundary is then attributed to the word containing it.
        end_idx = bisect_left(word_ends, consumed)
        if end_idx < len(labels):
            labels[end_idx] = 1
    return labels


def segment_SaT(text: str) -> List[int]:
    """
    Segment text using wtpsplit SaT model

    The text is lowercased and every "." and "," is removed before it is
    split on whitespace and passed to the model. The returned labels line up
    with the words of that cleaned text, so a token made up only of "." or
    "," characters produces no label.

    Args:
        text: Input text to segment

    Returns:
        List of labels: 0 = word is not the last word of c-unit,
                        1 = word is the last word of c-unit
        An empty list when the text has no words after cleaning. If the
        model call fails, the error is printed and every label is 0.
    """
    if not text.strip():
        return []

    # Single pass that lowercases, then deletes "." and ",".
    cleaned_text = text.lower().translate(str.maketrans("", "", ".,"))
    words = cleaned_text.split()
    if not words:
        return []

    sat_model = get_sat_model()

    try:
        sentences = sat_model.split(cleaned_text)
        return _label_sentence_ends(sentences, words)
    except Exception as e:
        # Best-effort: report the failure and fall back to "no boundaries".
        print(f"Error in SaT segmentation: {e}")
        return [0] * len(words)
| |
|
| |
|
| |
|
| | |
def reorganize_transcription_c_unit(session_id, base_dir="session_data"):
    """
    Placeholder that currently performs no work.

    Args:
        session_id: Accepted but not used by the current body.
        base_dir: Accepted but not used by the current body.

    Returns:
        None
    """
    return None
| |
|
| |
|
| |
|
if __name__ == "__main__":
    # Demo: segment an unpunctuated sample and print the resulting c-units.
    sample = "once a horse met elephant and then they saw a ball in a pool and then the horse tried to swim and get the ball they might be the same but they are doing something what do you think they are doing"

    print(f"Input text: {sample}")
    print(f"Words: {sample.split()}")

    flags = segment_SaT(sample)
    print(f"Segment labels: {flags}")

    # Rebuild the segments: a flag of 1 closes the segment being collected.
    tokens = sample.split()
    chunks = []
    pending = []
    for token, flag in zip(tokens, flags):
        pending.append(token)
        if flag != 1:
            continue
        chunks.append(" ".join(pending))
        pending = []

    # Words after the last flagged boundary form a final segment.
    if pending:
        chunks.append(" ".join(pending))

    print("\nSegmented text:")
    for number, chunk in enumerate(chunks, start=1):
        print(f"Segment {number}: {chunk}")