Deep Q-Network (DQN) - CartPole (From Scratch)

This is a Deep Q-Network (DQN) agent trained from scratch on the CartPole-v1 environment using PyTorch and Gymnasium.

Model Description

  • Algorithm: Deep Q-Network (DQN)
  • Environment: CartPole-v1
  • Implementation: From scratch using PyTorch
  • Training Episodes: 500
  • Success Rate: 100%
  • Average Reward: 500.00 (Perfect Score!)

Architecture

Q-Network

Input (4) → FC(128) → ReLU → FC(128) → ReLU → Output(2)
Total Parameters: 17,410 (a quick check of this figure appears after the list below)
  • State Space: 4 continuous values (cart position, velocity, pole angle, angular velocity)
  • Action Space: 2 discrete actions (push left, push right)
  • Hidden Layers: 2 layers with 128 neurons each
  • Activation: ReLU
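
As a quick check of the 17,410 total above, the parameter count of this two-hidden-layer MLP can be computed directly from the listed layer sizes (a small arithmetic sketch, not code from the training script):

# Parameter count for the 4 -> 128 -> 128 -> 2 fully connected network
state_dim, hidden_dim, action_dim = 4, 128, 2

fc1 = state_dim * hidden_dim + hidden_dim    # weights + biases:   4*128 + 128 =    640
fc2 = hidden_dim * hidden_dim + hidden_dim   #                   128*128 + 128 = 16,512
fc3 = hidden_dim * action_dim + action_dim   #                     128*2 + 2   =    258

print(fc1 + fc2 + fc3)  # 17410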

Training Details

Hyperparameters

  • Learning Rate: 0.001
  • Discount Factor (γ): 0.99
  • Epsilon Start: 1.0
  • Epsilon End: 0.01
  • Epsilon Decay: 0.995
  • Batch Size: 64
  • Replay Buffer: 10,000
  • Target Network Update: Every 10 episodes
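
As a small illustration of how the exploration settings above interact, assuming the usual per-episode multiplicative decay (the exact schedule used in training is not shown in this card):

# Epsilon starts at 1.0, is multiplied by 0.995 after each episode,
# and is floored at 0.01 (values from the hyperparameter list above).
eps_start, eps_end, eps_decay = 1.0, 0.01, 0.995

epsilon = eps_start
for episode in range(500):
    epsilon = max(eps_end, epsilon * eps_decay)

print(round(epsilon, 4))  # roughly 0.08 after 500 episodes under this schedule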

Key Techniques

  • Experience Replay: Random sampling from replay buffer to break correlation
  • Target Network: Separate network for stable Q-value targets
  • Epsilon-Greedy: Exploration-exploitation balance
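
A minimal sketch of the first and third techniques, assuming a standard deque-based buffer with uniform sampling; class and function names here are illustrative rather than the exact ones used in the original implementation:

import random
from collections import deque

import torch

class ReplayBuffer:
    """Fixed-size buffer; uniform random sampling breaks temporal correlation."""
    def __init__(self, capacity=10_000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size=64):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)

def select_action(q_network, state, epsilon, action_dim=2):
    """Epsilon-greedy: random action with probability epsilon, else argmax Q."""
    if random.random() < epsilon:
        return random.randrange(action_dim)
    with torch.no_grad():
        q_values = q_network(torch.FloatTensor(state).unsqueeze(0))
    return q_values.argmax().item()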

Performance

The agent achieves perfect performance (500/500 steps):

  • Balances the pole for maximum duration
  • 100% success rate on evaluation
  • CartPole-v1 is considered "solved" at an average reward of 475+ over 100 consecutive episodes

Training Progress

See dqn_training.png in this repository for the reward curve over the 500 training episodes.

Usage

import torch
import gymnasium as gym
from huggingface_hub import hf_hub_download

# Download model
model_path = hf_hub_download(
    repo_id="aryannzzz/dqn-cartpole-scratch",
    filename="dqn_cartpole.pth"
)

# Define Q-Network architecture
class QNetwork(torch.nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.network = torch.nn.Sequential(
            torch.nn.Linear(state_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, action_dim)
        )
    
    def forward(self, x):
        return self.network(x)

# Load model
checkpoint = torch.load(model_path, map_location='cpu')
q_network = QNetwork(state_dim=4, action_dim=2)
q_network.load_state_dict(checkpoint['q_network_state_dict'])
q_network.eval()

# Use the agent
env = gym.make('CartPole-v1')
state, _ = env.reset()

done = False
total_reward = 0

while not done:
    state_tensor = torch.FloatTensor(state).unsqueeze(0)
    with torch.no_grad():
        q_values = q_network(state_tensor)
    action = q_values.argmax().item()
    
    state, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    done = terminated or truncated

print(f"Total Reward: {total_reward}")
env.close()
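
To reproduce the evaluation figures reported above, the single-episode loop can be extended to many episodes; a minimal sketch reusing q_network and gym from the snippet above (the 100-episode count and the 500-step success criterion are assumptions, not the exact evaluation protocol):

# Evaluate the greedy policy over several episodes.
n_episodes = 100
rewards = []

env = gym.make('CartPole-v1')
for _ in range(n_episodes):
    state, _ = env.reset()
    done, episode_reward = False, 0.0
    while not done:
        with torch.no_grad():
            action = q_network(torch.FloatTensor(state).unsqueeze(0)).argmax().item()
        state, reward, terminated, truncated, _ = env.step(action)
        episode_reward += reward
        done = terminated or truncated
    rewards.append(episode_reward)
env.close()

avg_reward = sum(rewards) / n_episodes
success_rate = sum(r >= 500 for r in rewards) / n_episodes
print(f"Average reward: {avg_reward:.2f}, success rate: {success_rate:.0%}")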

Model Files

  • dqn_cartpole.pth - Complete model checkpoint (Q-network + optimizer state)
  • dqn_training.png - Training progress visualization
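
For reference, a checkpoint with the same layout can be written as sketched below; only the q_network_state_dict key is confirmed by the loading code above, and optimizer_state_dict is an assumed name for the saved optimizer state:

# Sketch of saving a checkpoint compatible with the Usage snippet above.
optimizer = torch.optim.Adam(q_network.parameters(), lr=1e-3)  # hypothetical optimizer
torch.save({
    'q_network_state_dict': q_network.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),  # assumed key name
}, 'dqn_cartpole.pth')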

Training Code

This model was trained with a custom, from-scratch DQN implementation; a minimal sketch of the training loop is shown below.
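
The sketch reuses the QNetwork class from the Usage section and the ReplayBuffer/select_action helpers sketched under Key Techniques. Hyperparameters match the table above, but the loss and optimizer details are standard DQN choices rather than a verbatim copy of the original script:

import gymnasium as gym
import torch
import torch.nn.functional as F

env = gym.make('CartPole-v1')
q_network = QNetwork(state_dim=4, action_dim=2)
target_network = QNetwork(state_dim=4, action_dim=2)
target_network.load_state_dict(q_network.state_dict())

optimizer = torch.optim.Adam(q_network.parameters(), lr=1e-3)
buffer = ReplayBuffer(capacity=10_000)
gamma, batch_size, epsilon = 0.99, 64, 1.0

for episode in range(500):
    state, _ = env.reset()
    done = False
    while not done:
        action = select_action(q_network, state, epsilon)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        buffer.push(state, action, reward, next_state, done)
        state = next_state

        if len(buffer) >= batch_size:
            batch = buffer.sample(batch_size)
            states, actions, rewards, next_states, dones = map(list, zip(*batch))
            states = torch.FloatTensor(states)
            actions = torch.LongTensor(actions).unsqueeze(1)
            rewards = torch.FloatTensor(rewards)
            next_states = torch.FloatTensor(next_states)
            dones = torch.FloatTensor(dones)

            # Q(s, a) for the taken actions, and bootstrapped targets from the frozen network
            q_values = q_network(states).gather(1, actions).squeeze(1)
            with torch.no_grad():
                next_q = target_network(next_states).max(1).values
            targets = rewards + gamma * next_q * (1 - dones)

            loss = F.mse_loss(q_values, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    epsilon = max(0.01, epsilon * 0.995)   # epsilon decay per episode
    if episode % 10 == 0:                  # target network sync every 10 episodes
        target_network.load_state_dict(q_network.state_dict())

env.close()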

Benchmarks

Metric              Value
------------------  ----------------
Average Reward      500.00
Success Rate        100%
Training Episodes   500
Training Time       ~2 minutes (GPU)
Parameters          17,410

Limitations

  • Trained specifically for CartPole-v1
  • May require retraining for different physics parameters
  • Neural network approach has higher computational cost than tabular methods

Citation

If you use this model, please reference:

@misc{dqn-cartpole-scratch,
  author = {aryannzzz},
  title = {DQN CartPole Model from Scratch},
  year = {2025},
  publisher = {HuggingFace},
  url = {https://huggingface.co/aryannzzz/dqn-cartpole-scratch}
}