2025-11-26 16:47:15 +08:00
|
|
|
|
package com.example.myapplication.data
|
|
|
|
|
|
|
|
|
|
|
|
import android.content.Context
|
|
|
|
|
|
import com.example.myapplication.Trie
|
|
|
|
|
|
import java.util.concurrent.atomic.AtomicBoolean
|
|
|
|
|
|
import java.util.PriorityQueue
|
|
|
|
|
|
import kotlin.math.max
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Word predictor that combines a bigram language model with a [Trie] of known words.
 *
 * The model is loaded on a background thread by [preload]. Until it is available,
 * [suggest] falls back to plain Trie prefix lookups. Once loaded, suggestions are taken
 * from (in order of preference): the bigram successors of the previous word, the cached
 * highest-scoring unigrams, and finally the Trie.
 *
 * Shared state is published through `@Volatile` fields; [model] is always written last so
 * that a reader that observes a non-null model also observes the indexes built for it.
 */
class BigramPredictor(
    private val context: Context,
    private val trie: Trie
) {
    /** The loaded model, or `null` while no load has completed successfully. */
    @Volatile private var model: BigramModel? = null

    /** `true` while a loader thread started by [preload] is running. */
    private val loading = AtomicBoolean(false)

    // Word <-> id mappings derived from the model vocabulary.
    @Volatile private var word2id: Map<String, Int> = emptyMap()

    @Volatile private var id2word: List<String> = emptyList()

    /** Vocabulary words ordered by descending unigram score, at most [unigramCacheSize] entries. */
    @Volatile private var topUnigrams: List<String> = emptyList()

    private val unigramCacheSize = 2000

    /**
     * Loads the language model on a background thread and builds the word -> id and
     * id -> word mappings plus the top-unigram cache.
     *
     * Returns immediately. If a load is already in progress the call is a no-op.
     * Nothing is published unless every step succeeds, so a failed load leaves the
     * predictor in its previous state (Trie-only when nothing was loaded before).
     */
    fun preload() {
        if (!loading.compareAndSet(false, true)) return

        Thread {
            try {
                val loaded = LanguageModelLoader.load(context)

                // Build the index. Vocabulary positions are the ids used by the bigram
                // tables, so every entry is kept (including the first three symbols).
                val index = HashMap<String, Int>(loaded.vocab.size * 2)
                loaded.vocab.forEachIndexed { id, word -> index[word] = id }
                val unigrams = buildTopUnigrams(loaded, unigramCacheSize)

                // Publish the derived data first and the model last: isReady() and
                // suggest() key off `model`, so they must never see a model whose
                // indexes have not been stored yet.
                word2id = index
                id2word = loaded.vocab
                topUnigrams = unigrams
                model = loaded
            } catch (_: Throwable) {
                // Deliberately silent: the predictor keeps working without a model
                // (only the Trie contributes suggestions).
            } finally {
                loading.set(false)
            }
        }.start()
    }

    /** Returns `true` once a model has been loaded and published. */
    fun isReady(): Boolean = model != null

    /**
     * Suggests up to [topK] words for [prefix], given the previous word [lastWord].
     *
     * Order of preference:
     * 1. successors of [lastWord] in the bigram table that match the prefix, ranked by
     *    their conditional log-probability;
     * 2. cached top unigrams that match the prefix, topped up from the Trie;
     * 3. when no model is loaded, the Trie alone.
     *
     * @param prefix text typed so far; surrounding whitespace is ignored. May be empty.
     * @param lastWord the preceding word, or `null` when there is no context.
     * @param topK maximum number of suggestions; values `<= 0` yield an empty list once
     *   a model is loaded.
     * @return suggestions, best first. Prefix matching against model words ignores case.
     */
    fun suggest(prefix: String, lastWord: String?, topK: Int = 10): List<String> {
        val m = model
        val pfx = prefix.trim()

        // No model yet: plain Trie prefix completion.
        if (m == null) return safeTriePrefix(pfx, topK)

        // Every model-backed path returns nothing for a non-positive topK, so skip the work.
        if (topK <= 0) return emptyList()

        // 1) + 2) Bigram neighbourhood of the previous word, if any candidate survives
        // the prefix filter.
        val bigramMatches = bigramCandidates(m, pfx, lastWord)
        if (bigramMatches.isNotEmpty()) return topKByScore(bigramMatches, topK)

        // 3) Fallback: pre-computed top unigrams filtered by prefix.
        val cachedUnigrams = getTopUnigrams(m)
        if (pfx.isEmpty()) return cachedUnigrams.take(topK)

        // topK is only a capacity hint here; cap it so a huge topK cannot exhaust memory.
        val results = ArrayList<String>(minOf(topK, INITIAL_CAPACITY_CAP))
        for (w in cachedUnigrams) {
            if (w.startsWith(pfx, ignoreCase = true)) {
                results.add(w)
                if (results.size >= topK) return results
            }
        }

        // Not enough unigram matches: top up with Trie completions, skipping duplicates.
        for (w in safeTriePrefix(pfx, topK)) {
            if (w !in results) {
                results.add(w)
                if (results.size >= topK) break
            }
        }
        return results
    }

    /**
     * Normalises a word chosen by the user before the caller stores it as the new context.
     *
     * @return [word] when it is in the loaded vocabulary, otherwise `"<unk>"`. While no
     *   model is loaded the vocabulary is empty, so every word maps to `"<unk>"`.
     *   The current implementation never returns `null`.
     */
    fun normalizeWordForContext(word: String): String? {
        // Hook for case/punctuation handling; out-of-vocabulary words become <unk>.
        return if (word in word2id) word else "<unk>"
    }

    /**
     * Collects the successors of [lastWord] that start with [pfx], each paired with its
     * bigram log-probability `logP(next | last)`.
     *
     * Successors of word id `r` occupy positions `biRowptr[r] until biRowptr[r + 1]` of
     * `biCols` / `biLogp`. All indexes are validated before use so a malformed or
     * mismatched model produces no candidates instead of an exception.
     */
    private fun bigramCandidates(
        m: BigramModel,
        pfx: String,
        lastWord: String?
    ): List<Pair<String, Float>> {
        val lastId = lastWord?.let { word2id[it] } ?: return emptyList()
        // Check the row pointer range *before* reading it.
        if (lastId < 0 || lastId + 1 >= m.biRowptr.size) return emptyList()

        val start = m.biRowptr[lastId]
        val end = m.biRowptr[lastId + 1]
        if (start !in 0..end || end > m.biCols.size) return emptyList()

        val matches = ArrayList<Pair<String, Float>>()
        for (i in start until end) {
            // Skip column ids that do not address a vocabulary entry.
            val w = m.vocab.getOrNull(m.biCols[i]) ?: continue
            if (pfx.isEmpty() || w.startsWith(pfx, ignoreCase = true)) {
                matches += w to m.biLogp[i]
            }
        }
        return matches
    }

    /**
     * Asks the Trie for up to [topK] words starting with [prefix].
     *
     * Returns an empty list for an empty prefix or when the Trie lookup throws.
     */
    private fun safeTriePrefix(prefix: String, topK: Int): List<String> {
        if (prefix.isEmpty()) return emptyList()

        return try {
            trie.startsWith(prefix, topK)
        } catch (_: Throwable) {
            emptyList()
        }
    }

    /**
     * Returns the cached top-unigram list, building and caching it from [m] when the
     * cache is still empty.
     */
    private fun getTopUnigrams(m: BigramModel): List<String> {
        val cached = topUnigrams
        if (cached.isNotEmpty()) return cached

        val built = buildTopUnigrams(m, unigramCacheSize)
        topUnigrams = built
        return built
    }

    /**
     * Returns at most [limit] vocabulary words of [m] ordered by descending unigram
     * log-probability. A non-positive [limit] yields an empty list.
     */
    private fun buildTopUnigrams(m: BigramModel, limit: Int): List<String> {
        if (limit <= 0) return emptyList()
        val heap = topKHeap(limit)

        for (i in m.vocab.indices) {
            heap.offer(m.vocab[i] to m.uniLogp[i])
            // Evict the current minimum so the heap keeps only the best `limit` entries.
            if (heap.size > limit) heap.poll()
        }

        return heap.toSortedListDescending()
    }

    /**
     * Returns the words of the [k] highest-scoring entries in [pairs], best first.
     * Uses a bounded min-heap, so only about `k` entries are held at any time.
     */
    private fun topKByScore(pairs: List<Pair<String, Float>>, k: Int): List<String> {
        val heap = topKHeap(k)

        for (p in pairs) {
            heap.offer(p)
            if (heap.size > k) heap.poll()
        }

        return heap.toSortedListDescending()
    }

    /**
     * Creates the min-heap used for top-k selection: the entry with the lowest score is
     * at the head and is therefore the first to be evicted.
     */
    private fun topKHeap(k: Int): PriorityQueue<Pair<String, Float>> {
        // k only sizes the initial backing array; the queue grows on demand. Clamp it so
        // that k <= 0 is legal and a very large k does not pre-allocate a giant array.
        return PriorityQueue(k.coerceIn(1, INITIAL_CAPACITY_CAP)) { a, b ->
            a.second.compareTo(b.second)
        }
    }

    /**
     * Drains this heap and returns the words ordered from highest to lowest score.
     * The heap is empty afterwards.
     */
    private fun PriorityQueue<Pair<String, Float>>.toSortedListDescending(): List<String> {
        val words = ArrayList<String>(size)

        while (true) {
            // poll() yields entries in ascending score order and null once drained.
            val entry = poll() ?: break
            words.add(entry.first)
        }

        words.reverse() // highest score first
        return words
    }

    private companion object {
        /** Upper bound for collection capacities that are derived from a caller-supplied k. */
        const val INITIAL_CAPACITY_CAP = 1024
    }
}
|