//
// Qwen3CoreML.swift
// Qwen3 CoreML Example
//
// Swift Example for Qwen3 CoreML Integration
// Qwen3-0.6B CoreML integration with Stateful KV-Cache and Int4 quantization
//
// Requirements:
// - iOS 18.0+ / macOS 15.0+ (Apple Neural Engine support)
// - 400-500MB RAM for both models
// - swift-transformers package
//
// Usage:
// let qwen3 = Qwen3CoreML()
// try await qwen3.loadModels()
// let response = try await qwen3.generateSync(userMessage: "Hello, world!")
//
import Foundation
import CoreML
import Tokenizers
/// Qwen3-0.6B CoreML model wrapper with Stateful KV-Cache
@MainActor
public final class Qwen3CoreML {
// MARK: - Configuration
public struct Config {
public static let maxContextLength = 1024
public static let maxTokens = 512
public static let temperature: Float = 0.7
public static let topK = 40
public static let topP: Float = 0.9
// Model paths (relative to app bundle or absolute)
public static let prefillModelName = "Qwen3-0.6B-Prefill-Int4"
public static let decodeModelName = "Qwen3-0.6B-Decode-Int4"
public static let tokenizerModelId = "Qwen/Qwen3-0.6B"
}
// MARK: - State
private var prefillModel: MLModel?
private var decodeModel: MLModel?
private var tokenizer: Tokenizer?
private var decodeState: MLState?
private(set) var isModelsLoaded = false
private(set) var isGenerating = false
// Qwen3 special tokens
private let eosTokenIds: Set<Int> = [151643, 151645] // <|endoftext|>, <|im_end|>
private let bosTokenId = 151643
private let chatTemplate = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n%@<|im_end|>\n<|im_start|>assistant\n"
// Performance tracking
private(set) var tokensPerSecond: Double = 0
private(set) var currentPosition = 0
// MARK: - Initialization
public init() {
print("🤖 Qwen3CoreML initialized")
}
// MARK: - Model Loading
/// Load both Prefill and Decode CoreML models and tokenizer
public func loadModels() async throws {
guard !isModelsLoaded else {
print("🤖 Qwen3: Models already loaded")
return
}
print("🤖 Qwen3: Loading CoreML models and tokenizer...")
do {
// Load Prefill model
prefillModel = try await loadModel(named: Config.prefillModelName)
print("✅ Prefill model loaded")
// Load Decode model and create its KV-cache state
decodeModel = try await loadModel(named: Config.decodeModelName, withState: true)
print("✅ Decode model loaded")
// Load tokenizer via Tokenizers framework
tokenizer = try await AutoTokenizer.from(pretrained: Config.tokenizerModelId)
print("✅ Tokenizer loaded")
isModelsLoaded = true
print("🎉 Qwen3 models loaded successfully")
} catch {
print("❌ Failed to load Qwen3 models: \(error.localizedDescription)")
throw Qwen3Error.modelLoadingFailed(error.localizedDescription)
}
}
/// Load and compile a single CoreML model, optionally creating the decode KV-cache state
private func loadModel(named modelName: String, withState: Bool = false) async throws -> MLModel {
let config = MLModelConfiguration()
config.computeUnits = .cpuAndNeuralEngine // Use ANE when available
// Try Bundle first, then local paths
var modelURL: URL?
// Check main bundle
if let url = Bundle.main.url(forResource: modelName, withExtension: "mlpackage") {
modelURL = url
}
// Check app support directory
else if let appSupport = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first {
let appDir = appSupport.appendingPathComponent("Qwen3CoreML")
let modelsDir = appDir.appendingPathComponent("Models")
let modelPath = modelsDir.appendingPathComponent("\(modelName).mlpackage")
if FileManager.default.fileExists(atPath: modelPath.path) {
modelURL = modelPath
}
}
guard let modelURL = modelURL else {
throw Qwen3Error.modelNotFound(modelName)
}
// Compile and load model
let compiledURL = try await MLModel.compileModel(at: modelURL)
let model = try MLModel(contentsOf: compiledURL, configuration: config)
// Create the KV-cache state for the decode model only
if withState {
decodeState = model.makeState()
}
return model
}
// MARK: - Text Generation
/// Generate text response for user message (streaming)
public func generate(
userMessage: String,
systemPrompt: String = "You are a helpful assistant.",
maxTokens: Int = Config.maxTokens,
temperature: Float = Config.temperature,
enableThinking: Bool = false
) -> AsyncStream<String> {
AsyncStream { continuation in
Task {
await generateInternal(
userMessage: userMessage,
systemPrompt: systemPrompt,
maxTokens: maxTokens,
temperature: temperature,
enableThinking: enableThinking,
continuation: continuation
)
}
}
}
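// Example (illustrative): consuming the streaming API from an async, @MainActor
// context, assuming `qwen3` is a loaded Qwen3CoreML instance:
//
//     for await chunk in qwen3.generate(userMessage: "Hello, world!") {
//         print(chunk, terminator: "")
//     }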
/// Generate text response for user message (non-streaming)
public func generateSync(
userMessage: String,
systemPrompt: String = "You are a helpful assistant.",
maxTokens: Int = Config.maxTokens,
temperature: Float = Config.temperature,
enableThinking: Bool = false
) async throws -> String {
guard isModelsLoaded, tokenizer != nil else {
throw Qwen3Error.modelNotLoaded
}
var result = ""
for await chunk in generate(
userMessage: userMessage,
systemPrompt: systemPrompt,
maxTokens: maxTokens,
temperature: temperature,
enableThinking: enableThinking
) {
result += chunk
}
return result
}
/// Reset conversation and KV-Cache state
public func resetConversation() {
decodeState = decodeModel?.makeState()
currentPosition = 0
print("🔄 Qwen3 conversation reset")
}
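// Note: decodeState (the KV cache) persists across generate() calls until
// resetConversation() is called. An illustrative reset between unrelated prompts,
// assuming `qwen3` is a loaded Qwen3CoreML instance:
//
//     qwen3.resetConversation()
//     let reply = try await qwen3.generateSync(userMessage: "A fresh question")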
// MARK: - Private Generation
private func generateInternal(
userMessage: String,
systemPrompt: String,
maxTokens: Int,
temperature: Float,
enableThinking: Bool,
continuation: AsyncStream<String>.Continuation
) async {
guard isModelsLoaded,
let prefillModel = prefillModel,
let decodeModel = decodeModel,
let tokenizer = tokenizer,
let decodeState = decodeState else {
continuation.finish()
return
}
isGenerating = true
let startTime = Date()
defer {
isGenerating = false
continuation.finish()
}
do {
// Format chat prompt
let chatPrompt = formatChatPrompt(
userMessage: userMessage,
systemPrompt: systemPrompt,
enableThinking: enableThinking
)
// Tokenize prompt
let inputTokens = tokenizer.encode(text: chatPrompt)
// Truncate the prompt if it would overflow the context window
var tokensToProcess = inputTokens
if tokensToProcess.count + maxTokens > Config.maxContextLength {
print("⚠️ Prompt too long, truncating...")
let keepCount = max(0, Config.maxContextLength - maxTokens)
tokensToProcess = Array(tokensToProcess.suffix(keepCount))
// Re-add BOS token if it was cut off
if tokensToProcess.first != bosTokenId {
tokensToProcess = [bosTokenId] + tokensToProcess
}
}
// Process prompt tokens with the Prefill model
try await processTokens(tokensToProcess, model: prefillModel)
// Generate new tokens with the Decode model, feeding each sampled token back in
var generatedTokens: [Int] = []
var isInThinkingBlock = false
var previousToken = tokensToProcess.last ?? bosTokenId
for _ in 0..<maxTokens {
let nextToken = try await generateNextToken(
previousToken: previousToken,
temperature: temperature,
decodeModel: decodeModel,
decodeState: decodeState
)
previousToken = nextToken
// Check for end of generation
if eosTokenIds.contains(nextToken) {
break
}
generatedTokens.append(nextToken)
// Handle thinking blocks (thinking mode)
if nextToken == 151667 { // <think>
isInThinkingBlock = true
} else if nextToken == 151668 { // </think>
isInThinkingBlock = false
if !enableThinking {
continue
}
}
// Decode token to text
let tokenText = tokenizer.decode(tokens: [nextToken])
// Stream token if not in thinking block or thinking enabled
if !isInThinkingBlock || enableThinking {
continuation.yield(tokenText)
}
}
// Calculate performance
let elapsed = Date().timeIntervalSince(startTime)
tokensPerSecond = Double(generatedTokens.count) / elapsed
print("📊 Generation: \(generatedTokens.count) tokens in \(String(format: "%.2f", elapsed))s (\(String(format: "%.1f", tokensPerSecond)) tok/s)")
} catch {
print("❌ Generation failed: \(error.localizedDescription)")
// Note: errors are only logged here; the stream is closed by the defer above
}
}
/// Process initial tokens using Prefill model
private func processTokens(_ tokens: [Int], model: MLModel) async throws {
let seqLen = tokens.count
// Create causal mask for all tokens
let causalMask = createCausalMask(seqLen: seqLen, totalLen: seqLen)
let inputIdsTensor = MLTensor(
shape: [1, seqLen],
scalars: tokens.map { Int32($0) },
scalarType: Int32.self
)
let inputs = try MLDictionaryFeatureProvider(dictionary: [
"inputIds": MLFeatureValue(tensor: inputIdsTensor),
"causalMask": MLFeatureValue(tensor: causalMask)
])
// Run prefill inference
_ = try await model.prediction(from: inputs)
currentPosition = seqLen
}
/// Generate the next token using the Decode model and its stateful KV cache
private func generateNextToken(
previousToken: Int,
temperature: Float,
decodeModel: MLModel,
decodeState: MLState
) async throws -> Int {
// Current position as input
let positionIds = [Int32(currentPosition)]
let positionTensor = MLTensor(
shape: [1, 1],
scalars: positionIds,
scalarType: Int32.self
)
// Feed the previously sampled token (or the last prompt token on the first step);
// its key/value entries are appended to the KV cache held in decodeState
let inputTensor = MLTensor(
shape: [1, 1],
scalars: [Int32(previousToken)],
scalarType: Int32.self
)
let inputs = try MLDictionaryFeatureProvider(dictionary: [
"inputIds": MLFeatureValue(tensor: inputTensor),
"positionIds": MLFeatureValue(tensor: positionTensor),
])
let output = try await decodeModel.prediction(from: inputs, using: decodeState)
guard let logitsTensor = output.featureValue(for: "logits")?.tensorValue(of: Float16.self) else {
throw Qwen3Error.inferenceError("No logits in model output")
}
// Sample from logits
let nextToken = sampleToken(from: logitsTensor, temperature: temperature)
// Update position for next step
currentPosition += 1
return nextToken
}
/// Sample next token from logits
private func sampleToken(from logitsTensor: MLTensor, temperature: Float) -> Int {
// Extract logits for the last token: [1, seqLen, vocabSize] -> [vocabSize]
let seqLen = logitsTensor.shape[1]
let vocabSize = logitsTensor.shape[2]
let lastTokenOffset = (seqLen - 1) * vocabSize
var logitsArray = [Float](repeating: 0, count: vocabSize)
logitsTensor.withUnsafeBufferPointer(of: Float16.self) { buffer in
for i in 0..<vocabSize {
logitsArray[i] = Float(buffer[lastTokenOffset + i]) // Last token position
}
}
if temperature <= 0 {
// Greedy sampling
return logitsArray.enumerated().max(by: { $0.element < $1.element })?.offset ?? 0
}
// Apply temperature and sample
let scaledLogits = logitsArray.map { $0 / temperature }
let maxLogit = scaledLogits.max() ?? 0
let expLogits = scaledLogits.map { exp($0 - maxLogit) }
let sumExp = expLogits.reduce(0, +)
let probs = expLogits.map { $0 / sumExp }
// Sample from distribution
let random = Float.random(in: 0..<1)
var cumulative: Float = 0
for (index, prob) in probs.enumerated() {
cumulative += prob
if random < cumulative {
return index
}
}
return vocabSize - 1
}
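/// Illustrative sketch (not called by sampleToken above): one way Config.topK could be
/// applied to the logits before the softmax/sampling step. Values below the k-th
/// largest logit are masked to -infinity so they receive zero probability.
/// Config.topP is likewise declared above but not applied here.
private func applyTopK(_ logits: [Float], k: Int) -> [Float] {
guard k > 0, k < logits.count else { return logits }
// Find the k-th largest logit and mask everything below it
let threshold = logits.sorted(by: >)[k - 1]
return logits.map { $0 >= threshold ? $0 : -Float.infinity }
}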
/// Create causal attention mask
private func createCausalMask(seqLen: Int, totalLen: Int) -> MLTensor {
var maskData = [Float16](repeating: Float16(-Float.infinity), count: seqLen * totalLen)
for i in 0..<seqLen {
for j in 0..<(totalLen - seqLen + i + 1) {
maskData[i * totalLen + j] = Float16(0)
}
}
return MLTensor(
shape: [1, 1, seqLen, totalLen],
scalars: maskData,
scalarType: Float16.self
)
}
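// Example: for seqLen == totalLen == 3 the mask produced above is
// (0 = attend, -inf = masked):
//
//     [    0, -inf, -inf ]
//     [    0,    0, -inf ]
//     [    0,    0,    0 ]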
/// Format chat prompt using the Qwen3 chat template
private func formatChatPrompt(userMessage: String, systemPrompt: String, enableThinking: Bool) -> String {
// Qwen3 soft switch: appending "/no_think" to the user turn disables thinking mode
let userContent = enableThinking ? userMessage : userMessage + " /no_think"
return "<|im_start|>system\n\(systemPrompt)<|im_end|>\n<|im_start|>user\n\(userContent)<|im_end|>\n<|im_start|>assistant\n"
}
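// Example: formatChatPrompt(userMessage: "Hi", systemPrompt: "You are a helpful assistant.",
// enableThinking: false) produces the following prompt string:
//
//     <|im_start|>system
//     You are a helpful assistant.<|im_end|>
//     <|im_start|>user
//     Hi /no_think<|im_end|>
//     <|im_start|>assistant
//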
}
// MARK: - Errors
public enum Qwen3Error: LocalizedError {
case modelNotFound(String)
case modelNotLoaded
case modelLoadingFailed(String)
case inferenceError(String)
case tokenizationError
public var errorDescription: String? {
switch self {
case .modelNotFound(let modelName):
return "Model '\(modelName)' not found. Place it in app bundle or ~/Library/Application Support/Qwen3CoreML/Models/"
case .modelNotLoaded:
return "Models are not loaded. Call loadModels() first."
case .modelLoadingFailed(let message):
return "Model loading failed: \(message)"
case .inferenceError(let message):
return "Inference error: \(message)"
case .tokenizationError:
return "Tokenization error"
}
}
}
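// Example (illustrative): surfacing a Qwen3Error to the user, assuming `qwen3`
// is a Qwen3CoreML instance:
//
//     do {
//         try await qwen3.loadModels()
//     } catch let error as Qwen3Error {
//         print(error.errorDescription ?? "Unknown Qwen3 error")
//     }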
// MARK: - Helper Methods
/// Extension with utility methods for text processing
extension Qwen3CoreML {
/// Correct text using Qwen3 (compatible with LLMRunner.correct())
public func correct(text: String) async throws -> String {
return try await generateSync(
userMessage: """
Please correct the following text by fixing punctuation, capitalization, and grammatical errors.
Keep the original language. Only output the corrected text, nothing else.
Text: \(text)
Corrected:
""",
systemPrompt: "You are a professional proofreader and text editor.",
maxTokens: 256,
temperature: 0.1 // Low temperature for consistent corrections
).trimmingCharacters(in: .whitespacesAndNewlines)
}
}
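// Example (illustrative) use of correct(text:) once the models are loaded:
//
//     let qwen3 = Qwen3CoreML()
//     try await qwen3.loadModels()
//     let fixed = try await qwen3.correct(text: "i has went to the store yesterday")
//     print(fixed)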