# Hugging Face Spaces status: Runtime error
import dataclasses

import datasets
from transformers import (
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    XLMRobertaForSequenceClassification,
    XLMRobertaTokenizer,
)
# Base checkpoint shared by the tokenizer and the classifier so that their
# vocabularies match.
model_name = 'xlm-roberta-large'
tokenizer = XLMRobertaTokenizer.from_pretrained(model_name)
# Sequence-classification model with two output labels (num_labels=2).
model = XLMRobertaForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
)

# Sentiment corpus; it is later indexed by split name ('train', 'test').
dataset = datasets.load_dataset('imdb')
| # Tokenize the dataset | |
| def tokenize(batch): | |
| return tokenizer(batch['text'], padding=True, truncation=True) | |
| dataset = dataset.map(tokenize, batched=True) | |
# Fine-tune the model on the dataset.
# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`
# and later dropped the old spelling; use whichever this install accepts.
_training_arg_names = {field.name for field in dataclasses.fields(TrainingArguments)}
_eval_strategy_key = (
    'eval_strategy' if 'eval_strategy' in _training_arg_names else 'evaluation_strategy'
)

training_args = TrainingArguments(
    output_dir='./results',
    # load_best_model_at_end requires checkpoints to be saved on the same
    # schedule as evaluation; the default save strategy is step-based.
    save_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    push_to_hub=False,
    logging_dir='./logs',
    logging_steps=10,
    load_best_model_at_end=True,
    # Provided by compute_metrics below under the key 'accuracy'.
    metric_for_best_model='accuracy',
    **{_eval_strategy_key: 'epoch'},
)


def compute_metrics(eval_pred):
    """Compute classification accuracy from one evaluation pass.

    Args:
        eval_pred: the Trainer's evaluation output, exposing ``predictions``
            (per-class logits) and ``label_ids``.

    Returns:
        ``{'accuracy': float}``, the key that ``metric_for_best_model`` names.
    """
    logits = eval_pred.predictions
    if isinstance(logits, tuple):  # extra model outputs may follow the logits
        logits = logits[0]
    predictions = logits.argmax(axis=-1)
    return {'accuracy': float((predictions == eval_pred.label_ids).mean())}


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    # Without a metric function, evaluation reports only the loss, and the
    # 'accuracy' lookup for the best model fails.
    compute_metrics=compute_metrics,
    # Pads each batch to a common length, whatever padding tokenize() applied.
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
)
trainer.train()
import torch

# Load the fine-tuned XLM-Roberta-Large model. Checkpoint directory names
# encode the global step, so ask the Trainer which checkpoint it judged best
# instead of hard-coding one. If none was recorded, use the model the
# Trainer finished with.
model_path = trainer.state.best_model_checkpoint
if model_path is not None:
    model = XLMRobertaForSequenceClassification.from_pretrained(model_path)
else:
    model = trainer.model
model.eval()  # inference mode (disables dropout)


def predict_sentiment(text):
    """Return the predicted sentiment label for a single piece of text.

    Args:
        text: the input string to classify.

    Returns:
        ``'positive'`` when the model predicts class 1, otherwise ``'negative'``.
    """
    inputs = tokenizer(text, truncation=True, return_tensors='pt')
    # The model may be on a GPU (e.g. the in-memory Trainer model); keep the
    # input tensors on the same device.
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
    with torch.no_grad():  # gradients are not needed for inference
        logits = model(**inputs).logits
    predicted_class = int(logits.argmax(dim=-1)[0])
    return 'positive' if predicted_class == 1 else 'negative'
import gradio as gr

# Create a Gradio interface for the predict_sentiment function.
# The gr.inputs / gr.outputs namespaces were removed in Gradio 4, and the
# legacy gr.outputs.Textbox did not accept `placeholder`. The top-level
# gr.Textbox component works as both the input and the output.
iface = gr.Interface(
    fn=predict_sentiment,
    inputs=gr.Textbox(placeholder='Enter text here...'),
    outputs=gr.Textbox(placeholder='Sentiment prediction...'),
)

# Launch the interface.
iface.launch()