import Foundation
import CoreML
import Tokenizers

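/// Runs a Qwen3 chat model on-device with Core ML.
///
/// A prefill model processes the prompt in one pass, and a stateful decode model then
/// generates one token at a time against its `MLState`. Tokenization uses the Hugging
/// Face `Tokenizers` package.
///
/// Typical usage (from an async, main-actor context):
///
///     let qwen = Qwen3CoreML()
///     try await qwen.loadModels()
///     for await chunk in qwen.generate(userMessage: "Hello!") {
///         print(chunk, terminator: "")
///     }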
@MainActor
public final class Qwen3CoreML {

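    /// Generation defaults plus the model package names and tokenizer identifier.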
    public struct Config {
        public static let maxContextLength = 1024
        public static let maxTokens = 512
        public static let temperature: Float = 0.7
        // topK / topP are reserved for future use; sampleToken(from:temperature:)
        // currently applies temperature-only sampling.
        public static let topK = 40
        public static let topP: Float = 0.9

        public static let prefillModelName = "Qwen3-0.6B-Prefill-Int4"
        public static let decodeModelName = "Qwen3-0.6B-Decode-Int4"
        public static let tokenizerModelId = "Qwen/Qwen3-0.6B"
    }

    private var prefillModel: MLModel?
    private var decodeModel: MLModel?
    private var tokenizer: Tokenizer?
    private var decodeState: MLState?

    private(set) var isModelsLoaded = false
    private(set) var isGenerating = false

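    // Qwen3 special tokens: 151643 = <|endoftext|>, 151645 = <|im_end|>; either one ends generation.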
    private let eosTokenIds: Set<Int> = [151643, 151645]
    private let bosTokenId = 151643
    // Default template kept for reference; formatChatPrompt(userMessage:systemPrompt:enableThinking:)
    // builds the actual prompt string.
    private let chatTemplate = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n%@<|im_end|>\n<|im_start|>assistant\n"

    private(set) var tokensPerSecond: Double = 0
    private(set) var currentPosition = 0

    public init() {
        print("🤖 Qwen3CoreML initialized")
    }

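    /// Loads the prefill and decode Core ML models and the Hugging Face tokenizer.
    /// Subsequent calls return immediately once everything is loaded.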
    public func loadModels() async throws {
        guard !isModelsLoaded else {
            print("🤖 Qwen3: Models already loaded")
            return
        }

        print("🤖 Qwen3: Loading CoreML models and tokenizer...")

        do {
            prefillModel = try await loadModel(named: Config.prefillModelName)
            print("✅ Prefill model loaded")

            decodeModel = try await loadModel(named: Config.decodeModelName, withState: true)
            print("✅ Decode model loaded")

            tokenizer = try await AutoTokenizer.from(pretrained: Config.tokenizerModelId)
            print("✅ Tokenizer loaded")

            isModelsLoaded = true
            print("🎉 Qwen3 models loaded successfully")
        } catch {
            print("❌ Failed to load Qwen3 models: \(error.localizedDescription)")
            throw Qwen3Error.modelLoadingFailed(error.localizedDescription)
        }
    }

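    /// Locates, compiles, and loads a single `.mlpackage`.
    /// The package is looked up in the app bundle first, then under
    /// `~/Library/Application Support/Qwen3CoreML/Models/`.
    /// When `withState` is true, a fresh `MLState` for the decode model is created as well.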
    private func loadModel(named modelName: String, withState: Bool = false) async throws -> MLModel {
        let config = MLModelConfiguration()
        config.computeUnits = .cpuAndNeuralEngine

        // Resolve the .mlpackage: app bundle first, then Application Support.
        var modelURL: URL?
        if let url = Bundle.main.url(forResource: modelName, withExtension: "mlpackage") {
            modelURL = url
        } else if let appSupport = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first {
            let modelPath = appSupport
                .appendingPathComponent("Qwen3CoreML")
                .appendingPathComponent("Models")
                .appendingPathComponent("\(modelName).mlpackage")
            if FileManager.default.fileExists(atPath: modelPath.path) {
                modelURL = modelPath
            }
        }

        guard let modelURL = modelURL else {
            throw Qwen3Error.modelNotFound(modelName)
        }

        // An .mlpackage must be compiled to .mlmodelc before it can be loaded.
        let compiledURL = try await MLModel.compileModel(at: modelURL)
        let model = try MLModel(contentsOf: compiledURL, configuration: config)

        if withState {
            decodeState = model.makeState()
        }
        return model
    }

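    /// Streams the model's response as an `AsyncStream<String>` of decoded token chunks.
    /// The stream finishes when an end-of-sequence token appears, `maxTokens` is reached,
    /// or generation fails.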
    public func generate(
        userMessage: String,
        systemPrompt: String = "You are a helpful assistant.",
        maxTokens: Int = Config.maxTokens,
        temperature: Float = Config.temperature,
        enableThinking: Bool = false
    ) -> AsyncStream<String> {
        AsyncStream { continuation in
            Task {
                await generateInternal(
                    userMessage: userMessage,
                    systemPrompt: systemPrompt,
                    maxTokens: maxTokens,
                    temperature: temperature,
                    enableThinking: enableThinking,
                    continuation: continuation
                )
            }
        }
    }

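    /// Convenience wrapper around `generate` that collects the streamed chunks into a single `String`.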
    public func generateSync(
        userMessage: String,
        systemPrompt: String = "You are a helpful assistant.",
        maxTokens: Int = Config.maxTokens,
        temperature: Float = Config.temperature,
        enableThinking: Bool = false
    ) async throws -> String {
        guard isModelsLoaded, tokenizer != nil else {
            throw Qwen3Error.modelNotLoaded
        }

        var result = ""
        for await chunk in generate(
            userMessage: userMessage,
            systemPrompt: systemPrompt,
            maxTokens: maxTokens,
            temperature: temperature,
            enableThinking: enableThinking
        ) {
            result += chunk
        }
        return result
    }

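    /// Discards the decode model's `MLState` and rewinds the position counter so the
    /// next generation starts from an empty context.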
    public func resetConversation() {
        decodeState = decodeModel?.makeState()
        currentPosition = 0
        print("🔄 Qwen3 conversation reset")
    }

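    /// Core generation loop: formats and tokenizes the chat prompt, prefills the prompt
    /// (all but its final token), then decodes autoregressively starting from that final
    /// token, yielding each decoded chunk to the continuation until EOS or `maxTokens`.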
    private func generateInternal(
        userMessage: String,
        systemPrompt: String,
        maxTokens: Int,
        temperature: Float,
        enableThinking: Bool,
        continuation: AsyncStream<String>.Continuation
    ) async {
        guard isModelsLoaded,
              let prefillModel = prefillModel,
              let decodeModel = decodeModel,
              let tokenizer = tokenizer,
              var decodeState = decodeState else {
            continuation.finish()
            return
        }

        isGenerating = true
        let startTime = Date()

        defer {
            isGenerating = false
            continuation.finish()
        }

        do {
            // Build and tokenize the chat-formatted prompt.
            let chatPrompt = formatChatPrompt(
                userMessage: userMessage,
                systemPrompt: systemPrompt,
                enableThinking: enableThinking
            )
            var inputTokens = tokenizer.encode(text: chatPrompt)

            // Keep prompt + generation inside the context window, truncating from the
            // front and re-adding the BOS token if necessary.
            if inputTokens.count + maxTokens > Config.maxContextLength {
                print("⚠️ Prompt too long, truncating...")
                let truncated = Array(inputTokens.suffix(Config.maxContextLength - maxTokens))
                inputTokens = truncated.first == bosTokenId ? truncated : [bosTokenId] + truncated
            }

            // Prefill all but the last prompt token; the decode loop feeds that last
            // token first, so its logits produce the first generated token.
            let promptPrefix = Array(inputTokens.dropLast())
            if !promptPrefix.isEmpty {
                try await processTokens(promptPrefix, model: prefillModel)
            }

            var generatedTokens: [Int] = []
            var isInThinkingBlock = false
            var nextInput = inputTokens.last ?? bosTokenId

            for _ in 0..<maxTokens {
                let nextToken = try await generateNextToken(
                    inputToken: nextInput,
                    temperature: temperature,
                    decodeModel: decodeModel,
                    decodeState: &decodeState
                )

                if eosTokenIds.contains(nextToken) {
                    break
                }

                generatedTokens.append(nextToken)
                nextInput = nextToken

                // Track <think> ... </think> blocks (Qwen3 token ids 151667 / 151668)
                // so thinking content can be suppressed when thinking is disabled.
                if nextToken == 151667 {
                    isInThinkingBlock = true
                } else if nextToken == 151668 {
                    isInThinkingBlock = false
                    if !enableThinking {
                        continue
                    }
                }

                let tokenText = tokenizer.decode(tokens: [nextToken])
                if !isInThinkingBlock || enableThinking {
                    continuation.yield(tokenText)
                }
            }

            let elapsed = Date().timeIntervalSince(startTime)
            tokensPerSecond = Double(generatedTokens.count) / elapsed
            print("📊 Generation: \(generatedTokens.count) tokens in \(String(format: "%.2f", elapsed))s (\(String(format: "%.1f", tokensPerSecond)) tok/s)")
        } catch {
            print("❌ Generation failed: \(error.localizedDescription)")
        }
    }

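    /// Prefill step: runs a token sequence through the prefill model with a causal
    /// attention mask and records the sequence length as the current position.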
    private func processTokens(_ tokens: [Int], model: MLModel) async throws {
        let seqLen = tokens.count

        // Causal mask: each position may attend to itself and everything before it.
        let causalMask = createCausalMask(seqLen: seqLen, totalLen: seqLen)
        let inputIdsTensor = MLTensor(
            shape: [1, seqLen],
            scalars: tokens.map { Int32($0) },
            scalarType: Int32.self
        )

        let inputs = try MLDictionaryFeatureProvider(dictionary: [
            "inputIds": MLFeatureValue(tensor: inputIdsTensor),
            "causalMask": MLFeatureValue(tensor: causalMask)
        ])

        // The prefill pass's logits are not used; the decode loop produces the output tokens.
        _ = try await model.prediction(from: inputs)
        currentPosition = seqLen
    }

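    /// Decode step: feeds a single input token and its position to the stateful decode
    /// model, samples the next token from the returned logits, and advances the position.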
    private func generateNextToken(
        inputToken: Int,
        temperature: Float,
        decodeModel: MLModel,
        decodeState: inout MLState
    ) async throws -> Int {
        // Position of the token being processed in the overall sequence.
        let positionTensor = MLTensor(
            shape: [1, 1],
            scalars: [Int32(currentPosition)],
            scalarType: Int32.self
        )

        // The previously generated token (or the last prompt token on the first step)
        // is the input id for this decode step.
        let inputIdsTensor = MLTensor(
            shape: [1, 1],
            scalars: [Int32(inputToken)],
            scalarType: Int32.self
        )

        let inputs = try MLDictionaryFeatureProvider(dictionary: [
            "inputIds": MLFeatureValue(tensor: inputIdsTensor),
            "positionIds": MLFeatureValue(tensor: positionTensor),
        ])

        // Stateful prediction: the decode model reads and updates `decodeState` in place.
        let output = try await decodeModel.prediction(from: inputs, using: decodeState)

        guard let logitsTensor = output.featureValue(for: "logits")?.tensorValue(of: Float16.self) else {
            throw Qwen3Error.inferenceError("No logits in model output")
        }

        let nextToken = sampleToken(from: logitsTensor, temperature: temperature)
        currentPosition += 1
        return nextToken
    }

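    /// Samples a token id from the last position's logits: greedy argmax when
    /// `temperature <= 0`, otherwise temperature-scaled softmax sampling.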
    private func sampleToken(from logitsTensor: MLTensor, temperature: Float) -> Int {
        // Logits have shape [1, seqLen, vocabSize]; sample from the last position.
        let seqLen = logitsTensor.shape[1]
        let vocabSize = logitsTensor.shape[2]
        let lastRowOffset = (seqLen - 1) * vocabSize

        var logitsArray = [Float](repeating: 0, count: vocabSize)
        logitsTensor.withUnsafeBufferPointer(of: Float16.self) { buffer in
            for i in 0..<vocabSize {
                logitsArray[i] = Float(buffer[lastRowOffset + i])
            }
        }

        // Greedy decoding.
        if temperature <= 0 {
            return logitsArray.enumerated().max(by: { $0.element < $1.element })?.offset ?? 0
        }

        // Temperature sampling: numerically stable softmax, then draw from the CDF.
        let scaledLogits = logitsArray.map { $0 / temperature }
        let maxLogit = scaledLogits.max() ?? 0
        let expLogits = scaledLogits.map { exp($0 - maxLogit) }
        let sumExp = expLogits.reduce(0, +)
        let probs = expLogits.map { $0 / sumExp }

        let random = Float.random(in: 0..<1)
        var cumulative: Float = 0
        for (index, prob) in probs.enumerated() {
            cumulative += prob
            if random < cumulative {
                return index
            }
        }
        return vocabSize - 1
    }

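    /// Builds an additive causal attention mask of shape [1, 1, seqLen, totalLen]:
    /// 0 where attention is allowed and -infinity where it is masked out.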
    private func createCausalMask(seqLen: Int, totalLen: Int) -> MLTensor {
        var maskData = [Float16](repeating: Float16(-Float.infinity), count: seqLen * totalLen)

        // Query position i may attend to key positions 0...(totalLen - seqLen + i).
        for i in 0..<seqLen {
            for j in 0..<(totalLen - seqLen + i + 1) {
                maskData[i * totalLen + j] = Float16(0)
            }
        }

        return MLTensor(
            shape: [1, 1, seqLen, totalLen],
            scalars: maskData,
            scalarType: Float16.self
        )
    }

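    /// Wraps the system prompt and user message in Qwen's ChatML-style template.
    /// When thinking is disabled, the `/no_think` soft switch is appended to the user turn.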
    private func formatChatPrompt(userMessage: String, systemPrompt: String, enableThinking: Bool) -> String {
        // Qwen3 reads a "/no_think" soft switch from the user turn to disable thinking mode.
        let userContent = enableThinking ? userMessage : userMessage + " /no_think"
        return "<|im_start|>system\n\(systemPrompt)<|im_end|>\n<|im_start|>user\n\(userContent)<|im_end|>\n<|im_start|>assistant\n"
    }
}

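/// Errors thrown during model loading, tokenization, and inference.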
public enum Qwen3Error: LocalizedError {
    case modelNotFound(String)
    case modelNotLoaded
    case modelLoadingFailed(String)
    case inferenceError(String)
    case tokenizationError

    public var errorDescription: String? {
        switch self {
        case .modelNotFound(let modelName):
            return "Model '\(modelName)' not found. Place it in app bundle or ~/Library/Application Support/Qwen3CoreML/Models/"
        case .modelNotLoaded:
            return "Models are not loaded. Call loadModels() first."
        case .modelLoadingFailed(let message):
            return "Model loading failed: \(message)"
        case .inferenceError(let message):
            return "Inference error: \(message)"
        case .tokenizationError:
            return "Tokenization error"
        }
    }
}

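/// Higher-level text helpers built on top of `generateSync`.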
extension Qwen3CoreML {

    /// Proofreads `text`, fixing punctuation, capitalization, and grammar while keeping
    /// the original language, and returns only the corrected text.
    public func correct(text: String) async throws -> String {
        return try await generateSync(
            userMessage: """
            Please correct the following text by fixing punctuation, capitalization, and grammatical errors.
            Keep the original language. Only output the corrected text, nothing else.

            Text: \(text)

            Corrected:
            """,
            systemPrompt: "You are a professional proofreader and text editor.",
            maxTokens: 256,
            temperature: 0.1
        ).trimmingCharacters(in: .whitespacesAndNewlines)
    }
}