69 lines
2.5 KiB
Kotlin
69 lines
2.5 KiB
Kotlin
|
|
package com.example.myapplication.data
|
||
|
|
|
||
|
|
import android.content.Context
import java.io.BufferedReader
import java.io.InputStreamReader
import java.nio.ByteBuffer
import java.nio.ByteOrder
|
||
|
|
|
||
|
|
/**
 * In-memory bigram language model whose bigram table is stored in CSR
 * (compressed sparse row) layout.
 *
 * Equality and hashing are content-based: the generated data-class
 * implementations would compare the array properties by reference, so two
 * models holding identical data would never be equal.
 *
 * @property vocab All vocabulary entries (including `<unk>`, `<s>` and `</s>`),
 *   kept in full so that list positions line up with the bigram matrix indices.
 * @property uniLogp One value per vocabulary entry (length = `vocab.size`);
 *   presumably unigram log-probabilities — confirm against the model exporter.
 * @property biRowptr CSR row pointers (length = `vocab.size + 1`).
 * @property biCols CSR column indices (length = nnz, the number of stored entries).
 * @property biLogp CSR values (length = nnz); presumably bigram
 *   log-probabilities — confirm against the model exporter.
 */
data class BigramModel(
    val vocab: List<String>,
    val uniLogp: FloatArray,
    val biRowptr: IntArray,
    val biCols: IntArray,
    val biLogp: FloatArray,
) {
    /** Compares the vocabulary and every array element-wise. */
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is BigramModel) return false
        return vocab == other.vocab &&
            uniLogp.contentEquals(other.uniLogp) &&
            biRowptr.contentEquals(other.biRowptr) &&
            biCols.contentEquals(other.biCols) &&
            biLogp.contentEquals(other.biLogp)
    }

    /** Hash derived from array contents so that it stays consistent with [equals]. */
    override fun hashCode(): Int {
        var result = vocab.hashCode()
        result = 31 * result + uniLogp.contentHashCode()
        result = 31 * result + biRowptr.contentHashCode()
        result = 31 * result + biCols.contentHashCode()
        result = 31 * result + biLogp.contentHashCode()
        return result
    }
}
|
||
|
|
|
||
|
|
/**
 * Loads a [BigramModel] from the application's bundled assets.
 *
 * The model is read from `vocab.txt` (one entry per line) and four raw binary
 * assets holding little-endian 32-bit values: `uni_logp.bin`, `bi_rowptr.bin`,
 * `bi_cols.bin` and `bi_logp.bin`.
 */
object LanguageModelLoader {
    /** Size in bytes of one encoded 32-bit element. */
    private const val ELEMENT_BYTES = 4

    /**
     * Reads all model assets and checks that they are mutually consistent.
     *
     * Every asset is read fully into memory on the calling thread.
     *
     * @param context Context whose `assets` contain the model files.
     * @return The assembled model.
     * @throws IllegalArgumentException if an asset's size is not a multiple of
     *   4 bytes, if the array lengths disagree, or if the CSR structure is malformed.
     */
    fun load(context: Context): BigramModel {
        val vocab = context.assets.open("vocab.txt").bufferedReader()
            .use(BufferedReader::readLines)

        val uniLogp = readFloat32(context, "uni_logp.bin")
        val biRowptr = readInt32(context, "bi_rowptr.bin")
        val biCols = readInt32(context, "bi_cols.bin")
        val biLogp = readFloat32(context, "bi_logp.bin")

        require(uniLogp.size == vocab.size) { "uni_logp length != vocab size" }
        require(biRowptr.size == vocab.size + 1) { "bi_rowptr length invalid" }
        require(biCols.size == biLogp.size) { "bi cols/logp nnz mismatch" }

        // CSR structural checks: fail here with a clear message instead of with an
        // index-out-of-bounds error at lookup time. biRowptr is non-empty at this
        // point because its length was just verified to be vocab.size + 1.
        require(biRowptr.first() == 0 && biRowptr.last() == biCols.size) {
            "bi_rowptr does not span bi_cols"
        }
        require((1 until biRowptr.size).all { biRowptr[it - 1] <= biRowptr[it] }) {
            "bi_rowptr is not non-decreasing"
        }
        require(biCols.all { it in vocab.indices }) { "bi_cols index out of vocab range" }

        return BigramModel(vocab, uniLogp, biRowptr, biCols, biLogp)
    }

    /**
     * Reads the asset [name] fully and wraps it in a little-endian buffer.
     *
     * @throws IllegalArgumentException if the size is not a whole number of
     *   32-bit elements (previously the trailing bytes were silently dropped).
     */
    private fun readLittleEndian(context: Context, name: String): ByteBuffer {
        val bytes = context.assets.open(name).use { it.readBytes() }
        require(bytes.size % ELEMENT_BYTES == 0) {
            "$name: size ${bytes.size} is not a multiple of $ELEMENT_BYTES bytes"
        }
        return ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
    }

    /** Decodes the asset [name] as a sequence of little-endian 32-bit integers. */
    private fun readInt32(context: Context, name: String): IntArray {
        val words = readLittleEndian(context, name).asIntBuffer()
        return IntArray(words.remaining()).also { words.get(it) }
    }

    /** Decodes the asset [name] as a sequence of little-endian 32-bit floats. */
    private fun readFloat32(context: Context, name: String): FloatArray {
        val words = readLittleEndian(context, name).asFloatBuffer()
        return FloatArray(words.remaining()).also { words.get(it) }
    }
}
|