// 142 lines · 5.4 KiB · Kotlin
package com.example.myapplication.data
|
|
|
|
import android.content.Context
|
|
import java.io.BufferedReader
|
|
import java.io.FileInputStream
|
|
import java.io.FileNotFoundException
|
|
import java.io.InputStream
|
|
import java.io.InputStreamReader
|
|
import java.nio.ByteBuffer
|
|
import java.nio.ByteOrder
|
|
import java.nio.channels.Channels
|
|
import java.nio.channels.FileChannel
|
|
import kotlin.math.max
|
|
|
|
/**
 * In-memory bigram language model.
 *
 * The bigram table is stored in CSR (compressed sparse row) form: [biRowptr] holds the row
 * boundaries, while [biCols] and [biLogp] are parallel arrays with one entry per stored
 * (non-zero) element.
 * NOTE(review): `uniLogp` / `biLogp` presumably hold log-probabilities, judging by their
 * names — confirm against the tool that produces the model files.
 *
 * `equals` and `hashCode` are overridden because the generated data-class versions compare
 * array properties by reference, which would make two models with identical contents unequal.
 */
data class BigramModel(
    val vocab: List<String>, // every token is kept (including <unk>, <s>, </s>) so indices line up with the bigram matrix
    val uniLogp: FloatArray, // length = vocab.size
    val biRowptr: IntArray, // length = vocab.size + 1 (CSR)
    val biCols: IntArray, // length = nnz
    val biLogp: FloatArray // length = nnz
) {
    /** Structural equality: the vocabulary and every array are compared element by element. */
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is BigramModel) return false
        return vocab == other.vocab &&
            uniLogp.contentEquals(other.uniLogp) &&
            biRowptr.contentEquals(other.biRowptr) &&
            biCols.contentEquals(other.biCols) &&
            biLogp.contentEquals(other.biLogp)
    }

    /** Hash code derived from array contents so that it stays consistent with [equals]. */
    override fun hashCode(): Int {
        var result = vocab.hashCode()
        result = 31 * result + uniLogp.contentHashCode()
        result = 31 * result + biRowptr.contentHashCode()
        result = 31 * result + biCols.contentHashCode()
        result = 31 * result + biLogp.contentHashCode()
        return result
    }
}
|
|
|
|
/**
 * Loads a [BigramModel] from the application's assets.
 *
 * The model consists of `vocab.txt` (one token per line) and four binary files holding
 * little-endian 32-bit values: `uni_logp.bin`, `bi_rowptr.bin`, `bi_cols.bin` and `bi_logp.bin`.
 */
object LanguageModelLoader {
    /** Size in bytes of one int32 / float32 element. */
    private const val WORD_BYTES = 4

    /** Size of the scratch buffer used when an asset has to be streamed. */
    private const val STREAM_BUFFER_BYTES = 64 * 1024

    /** Smallest initial element capacity used by the streaming readers. */
    private const val MIN_STREAM_CAPACITY = 1024

    /**
     * Reads all model files and checks that they are mutually consistent.
     *
     * @param context used only to reach the asset manager.
     * @return the assembled model.
     * @throws IllegalArgumentException if array lengths disagree with the vocabulary size or
     *   the CSR row pointers are malformed.
     * @throws IllegalStateException if a streamed binary asset ends in the middle of a value.
     */
    fun load(context: Context): BigramModel {
        val vocab = context.assets.open("vocab.txt").bufferedReader()
            .use(BufferedReader::readLines)

        val uniLogp = readFloat32(context, "uni_logp.bin")
        val biRowptr = readInt32(context, "bi_rowptr.bin")
        val biCols = readInt32(context, "bi_cols.bin")
        val biLogp = readFloat32(context, "bi_logp.bin")

        require(uniLogp.size == vocab.size) { "uni_logp length != vocab size" }
        require(biRowptr.size == vocab.size + 1) { "bi_rowptr length invalid" }
        require(biCols.size == biLogp.size) { "bi cols/logp nnz mismatch" }
        // Fail here rather than with an index error on first lookup.
        requireValidRowPointers(biRowptr, biCols.size)

        return BigramModel(vocab, uniLogp, biRowptr, biCols, biLogp)
    }

    /**
     * Checks the CSR invariants of [rowptr]: starts at 0, never decreases, ends at [nnz].
     * [rowptr] must be non-empty, which the length check in [load] guarantees.
     */
    private fun requireValidRowPointers(rowptr: IntArray, nnz: Int) {
        require(rowptr.first() == 0) { "bi_rowptr must start at 0" }
        require(rowptr.last() == nnz) { "bi_rowptr end != nnz" }
        for (row in 1 until rowptr.size) {
            require(rowptr[row - 1] <= rowptr[row]) { "bi_rowptr decreases at row ${row - 1}" }
        }
    }

