diff --git a/go.mod b/go.mod index 255d9d7..ab5b3ed 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module game-driver -go 1.23.2 +go 1.26 require ( github.com/adrg/libvlc-go/v3 v3.1.6 @@ -14,10 +14,10 @@ require ( github.com/hypebeast/go-osc v0.0.0-20220308234300-cec5a8a1e5f5 github.com/spf13/viper v1.21.0 github.com/tencentcloud/tencentcloud-cls-sdk-go v1.0.11 + github.com/tphakala/go-audio-resampler v1.2.0 github.com/urfave/cli/v3 v3.8.0 github.com/warthog618/go-gpiocdev v0.9.1 github.com/youpy/go-wav v0.3.2 - github.com/zeozeozeo/gomplerate v0.0.0-20250404113140-0fbb236df825 go.uber.org/zap v1.27.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 ) @@ -47,6 +47,7 @@ require ( github.com/spf13/cast v1.10.0 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/tphakala/simd v1.0.22 // indirect github.com/youpy/go-riff v0.1.0 // indirect github.com/ysmood/fetchup v0.3.0 // indirect github.com/ysmood/goob v0.4.0 // indirect @@ -58,8 +59,9 @@ require ( go.uber.org/multierr v1.11.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/net v0.37.0 // indirect - golang.org/x/sys v0.31.0 // indirect + golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.28.0 // indirect + gonum.org/v1/gonum v0.17.0 // indirect google.golang.org/protobuf v1.36.5 // indirect gopkg.in/ini.v1 v1.67.0 // indirect ) diff --git a/go.sum b/go.sum index 4c9e153..2d207ac 100644 --- a/go.sum +++ b/go.sum @@ -126,6 +126,10 @@ github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8 github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/tencentcloud/tencentcloud-cls-sdk-go v1.0.11 h1:LJshkcQ14A/7XCgqalheBHv8qLwwOXr/xqttQbjWdHM= github.com/tencentcloud/tencentcloud-cls-sdk-go v1.0.11/go.mod h1:WU+0TXfVbSctEsUUf4KmIKnfr+tknbjcsnx/TrEIPH4= +github.com/tphakala/go-audio-resampler v1.2.0 h1:AeNmdDtAJU0yHkKID7YoUdS2K5ZMNtwbjbDh1hHCMww= 
+github.com/tphakala/go-audio-resampler v1.2.0/go.mod h1:2jZ7uTFDvnfMZiDkXS1lF/Z7KmsF2tqsNuL/NyceJ2o= +github.com/tphakala/simd v1.0.22 h1:3wHL91t4yvhCB0ycyTznvucTHax+QGpYkvOhqfraTYw= +github.com/tphakala/simd v1.0.22/go.mod h1:8xsPUbOTnNI4WUdPlXVlWXt85Y8RCm3xqGAo8PLxYyA= github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o= github.com/uber/jaeger-client-go v2.30.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= github.com/uber/jaeger-lib v2.4.1+incompatible h1:td4jdvLcExb4cBISKIpHuGoVXh+dVKhn2Um6rjCsSsg= @@ -157,8 +161,6 @@ github.com/ysmood/leakless v0.9.0/go.mod h1:R8iAXPRaG97QJwqxs74RdwzcRHT1SWCGTNqY github.com/zaf/g711 v0.0.0-20190814101024-76a4a538f52b/go.mod h1:T2h1zV50R/q0CVYnsQOQ6L7P4a2ZxH47ixWcMXFGyx8= github.com/zaf/g711 v1.4.0 h1:XZYkjjiAg9QTBnHqEg37m2I9q3IIDv5JRYXs2N8ma7c= github.com/zaf/g711 v1.4.0/go.mod h1:eCDXt3dSp/kYYAoooba7ukD/Q75jvAaS4WOMr0l1Roo= -github.com/zeozeozeo/gomplerate v0.0.0-20250404113140-0fbb236df825 h1:rViu1xhQRtdJogc39jF46PS01xHVD736JowXl2qOcPM= -github.com/zeozeozeo/gomplerate v0.0.0-20250404113140-0fbb236df825/go.mod h1:ASuMFHITnaVdPvMkoDGI4tTwYG9fW7Mxv2j5AuvTo8Q= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= @@ -192,8 +194,8 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod 
h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= @@ -206,6 +208,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= +gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= +gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= diff --git a/pkg/audio/context.go b/pkg/audio/context.go index 7dbd90e..f830793 100644 --- a/pkg/audio/context.go +++ b/pkg/audio/context.go @@ -13,8 +13,8 @@ var ( ) const ( - UniversalSampleRate = 44100 // 通用采样率(高质量音频) - DefaultChannelCount = 2 // 声道数(立体声) + UniversalSampleRate = 44100 + DefaultChannelCount = 2 ) func initContext() (*oto.Context, error) { diff --git a/pkg/audio/context_test.go b/pkg/audio/context_test.go index 5fa0a22..10d2ef0 100644 --- a/pkg/audio/context_test.go +++ b/pkg/audio/context_test.go @@ -5,7 +5,6 @@ import ( ) func TestInitContext(t *testing.T) { - // 第一次调用应该成功 ctx1, err := initContext() if err != nil { t.Fatalf("第一次 initContext 失败: %v", err) @@ 
-14,7 +13,6 @@ func TestInitContext(t *testing.T) { t.Fatal("返回的 context 不应为 nil") } - // 第二次调用应该返回相同的 context ctx2, err := initContext() if err != nil { t.Fatalf("第二次 initContext 失败: %v", err) diff --git a/pkg/audio/doc.go b/pkg/audio/doc.go index c799627..abd3b42 100644 --- a/pkg/audio/doc.go +++ b/pkg/audio/doc.go @@ -20,6 +20,14 @@ // defer cleanup() // // ... 播放中 ... // +// 采样率说明: +// - 统一采样率:固定使用 44100 Hz(UniversalSampleRate) +// - oto/v3 只支持一个全局 Context,统一采样率可避免冲突 +// - 其他采样率会自动重采样到 44100 Hz(Sinc 重采样) +// - 16000 Hz 音频(TTS):自动重采样,正常速度 ✅ +// - 44100 Hz 音频(BGM):无需重采样,正常速度 ✅ +// - 其他采样率:自动重采样,正常速度 ✅ +// // 资源管理: // - 一次性播放: 函数内部自动管理所有资源 // - 循环播放: 调用者必须调用 defer cleanup() 清理资源 diff --git a/pkg/audio/loop.go b/pkg/audio/loop.go index f746a27..1ee4b4f 100644 --- a/pkg/audio/loop.go +++ b/pkg/audio/loop.go @@ -9,17 +9,12 @@ import ( "github.com/ebitengine/oto/v3" "github.com/hajimehoshi/go-mp3" + "go.uber.org/zap" ) // PlayMP3Loop 循环播放 MP3(非阻塞) // 返回 player 和清理函数,调用者负责 defer cleanup() func PlayMP3Loop(r io.ReadCloser) (*oto.Player, func() error, error) { - otoCtx, err := initContext() - if err != nil { - r.Close() - return nil, func() error { return nil }, err - } - // Read the entire MP3 into memory for seeking support data, err := io.ReadAll(r) if err != nil { @@ -36,14 +31,16 @@ func PlayMP3Loop(r io.ReadCloser) (*oto.Player, func() error, error) { // 获取采样率信息 sampleRate := int(dec.SampleRate()) - // 需要重采样 + // 需要重采样(使用 Sinc 高质量重采样) var reader io.Reader = dec if needsResampling(sampleRate) { - resampleReader, err := newResamplingReader(dec, sampleRate, UniversalSampleRate, 2) - if err != nil { - return nil, func() error { return nil }, err - } - reader = resampleReader + zap.S().Infof("BGM Sinc 重采样: %d Hz → %d Hz", sampleRate, UniversalSampleRate) + reader = newSincResampler(dec, sampleRate, UniversalSampleRate, 2) + } + + otoCtx, err := initContext() + if err != nil { + return nil, func() error { return nil }, err } player := otoCtx.NewPlayer(reader) diff --git a/pkg/audio/play.go 
b/pkg/audio/play.go index 41ed6f5..5c7e9b7 100644 --- a/pkg/audio/play.go +++ b/pkg/audio/play.go @@ -14,11 +14,6 @@ import ( // PlayWav 播放 WAV 文件(阻塞),直到完成或 context 取消 func PlayWav(ctx context.Context, r io.ReadCloser) error { - otoCtx, err := initContext() - if err != nil { - return fmt.Errorf("音频上下文初始化失败: %w", err) - } - // Read the entire file into memory since wav.NewReader needs ReadAt data, err := io.ReadAll(r) if err != nil { @@ -38,20 +33,20 @@ func PlayWav(ctx context.Context, r io.ReadCloser) error { duration, _ := dec.Duration() sourceRate := int(format.SampleRate) - channels := int(format.NumChannels) zap.S().Infof("WAV 音频: %d ch, %d Hz, 时长: %v", - channels, sourceRate, duration) + format.NumChannels, sourceRate, duration) - // 需要重采样 + // 需要重采样(使用 Sinc 高质量重采样) var reader io.Reader = dec if needsResampling(sourceRate) { - zap.S().Infof("重采样: %d Hz → %d Hz", sourceRate, UniversalSampleRate) - resampleReader, err := newResamplingReader(dec, sourceRate, UniversalSampleRate, channels) - if err != nil { - return fmt.Errorf("创建重采样器失败: %w", err) - } - reader = resampleReader + zap.S().Infof("Sinc 重采样: %d Hz → %d Hz", sourceRate, UniversalSampleRate) + reader = newSincResampler(dec, sourceRate, UniversalSampleRate, int(format.NumChannels)) + } + + otoCtx, err := initContext() + if err != nil { + return fmt.Errorf("音频上下文初始化失败: %w", err) } player := otoCtx.NewPlayer(reader) @@ -82,11 +77,6 @@ func PlayWav(ctx context.Context, r io.ReadCloser) error { // PlayMP3 播放 MP3 文件(阻塞),直到完成或 context 取消 func PlayMP3(ctx context.Context, r io.ReadCloser) error { - otoCtx, err := initContext() - if err != nil { - return fmt.Errorf("音频上下文初始化失败: %w", err) - } - dec, err := mp3.NewDecoder(r) if err != nil { r.Close() @@ -100,17 +90,19 @@ func PlayMP3(ctx context.Context, r io.ReadCloser) error { channels := 2 // MP3 通常是立体声 duration := time.Duration(float64(sampleCount)/float64(sampleRate)*1000) * time.Millisecond - zap.S().Infof("MP3 音频: %d Hz, 时长约: %v", sampleRate, duration) + 
zap.S().Infof("MP3 音频: %d Hz → %d Hz, 时长约: %v", + sampleRate, UniversalSampleRate, duration) - // 需要重采样 + // 需要重采样(使用 Sinc 高质量重采样) var reader io.Reader = dec if needsResampling(sampleRate) { - zap.S().Infof("重采样: %d Hz → %d Hz", sampleRate, UniversalSampleRate) - resampleReader, err := newResamplingReader(dec, sampleRate, UniversalSampleRate, channels) - if err != nil { - return fmt.Errorf("创建重采样器失败: %w", err) - } - reader = resampleReader + zap.S().Infof("Sinc 重采样: %d Hz → %d Hz", sampleRate, UniversalSampleRate) + reader = newSincResampler(dec, sampleRate, UniversalSampleRate, channels) + } + + otoCtx, err := initContext() + if err != nil { + return fmt.Errorf("音频上下文初始化失败: %w", err) } player := otoCtx.NewPlayer(reader) diff --git a/pkg/audio/resampler.go b/pkg/audio/resampler.go deleted file mode 100644 index c474ff0..0000000 --- a/pkg/audio/resampler.go +++ /dev/null @@ -1,142 +0,0 @@ -package audio - -import ( - "io" - "sync" - - "github.com/zeozeozeo/gomplerate" -) - -const ( - resampleBufferSize = 8192 // 重采样缓冲区大小(int16 样本数) -) - -var ( - bufferPool = sync.Pool{ - New: func() any { - return make([]byte, resampleBufferSize*2) // int16 = 2 bytes - }, - } -) - -// resamplingReader 包装 io.Reader 并提供音频重采样 -// 使用 io.Reader 接口实现流式重采样 -type resamplingReader struct { - source io.Reader - resampler *gomplerate.Resampler - inputBuf []byte // 原始数据缓冲区 - outputBuf []byte // 重采样后的输出缓冲区 - eof bool -} - -// newResamplingReader 创建重采样 reader -// 参数: -// - src: 源数据 reader -// - sourceRate: 源采样率(如 16000) -// - targetRate: 目标采样率(如 44100) -// - channels: 声道数(1=单声道, 2=立体声) -func newResamplingReader(src io.Reader, sourceRate, targetRate, channels int) (io.Reader, error) { - resampler, err := gomplerate.NewResampler(channels, sourceRate, targetRate) - if err != nil { - return nil, err - } - - return &resamplingReader{ - source: src, - resampler: resampler, - inputBuf: make([]byte, 0, resampleBufferSize*2), - outputBuf: make([]byte, 0, resampleBufferSize*2), - }, nil -} - -func (r 
*resamplingReader) Read(p []byte) (n int, err error) { - // 循环读取直到填满 p 或遇到错误 - for len(r.outputBuf) < len(p) { - if r.eof { - break - } - - // 读取源数据到输入缓冲区 - if err := r.readSource(); err != nil { - if err == io.EOF { - r.eof = true - } else { - return n, err - } - } - - // 如果没有数据可处理,退出 - if len(r.inputBuf) == 0 { - break - } - - // 将字节转换为 int16 并重采样 - int16Data := bytesToInt16(r.inputBuf) - resampled := r.resampler.ResampleInt16(int16Data) - - // 将重采样后的数据转回字节并追加到输出缓冲区 - r.outputBuf = append(r.outputBuf, int16ToBytes(resampled)...) - - // 清空输入缓冲区(所有数据已处理) - r.inputBuf = r.inputBuf[:0] - } - - // 从输出缓冲区复制数据到 p - n = copy(p, r.outputBuf) - - // 移除已读取的数据 - if n < len(r.outputBuf) { - r.outputBuf = r.outputBuf[n:] - } else { - r.outputBuf = r.outputBuf[:0] - } - - // 如果没有更多数据,返回 EOF - if n == 0 && r.eof && len(r.outputBuf) == 0 { - return 0, io.EOF - } - - return n, nil -} - -// readSource 从源读取数据到输入缓冲区 -func (r *resamplingReader) readSource() error { - const readSize = 4096 - - // 从池中借用临时缓冲区 - tempBuf := bufferPool.Get().([]byte) - defer bufferPool.Put(tempBuf) - - // 读取数据 - rn, err := r.source.Read(tempBuf[:readSize]) - if rn > 0 { - // 追加到输入缓冲区 - r.inputBuf = append(r.inputBuf, tempBuf[:rn]...) 
- } - - return err -} - -// bytesToInt16 将字节切片转换为 int16 切片(小端序) -func bytesToInt16(b []byte) []int16 { - result := make([]int16, len(b)/2) - for i := range result { - result[i] = int16(b[i*2]) | int16(b[i*2+1])<<8 - } - return result -} - -// int16ToBytes 将 int16 切片转换为字节切片(小端序) -func int16ToBytes(i []int16) []byte { - result := make([]byte, len(i)*2) - for n, v := range i { - result[n*2] = byte(v) - result[n*2+1] = byte(v >> 8) - } - return result -} - -// needsResampling 检查音频是否需要重采样到 UniversalSampleRate -func needsResampling(sourceRate int) bool { - return sourceRate != UniversalSampleRate -} diff --git a/pkg/audio/sinc_resampler.go b/pkg/audio/sinc_resampler.go new file mode 100644 index 0000000..9ad695b --- /dev/null +++ b/pkg/audio/sinc_resampler.go @@ -0,0 +1,148 @@ +package audio + +import ( + "io" + + resampling "github.com/tphakala/go-audio-resampler" + "go.uber.org/zap" +) + +// minProcessSamples 是 FIR 滤波器产生可靠输出所需的最小输入样本数 +const minProcessSamples = 64 + +// needsResampling 检查是否需要重采样 +func needsResampling(sourceRate int) bool { + return sourceRate != UniversalSampleRate +} + +// sincResampler 基于 go-audio-resampler 的高质量重采样器 +// 使用 Windowed Sinc + Polyphase FIR 算法,专业级音质 +type sincResampler struct { + decoder io.Reader + resampler resampling.Resampler + inputBuf []float64 // 输入缓冲区:int16→float64 转换后暂存 + outputBuf []float64 // 输出缓冲区:Process/Flush 产出但未消费的样本 + inputBytes []byte // 复用的字节读取缓冲区 + flushed bool // 是否已完成 Flush + eof bool // 上游是否已返回 EOF +} + +// newSincResampler 创建高质量 Sinc 重采样器 +// 使用场景:大广场音效、高保真音乐 +func newSincResampler(src io.Reader, inRate, outRate, channels int) io.Reader { + if inRate == outRate { + return src + } + + config := &resampling.Config{ + InputRate: float64(inRate), + OutputRate: float64(outRate), + Channels: channels, + Quality: resampling.QualitySpec{ + Preset: resampling.QualityVeryHigh, + }, + } + + r, err := resampling.New(config) + if err != nil { + zap.S().Warnf("Sinc 重采样器创建失败,降级为透传: %v", err) + return src + } + + return 
&sincResampler{ + decoder: src, + resampler: r, + inputBuf: make([]float64, 0, 4096), + outputBuf: make([]float64, 0, 4096), + inputBytes: make([]byte, 1024), + } +} + +func (r *sincResampler) Read(p []byte) (int, error) { + if len(p) < 2 { + return 0, io.ErrShortBuffer + } + maxSamples := len(p) / 2 + + // 主循环:直到有足够输出数据或 EOF + for len(r.outputBuf) < maxSamples { + // 阶段1:从上游读取数据,累积到 inputBuf + for len(r.inputBuf) < minProcessSamples && !r.eof { + nn, readErr := r.decoder.Read(r.inputBytes) + if readErr != nil && readErr != io.EOF { + return 0, readErr + } + if readErr == io.EOF || nn == 0 { + r.eof = true + break + } + + sampleCount := nn / 2 + for i := range sampleCount { + sample := int16(r.inputBytes[i*2]) | int16(r.inputBytes[i*2+1])<<8 + r.inputBuf = append(r.inputBuf, float64(sample)/32768.0) + } + } + + // 阶段2:处理输入数据 + if len(r.inputBuf) > 0 { + output, err := r.resampler.Process(r.inputBuf) + if err != nil { + return 0, err + } + r.inputBuf = r.inputBuf[:0] + if len(output) > 0 { + r.outputBuf = append(r.outputBuf, output...) + } + continue + } + + // 阶段3:EOF 且 inputBuf 为空,调用 Flush 获取尾部残留 + if r.eof && !r.flushed { + r.flushed = true + flushed, err := r.resampler.Flush() + if err != nil { + return 0, err + } + if len(flushed) > 0 { + r.outputBuf = append(r.outputBuf, flushed...) 
+ } + continue + } + + // 无更多数据可获取 + break + } + + if len(r.outputBuf) == 0 { + return 0, io.EOF + } + + // 写入输出 + n := min(len(r.outputBuf), maxSamples) + writeFloat64ToLE16(p, r.outputBuf[:n]) + if n < len(r.outputBuf) { + r.outputBuf = r.outputBuf[n:] + } else { + r.outputBuf = r.outputBuf[:0] + } + + return n * 2, nil +} + +// writeFloat64ToLE16 将 float64 样本转换为 int16 LE 写入 buf +func writeFloat64ToLE16(buf []byte, samples []float64) { + for i, s := range samples { + if s > 1.0 { + s = 1.0 + } else if s < -1.0 { + s = -1.0 + } + v := int32(s * 32768.0) + if v > 32767 { + v = 32767 + } + buf[i*2] = byte(v) + buf[i*2+1] = byte(v >> 8) + } +} diff --git a/pkg/audio/sinc_resampler_test.go b/pkg/audio/sinc_resampler_test.go new file mode 100644 index 0000000..50cb3ab --- /dev/null +++ b/pkg/audio/sinc_resampler_test.go @@ -0,0 +1,216 @@ +package audio + +import ( + "bytes" + "io" + "math" + "testing" +) + +// TestSincResamplerUpsampling 测试上采样 16000Hz → 44100Hz +func TestSincResamplerUpsampling(t *testing.T) { + // VeryHigh 质量 FIR 延迟约 969 输入样本,数据量需远超延迟 + inputSamples := make([]int16, 8000) + for i := range inputSamples { + inputSamples[i] = int16(math.Sin(2*math.Pi*440.0*float64(i)/16000.0) * 8000) + } + + inputData := encodeInt16LE(inputSamples) + r := newSincResampler(inputData, 16000, 44100, 2).(*sincResampler) + + outputSamples := readAllSamples(t, r) + expectedSamples := int(float64(len(inputSamples)) * 44100.0 / 16000.0) + + t.Logf("输入: %d 样本 @ 16000Hz", len(inputSamples)) + t.Logf("输出: %d 样本 @ 44100Hz (期望 ~%d)", outputSamples, expectedSamples) + + if outputSamples == 0 { + t.Fatal("没有输出数据") + } + // 上采样:输出应多于输入 + if outputSamples <= len(inputSamples) { + t.Errorf("上采样失败:输出(%d) 应多于输入(%d)", outputSamples, len(inputSamples)) + } + assertWithinTolerance(t, outputSamples, expectedSamples, 0.15) +} + +// TestSincResamplerPassthrough 测试采样率相同时直接透传 +func TestSincResamplerPassthrough(t *testing.T) { + inputSamples := []int16{100, 200, 300, 400, 500, 600} + inputData := 
encodeInt16LE(inputSamples) + + r := newSincResampler(inputData, 16000, 16000, 2) + if _, ok := r.(*bytes.Buffer); !ok { + t.Error("采样率相同时应该直接透传原始 reader") + } +} + +// TestSincResamplerDownsampling 测试下采样 44100Hz → 16000Hz +func TestSincResamplerDownsampling(t *testing.T) { + inputSamples := make([]int16, 8000) + for i := range inputSamples { + inputSamples[i] = int16(math.Sin(2*math.Pi*440.0*float64(i)/44100.0) * 8000) + } + + inputData := encodeInt16LE(inputSamples) + r := newSincResampler(inputData, 44100, 16000, 2).(*sincResampler) + + outputSamples := readAllSamples(t, r) + expectedSamples := int(float64(len(inputSamples)) * 16000.0 / 44100.0) + + t.Logf("输入: %d 样本 @ 44100Hz", len(inputSamples)) + t.Logf("输出: %d 样本 @ 16000Hz (期望 ~%d)", outputSamples, expectedSamples) + + if outputSamples == 0 { + t.Fatal("没有输出数据") + } + // 下采样:输出应少于输入 + if outputSamples >= len(inputSamples) { + t.Errorf("下采样失败:输出(%d) 应少于输入(%d)", outputSamples, len(inputSamples)) + } + assertWithinTolerance(t, outputSamples, expectedSamples, 0.15) +} + +// TestSincResamplerFlush 测试小数据量时 Flush 获取尾部残留 +func TestSincResamplerFlush(t *testing.T) { + // 小数据集:输入少于 FIR 延迟,输出主要来自 Flush + inputSamples := make([]int16, 500) + for i := range inputSamples { + inputSamples[i] = int16(i * 100) + } + + inputData := encodeInt16LE(inputSamples) + r := newSincResampler(inputData, 16000, 44100, 2).(*sincResampler) + + outputSamples := readAllSamples(t, r) + t.Logf("小数据输入: %d 样本, 输出: %d 样本 (来自 Flush)", len(inputSamples), outputSamples) + + // 即使输入小于延迟,Flush 也应产出数据 + if outputSamples == 0 { + t.Fatal("Flush 未产生任何数据") + } +} + +// TestSincResamplerShortBuffer 测试 io.Reader 边界行为 +func TestSincResamplerShortBuffer(t *testing.T) { + inputSamples := make([]int16, 2000) + for i := range inputSamples { + inputSamples[i] = int16(i) + } + + inputData := encodeInt16LE(inputSamples) + r := newSincResampler(inputData, 16000, 44100, 2).(*sincResampler) + + // 1 字节 buffer → ErrShortBuffer + _, err := r.Read(make([]byte, 1)) + if 
err != io.ErrShortBuffer { + t.Errorf("期望 io.ErrShortBuffer,得到: %v", err) + } + + // 2 字节 buffer → 正常工作 + buf := make([]byte, 2) + n, err := r.Read(buf) + if n != 2 || err != nil { + t.Errorf("2 字节 buffer 应正常读取: n=%d, err=%v", n, err) + } +} + +// TestSincResamplerStreaming 测试流式多次 Read 的正确性 +func TestSincResamplerStreaming(t *testing.T) { + inputSamples := make([]int16, 10000) + for i := range inputSamples { + inputSamples[i] = int16(math.Sin(2*math.Pi*440.0*float64(i)/16000.0) * 8000) + } + + inputData := encodeInt16LE(inputSamples) + r := newSincResampler(inputData, 16000, 44100, 2).(*sincResampler) + + // 小 buffer 模拟流式读取 + buf := make([]byte, 128) + totalSamples := 0 + readCount := 0 + + for { + n, err := r.Read(buf) + if n > 0 { + totalSamples += n / 2 + readCount++ + } + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("读取失败: %v", err) + } + } + + expectedSamples := int(float64(len(inputSamples)) * 44100.0 / 16000.0) + t.Logf("流式读取: %d 次, 共 %d 样本 (期望 ~%d)", readCount, totalSamples, expectedSamples) + + if readCount < 50 { + t.Errorf("流式读取次数过少: %d", readCount) + } + assertWithinTolerance(t, totalSamples, expectedSamples, 0.15) +} + +// TestSincResamplerSineWave 测试已知正弦波信号的重采样 +func TestSincResamplerSineWave(t *testing.T) { + const freq = 440.0 + const inRate = 16000 + inputSamples := make([]int16, inRate/4) // 0.25 秒 + for i := range inputSamples { + inputSamples[i] = int16(math.Sin(2*math.Pi*freq*float64(i)/float64(inRate)) * 16000) + } + + inputData := encodeInt16LE(inputSamples) + r := newSincResampler(inputData, inRate, 44100, 2).(*sincResampler) + + output := readAllSamples(t, r) + expected := int(float64(len(inputSamples)) * 44100.0 / float64(inRate)) + + t.Logf("440Hz 正弦波: %d → %d 样本 (期望 ~%d)", len(inputSamples), output, expected) + + if output == 0 { + t.Fatal("正弦波重采样无输出") + } + assertWithinTolerance(t, output, expected, 0.15) +} + +// --- 辅助函数 --- + +func encodeInt16LE(samples []int16) *bytes.Buffer { + buf := bytes.NewBuffer(nil) + for _, s 
:= range samples { + buf.Write([]byte{byte(s), byte(s >> 8)}) + } + return buf +} + +func readAllSamples(t *testing.T, r io.Reader) int { + t.Helper() + outputData := bytes.NewBuffer(nil) + buf := make([]byte, 4096) + for { + n, err := r.Read(buf) + if n > 0 { + outputData.Write(buf[:n]) + } + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("读取失败: %v", err) + } + } + return outputData.Len() / 2 +} + +func assertWithinTolerance(t *testing.T, actual, expected int, tolerance float64) { + t.Helper() + delta := math.Abs(float64(actual - expected)) + maxDelta := float64(expected) * tolerance + if delta > maxDelta && delta > 10 { + t.Errorf("超出容忍度: 实际 %d, 期望 %d (差: %.0f, 上限: %.0f)", + actual, expected, delta, maxDelta) + } +}