Files
Android-key-of-love/app/src/main/java/com/example/myapplication/data/LanguageModelLoader.kt
pengxiaolong 673b4491d7 优化plus
2026-01-15 21:32:32 +08:00

142 lines
5.4 KiB
Kotlin

package com.example.myapplication.data
import android.content.Context
import java.io.BufferedReader
import java.io.FileInputStream
import java.io.FileNotFoundException
import java.io.InputStream
import java.io.InputStreamReader
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.channels.Channels
import java.nio.channels.FileChannel
import kotlin.math.max
/**
 * Bigram language model: unigram log-probabilities plus a sparse bigram matrix in
 * compressed sparse row (CSR) form.
 *
 * Row `i` of the bigram matrix occupies indices `biRowptr[i] until biRowptr[i + 1]`
 * of [biCols] and [biLogp].
 *
 * [equals] and [hashCode] are overridden to compare the arrays by content. The
 * data-class defaults would compare array references, so two models loaded from
 * identical data would never be equal.
 */
data class BigramModel(
    val vocab: List<String>, // every token (including <unk>, <s>, </s>), index-aligned with the bigram matrix
    val uniLogp: FloatArray, // size == vocab.size
    val biRowptr: IntArray, // size == vocab.size + 1 (CSR row pointers)
    val biCols: IntArray, // size == nnz
    val biLogp: FloatArray // size == nnz
) {
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is BigramModel) return false
        return vocab == other.vocab &&
            uniLogp.contentEquals(other.uniLogp) &&
            biRowptr.contentEquals(other.biRowptr) &&
            biCols.contentEquals(other.biCols) &&
            biLogp.contentEquals(other.biLogp)
    }

    override fun hashCode(): Int {
        var result = vocab.hashCode()
        result = 31 * result + uniLogp.contentHashCode()
        result = 31 * result + biRowptr.contentHashCode()
        result = 31 * result + biCols.contentHashCode()
        result = 31 * result + biLogp.contentHashCode()
        return result
    }
}
/**
 * Loads the bundled [BigramModel] from the app's assets.
 *
 * Asset layout (the `.bin` files are headerless little-endian arrays):
 * - `vocab.txt`: one token per line; a token's line number is its index.
 * - `uni_logp.bin`: float32 × vocab.size
 * - `bi_rowptr.bin`: int32 × (vocab.size + 1), CSR row pointers
 * - `bi_cols.bin`: int32 × nnz, CSR column indices (vocab indices)
 * - `bi_logp.bin`: float32 × nnz, values aligned with `bi_cols.bin`
 *
 * Uncompressed assets are memory-mapped and copied out in one bulk read. Assets
 * stored compressed in the APK cannot be opened with `openFd`, so they are
 * streamed instead.
 */
object LanguageModelLoader {
    /** Size in bytes of one int32 / float32 element. */
    private const val ELEMENT_BYTES = 4

    /**
     * Reads every model asset and validates that the arrays form a consistent model.
     *
     * @throws IllegalArgumentException if array lengths disagree with each other or with
     *   the vocabulary, if the CSR structure is malformed, or if a memory-mapped asset's
     *   byte length is not a multiple of 4.
     * @throws IllegalStateException if a streamed asset's byte length is not a multiple of 4.
     * @throws java.io.IOException if an asset is missing or cannot be read.
     */
    fun load(context: Context): BigramModel {
        val vocab = context.assets.open("vocab.txt").bufferedReader()
            .use(BufferedReader::readLines)
        val uniLogp = readFloat32(context, "uni_logp.bin")
        val biRowptr = readInt32(context, "bi_rowptr.bin")
        val biCols = readInt32(context, "bi_cols.bin")
        val biLogp = readFloat32(context, "bi_logp.bin")
        require(uniLogp.size == vocab.size) { "uni_logp length != vocab size" }
        require(biRowptr.size == vocab.size + 1) { "bi_rowptr length invalid" }
        require(biCols.size == biLogp.size) { "bi cols/logp nnz mismatch" }
        requireValidCsr(biRowptr, biCols, vocab.size)
        return BigramModel(vocab, uniLogp, biRowptr, biCols, biLogp)
    }

    /**
     * Checks the CSR invariants that lookups rely on. A corrupt asset then fails here
     * with a clear message, not later as an out-of-bounds access.
     * Assumes `rowptr.size == vocabSize + 1` has already been verified.
     */
    private fun requireValidCsr(rowptr: IntArray, cols: IntArray, vocabSize: Int) {
        require(rowptr.first() == 0 && rowptr.last() == cols.size) {
            "bi_rowptr bounds invalid: first=${rowptr.first()}, last=${rowptr.last()}, nnz=${cols.size}"
        }
        for (row in 0 until vocabSize) {
            require(rowptr[row] <= rowptr[row + 1]) { "bi_rowptr decreases at row $row" }
        }
        for (i in cols.indices) {
            val col = cols[i]
            require(col in 0 until vocabSize) { "bi_cols[$i] = $col out of vocab range" }
        }
    }

    /** Reads asset [name] as a headerless little-endian int32 array. */
    private fun readInt32(context: Context, name: String): IntArray =
        readLittleEndian(context, name, "int32") { bytes ->
            val view = bytes.asIntBuffer()
            IntArray(view.remaining()).also { view.get(it) }
        }

    /** Reads asset [name] as a headerless little-endian float32 array. */
    private fun readFloat32(context: Context, name: String): FloatArray =
        readLittleEndian(context, name, "float32") { bytes ->
            val view = bytes.asFloatBuffer()
            FloatArray(view.remaining()).also { view.get(it) }
        }

    /**
     * Loads the raw bytes of asset [name], switches them to little-endian order, and
     * hands them to [decode]. The asset is memory-mapped when possible and streamed
     * otherwise. [kind] only labels error messages.
     */
    private fun <T> readLittleEndian(
        context: Context,
        name: String,
        kind: String,
        decode: (ByteBuffer) -> T
    ): T {
        val bytes = mapAsset(context, name, kind) ?: streamAsset(context, name, kind)
        return decode(bytes.order(ByteOrder.LITTLE_ENDIAN))
    }

    /**
     * Memory-maps an uncompressed asset. Returns null when `openFd` reports the asset as
     * unavailable, which is how compressed assets behave.
     *
     * Closing the channel after mapping is safe: a mapping stays valid independently of
     * the channel that created it.
     */
    private fun mapAsset(context: Context, name: String, kind: String): ByteBuffer? {
        val afd = try {
            context.assets.openFd(name)
        } catch (e: FileNotFoundException) {
            // Compressed assets do not support openFd; the caller falls back to streaming.
            return null
        }
        return afd.use {
            val length = afd.length
            require(length % ELEMENT_BYTES == 0L) { "$kind length invalid: $length" }
            require(length <= Int.MAX_VALUE.toLong()) { "$kind asset too large: $length" }
            FileInputStream(afd.fileDescriptor).channel.use { channel ->
                // The asset lives inside the APK file, so map only its own byte range.
                channel.map(FileChannel.MapMode.READ_ONLY, afd.startOffset, length)
            }
        }
    }

    /** Reads the whole asset through a stream, for assets that cannot be memory-mapped. */
    private fun streamAsset(context: Context, name: String, kind: String): ByteBuffer {
        val bytes = context.assets.open(name).use { input: InputStream -> input.readBytes() }
        check(bytes.size % ELEMENT_BYTES == 0) { "truncated $kind stream" }
        return ByteBuffer.wrap(bytes)
    }
}