Sakalti commited on
Commit
23ec8a5
·
verified ·
1 Parent(s): 6c25565

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from datasets import load_dataset
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
5
+ import torch
6
+ import subprocess
7
+
8
+ def finetune(model_name, hf_token, upload_repo):
9
+ os.environ["HF_TOKEN"] = hf_token
10
+
11
+ # トークナイザとモデル準備
12
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
13
+ model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=hf_token)
14
+
15
+ # データセット読み込み(日本語チャット)
16
+ dataset = load_dataset("rinna/llm-japanese-dataset-v1", split="train")
17
+
18
+ # 前処理
19
+ def tokenize_fn(example):
20
+ return tokenizer(example["text"], truncation=True, max_length=512)
21
+
22
+ tokenized_dataset = dataset.map(tokenize_fn, batched=True, remove_columns=dataset.column_names)
23
+
24
+ # データコラレータ
25
+ data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
26
+
27
+ # トレーニング設定
28
+ training_args = TrainingArguments(
29
+ output_dir="./finetuned_model",
30
+ per_device_train_batch_size=2,
31
+ num_train_epochs=1,
32
+ save_total_limit=1,
33
+ logging_steps=10,
34
+ push_to_hub=True,
35
+ hub_model_id=upload_repo,
36
+ hub_token=hf_token
37
+ )
38
+
39
+ # Trainerセットアップ
40
+ trainer = Trainer(
41
+ model=model,
42
+ args=training_args,
43
+ train_dataset=tokenized_dataset,
44
+ data_collator=data_collator
45
+ )
46
+
47
+ # 学習実行
48
+ trainer.train()
49
+
50
+ # モデルをHugging Face Hubへアップロード
51
+ trainer.push_to_hub()
52
+
53
+ return f"ファインチューニング完了!モデルは https://huggingface.co/{upload_repo} にアップロードされました。"
54
+
55
+ # Gradioインターフェース
56
+ with gr.Blocks() as demo:
57
+ gr.Markdown("# 日本語チャットモデル 簡易ファインチューニング")
58
+
59
+ model_name = gr.Textbox(label="元モデル名(例:rinna/japanese-gpt-neox-3.6b)")
60
+ hf_token = gr.Textbox(label="Hugging Face トークン", type="password")
61
+ upload_repo = gr.Textbox(label="アップロード先リポジトリ名(例:yourname/finetuned-chat-jp)")
62
+
63
+ start_btn = gr.Button("ファインチューニング開始")
64
+ output = gr.Textbox(label="実行結果")
65
+
66
+ start_btn.click(finetune, inputs=[model_name, hf_token, upload_repo], outputs=output)
67
+
68
+ demo.launch()