package com.example.myapplication.data

import android.content.Context
import com.example.myapplication.Trie
import java.util.concurrent.atomic.AtomicBoolean
import java.util.PriorityQueue
import kotlin.math.max

/**
 * Word predictor that combines a bigram language model with a [Trie] prefix lookup.
 *
 * The model is loaded in the background by [preload]. Until a load succeeds, [suggest] answers
 * from the [Trie] alone. If a load fails, the predictor keeps running without a model (or keeps
 * the previously loaded one).
 */
class BigramPredictor(
    private val context: Context,
    private val trie: Trie
) {
    /**
     * Immutable snapshot of a loaded model plus the lookup structures derived from it.
     *
     * Everything is built before the snapshot is published through the single volatile [state]
     * field. A reader that sees a non-null snapshot therefore always sees a complete,
     * mutually consistent model/index pair.
     *
     * @property model the loaded bigram model.
     * @property word2id word -> index into `model.vocab`; the same ids index the bigram tables.
     * @property topUnigrams up to [UNIGRAM_CACHE_SIZE] words with the highest `uniLogp`, best first.
     */
    private class State(
        val model: BigramModel,
        val word2id: Map<String, Int>,
        val topUnigrams: List<String>
    )

    /** Current snapshot, or null while no model has been loaded successfully. */
    @Volatile
    private var state: State? = null

    /** Set while a background load is running; prevents overlapping loads. */
    private val loading = AtomicBoolean(false)

    /**
     * Starts loading the language model on a background thread and builds its lookup tables.
     *
     * Returns immediately. The call is a no-op if a load is already in progress. Calling it again
     * after a load has finished reloads the model.
     *
     * Any failure is swallowed on purpose: the predictor keeps working with the [Trie] only (or
     * with the previously loaded model, which is left untouched).
     */
    fun preload() {
        if (!loading.compareAndSet(false, true)) return
        Thread {
            try {
                // Publish only after every derived structure is ready (see State).
                state = buildState(LanguageModelLoader.load(context))
            } catch (ignored: Throwable) {
                // Deliberately silent: running without a model is a supported mode.
            } finally {
                loading.set(false)
            }
        }.start()
    }

    /** Returns true once a model (with all of its lookup tables) is available. */
    fun isReady(): Boolean = state != null

    /**
     * Suggests up to [topK] words that start with [prefix] (case-insensitive, after trimming),
     * using [lastWord] as the preceding context.
     *
     * Strategy, in order:
     * 1. No model loaded: pure [Trie] prefix lookup. An empty prefix yields no results.
     * 2. [lastWord] is in the vocabulary and has bigram successors matching the prefix: those
     *    successors, ranked by `biLogp` (highest first).
     * 3. Otherwise, an empty prefix returns the cached top unigrams. A non-empty prefix returns
     *    cached top unigrams matching the prefix, then topped up (without duplicates) from the
     *    [Trie].
     *
     * @return at most [topK] words; empty when [topK] <= 0.
     */
    fun suggest(prefix: String, lastWord: String?, topK: Int = 10): List<String> {
        if (topK <= 0) return emptyList()
        val pfx = prefix.trim()

        // Read the snapshot once so the whole call sees one consistent model.
        val current = state ?: return safeTriePrefix(pfx, topK)

        val bigram = bigramCandidates(current, pfx, lastWord)
        if (bigram.isNotEmpty()) return topKByScore(bigram, topK)

        return unigramFallback(current, pfx, topK)
    }

    /**
     * Maps a word the user picked to the form used as bigram context for the next [suggest] call.
     *
     * Returns [word] unchanged if it is in the loaded vocabulary. Otherwise it returns "", which
     * also happens when no model is loaded. Never returns null in the current implementation.
     * This is the place to add case or punctuation normalization.
     */
    fun normalizeWordForContext(word: String): String? {
        // NOTE(review): the original comment said OOV words could be mapped to a special token,
        // but the token name appears truncated. If the vocabulary contains an unknown-word
        // symbol, returning it here would let OOV contexts use its bigram row — confirm.
        return if (state?.word2id?.containsKey(word) == true) word else ""
    }

    /** Builds the complete snapshot for [m]; throws if any derived structure cannot be built. */
    private fun buildState(m: BigramModel): State {
        // Index every vocabulary entry, including the leading special symbols (the original
        // note warns not to drop the first three), so ids stay aligned with the bigram tables.
        val index = HashMap<String, Int>(m.vocab.size * 2)
        m.vocab.forEachIndexed { idx, w -> index[w] = idx }
        return State(m, index, buildTopUnigrams(m, UNIGRAM_CACHE_SIZE))
    }

    /**
     * Collects (word, logP(next | lastWord)) pairs for successors of [lastWord] that match [pfx].
     *
     * The successors of row `lastId` are `biCols[biRowptr[lastId] until biRowptr[lastId + 1]]`,
     * with their scores at the same positions in `biLogp`. Any out-of-range row pointer or column
     * id is treated as "no data" instead of throwing.
     */
    private fun bigramCandidates(
        current: State,
        pfx: String,
        lastWord: String?
    ): List<Pair<String, Float>> {
        val m = current.model
        val lastId = lastWord?.let { current.word2id[it] } ?: return emptyList()
        if (lastId < 0 || lastId + 1 >= m.biRowptr.size) return emptyList()

        val start = m.biRowptr[lastId]
        val end = m.biRowptr[lastId + 1]
        if (start !in 0..end || end > m.biCols.size || end > m.biLogp.size) return emptyList()

        val out = ArrayList<Pair<String, Float>>(end - start)
        for (i in start until end) {
            val word = m.vocab.getOrNull(m.biCols[i]) ?: continue
            if (pfx.isEmpty() || word.startsWith(pfx, ignoreCase = true)) {
                out += word to m.biLogp[i]
            }
        }
        return out
    }

    /**
     * Unigram fallback when no bigram candidate matched.
     *
     * Uses the cached top unigrams (filtered by [pfx]), then fills up from the [Trie]. Order is
     * preserved and duplicates are skipped.
     */
    private fun unigramFallback(current: State, pfx: String, topK: Int): List<String> {
        if (pfx.isEmpty()) return current.topUnigrams.take(topK)

        val results = LinkedHashSet<String>()
        current.topUnigrams.asSequence()
            .filter { it.startsWith(pfx, ignoreCase = true) }
            .take(topK)
            .toCollection(results)

        if (results.size < topK) {
            for (w in safeTriePrefix(pfx, topK)) {
                if (results.size >= topK) break
                results += w
            }
        }
        return results.toList()
    }

    /**
     * Asks the [Trie] for up to [topK] words starting with [prefix].
     *
     * Returns an empty list for an empty prefix, or if the Trie throws. Ranking is whatever
     * `Trie.startsWith` provides.
     */
    private fun safeTriePrefix(prefix: String, topK: Int): List<String> {
        if (prefix.isEmpty()) return emptyList()
        return try {
            trie.startsWith(prefix, topK)
        } catch (ignored: Throwable) {
            emptyList()
        }
    }

    /** Returns up to [limit] vocabulary words with the highest `uniLogp`, best first. */
    private fun buildTopUnigrams(model: BigramModel, limit: Int): List<String> {
        if (limit <= 0) return emptyList()
        val heap = topKHeap(limit)
        model.vocab.forEachIndexed { i, w ->
            heap.offer(w to model.uniLogp[i])
            if (heap.size > limit) heap.poll()
        }
        return heap.toSortedListDescending()
    }

    /** Returns the words of the [k] highest-scoring [pairs], best first. */
    private fun topKByScore(pairs: List<Pair<String, Float>>, k: Int): List<String> {
        if (k <= 0) return emptyList()
        val heap = topKHeap(k)
        for (p in pairs) {
            heap.offer(p)
            if (heap.size > k) heap.poll()
        }
        return heap.toSortedListDescending()
    }

    /**
     * Creates a min-heap ordered by score, so the lowest-scoring entry is evicted first.
     * Used to keep the best [k] candidates.
     */
    private fun topKHeap(k: Int): PriorityQueue<Pair<String, Float>> =
        PriorityQueue(k.coerceAtLeast(1)) { a, b -> a.second.compareTo(b.second) }

    /** Drains this min-heap and returns its words from highest to lowest score. */
    private fun PriorityQueue<Pair<String, Float>>.toSortedListDescending(): List<String> {
        val ascending = ArrayList<String>(size)
        while (true) {
            val p = poll() ?: break
            ascending += p.first
        }
        ascending.reverse()
        return ascending
    }

    private companion object {
        /** Number of top unigrams precomputed for the fallback path. */
        const val UNIGRAM_CACHE_SIZE = 2000
    }
}