Files
Android-key-of-love/app/src/main/java/com/example/myapplication/data/LanguageModelLoader.kt

142 lines
5.4 KiB
Kotlin
Raw Normal View History

2025-11-26 16:47:15 +08:00
package com.example.myapplication.data
import android.content.Context
import java.io.BufferedReader
2026-01-15 21:32:32 +08:00
import java.io.FileInputStream
import java.io.FileNotFoundException
import java.io.InputStream
2025-11-26 16:47:15 +08:00
import java.io.InputStreamReader
2026-01-15 21:32:32 +08:00
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.channels.Channels
import java.nio.channels.FileChannel
import kotlin.math.max
2025-11-26 16:47:15 +08:00
/**
 * In-memory bigram language model: a vocabulary, per-token unigram log-probabilities and a
 * bigram log-probability matrix stored in CSR (compressed sparse row) form.
 *
 * Because the properties are primitive arrays, [equals] and [hashCode] are overridden to compare
 * array *contents*; the compiler-generated data-class versions would compare array references,
 * making two models with identical data unequal.
 */
data class BigramModel(
    // Keeps every token (including <unk>, <s>, </s>) so indices line up with the bigram matrix.
    val vocab: List<String>,
    // Length = vocab.size
    val uniLogp: FloatArray,
    // Length = vocab.size + 1 (CSR row pointers)
    val biRowptr: IntArray,
    // Length = nnz
    val biCols: IntArray,
    // Length = nnz
    val biLogp: FloatArray
) {
    /** Structural equality: the vocabulary and the contents of all four arrays must match. */
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is BigramModel) return false
        return vocab == other.vocab &&
            uniLogp.contentEquals(other.uniLogp) &&
            biRowptr.contentEquals(other.biRowptr) &&
            biCols.contentEquals(other.biCols) &&
            biLogp.contentEquals(other.biLogp)
    }

    /** Content-based hash, consistent with [equals]. */
    override fun hashCode(): Int {
        var result = vocab.hashCode()
        result = 31 * result + uniLogp.contentHashCode()
        result = 31 * result + biRowptr.contentHashCode()
        result = 31 * result + biCols.contentHashCode()
        result = 31 * result + biLogp.contentHashCode()
        return result
    }
}
/**
 * Loads the bigram language model from the application's assets.
 *
 * The model consists of a text vocabulary (`vocab.txt`, one token per line) and four binary
 * files holding little-endian 32-bit values (`uni_logp.bin`, `bi_rowptr.bin`, `bi_cols.bin`,
 * `bi_logp.bin`).
 */
object LanguageModelLoader {
    /** Width in bytes of one int32/float32 value in the binary assets. */
    private const val VALUE_BYTES = 4

    /**
     * Reads all model assets and assembles a [BigramModel].
     *
     * @param context used only to reach the asset manager.
     * @return the loaded model.
     * @throws IllegalArgumentException if the array lengths disagree with each other or the
     *   CSR row pointer does not start at 0 and end at the number of stored entries.
     * @throws IllegalStateException if a streamed binary asset is not a whole number of
     *   4-byte values.
     */
    fun load(context: Context): BigramModel {
        val vocab = context.assets.open("vocab.txt").bufferedReader()
            .use(BufferedReader::readLines)
        val uniLogp = readFloat32(context, "uni_logp.bin")
        val biRowptr = readInt32(context, "bi_rowptr.bin")
        val biCols = readInt32(context, "bi_cols.bin")
        val biLogp = readFloat32(context, "bi_logp.bin")

        require(uniLogp.size == vocab.size) { "uni_logp length != vocab size" }
        require(biRowptr.size == vocab.size + 1) { "bi_rowptr length invalid" }
        require(biCols.size == biLogp.size) { "bi cols/logp nnz mismatch" }
        // biRowptr is non-empty here (size == vocab.size + 1). A CSR row pointer must span
        // exactly [0, nnz]; rejecting a bad file now avoids out-of-bounds reads at lookup time.
        require(biRowptr.first() == 0 && biRowptr.last() == biCols.size) {
            "bi_rowptr bounds invalid"
        }
        return BigramModel(vocab, uniLogp, biRowptr, biCols, biLogp)
    }

    /** Reads the asset [name] as an array of little-endian int32 values. */
    private fun readInt32(context: Context, name: String): IntArray =
        readValues(context, name, "int32") { bytes, count ->
            IntArray(count).also { bytes.asIntBuffer().get(it) }
        }

    /** Reads the asset [name] as an array of little-endian float32 values. */
    private fun readFloat32(context: Context, name: String): FloatArray =
        readValues(context, name, "float32") { bytes, count ->
            FloatArray(count).also { bytes.asFloatBuffer().get(it) }
        }

    /**
     * Shared reader behind [readInt32] and [readFloat32].
     *
     * First tries to memory-map the asset through its file descriptor; if that raises
     * [FileNotFoundException] (compressed assets do not support `openFd`), the asset is read
     * fully through a stream instead. In both cases [decode] receives a little-endian buffer
     * positioned at the first value together with the number of 4-byte values it holds.
     *
     * @param typeName label used in error messages ("int32" / "float32").
     * @throws IllegalArgumentException if the mapped asset's length is not a multiple of 4 or
     *   exceeds [Int.MAX_VALUE] bytes.
     * @throws IllegalStateException if the streamed asset's length is not a multiple of 4.
     */
    private fun <T> readValues(
        context: Context,
        name: String,
        typeName: String,
        decode: (ByteBuffer, Int) -> T
    ): T = try {
        context.assets.openFd(name).use { afd ->
            FileInputStream(afd.fileDescriptor).channel.use { channel ->
                val length = afd.length
                require(length % VALUE_BYTES == 0L) { "$typeName length invalid: $length" }
                require(length <= Int.MAX_VALUE.toLong()) { "$typeName asset too large: $length" }
                // The descriptor covers the whole container file, so map only this asset's slice.
                val mapped = channel.map(FileChannel.MapMode.READ_ONLY, afd.startOffset, length)
                mapped.order(ByteOrder.LITTLE_ENDIAN)
                // Decode while the descriptor is still open; the result is a plain heap array.
                decode(mapped, (length / VALUE_BYTES).toInt())
            }
        }
    } catch (e: FileNotFoundException) {
        // Compressed assets do not support openFd; fall back to streaming.
        context.assets.open(name).use { input ->
            val bytes = input.readBytes()
            check(bytes.size % VALUE_BYTES == 0) { "truncated $typeName stream" }
            decode(
                ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN),
                bytes.size / VALUE_BYTES
            )
        }
    }
}