bert-tiny-amd / simple_upload.py
Adya662's picture
Upload trained BERT-Tiny AMD model
4523f56 verified
#!/usr/bin/env python3
"""
Simple script to upload model files to Hugging Face Hub
"""
import os
import torch
from transformers import AutoTokenizer
from huggingface_hub import HfApi
import json
from pathlib import Path
# Configuration
# Hub repository that upload_files() pushes to; also used to print the model URL.
REPO_ID = "Adya662/bert-tiny-amd"
# Local checkpoint file; upload_files() copies it to "pytorch_model.bin" before uploading.
MODEL_PATH = "best_enhanced_progressive_amd.pth"
# Hub model whose tokenizer is loaded and re-saved next to the weights.
BASE_MODEL = "prajjwal1/bert-tiny"
def create_model_config() -> dict:
    """Build the dictionary that is serialized as the model's ``config.json``.

    The values describe a 2-layer, 128-hidden, 2-head BERT with a single
    output label.

    Returns:
        dict: JSON-serializable configuration; ``"model_type"`` is the
        first key.
    """
    # NOTE(review): "attention_proxy_dtype" and "attention_dropout" do not look
    # like standard BERT config fields (BERT's attention dropout is normally
    # "attention_probs_dropout_prob") -- confirm against BertConfig.
    # NOTE(review): "num_labels": 1 together with
    # "problem_type": "single_label_classification" looks inconsistent with a
    # single-logit sigmoid head -- verify before fine-tuning from this config.
    return {
        # "model_type" used to be listed twice in this literal; one entry is
        # enough and the resulting dict is unchanged.
        "model_type": "bert",
        "architectures": ["BertForSequenceClassification"],
        "attention_proxy_dtype": "float32",
        "attention_dropout": 0.1,
        "classifier_dropout": None,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "hidden_size": 128,
        "initializer_range": 0.02,
        "intermediate_size": 512,
        "layer_norm_eps": 1e-12,
        "max_position_embeddings": 512,
        "num_attention_heads": 2,
        "num_hidden_layers": 2,
        "num_labels": 1,
        "pad_token_id": 0,
        "position_embedding_type": "absolute",
        "problem_type": "single_label_classification",
        "torch_dtype": "float32",
        "transformers_version": "4.21.0",
        "type_vocab_size": 2,
        "use_cache": True,
        "vocab_size": 30522,
    }
def create_training_metadata():
    """Return the descriptive metadata saved as ``training_metadata.json``.

    The result bundles identification fields (name, base model, task,
    dataset, language, license, pipeline tag), a tag list, the reported
    evaluation scores and the training-run settings.
    """
    tag_list = [
        "text-classification",
        "answering-machine-detection",
        "bert-tiny",
        "binary-classification",
        "call-center",
        "voice-processing",
    ]

    # Reported evaluation scores.
    eval_scores = dict(
        validation_accuracy=0.9394,
        precision=0.9275,
        recall=0.8727,
        f1_score=0.8993,
    )

    # Dataset sizes and hyperparameters of the training run.
    run_settings = dict(
        total_samples=3548,
        training_samples=2838,
        validation_samples=710,
        epochs=15,
        batch_size=32,
        learning_rate=3e-5,
        device="mps",
    )

    return {
        "model_name": "bert-tiny-amd",
        "base_model": "prajjwal1/bert-tiny",
        "task": "text-classification",
        "dataset": "ElevateNow call center transcripts",
        "language": "en",
        "license": "mit",
        "pipeline_tag": "text-classification",
        "tags": tag_list,
        "performance": eval_scores,
        "training_details": run_settings,
    }
def _write_json(path, payload):
    """Serialize *payload* to *path* as indented UTF-8 JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def upload_files() -> bool:
    """Generate the model artifacts in the current directory and upload them.

    Writes ``config.json``, ``training_metadata.json``, the tokenizer files
    of ``BASE_MODEL``, ``pytorch_model.bin`` (a copy of ``MODEL_PATH``) and
    ``README.md`` into the working directory, then uploads exactly those
    files to ``REPO_ID`` on the Hugging Face Hub.

    Returns:
        bool: ``True`` when the upload succeeded; ``False`` when the
        checkpoint file is missing or the upload raised an exception.
    """
    print("🚀 Starting file upload to Hugging Face Hub...")

    # Fail fast: without the checkpoint there is nothing worth uploading, so
    # check before writing files or downloading the tokenizer.
    if not os.path.exists(MODEL_PATH):
        print(f"❌ Model file {MODEL_PATH} not found!")
        return False

    # Initialize HF API
    api = HfApi()

    # Names of the files this function generates; only these are uploaded.
    upload_names = [
        "config.json",
        "training_metadata.json",
        "pytorch_model.bin",
        "README.md",
    ]

    # Save model configuration and training metadata
    _write_json("config.json", create_model_config())
    _write_json("training_metadata.json", create_training_metadata())

    # Load and save tokenizer from base model
    print("📥 Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    # save_pretrained() returns the paths it wrote; remember their names so
    # they are included in the upload.
    tokenizer_files = tokenizer.save_pretrained(".")
    upload_names.extend(Path(saved).name for saved in tokenizer_files)

    # Copy model weights
    print("📥 Copying model weights...")
    import shutil
    shutil.copy2(MODEL_PATH, "pytorch_model.bin")
    print("✅ Model weights copied successfully")

    # Create README.md
    # The body of the `with torch.no_grad():` example is indented so the
    # published snippet is valid Python.
    readme_content = """---
