aryannzzz committed (verified) · Commit 8346b0e · Parent: 70175b7

Upload README.md with huggingface_hub

Files changed (1): README.md (+158, -0)

---
library_name: pytorch
tags:
- reinforcement-learning
- deep-q-network
- dqn
- cartpole
- from-scratch
- pytorch
license: mit
---

# Deep Q-Network (DQN) - CartPole (From Scratch)

This is a **Deep Q-Network (DQN) agent** trained from scratch on the CartPole-v1 environment using PyTorch and Gymnasium.

## Model Description

- **Algorithm**: Deep Q-Network (DQN)
- **Environment**: CartPole-v1
- **Implementation**: From scratch using PyTorch
- **Training Episodes**: 500
- **Success Rate**: 100%
- **Average Reward**: 500.00 (the maximum possible in CartPole-v1)

## Architecture

### Q-Network
```
Input (4) → FC(128) → ReLU → FC(128) → ReLU → Output(2)
Total Parameters: 17,410
```

- **State Space**: 4 continuous values (cart position, cart velocity, pole angle, pole angular velocity)
- **Action Space**: 2 discrete actions (push left, push right)
- **Hidden Layers**: 2 layers with 128 neurons each
- **Activation**: ReLU
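
The parameter count follows directly from the layer sizes: (4×128 + 128) + (128×128 + 128) + (128×2 + 2) = 640 + 16,512 + 258 = 17,410. A quick sanity check (a minimal sketch, not taken from the training script):

```python
import torch

# Same 4 -> 128 -> 128 -> 2 architecture described above
net = torch.nn.Sequential(
    torch.nn.Linear(4, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 2),
)

# (4*128 + 128) + (128*128 + 128) + (128*2 + 2) = 17,410 trainable parameters
print(sum(p.numel() for p in net.parameters()))  # 17410
```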

## Training Details

### Hyperparameters
- Learning Rate: 0.001
- Discount Factor (γ): 0.99
- Epsilon Start: 1.0
- Epsilon End: 0.01
- Epsilon Decay: 0.995 (see the schedule sketch after this list)
- Batch Size: 64
- Replay Buffer Size: 10,000 transitions
- Target Network Update: every 10 episodes
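
Taken together, the epsilon settings define the exploration schedule. A minimal sketch, assuming the decay factor is applied once per episode (the training script may apply it per step instead):

```python
# Epsilon-greedy schedule: epsilon is multiplied by 0.995 each episode, floored at 0.01
eps_start, eps_end, eps_decay = 1.0, 0.01, 0.995

epsilon = eps_start
for episode in range(500):
    epsilon = max(eps_end, epsilon * eps_decay)

# 0.995 ** 500 ≈ 0.082, so epsilon never reaches the 0.01 floor within 500 episodes
print(f"epsilon after 500 episodes: {epsilon:.3f}")
```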

### Key Techniques
- **Experience Replay**: Random sampling from a replay buffer to break temporal correlation between consecutive transitions
- **Target Network**: A separate, periodically synced network that provides stable Q-value targets (see the update sketch below)
- **Epsilon-Greedy**: Balances exploration and exploitation during training
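
These three techniques feed into the standard DQN temporal-difference update. A minimal sketch of one gradient step under the hyperparameters above (variable names and the MSE loss are illustrative assumptions, not taken from `02_dqn_from_scratch.py`):

```python
import torch
import torch.nn.functional as F

def dqn_update(q_network, target_network, optimizer, batch, gamma=0.99):
    """One gradient step on a batch of 64 transitions sampled from the replay buffer."""
    # states/next_states: float tensors, actions: int64 tensor,
    # rewards: float tensor, dones: 0/1 float tensor
    states, actions, rewards, next_states, dones = batch

    # Q(s, a) for the actions that were actually taken
    q_values = q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # Bootstrapped targets come from the frozen target network
    with torch.no_grad():
        next_q = target_network(next_states).max(dim=1).values
        targets = rewards + gamma * next_q * (1 - dones)

    # MSE loss is an assumption; Huber (smooth L1) loss is also common for DQN
    loss = F.mse_loss(q_values, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Every 10 episodes the target network is synced with the online network:
# target_network.load_state_dict(q_network.state_dict())
```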

## Performance

The agent achieves **perfect performance** (500/500 steps):
- Balances the pole for the full episode length (CartPole-v1 truncates episodes at 500 steps)
- 100% success rate across evaluation episodes (see the evaluation sketch below the plot)
- CartPole-v1 is commonly considered "solved" at an average reward of 475+ over 100 consecutive episodes (the 195 threshold applies to the older CartPole-v0)

![Training Progress](dqn_training.png)
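
A greedy evaluation loop along these lines reproduces the numbers above, given a `q_network` loaded as in the Usage section below (a minimal sketch; the episode count and details are assumptions, not the original evaluation script):

```python
import gymnasium as gym
import torch

def evaluate(q_network, n_episodes=100):
    """Run the greedy policy and report average reward and success rate."""
    env = gym.make("CartPole-v1")
    rewards = []
    for _ in range(n_episodes):
        state, _ = env.reset()
        done, total = False, 0.0
        while not done:
            with torch.no_grad():
                q_values = q_network(torch.FloatTensor(state).unsqueeze(0))
            state, reward, terminated, truncated, _ = env.step(q_values.argmax().item())
            total += reward
            done = terminated or truncated
        rewards.append(total)
    env.close()
    avg_reward = sum(rewards) / len(rewards)
    success_rate = sum(r == 500 for r in rewards) / len(rewards)
    return avg_reward, success_rate
```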

## Usage

```python
import torch
import gymnasium as gym
from huggingface_hub import hf_hub_download

# Download model
model_path = hf_hub_download(
    repo_id="aryannzzz/dqn-cartpole-scratch",
    filename="dqn_cartpole.pth"
)

# Define Q-Network architecture
class QNetwork(torch.nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.network = torch.nn.Sequential(
            torch.nn.Linear(state_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, x):
        return self.network(x)

# Load model
checkpoint = torch.load(model_path, map_location='cpu')
q_network = QNetwork(state_dim=4, action_dim=2)
q_network.load_state_dict(checkpoint['q_network_state_dict'])
q_network.eval()

# Use the agent
env = gym.make('CartPole-v1')
state, _ = env.reset()

done = False
total_reward = 0

while not done:
    state_tensor = torch.FloatTensor(state).unsqueeze(0)
    with torch.no_grad():
        q_values = q_network(state_tensor)
    action = q_values.argmax().item()

    state, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    done = terminated or truncated

print(f"Total Reward: {total_reward}")
env.close()
```

## Model Files

- `dqn_cartpole.pth` - Complete model checkpoint (Q-network + optimizer state)
- `dqn_training.png` - Training progress visualization
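
Only the `q_network_state_dict` key is relied on in the Usage example; to see what else `dqn_cartpole.pth` stores (other key names are not guaranteed), inspect the checkpoint directly:

```python
import torch
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(
    repo_id="aryannzzz/dqn-cartpole-scratch",
    filename="dqn_cartpole.pth",
)

checkpoint = torch.load(model_path, map_location="cpu")
print(list(checkpoint.keys()))  # expect 'q_network_state_dict' plus optimizer state
```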

## Training Code

This model was trained with a custom DQN implementation:
- Repository: [RL-Competition-Starter](https://github.com/aryannzzz/RL-Competition-Starter)
- File: `02_dqn_from_scratch.py`

## Benchmarks

| Metric | Value |
|--------|-------|
| Average Reward | 500.00 |
| Success Rate | 100% |
| Training Episodes | 500 |
| Training Time | ~2 minutes (GPU) |
| Parameters | 17,410 |

## Limitations

- Trained specifically for CartPole-v1
- May require retraining for environments with different physics parameters
- The neural-network approach has a higher computational cost than tabular methods

## Citation

If you use this model, please reference:
```bibtex
@misc{dqn-cartpole-scratch,
  author = {aryannzzz},
  title = {DQN CartPole Model from Scratch},
  year = {2025},
  publisher = {HuggingFace},
  url = {https://huggingface.co/aryannzzz/dqn-cartpole-scratch}
}
```