    /** Reads the asset [name] as a little-endian int32 array. */
    private fun readInt32(context: Context, name: String): IntArray =
        readAsset(context, name, ::readInt32Channel, ::readInt32Stream)

    /** Reads the asset [name] as a little-endian float32 array. */
    private fun readFloat32(context: Context, name: String): FloatArray =
        readAsset(context, name, ::readFloat32Channel, ::readFloat32Stream)

    /**
     * Opens the asset [name] through a file descriptor and decodes it with [fromChannel];
     * if the descriptor cannot be obtained, decodes the asset's input stream with [fromStream].
     */
    private fun <T> readAsset(
        context: Context,
        name: String,
        fromChannel: (FileChannel, Long, Long) -> T,
        fromStream: (InputStream) -> T
    ): T {
        try {
            context.assets.openFd(name).use { afd ->
                FileInputStream(afd.fileDescriptor).channel.use { channel ->
                    return fromChannel(channel, afd.startOffset, afd.length)
                }
            }
        } catch (e: FileNotFoundException) {
            // Compressed assets do not support openFd; fall back to streaming.
        }
        return context.assets.open(name).use(fromStream)
    }

    /**
     * Maps [length] bytes of [channel] starting at [offset] read-only, as a little-endian buffer.
     *
     * @param typeName element type label used in error messages ("int32" / "float32").
     * @throws IllegalArgumentException if [length] is not a multiple of four or exceeds [Int.MAX_VALUE].
     */
    private fun mapWords(channel: FileChannel, offset: Long, length: Long, typeName: String): ByteBuffer {
        require(length % WORD_BYTES == 0L) { "$typeName length invalid: $length" }
        require(length <= Int.MAX_VALUE.toLong()) { "$typeName asset too large: $length" }
        return channel.map(FileChannel.MapMode.READ_ONLY, offset, length)
            .order(ByteOrder.LITTLE_ENDIAN)
    }

    /** Copies the int32 values in the given region of [channel] into a new array. */
    private fun readInt32Channel(channel: FileChannel, offset: Long, length: Long): IntArray {
        val mapped = mapWords(channel, offset, length, "int32")
        return IntArray(mapped.remaining() / WORD_BYTES).also { mapped.asIntBuffer().get(it) }
    }

    /** Copies the float32 values in the given region of [channel] into a new array. */
    private fun readFloat32Channel(channel: FileChannel, offset: Long, length: Long): FloatArray {
        val mapped = mapWords(channel, offset, length, "float32")
        return FloatArray(mapped.remaining() / WORD_BYTES).also { mapped.asFloatBuffer().get(it) }
    }

    /**
     * Reads [input] to the end in chunks, calling [consumeWord] once for every complete
     * 4-byte value available in the little-endian scratch buffer. [consumeWord] must consume
     * exactly four bytes per call. The stream is closed when this returns.
     *
     * @param typeName element type label used in the truncation error message.
     * @throws IllegalStateException if the stream length is not a multiple of four bytes.
     */
    private inline fun drainWords(
        input: InputStream,
        typeName: String,
        consumeWord: (ByteBuffer) -> Unit
    ) {
        val buffer = ByteBuffer.allocateDirect(STREAM_BUFFER_BYTES).order(ByteOrder.LITTLE_ENDIAN)
        Channels.newChannel(input).use { channel ->
            while (true) {
                val read = channel.read(buffer)
                if (read == -1) break
                if (read == 0) continue
                buffer.flip()
                while (buffer.remaining() >= WORD_BYTES) consumeWord(buffer)
                // Keep a trailing partial value (at most 3 bytes) for the next read.
                buffer.compact()
            }
        }
        buffer.flip()
        check(buffer.remaining() == 0) { "truncated $typeName stream" }
    }

    /** Decodes all of [input] as int32 values, growing the result array as needed. */
    private fun readInt32Stream(input: InputStream): IntArray {
        // available() is only a sizing hint; the array doubles whenever it fills up.
        var out = IntArray(max(MIN_STREAM_CAPACITY, input.available() / WORD_BYTES))
        var count = 0
        drainWords(input, "int32") { chunk ->
            if (count == out.size) out = out.copyOf(out.size * 2)
            out[count++] = chunk.getInt()
        }
        return out.copyOf(count)
    }

    /** Decodes all of [input] as float32 values, growing the result array as needed. */
    private fun readFloat32Stream(input: InputStream): FloatArray {
        // available() is only a sizing hint; the array doubles whenever it fills up.
        var out = FloatArray(max(MIN_STREAM_CAPACITY, input.available() / WORD_BYTES))
        var count = 0
        drainWords(input, "float32") { chunk ->
            if (count == out.size) out = out.copyOf(out.size * 2)
            out[count++] = chunk.getFloat()
        }
        return out.copyOf(count)
    }
}
|