license: mit
tags:
- text-classification
- answering-machine-detection
- bert-tiny
- binary-classification
- call-center
- voice-processing
pipeline_tag: text-classification
---
# BERT-Tiny AMD Classifier
A lightweight BERT-Tiny model fine-tuned for Answering Machine Detection (AMD) in call center environments.
## Model Description
This model is based on `prajjwal1/bert-tiny` and fine-tuned to classify phone call transcripts as either human or machine (answering machine/voicemail) responses. It's designed for real-time call center applications where quick and accurate detection of answering machines is crucial.
## Model Architecture
- **Base Model**: `prajjwal1/bert-tiny` (2 layers, 128 hidden size, 2 attention heads)
- **Total Parameters**: ~4.4M (lightweight and efficient)
- **Input**: User transcript text (max 128 tokens)
- **Output**: Single logit with sigmoid activation for binary classification
- **Loss Function**: BCEWithLogitsLoss with positive weight for class imbalance
## Performance
- **Validation Accuracy**: 93.94%
- **Precision**: 92.75%
- **Recall**: 87.27%
- **F1-Score**: 89.93%
- **Training Device**: MPS (Apple Silicon GPU)
- **Best Epoch**: 15 (with early stopping)
## Training Data
- **Total Samples**: 3,548 phone call transcripts
- **Training Set**: 2,838 samples
- **Validation Set**: 710 samples
- **Class Distribution**: 30.8% machine calls, 69.2% human calls
- **Source**: ElevateNow call center data
## Usage
### Basic Inference
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Load model and tokenizer
model = AutoModelForSequenceClassification.from_pretrained("Adya662/bert-tiny-amd")
tokenizer = AutoTokenizer.from_pretrained("Adya662/bert-tiny-amd")
# Prepare input
text = "Hello, this is John speaking"
inputs = tokenizer(text, return_tensors="pt", max_length=128, truncation=True, padding=True)
# Make prediction
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits.squeeze(-1)
    probability = torch.sigmoid(logits).item()
is_machine = probability >= 0.5
print(f"Prediction: {'Machine' if is_machine else 'Human'}")
print(f"Confidence: {probability:.4f}")
```
## Training Details
- **Optimizer**: AdamW with weight decay (0.01)
- **Learning Rate**: 3e-5 with linear scheduling
- **Batch Size**: 32
- **Epochs**: 15 (with early stopping)
- **Early Stopping**: Patience of 3 epochs
- **Class Imbalance**: Handled with positive weight
## Limitations
- Trained on English phone call transcripts
- May not generalize well to other languages or domains
- Performance may vary with different transcription quality
- Designed for short utterances (max 128 tokens)
## License
MIT License - see LICENSE file for details.
"""
    with open("README.md", "w", encoding="utf-8") as f:
        f.write(readme_content)

    # Upload to Hub
    print("⬆️ Uploading to Hugging Face Hub...")
    try:
        api.upload_folder(
            folder_path=".",
            repo_id=REPO_ID,
            repo_type="model",
            # Restrict the upload to the generated artifacts so the source
            # checkpoint, this script and any unrelated files in the working
            # directory are not pushed to the Hub.
            allow_patterns=upload_names,
            commit_message="Upload trained BERT-Tiny AMD model"
        )
    except Exception as e:
        # Script-level boundary: report any upload error and signal failure
        # to the caller instead of crashing with a traceback.
        print(f"❌ Upload failed: {e}")
        return False

    print("✅ Model uploaded successfully!")
    print(f"🔗 Model available at: https://huggingface.co/{REPO_ID}")
    return True
if __name__ == "__main__":
    # upload_files() signals failure through its boolean return value.
    success = upload_files()
    if success:
        print("\n🎉 Model deployment completed successfully!")
    else:
        print("\n💥 Model deployment failed!")
        # Exit non-zero so shells and CI jobs can detect the failed deployment.
        raise SystemExit(1)