From 03876855f440438dcc66065e6395d8112673694e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 14 Jan 2025 18:20:27 +0800 Subject: [PATCH] Add C++ runtime support for Whisper models (#698) --- .github/scripts/run-offline-whisper.sh | 84 ++++ .github/workflows/build-doc.yml | 2 +- .github/workflows/macos-cpu-wheels.yml | 8 +- .github/workflows/publish_to_pypi.yml | 6 +- .github/workflows/python_style_check.yml | 6 +- .github/workflows/run-cpp-test.yaml | 12 +- .github/workflows/run-cpp-websocket-test.yaml | 6 +- .github/workflows/run-python-test.yaml | 6 +- .github/workflows/style_check.yml | 6 +- .../test-offline-websocket-rtf-wer.yaml | 8 +- .../test-online-websocket-rtf-wer.yaml | 8 +- .github/workflows/ubuntu-cpu-wheels.yml | 8 +- .github/workflows/ubuntu-cuda-wheels.yml | 6 +- .github/workflows/windows-x64-cpu-wheels.yml | 6 +- .../offline-recognizer-sense-voice-impl.h | 4 + .../cpp_api/offline-recognizer-whisper-impl.h | 363 ++++++++++++++++++ sherpa/cpp_api/offline-recognizer.cc | 8 +- sherpa/cpp_api/offline-stream.h | 5 + sherpa/csrc/CMakeLists.txt | 15 +- sherpa/csrc/base64-decode.cc | 67 ++++ sherpa/csrc/base64-decode.h | 19 + sherpa/csrc/fbank-features.cc | 15 +- sherpa/csrc/fbank-features.h | 4 +- sherpa/csrc/offline-ctc-one-best-decoder.cc | 3 +- sherpa/csrc/offline-model-config.cc | 6 + sherpa/csrc/offline-model-config.h | 4 + .../offline-sense-voice-model-meta-data.h | 2 +- sherpa/csrc/offline-sense-voice-model.cc | 24 +- sherpa/csrc/offline-sense-voice-model.h | 2 + sherpa/csrc/offline-stream.cc | 88 ++++- ...-transducer-modified-beam-search-decoder.h | 9 +- sherpa/csrc/offline-whisper-model-config.cc | 65 ++++ sherpa/csrc/offline-whisper-model-config.h | 44 +++ .../csrc/offline-whisper-model-meta-data.cc | 68 ++++ sherpa/csrc/offline-whisper-model-meta-data.h | 50 +++ sherpa/csrc/offline-whisper-model.cc | 243 ++++++++++++ sherpa/csrc/offline-whisper-model.h | 68 ++++ sherpa/csrc/online-stream.cc | 4 +- ...transducer-modified-beam-search-decoder.cc | 2 +- ...-transducer-modified-beam-search-decoder.h | 4 +- sherpa/csrc/parse-options.cc | 128 +----- sherpa/csrc/symbol-table.cc | 13 +- sherpa/csrc/symbol-table.h | 7 +- sherpa/csrc/text-utils.cc | 349 +++++++++++++++++ sherpa/csrc/text-utils.h | 123 ++++++ sherpa/python/csrc/CMakeLists.txt | 1 + sherpa/python/csrc/offline-model-config.cc | 10 +- .../csrc/offline-sense-voice-model-config.cc | 1 + .../csrc/offline-whisper-model-config.cc | 27 ++ .../csrc/offline-whisper-model-config.h | 15 + sherpa/python/sherpa/__init__.py | 2 + 51 files changed, 1815 insertions(+), 219 deletions(-) create mode 100755 .github/scripts/run-offline-whisper.sh create mode 100644 sherpa/cpp_api/offline-recognizer-whisper-impl.h create mode 100644 sherpa/csrc/base64-decode.cc create mode 100644 sherpa/csrc/base64-decode.h create mode 100644 sherpa/csrc/offline-whisper-model-config.cc create mode 100644 sherpa/csrc/offline-whisper-model-config.h create mode 100644 sherpa/csrc/offline-whisper-model-meta-data.cc create mode 100644 sherpa/csrc/offline-whisper-model-meta-data.h create mode 100644 sherpa/csrc/offline-whisper-model.cc create mode 100644 sherpa/csrc/offline-whisper-model.h create mode 100644 sherpa/csrc/text-utils.cc create mode 100644 sherpa/csrc/text-utils.h create mode 100644 sherpa/python/csrc/offline-whisper-model-config.cc create mode 100644 sherpa/python/csrc/offline-whisper-model-config.h diff --git a/.github/scripts/run-offline-whisper.sh b/.github/scripts/run-offline-whisper.sh new file mode 100755 index 000000000..4e04358b1 
--- /dev/null +++ b/.github/scripts/run-offline-whisper.sh @@ -0,0 +1,84 @@ +#!/usr/bin/env bash + +set -ex + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "==========================================================================" +model_list=( +base +base.en +distil-large-v2 +distil-medium.en +distil-small.en +medium +medium.en +small +small.en +tiny +tiny.en +turbo +) + +curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-zh-wenet-aishell2/resolve/main/test_wavs/0.wav +mv 0.wav zh.wav + +for m in ${model_list[@]}; do + d=sherpa-whisper-$m + log "----------testing $d----------" + curl -SL -O https://github.com/k2-fsa/sherpa/releases/download/asr-models/$d.tar.bz2 + tar xvf $d.tar.bz2 + rm $d.tar.bz2 + ls -lh $d + + if [[ $d == *en ]]; then + log "decode a single file" + + ./build/bin/sherpa-offline \ + --debug=1 \ + --whisper-model=./$d/model.pt \ + --tokens=./$d/tokens.txt \ + ./$d/test_wavs/0.wav + + log "decode two files" + ./build/bin/sherpa-offline \ + --debug=1 \ + --whisper-model=./$d/model.pt \ + --tokens=./$d/tokens.txt \ + ./$d/test_wavs/0.wav \ + ./$d/test_wavs/1.wav + fi + + if [[ $d != *en ]]; then + + log "decode a single file" + + ./build/bin/sherpa-offline \ + --debug=1 \ + --whisper-model=./$d/model.pt \ + --tokens=./$d/tokens.txt \ + ./$d/test_wavs/0.wav + + log "decode two files" + ./build/bin/sherpa-offline \ + --debug=1 \ + --whisper-model=./$d/model.pt \ + --tokens=./$d/tokens.txt \ + ./$d/test_wavs/0.wav \ + ./$d/test_wavs/1.wav + + log "decode three files" + ./build/bin/sherpa-offline \ + --debug=1 \ + --whisper-model=./$d/model.pt \ + --tokens=./$d/tokens.txt \ + ./$d/test_wavs/0.wav \ + ./$d/test_wavs/1.wav \ + ./zh.wav + fi + rm -rf $d +done diff --git a/.github/workflows/build-doc.yml b/.github/workflows/build-doc.yml index 6b7efb1ac..444552bad 100644 --- a/.github/workflows/build-doc.yml +++ b/.github/workflows/build-doc.yml @@ -49,7 +49,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: [3.8] + python-version: ["3.10"] steps: # refer to https://github.com/actions/checkout - uses: actions/checkout@v4 diff --git a/.github/workflows/macos-cpu-wheels.yml b/.github/workflows/macos-cpu-wheels.yml index 2e04da114..ea5f549ed 100644 --- a/.github/workflows/macos-cpu-wheels.yml +++ b/.github/workflows/macos-cpu-wheels.yml @@ -17,7 +17,7 @@ jobs: outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 - name: Generating build matrix @@ -38,12 +38,12 @@ jobs: ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 - name: Setup Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -73,7 +73,7 @@ jobs: ls -lh ./wheelhouse/ - name: Upload Wheel - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-macos-latest-cpu path: wheelhouse/*.whl diff --git a/.github/workflows/publish_to_pypi.yml b/.github/workflows/publish_to_pypi.yml index 7fbd9ba4f..1a845398b 100644 --- a/.github/workflows/publish_to_pypi.yml +++ b/.github/workflows/publish_to_pypi.yml @@ -21,14 +21,14 @@ jobs: pypi: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: 
fetch-depth: 0 - name: Setup Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: "3.10" - name: Install Python dependencies shell: bash diff --git a/.github/workflows/python_style_check.yml b/.github/workflows/python_style_check.yml index d4c659e5a..8aba03011 100644 --- a/.github/workflows/python_style_check.yml +++ b/.github/workflows/python_style_check.yml @@ -44,16 +44,16 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.8] + python-version: ["3.10"] fail-fast: false steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/run-cpp-test.yaml b/.github/workflows/run-cpp-test.yaml index 8ed7ba874..b02a43093 100644 --- a/.github/workflows/run-cpp-test.yaml +++ b/.github/workflows/run-cpp-test.yaml @@ -49,18 +49,17 @@ concurrency: jobs: run_cpp_tests: - if: github.event.label.name == 'ready' || github.event.label.name == 'cpp' || github.event_name == 'push' runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: os: [ubuntu-latest] torch: ["2.1.2"] - python-version: ["3.8"] + python-version: ["3.10"] build_type: ["Release", "Debug"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 @@ -70,7 +69,7 @@ jobs: key: ${{ matrix.os }}-${{ matrix.torch }}-${{ matrix.python-version }}-${{ matrix.build_type }} - name: Setup Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -174,6 +173,11 @@ jobs: ./bin/sherpa-offline --help + - name: Run offline whisper + shell: bash + run: | + .github/scripts/run-offline-whisper.sh + - name: Run offline sense-voice shell: bash run: | diff --git a/.github/workflows/run-cpp-websocket-test.yaml b/.github/workflows/run-cpp-websocket-test.yaml index 336396e18..3e91d76fb 100644 --- a/.github/workflows/run-cpp-websocket-test.yaml +++ b/.github/workflows/run-cpp-websocket-test.yaml @@ -33,11 +33,11 @@ jobs: matrix: os: [ubuntu-latest] torch: ["1.13.1"] - python-version: ["3.8"] + python-version: ["3.10"] build_type: ["Release", "Debug"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 @@ -47,7 +47,7 @@ jobs: key: ${{ matrix.os }}-${{ matrix.torch }}-${{ matrix.python-version }}-${{ matrix.build_type }} - name: Setup Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/run-python-test.yaml b/.github/workflows/run-python-test.yaml index 907eb1e44..a2167001e 100644 --- a/.github/workflows/run-python-test.yaml +++ b/.github/workflows/run-python-test.yaml @@ -38,11 +38,11 @@ jobs: os: [ubuntu-latest] torch: ["1.13.1"] torchaudio: ["0.13.1"] - python-version: ["3.8"] + python-version: ["3.10"] build_type: ["Release"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 @@ -52,7 +52,7 @@ jobs: key: ${{ matrix.os }}-${{ matrix.torch }}-${{ matrix.python-version }}-${{ matrix.build_type }} - name: Setup Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index f7ad8fe7f..5e60e94a1 100644 --- a/.github/workflows/style_check.yml +++ 
b/.github/workflows/style_check.yml @@ -42,16 +42,16 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8] + python-version: ["3.10"] fail-fast: false steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/test-offline-websocket-rtf-wer.yaml b/.github/workflows/test-offline-websocket-rtf-wer.yaml index fef7aec6f..24c359b40 100644 --- a/.github/workflows/test-offline-websocket-rtf-wer.yaml +++ b/.github/workflows/test-offline-websocket-rtf-wer.yaml @@ -36,12 +36,12 @@ jobs: os: [ubuntu-latest] torch: ["1.13.1"] torchaudio: ["0.13.1"] - python-version: ["3.8"] + python-version: ["3.10"] decoding_method: ["greedy_search", "modified_beam_search"] num_connections: [50, 100, 200] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 @@ -55,7 +55,7 @@ jobs: echo "Number of CPU cores: ${{ steps.cpu-cores.outputs.count }}" - name: Setup Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -115,7 +115,7 @@ jobs: cat ./log.txt - name: Upload decoding results - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: os-${{ matrix.os }}-decoding-method-${{ matrix.decoding_method }} path: ./*.txt diff --git a/.github/workflows/test-online-websocket-rtf-wer.yaml b/.github/workflows/test-online-websocket-rtf-wer.yaml index 3649a5c60..bcb7a927d 100644 --- a/.github/workflows/test-online-websocket-rtf-wer.yaml +++ b/.github/workflows/test-online-websocket-rtf-wer.yaml @@ -36,12 +36,12 @@ jobs: os: [ubuntu-latest] torch: ["1.13.1"] torchaudio: ["0.13.1"] - python-version: ["3.8"] + python-version: ["3.10"] decoding_method: ["greedy_search", "modified_beam_search"] num_connections: [50, 100, 200] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 @@ -55,7 +55,7 @@ jobs: echo "Number of CPU cores: ${{ steps.cpu-cores.outputs.count }}" - name: Setup Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -114,7 +114,7 @@ jobs: cat ./log.txt - name: Upload decoding results - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: os-${{ matrix.os }}-decoding-method-${{ matrix.decoding_method }} path: ./*.txt diff --git a/.github/workflows/ubuntu-cpu-wheels.yml b/.github/workflows/ubuntu-cpu-wheels.yml index dbb7f1480..c1eaccea2 100644 --- a/.github/workflows/ubuntu-cpu-wheels.yml +++ b/.github/workflows/ubuntu-cpu-wheels.yml @@ -17,7 +17,7 @@ jobs: outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 - name: Generating build matrix @@ -38,12 +38,12 @@ jobs: ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -88,7 +88,7 @@ jobs: ls -lh ./wheelhouse/ - name: Upload Wheel - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu path: wheelhouse/*.whl diff --git 
a/.github/workflows/ubuntu-cuda-wheels.yml b/.github/workflows/ubuntu-cuda-wheels.yml index c7fafae85..0e8469eca 100644 --- a/.github/workflows/ubuntu-cuda-wheels.yml +++ b/.github/workflows/ubuntu-cuda-wheels.yml @@ -17,7 +17,7 @@ jobs: outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 - name: Generating build matrix @@ -52,7 +52,7 @@ jobs: echo "github.workspace ${{ github.workspace }}" - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -122,7 +122,7 @@ jobs: ls -lh ./wheelhouse/ - name: Upload Wheel - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cuda path: wheelhouse/*.whl diff --git a/.github/workflows/windows-x64-cpu-wheels.yml b/.github/workflows/windows-x64-cpu-wheels.yml index 4bb201553..e647f1187 100644 --- a/.github/workflows/windows-x64-cpu-wheels.yml +++ b/.github/workflows/windows-x64-cpu-wheels.yml @@ -19,7 +19,7 @@ jobs: outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 - name: Generating build matrix @@ -40,7 +40,7 @@ jobs: ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 @@ -60,7 +60,7 @@ jobs: ls -lh ./wheelhouse/ - name: Upload Wheel - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-windows-latest-cpu path: wheelhouse/*.whl diff --git a/sherpa/cpp_api/offline-recognizer-sense-voice-impl.h b/sherpa/cpp_api/offline-recognizer-sense-voice-impl.h index a0f1216a4..7a18dac19 100644 --- a/sherpa/cpp_api/offline-recognizer-sense-voice-impl.h +++ b/sherpa/cpp_api/offline-recognizer-sense-voice-impl.h @@ -4,6 +4,10 @@ #ifndef SHERPA_CPP_API_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_ #define SHERPA_CPP_API_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_ +#include +#include +#include +#include #include "sherpa/csrc/macros.h" #include "sherpa/csrc/offline-ctc-decoder.h" diff --git a/sherpa/cpp_api/offline-recognizer-whisper-impl.h b/sherpa/cpp_api/offline-recognizer-whisper-impl.h new file mode 100644 index 000000000..c44646f31 --- /dev/null +++ b/sherpa/cpp_api/offline-recognizer-whisper-impl.h @@ -0,0 +1,363 @@ +// sherpa/cpp_api/offline-recognizer-whisper-impl.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_CPP_API_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_ +#define SHERPA_CPP_API_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_ + +#include +#include +#include +#include + +#include "sherpa/csrc/macros.h" +#include "sherpa/csrc/offline-whisper-model.h" +#include "sherpa/csrc/symbol-table.h" + +namespace sherpa { + +static OfflineRecognitionResult Convert(const std::vector &tokens, + const SymbolTable &sym_table) { + OfflineRecognitionResult r; + r.tokens.reserve(tokens.size()); + + std::string text; + for (auto i : tokens) { + auto sym = sym_table[i]; + text.append(sym); + + r.tokens.push_back(std::move(sym)); + } + r.text = std::move(text); + + return r; +} + +class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { + public: + explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config) + : config_(config), symbol_table_(config.model.tokens) { + 
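    // whisper's tokens.txt stores each byte-level BPE symbol base64-encoded,
    // one "<symbol> <id>" pair per line, so the table must be decoded before
    // symbols can be concatenated into text. For a hypothetical entry
    // "IGhlbGxv 1012", the decode step below recovers the raw token bytes:
    //
    //   Base64Decode("IGhlbGxv") == " hello"
    //
    // (the entry is made up for illustration; Base64Decode() is the helper
    // added by this patch in sherpa/csrc/base64-decode.h)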
+    symbol_table_.ApplyBase64Decode();
+
+    model_ = std::make_unique<OfflineWhisperModel>(config.model);
+
+    config_.feat_config.normalize_samples = true;
+
+    auto whisper_opts = kaldifeat::WhisperFbankOptions();
+    whisper_opts.num_mels = model_->GetModelMetadata().n_mels;
+
+    whisper_ = std::make_unique<kaldifeat::WhisperFbank>(whisper_opts);
+  }
+
+  std::unique_ptr<OfflineStream> CreateStream() override {
+    return std::make_unique<OfflineStream>(whisper_.get(),
+                                           config_.feat_config);
+  }
+
+  void DecodeStreams(OfflineStream **ss, int32_t n) override {
+    InferenceMode no_grad;
+    if (n == 1) {
+      DecodeStream(ss[0]);
+      return;
+    }
+
+    auto device = model_->Device();
+
+#if 0
+    // TODO(fangjun): Figure out why this branch does not work.
+    // All wave files are decoded into the same result like the first wave file
+    std::vector<torch::Tensor> features_vec(n);
+    for (int32_t i = 0; i != n; ++i) {
+      auto features = ss[i]->GetFeatures();
+      features = PadOrTrimFeatures(features);
+      features_vec[i] = PadOrTrimFeatures(features);
+    }
+
+    auto features = torch::stack(features_vec, 0).to(device).permute({0, 2, 1});
+
+    torch::Tensor n_layer_cross_k_cache;
+    torch::Tensor n_layer_cross_v_cache;
+
+    std::tie(n_layer_cross_k_cache, n_layer_cross_v_cache) =
+        model_->RunEncoder(features);
+#else
+    std::vector<torch::Tensor> n_layer_cross_k_cache_list;
+    std::vector<torch::Tensor> n_layer_cross_v_cache_list;
+
+    for (int32_t i = 0; i != n; ++i) {
+      auto features = ss[i]->GetFeatures();
+      features = PadOrTrimFeatures(features).to(device).t().unsqueeze(0);
+
+      torch::Tensor n_layer_cross_k_cache;
+      torch::Tensor n_layer_cross_v_cache;
+
+      std::tie(n_layer_cross_k_cache, n_layer_cross_v_cache) =
+          model_->RunEncoder(features);
+      n_layer_cross_k_cache_list.push_back(n_layer_cross_k_cache);
+      n_layer_cross_v_cache_list.push_back(n_layer_cross_v_cache);
+    }
+
+    torch::Tensor n_layer_cross_k_cache =
+        torch::cat(n_layer_cross_k_cache_list, 1);
+    torch::Tensor n_layer_cross_v_cache =
+        torch::cat(n_layer_cross_v_cache_list, 1);
+#endif
+
+    auto meta_data = model_->GetModelMetadata();
+    auto sot_sequence = meta_data.sot_sequence;
+    sot_sequence.push_back(meta_data.no_timestamps);
+    torch::Tensor tokens =
+        torch::tensor(sot_sequence, torch::dtype(torch::kLong).device(device))
+            .reshape({1, -1})
+            .repeat({n, 1});
+
+    if (meta_data.is_multilingual) {
+      // sot_sequence: [sot, language, task, notimestamp]
+      auto language = config_.model.whisper.language;
+      if (!language.empty()) {
+        if (!meta_data.lang2id.count(language)) {
+          SHERPA_LOG(FATAL) << "language '" << language << "' is not valid";
+        }
+        tokens.index_put_({"...", 1}, meta_data.lang2id.at(language));
+      } else {
+        if (config_.model.debug) {
+          SHERPA_LOGE("Begin to detect language");
+        }
+        auto detected_language = model_->DetectLanguage(n_layer_cross_k_cache,
+                                                        n_layer_cross_v_cache);
+        tokens.index_put_({"...", 1}, detected_language);
+
+        if (config_.model.debug) {
+          detected_language = detected_language.cpu();
+          auto acc = detected_language.accessor<int64_t, 1>();
+          for (int32_t i = 0; i != n; ++i) {
+            SHERPA_LOGE("Wave %d: detected language: %s", i,
+                        meta_data.id2lang.at(acc[i]).c_str());
+          }
+        }
+      }
+
+      if (config_.model.whisper.task == "translate") {
+        tokens.index_put_({"...", 2}, meta_data.translate);
+      }
+    }
+
+    torch::Tensor logits;
+
+    torch::Tensor n_layer_self_k_cache =
+        torch::zeros({meta_data.n_text_layer, n, meta_data.n_text_ctx,
+                      meta_data.n_text_state},
+                     torch::dtype(torch::kFloat).device(device));
+
+    torch::Tensor n_layer_self_v_cache =
+        torch::zeros({meta_data.n_text_layer, n, meta_data.n_text_ctx,
+                      meta_data.n_text_state},
+                     torch::dtype(torch::kFloat).device(device));
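    // For a sense of scale: with the published whisper "tiny" multilingual
    // dimensions (n_text_layer = 4, n_text_ctx = 448, n_text_state = 384;
    // illustrative numbers, not read from the model file) and n = 2 streams,
    // each self-attention cache above is a (4, 2, 448, 384) float tensor,
    // i.e. 1,376,256 floats (~5.5 MB), while the cross caches returned by
    // RunEncoder() are (4, 2, 1500, 384): the encoder always emits
    // n_audio_ctx = 1500 frames for its fixed 30-second input window.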
+ + torch::Tensor offset = + torch::zeros({n}, torch::dtype(torch::kInt).device(device)); + + std::tie(logits, n_layer_self_k_cache, n_layer_self_v_cache) = + model_->RunDecoder(tokens, n_layer_self_k_cache, n_layer_self_v_cache, + n_layer_cross_k_cache, n_layer_cross_v_cache, + offset); + + torch::Tensor eot = torch::tensor( + {meta_data.eot}, torch::dtype(torch::kLong).device(device)); + + torch::Tensor results = + torch::full({n, meta_data.n_text_ctx}, meta_data.eot, + torch::dtype(torch::kLong).device(device)); + + torch::Tensor num_decoded_tokens = + torch::zeros({n}, torch::dtype(torch::kLong).device(device)); + + torch::Tensor new2old = + torch::arange(n, torch::dtype(torch::kLong).device(device)); + + for (int32_t i = 0; i < meta_data.n_text_ctx; ++i) { + tokens = logits.slice(1, -1).argmax(-1); + torch::Tensor eot_indexes = (tokens.squeeze() == eot).nonzero().squeeze(); + + if (eot_indexes.numel()) { + num_decoded_tokens.index_put_( + {"...", new2old.index_select(0, eot_indexes)}, i); + + if (eot_indexes.numel() == tokens.size(0)) { + break; + } + + torch::Tensor non_eot_indexes = + (tokens.squeeze() != eot).nonzero().squeeze(); + + tokens = tokens.index_select(0, non_eot_indexes); + + offset = offset.index_select(0, non_eot_indexes); + new2old = new2old.index_select(0, non_eot_indexes); + n_layer_cross_k_cache = + n_layer_cross_k_cache.index_select(1, non_eot_indexes); + n_layer_cross_v_cache = + n_layer_cross_v_cache.index_select(1, non_eot_indexes); + n_layer_self_k_cache = + n_layer_self_k_cache.index_select(1, non_eot_indexes); + n_layer_self_v_cache = + n_layer_self_v_cache.index_select(1, non_eot_indexes); + } + + results.index_put_({new2old, i}, tokens.squeeze()); + offset.add_(logits.size(1)); + + std::tie(logits, n_layer_self_k_cache, n_layer_self_v_cache) = + model_->RunDecoder(tokens, n_layer_self_k_cache, n_layer_self_v_cache, + n_layer_cross_k_cache, n_layer_cross_v_cache, + offset); + } + num_decoded_tokens = num_decoded_tokens.cpu(); + auto acc = num_decoded_tokens.accessor(); + results = results.cpu(); + auto p = results.data_ptr(); + for (int32_t i = 0; i != n; ++i) { + auto token_ids = std::vector{p + i * results.size(1), + p + i * results.size(1) + acc[i]}; + + ss[i]->SetResult(Convert(token_ids, symbol_table_)); + } + } + + private: + void DecodeStream(OfflineStream *s) { + auto device = model_->Device(); + + torch::Tensor features = s->GetFeatures(); + features = PadOrTrimFeatures(features); + features = features.t().unsqueeze(0).to(device); + + torch::Tensor n_layer_cross_k_cache; + torch::Tensor n_layer_cross_v_cache; + + std::tie(n_layer_cross_k_cache, n_layer_cross_v_cache) = + model_->RunEncoder(features); + + auto meta_data = model_->GetModelMetadata(); + auto sot_sequence = meta_data.sot_sequence; + sot_sequence.push_back(meta_data.no_timestamps); + + if (meta_data.is_multilingual) { + // sot_sequence: [sot, language, task, notimestamp] + auto language = config_.model.whisper.language; + if (!language.empty()) { + if (!meta_data.lang2id.count(language)) { + SHERPA_LOG(FATAL) << "language '" << language << " is not valid"; + } + + sot_sequence[1] = meta_data.lang2id.at(language); + } else { + if (config_.model.debug) { + SHERPA_LOGE("Begin to detect language"); + } + sot_sequence[1] = + model_->DetectLanguage(n_layer_cross_k_cache, n_layer_cross_v_cache) + .item() + .toInt(); + if (config_.model.debug) { + SHERPA_LOGE("Detected language: %s", + meta_data.id2lang.at(sot_sequence[1]).c_str()); + } + } + + if (config_.model.whisper.task == "translate") { + 
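        // With the multilingual vocabulary used by tiny/base/small/medium
        // (token ids quoted from openai/whisper's tokenizer, not from this
        // patch), the initial sequence [sot, language, task, no_timestamps]
        // for English transcription is [50258, 50259, 50359, 50363], i.e.
        // <|startoftranscript|>, <|en|>, <|transcribe|>, <|notimestamps|>;
        // translation only swaps index 2 to 50358 (<|translate|>).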
sot_sequence[2] = meta_data.translate; + } + } + + torch::Tensor tokens = + torch::from_blob(sot_sequence.data(), + {1, static_cast(sot_sequence.size())}, + torch::kLong) + .to(device); + + torch::Tensor logits; + + torch::Tensor n_layer_self_k_cache = + torch::zeros({meta_data.n_text_layer, 1, meta_data.n_text_ctx, + meta_data.n_text_state}, + torch::dtype(torch::kFloat).device(device)); + + torch::Tensor n_layer_self_v_cache = + torch::zeros({meta_data.n_text_layer, 1, meta_data.n_text_ctx, + meta_data.n_text_state}, + torch::dtype(torch::kFloat).device(device)); + + torch::Tensor offset = + torch::zeros({1}, torch::dtype(torch::kInt).device(device)); + + std::tie(logits, n_layer_self_k_cache, n_layer_self_v_cache) = + model_->RunDecoder(tokens, n_layer_self_k_cache, n_layer_self_v_cache, + n_layer_cross_k_cache, n_layer_cross_v_cache, + offset); + + torch::Tensor eot = torch::tensor( + {meta_data.eot}, torch::dtype(torch::kLong).device(device)); + + torch::Tensor results = + torch::full({1, meta_data.n_text_ctx}, meta_data.eot, + torch::dtype(torch::kLong).device(device)); + + int32_t i; + for (i = 0; i < meta_data.n_text_ctx; ++i) { + tokens = logits.slice(1, -1).argmax(-1); + if ((tokens == eot).sum().item().toInt() == 1) { + break; + } + results.slice(1, i, i + 1) = tokens; + offset.add_(logits.size(1)); + + std::tie(logits, n_layer_self_k_cache, n_layer_self_v_cache) = + model_->RunDecoder(tokens, n_layer_self_k_cache, n_layer_self_v_cache, + n_layer_cross_k_cache, n_layer_cross_v_cache, + offset); + } + results = results.slice(1, 0, i).cpu(); + + std::vector token_ids = { + results.data_ptr(), + results.data_ptr() + results.numel()}; + + s->SetResult(Convert(token_ids, symbol_table_)); + } + + private: + void WarmUp() { + SHERPA_LOG(INFO) << "WarmUp begins"; + + SHERPA_LOG(INFO) << "WarmUp ended"; + } + + torch::Tensor PadOrTrimFeatures(const torch::Tensor &feat) { + auto features = feat; + int32_t target_len = 3000; + int32_t src_len = features.size(0); + if (src_len > target_len) { + SHERPA_LOGE( + "\nInput audio is too long (about %.3f seconds). 
Only the first %d " + "seconds are used.", + src_len * 0.01, static_cast(target_len * 0.01)); + features = features.slice(0, 0, target_len); + } else if (src_len < target_len) { + int32_t padding = target_len - src_len; + features = torch::nn::functional::pad( + features, torch::nn::functional::PadFuncOptions({0, 0, 0, padding}) + .mode(torch::kConstant) + .value(0)); + } + + return features; + } + + private: + OfflineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr whisper_; + std::unique_ptr model_; +}; +} // namespace sherpa +#endif // SHERPA_CPP_API_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_ diff --git a/sherpa/cpp_api/offline-recognizer.cc b/sherpa/cpp_api/offline-recognizer.cc index 819127dff..cab51a3f7 100644 --- a/sherpa/cpp_api/offline-recognizer.cc +++ b/sherpa/cpp_api/offline-recognizer.cc @@ -11,6 +11,7 @@ #include "sherpa/cpp_api/offline-recognizer-impl.h" #include "sherpa/cpp_api/offline-recognizer-sense-voice-impl.h" #include "sherpa/cpp_api/offline-recognizer-transducer-impl.h" +#include "sherpa/cpp_api/offline-recognizer-whisper-impl.h" #include "sherpa/csrc/file-utils.h" #include "sherpa/csrc/log.h" #include "torch/script.h" @@ -130,7 +131,7 @@ void OfflineRecognizerConfig::Validate() const { } AssertFileExists(tokens); - if (!model.sense_voice.model.empty()) { + if (!model.sense_voice.model.empty() || !model.whisper.model.empty()) { model.tokens = tokens; model.use_gpu = use_gpu; if (!model.Validate()) { @@ -194,6 +195,11 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config) { return; } + if (!config.model.whisper.model.empty()) { + impl_ = std::make_unique(config); + return; + } + if (!config.nn_model.empty()) { torch::jit::Module m = torch::jit::load(config.nn_model, torch::kCPU); if (!m.hasattr("joiner")) { diff --git a/sherpa/cpp_api/offline-stream.h b/sherpa/cpp_api/offline-stream.h index 274e162b9..8deec738a 100644 --- a/sherpa/cpp_api/offline-stream.h +++ b/sherpa/cpp_api/offline-stream.h @@ -9,6 +9,7 @@ #include #include "kaldifeat/csrc/feature-fbank.h" +#include "kaldifeat/csrc/whisper-fbank.h" #include "sherpa/cpp_api/feature-config.h" #include "sherpa/csrc/context-graph.h" #include "torch/script.h" @@ -55,6 +56,10 @@ class OfflineStream { OfflineStream(kaldifeat::Fbank *fbank, const FeatureConfig &feat_config, ContextGraphPtr context_graph = nullptr); + OfflineStream(kaldifeat::WhisperFbank *whisper, + const FeatureConfig &feat_config, + ContextGraphPtr context_graph = nullptr); + /** Create a stream from a WAVE file. * * @param wave_file Path to the WAVE file. 
Its sample frequency should
diff --git a/sherpa/csrc/CMakeLists.txt b/sherpa/csrc/CMakeLists.txt
index 628d63e80..73fda44a7 100644
--- a/sherpa/csrc/CMakeLists.txt
+++ b/sherpa/csrc/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Please sort the filenames alphabetically
 set(sherpa_srcs
+  base64-decode.cc
   byte_util.cc
   context-graph.cc
   fbank-features.cc
@@ -8,15 +9,22 @@ set(sherpa_srcs
   log.cc
   offline-conformer-ctc-model.cc
   offline-conformer-transducer-model.cc
-  offline-ctc-one-best-decoder.cc
   offline-ctc-greedy-search-decoder.cc
+  offline-ctc-one-best-decoder.cc
+  offline-model-config.cc
   offline-nemo-enc-dec-ctc-model-bpe.cc
+  offline-sense-voice-model-config.cc
+  offline-sense-voice-model-meta-data.cc
+  offline-sense-voice-model.cc
   offline-stream.cc
   offline-transducer-fast-beam-search-decoder.cc
   offline-transducer-greedy-search-decoder.cc
   offline-transducer-modified-beam-search-decoder.cc
   offline-wav2vec2-ctc-model.cc
   offline-wenet-conformer-ctc-model.cc
+  offline-whisper-model-config.cc
+  offline-whisper-model-meta-data.cc
+  offline-whisper-model.cc
   online-conformer-transducer-model.cc
   online-conv-emformer-transducer-model.cc
   online-emformer-transducer-model.cc
@@ -30,10 +38,7 @@ set(sherpa_srcs
   parse-options.cc
   resample.cc
   symbol-table.cc
-  offline-sense-voice-model-config.cc
-  offline-sense-voice-model.cc
-  offline-sense-voice-model-meta-data.cc
-  offline-model-config.cc
+  text-utils.cc
 )
 
 add_library(sherpa_core ${sherpa_srcs})
diff --git a/sherpa/csrc/base64-decode.cc b/sherpa/csrc/base64-decode.cc
new file mode 100644
index 000000000..0b9a35df8
--- /dev/null
+++ b/sherpa/csrc/base64-decode.cc
@@ -0,0 +1,67 @@
+// sherpa/csrc/base64-decode.cc
+//
+// Copyright (c) 2022-2025 Xiaomi Corporation
+
+#include "sherpa/csrc/base64-decode.h"
+
+#include "sherpa/csrc/macros.h"
+
+namespace sherpa {
+
+static int32_t Ord(char c) {
+  if (c >= 'A' && c <= 'Z') {
+    return c - 'A';
+  } else if (c >= 'a' && c <= 'z') {
+    return c - 'a' + ('Z' - 'A') + 1;
+  } else if (c >= '0' && c <= '9') {
+    return c - '0' + ('Z' - 'A') + ('z' - 'a') + 2;
+  } else if (c == '+') {
+    return 62;
+  } else if (c == '/') {
+    return 63;
+  }
+
+  SHERPA_LOGE("Unknown character %d, %c\n", c, c);
+
+  exit(-1);
+}
+
+// see
+// https://github.com/ReneNyffenegger/cpp-base64/blob/master/base64.cpp#L243
+std::string Base64Decode(const std::string &s) {
+  if (s.empty()) {
+    SHERPA_LOGE("Empty string!");
+    exit(-1);
+  }
+
+  int32_t n = static_cast<int32_t>(s.size()) / 4 * 3;
+
+  std::string ans;
+  ans.reserve(n);
+
+  int32_t i = 0;
+  while (i < static_cast<int32_t>(s.size())) {
+    if (s[i] == '=') {
+      return " ";
+    }
+
+    int32_t first = (Ord(s[i]) << 2) + ((Ord(s[i + 1]) & 0x30) >> 4);
+    ans.push_back(static_cast<char>(first));
+
+    if (i + 2 < static_cast<int32_t>(s.size()) && s[i + 2] != '=') {
+      int32_t second =
+          ((Ord(s[i + 1]) & 0x0f) << 4) + ((Ord(s[i + 2]) & 0x3c) >> 2);
+      ans.push_back(static_cast<char>(second));
+
+      if (i + 3 < static_cast<int32_t>(s.size()) && s[i + 3] != '=') {
+        int32_t third = ((Ord(s[i + 2]) & 0x03) << 6) + Ord(s[i + 3]);
+        ans.push_back(static_cast<char>(third));
+      }
+    }
+    i += 4;
+  }
+
+  return ans;
+}
+
+}  // namespace sherpa
diff --git a/sherpa/csrc/base64-decode.h b/sherpa/csrc/base64-decode.h
new file mode 100644
index 000000000..f922c94dd
--- /dev/null
+++ b/sherpa/csrc/base64-decode.h
@@ -0,0 +1,19 @@
+// sherpa/csrc/base64-decode.h
+//
+// Copyright (c) 2022-2025 Xiaomi Corporation
+
+#ifndef SHERPA_CSRC_BASE64_DECODE_H_
+#define SHERPA_CSRC_BASE64_DECODE_H_
+
+#include <string>
+
+namespace sherpa {
+
+/** @param s A base64 encoded string.
+ * @return Return the decoded string.
+ */
+std::string Base64Decode(const std::string &s);
+
+}  // namespace sherpa
+
+#endif  // SHERPA_CSRC_BASE64_DECODE_H_
diff --git a/sherpa/csrc/fbank-features.cc b/sherpa/csrc/fbank-features.cc
index d1a809a6c..9aa6b4e4a 100644
--- a/sherpa/csrc/fbank-features.cc
+++ b/sherpa/csrc/fbank-features.cc
@@ -20,6 +20,8 @@
 
 #include "kaldi_native_io/csrc/kaldi-io.h"
 #include "kaldi_native_io/csrc/wave-reader.h"
+#include "kaldifeat/csrc/feature-fbank.h"
+#include "kaldifeat/csrc/whisper-fbank.h"
 #include "sherpa/csrc/log.h"
 #include "torch/script.h"
 
@@ -54,8 +56,9 @@ std::pair<torch::Tensor, float> ReadWave(const std::string &filename,
   return {tensor / 32768, wave_data.Duration()};
 }
 
+template <typename FbankComputer>
 std::vector<torch::Tensor> ComputeFeatures(
-    kaldifeat::Fbank &fbank,  // NOLINT
+    FbankComputer &fbank,  // NOLINT
     const std::vector<torch::Tensor> &wave_data,
     std::vector<int32_t> *num_frames /*=nullptr*/) {
   const auto &frame_opts = fbank.GetOptions().frame_opts;
@@ -84,4 +87,14 @@ std::vector<torch::Tensor> ComputeFeatures(
   return ans;
 }
 
+template std::vector<torch::Tensor> ComputeFeatures<kaldifeat::Fbank>(
+    kaldifeat::Fbank &fbank,  // NOLINT
+    const std::vector<torch::Tensor> &wave_data,
+    std::vector<int32_t> *num_frames /*= nullptr*/);
+
+template std::vector<torch::Tensor> ComputeFeatures<kaldifeat::WhisperFbank>(
+    kaldifeat::WhisperFbank &fbank,  // NOLINT
+    const std::vector<torch::Tensor> &wave_data,
+    std::vector<int32_t> *num_frames /*= nullptr*/);
+
 }  // namespace sherpa
diff --git a/sherpa/csrc/fbank-features.h b/sherpa/csrc/fbank-features.h
index bd05bf87e..9cf6c95b8 100644
--- a/sherpa/csrc/fbank-features.h
+++ b/sherpa/csrc/fbank-features.h
@@ -23,7 +23,6 @@
 #include <utility>
 #include <vector>
 
-#include "kaldifeat/csrc/feature-fbank.h"
 #include "torch/script.h"
 
 namespace sherpa {
@@ -60,8 +59,9 @@ std::pair<torch::Tensor, float> ReadWave(const std::string &filename,
  *         number of feature frames and the number of columns equals to the
  *         feature dimension.
  */
+template <typename FbankComputer>
 std::vector<torch::Tensor> ComputeFeatures(
-    kaldifeat::Fbank &fbank,  // NOLINT
+    FbankComputer &fbank,  // NOLINT
     const std::vector<torch::Tensor> &wave_data,
     std::vector<int32_t> *num_frames = nullptr);
 }  // namespace sherpa
diff --git a/sherpa/csrc/offline-ctc-one-best-decoder.cc b/sherpa/csrc/offline-ctc-one-best-decoder.cc
index a46cacdcd..795ef9085 100644
--- a/sherpa/csrc/offline-ctc-one-best-decoder.cc
+++ b/sherpa/csrc/offline-ctc-one-best-decoder.cc
@@ -68,7 +68,8 @@ std::vector<OfflineCtcDecoderResult> OfflineCtcOneBestDecoder::Decode(
       last_token_is_blank = true;
       continue;
     }
-    if (t != 0 && !p->tokens.empty() && token == p->tokens.back() && (!last_token_is_blank)) {
+    if (t != 0 && !p->tokens.empty() && token == p->tokens.back() &&
+        (!last_token_is_blank)) {
       // This is a repeat, skip it.
       ++t;
       continue;
diff --git a/sherpa/csrc/offline-model-config.cc b/sherpa/csrc/offline-model-config.cc
index aa935e072..588a605e6 100644
--- a/sherpa/csrc/offline-model-config.cc
+++ b/sherpa/csrc/offline-model-config.cc
@@ -12,6 +12,7 @@ namespace sherpa {
 
 void OfflineModelConfig::Register(ParseOptions *po) {
   sense_voice.Register(po);
+  whisper.Register(po);
 
   // TODO(fangjun): enable it
   // po->Register("tokens", &tokens, "Path to tokens.txt");
@@ -33,6 +34,10 @@ bool OfflineModelConfig::Validate() const {
     return sense_voice.Validate();
   }
 
+  if (!whisper.model.empty()) {
+    return whisper.Validate();
+  }
+
   return true;
 }
 
@@ -41,6 +46,7 @@ std::string OfflineModelConfig::ToString() const {
   os << "OfflineModelConfig(";
 
   os << "sense_voice=" << sense_voice.ToString() << ", ";
+  os << "whisper=" << whisper.ToString() << ", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "debug=" << (debug ? "True" : "False") << ", ";
   os << "use_gpu=" << (debug ?
"True" : "False") << ")"; diff --git a/sherpa/csrc/offline-model-config.h b/sherpa/csrc/offline-model-config.h index 3f89c2b2d..310649f17 100644 --- a/sherpa/csrc/offline-model-config.h +++ b/sherpa/csrc/offline-model-config.h @@ -8,11 +8,13 @@ #include "sherpa/cpp_api/parse-options.h" #include "sherpa/csrc/offline-sense-voice-model-config.h" +#include "sherpa/csrc/offline-whisper-model-config.h" namespace sherpa { struct OfflineModelConfig { OfflineSenseVoiceModelConfig sense_voice; + OfflineWhisperModelConfig whisper; std::string tokens; bool debug = false; @@ -20,8 +22,10 @@ struct OfflineModelConfig { OfflineModelConfig() = default; OfflineModelConfig(const OfflineSenseVoiceModelConfig &sense_voice, + const OfflineWhisperModelConfig &whisper, const std::string &tokens, bool debug, bool use_gpu) : sense_voice(sense_voice), + whisper(whisper), tokens(tokens), debug(debug), use_gpu(use_gpu) {} diff --git a/sherpa/csrc/offline-sense-voice-model-meta-data.h b/sherpa/csrc/offline-sense-voice-model-meta-data.h index eb9910521..ba796ae75 100644 --- a/sherpa/csrc/offline-sense-voice-model-meta-data.h +++ b/sherpa/csrc/offline-sense-voice-model-meta-data.h @@ -50,4 +50,4 @@ struct OfflineSenseVoiceModelMetaData { }; } // namespace sherpa -#endif // SHERPA_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_ +#endif // SHERPA_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_ diff --git a/sherpa/csrc/offline-sense-voice-model.cc b/sherpa/csrc/offline-sense-voice-model.cc index 6b6496d58..07f681e98 100644 --- a/sherpa/csrc/offline-sense-voice-model.cc +++ b/sherpa/csrc/offline-sense-voice-model.cc @@ -3,7 +3,9 @@ // Copyright (c) 2025 Xiaomi Corporation #include "sherpa/csrc/offline-sense-voice-model.h" +#include #include +#include #include "sherpa/cpp_api/macros.h" #include "sherpa/csrc/macros.h" @@ -20,18 +22,18 @@ static std::vector ToFloat(const std::string &s) { class OfflineSenseVoiceModel::Impl { public: - Impl(const OfflineModelConfig &config) { + explicit Impl(const OfflineModelConfig &config) { torch::jit::ExtraFilesMap meta_data{ - {"model_type", ""}, {"lfr_window_size", ""}, - {"lfr_window_shift", ""}, {"neg_mean", ""}, - {"inv_stddev", ""}, {"vocab_size", ""}, - {"normalize_samples", ""}, {"version", ""}, - {"model_author", ""}, {"maintainer", ""}, - {"lang_auto", ""}, {"lang_zh", ""}, - {"lang_en", ""}, {"lang_yue", ""}, - {"lang_ja", ""}, {"lang_ko", ""}, - {"lang_nospeech", ""}, {"with_itn", ""}, - {"without_itn", ""}, {"url", ""}, + {"model_type", {}}, {"lfr_window_size", {}}, + {"lfr_window_shift", {}}, {"neg_mean", {}}, + {"inv_stddev", {}}, {"vocab_size", {}}, + {"normalize_samples", {}}, {"version", {}}, + {"model_author", {}}, {"maintainer", {}}, + {"lang_auto", {}}, {"lang_zh", {}}, + {"lang_en", {}}, {"lang_yue", {}}, + {"lang_ja", {}}, {"lang_ko", {}}, + {"lang_nospeech", {}}, {"with_itn", {}}, + {"without_itn", {}}, {"url", {}}, }; if (config.use_gpu) { device_ = torch::Device{torch::kCUDA}; diff --git a/sherpa/csrc/offline-sense-voice-model.h b/sherpa/csrc/offline-sense-voice-model.h index 023d66530..0bdd66362 100644 --- a/sherpa/csrc/offline-sense-voice-model.h +++ b/sherpa/csrc/offline-sense-voice-model.h @@ -4,6 +4,8 @@ #ifndef SHERPA_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_ #define SHERPA_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_ +#include +#include #include #include "sherpa/csrc/offline-model-config.h" diff --git a/sherpa/csrc/offline-stream.cc b/sherpa/csrc/offline-stream.cc index c59b21cc2..b62b7bca9 100644 --- a/sherpa/csrc/offline-stream.cc +++ b/sherpa/csrc/offline-stream.cc @@ -3,10 +3,10 @@ // 
Copyright (c) 2022 Xiaomi Corporation #include "sherpa/cpp_api/offline-stream.h" +#include #include #include -#include "nlohmann/json.hpp" #include "sherpa/cpp_api/feature-config.h" #include "sherpa/csrc/fbank-features.h" #include "sherpa/csrc/log.h" @@ -14,25 +14,46 @@ namespace sherpa { std::string OfflineRecognitionResult::AsJsonString() const { - using json = nlohmann::json; - json j; - j["text"] = text; - j["tokens"] = tokens; - std::ostringstream os; - os << "["; + os << "{"; + + os << "\"text\"" + << ": "; + os << std::quoted(text) << ", "; + std::string sep = ""; for (auto t : timestamps) { os << sep << std::fixed << std::setprecision(2) << t; - sep = ","; + sep = ", "; + } + os << "], "; + + os << "\"" + << "tokens" + << "\"" + << ":"; + os << "["; + + sep = ""; + auto oldFlags = os.flags(); + for (const auto &t : tokens) { + if (t.size() == 1 && static_cast(t[0]) > 0x7f) { + const uint8_t *p = reinterpret_cast(t.c_str()); + os << sep << "\"" + << "<0x" << std::hex << std::uppercase << static_cast(p[0]) + << ">" + << "\""; + os.flags(oldFlags); + } else { + os << sep << std::quoted(t); + } + sep = ", "; } os << "]"; - // NOTE: We don't use j["timestamps"] = timestamps; - // because we need to control the number of decimal points to keep - j["timestamps"] = os.str(); + os << "}"; - return j.dump(); + return os.str(); } class OfflineStream::OfflineStreamImpl { @@ -48,9 +69,27 @@ class OfflineStream::OfflineStreamImpl { } } + OfflineStreamImpl(kaldifeat::WhisperFbank *whisper, + const FeatureConfig &feat_config, + ContextGraphPtr context_graph) + : whisper_(whisper), + feat_config_(feat_config), + context_graph_(context_graph) { + if (!feat_config_.nemo_normalize.empty()) { + SHERPA_CHECK_EQ(feat_config_.nemo_normalize, "per_feature") + << "Only per_feature is implemented at present"; + } + } + void AcceptWaveFile(const std::string &wave_file) { - torch::Tensor samples = - ReadWave(wave_file, fbank_->GetFrameOptions().samp_freq).first; + torch::Tensor samples; + if (fbank_) { + samples = ReadWave(wave_file, fbank_->GetFrameOptions().samp_freq).first; + } else { + samples = + ReadWave(wave_file, whisper_->GetFrameOptions().samp_freq).first; + } + if (!feat_config_.normalize_samples) { samples.mul_(32767); } @@ -59,7 +98,11 @@ class OfflineStream::OfflineStreamImpl { // We return audio samples directly, e.g., for Wav2Vec2.0 features_ = samples; } else { - features_ = ComputeFeatures(*fbank_, {samples})[0]; + if (fbank_) { + features_ = ComputeFeatures(*fbank_, {samples})[0]; + } else { + features_ = ComputeFeatures(*whisper_, {samples})[0]; + } features_ = Normalize(features_); } } @@ -76,7 +119,11 @@ class OfflineStream::OfflineStreamImpl { // We return audio samples directly, e.g., for Wav2Vec2.0 features_ = tensor.clone(); } else { - features_ = ComputeFeatures(*fbank_, {tensor})[0]; + if (fbank_) { + features_ = ComputeFeatures(*fbank_, {tensor})[0]; + } else { + features_ = ComputeFeatures(*whisper_, {tensor})[0]; + } features_ = Normalize(features_); } } @@ -117,7 +164,8 @@ class OfflineStream::OfflineStreamImpl { private: torch::Tensor features_; OfflineRecognitionResult result_; - kaldifeat::Fbank *fbank_ = nullptr; // not owned + kaldifeat::Fbank *fbank_ = nullptr; // not owned + kaldifeat::WhisperFbank *whisper_ = nullptr; // not owned FeatureConfig feat_config_; ContextGraphPtr context_graph_; }; @@ -130,6 +178,12 @@ OfflineStream::OfflineStream(kaldifeat::Fbank *fbank, : impl_(std::make_unique(fbank, feat_config, context_graph)) {} 
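The impl plumbing above lets a stream be backed by either a kaldifeat::Fbank or a kaldifeat::WhisperFbank, and the public constructor added just below exposes the whisper path. A minimal usage sketch, assuming the kaldifeat::WhisperFbank API as it is used elsewhere in this patch (the wave path is a placeholder):

#include "kaldifeat/csrc/whisper-fbank.h"
#include "sherpa/cpp_api/offline-stream.h"

void Demo() {
  kaldifeat::WhisperFbankOptions opts;
  opts.num_mels = 80;  // whisper models use 80 mels (128 for large-v3-style)

  kaldifeat::WhisperFbank whisper(opts);

  sherpa::FeatureConfig feat_config;
  feat_config.normalize_samples = true;  // whisper expects samples in [-1, 1]

  sherpa::OfflineStream stream(&whisper, feat_config);
  stream.AcceptWaveFile("./test.wav");  // placeholder path

  torch::Tensor feats = stream.GetFeatures();  // (num_frames, num_mels)
}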
+OfflineStream::OfflineStream(kaldifeat::WhisperFbank *whisper, + const FeatureConfig &feat_config, + ContextGraphPtr context_graph /* nullptr */) + : impl_(std::make_unique(whisper, feat_config, + context_graph)) {} + void OfflineStream::AcceptWaveFile(const std::string &filename) { impl_->AcceptWaveFile(filename); } diff --git a/sherpa/csrc/offline-transducer-modified-beam-search-decoder.h b/sherpa/csrc/offline-transducer-modified-beam-search-decoder.h index a6141b71a..7effc3b5b 100644 --- a/sherpa/csrc/offline-transducer-modified-beam-search-decoder.h +++ b/sherpa/csrc/offline-transducer-modified-beam-search-decoder.h @@ -15,9 +15,12 @@ namespace sherpa { class OfflineTransducerModifiedBeamSearchDecoder : public OfflineTransducerDecoder { public: - OfflineTransducerModifiedBeamSearchDecoder( - OfflineTransducerModel *model, int32_t num_active_paths, float temperature) - : model_(model), num_active_paths_(num_active_paths), temperature_(temperature) {} + OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model, + int32_t num_active_paths, + float temperature) + : model_(model), + num_active_paths_(num_active_paths), + temperature_(temperature) {} /** Run modified beam search given the output from the encoder model. * diff --git a/sherpa/csrc/offline-whisper-model-config.cc b/sherpa/csrc/offline-whisper-model-config.cc new file mode 100644 index 000000000..81b5d8d0c --- /dev/null +++ b/sherpa/csrc/offline-whisper-model-config.cc @@ -0,0 +1,65 @@ +// sherpa/csrc/offline-whisper-model-config.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa/csrc/offline-whisper-model-config.h" + +#include "sherpa/csrc/file-utils.h" +#include "sherpa/csrc/macros.h" + +namespace sherpa { + +void OfflineWhisperModelConfig::Register(ParseOptions *po) { + po->Register("whisper-model", &model, + "Path to the torchscript model of whisper"); + + po->Register( + "whisper-language", &language, + "The spoken language in the input audio file. Example values: " + "en, de, fr, zh, jp. If it is not given for a multilingual model, we will" + " infer the language from the input audio file. " + "Please refer to " + "https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10" + " for valid values. Note that for non-multilingual models, it supports " + "only 'en'"); + + po->Register("whisper-task", &task, + "Valid values: transcribe, translate. " + "Note that for non-multilingual models, it supports " + "only 'transcribe'"); +} + +bool OfflineWhisperModelConfig::Validate() const { + if (model.empty()) { + SHERPA_LOGE("Please provide --whisper-model"); + return false; + } + + if (!FileExists(model)) { + SHERPA_LOGE("whisper model file '%s' does not exist", model.c_str()); + return false; + } + + if (task != "translate" && task != "transcribe") { + SHERPA_LOGE( + "--whisper-task supports only translate and transcribe. 
Given: %s", + task.c_str()); + + return false; + } + + return true; +} + +std::string OfflineWhisperModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineWhisperModelConfig("; + os << "model=\"" << model << "\", "; + os << "language=\"" << language << "\", "; + os << "task=\"" << task << "\")"; + + return os.str(); +} + +} // namespace sherpa diff --git a/sherpa/csrc/offline-whisper-model-config.h b/sherpa/csrc/offline-whisper-model-config.h new file mode 100644 index 000000000..ee6dbe1f9 --- /dev/null +++ b/sherpa/csrc/offline-whisper-model-config.h @@ -0,0 +1,44 @@ +// sherpa/csrc/offline-whisper-model-config.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_ +#define SHERPA_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_ + +#include + +#include "sherpa/cpp_api/parse-options.h" + +namespace sherpa { + +struct OfflineWhisperModelConfig { + std::string model; + + // Available languages can be found at + // https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + // + // Note: For non-multilingual models, it supports only "en" + // + // If empty, we will infer it from the input audio file when + // the model is multilingual. + std::string language; + + // Valid values are transcribe and translate + // + // Note: For non-multilingual models, it supports only "transcribe" + std::string task = "transcribe"; + + OfflineWhisperModelConfig() = default; + OfflineWhisperModelConfig(const std::string &model, + const std::string &language, + const std::string &task) + : model(model), language(language), task(task) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa + +#endif // SHERPA_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_ diff --git a/sherpa/csrc/offline-whisper-model-meta-data.cc b/sherpa/csrc/offline-whisper-model-meta-data.cc new file mode 100644 index 000000000..d1be779c1 --- /dev/null +++ b/sherpa/csrc/offline-whisper-model-meta-data.cc @@ -0,0 +1,68 @@ +// sherpa/csrc/offline-whisper-model-meta-data.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa/csrc/offline-whisper-model-meta-data.h" + +#include +#include +#include + +namespace sherpa { + +std::string OfflineWhisperModelMetaData::ToString() const { + std::ostringstream os; + + os << "----------whisper meta data----------\n"; + + os << " comment: " << comment << "\n"; + os << " n_mels: " << n_mels << "\n"; + os << " n_audio_ctx: " << n_audio_ctx << "\n"; + os << " n_audio_state: " << n_audio_state << "\n"; + os << " n_audio_head: " << n_audio_head << "\n"; + os << " n_audio_layer: " << n_audio_layer << "\n"; + os << " n_vocab: " << n_vocab << "\n"; + os << " n_text_ctx: " << n_text_ctx << "\n"; + os << " n_text_state: " << n_text_state << "\n"; + os << " n_text_head: " << n_text_head << "\n"; + os << " n_text_layer: " << n_text_layer << "\n"; + os << " sot: " << sot << "\n"; + os << " sot_index: " << sot_index << "\n"; + os << " eot: " << eot << "\n"; + os << " blank_id: " << blank_id << "\n"; + os << " is_multilingual: " << is_multilingual << "\n"; + os << " no_speech: " << no_speech << "\n"; + os << " non_speech_tokens: " << non_speech_tokens << "\n"; + os << " transcribe: " << transcribe << "\n"; + os << " translate: " << translate << "\n"; + os << " sot_prev: " << sot_prev << "\n"; + os << " sot_lm: " << sot_lm << "\n"; + os << " no_timestamps: " << no_timestamps << "\n"; + os << " sot_sequence:"; + for (auto i : sot_sequence) { + os << " " << i; + } + os << "\n"; + + 
std::vector langs; + langs.reserve(lang2id.size()); + for (const auto &p : lang2id) { + langs.push_back(p.first); + } + std::sort(langs.begin(), langs.end()); + + os << " lang2id: (" << lang2id.size() << ")" << "\n "; + int32_t k = 0; + for (const auto &lang : langs) { + os << lang << " -> " << lang2id.at(lang) << ", "; + k += 1; + if (k % 10 == 0) { + os << "\n "; + } + } + os << "\n"; + + return os.str(); +} + +} // namespace sherpa diff --git a/sherpa/csrc/offline-whisper-model-meta-data.h b/sherpa/csrc/offline-whisper-model-meta-data.h new file mode 100644 index 000000000..b5df0cedd --- /dev/null +++ b/sherpa/csrc/offline-whisper-model-meta-data.h @@ -0,0 +1,50 @@ +// sherpa/csrc/offline-whisper-model-meta-data.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_CSRC_OFFLINE_WHISPER_MODEL_META_DATA_H_ +#define SHERPA_CSRC_OFFLINE_WHISPER_MODEL_META_DATA_H_ + +#include +#include +#include + +#include "torch/script.h" + +namespace sherpa { + +struct OfflineWhisperModelMetaData { + int32_t n_mels; + int32_t n_audio_ctx; + int32_t n_audio_state; + int32_t n_audio_head; + int32_t n_audio_layer; + int32_t n_vocab; + int32_t n_text_ctx; + int32_t n_text_state; + int32_t n_text_head; + int32_t n_text_layer; + int32_t sot; + int32_t sot_index; + int32_t eot; + int32_t blank_id; + int32_t is_multilingual; + int32_t no_speech; + int32_t non_speech_tokens; + int32_t transcribe; + int32_t translate; + int32_t sot_prev; + int32_t sot_lm; + int32_t no_timestamps; + + std::string comment; + std::vector sot_sequence; + std::unordered_map lang2id; + std::unordered_map id2lang; + std::vector all_languages_id; + + std::string ToString() const; +}; + +} // namespace sherpa + +#endif // SHERPA_CSRC_OFFLINE_WHISPER_MODEL_META_DATA_H_ diff --git a/sherpa/csrc/offline-whisper-model.cc b/sherpa/csrc/offline-whisper-model.cc new file mode 100644 index 000000000..55b5ed424 --- /dev/null +++ b/sherpa/csrc/offline-whisper-model.cc @@ -0,0 +1,243 @@ +// sherpa/csrc/offline-whisper-model.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa/csrc/offline-whisper-model.h" + +#include +#include +#include +#include + +#include "sherpa/cpp_api/macros.h" +#include "sherpa/csrc/macros.h" +#include "sherpa/csrc/offline-whisper-model-meta-data.h" +#include "sherpa/csrc/text-utils.h" +namespace sherpa { + +class OfflineWhisperModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) { + torch::jit::ExtraFilesMap meta_data{ + {"model_type", {}}, + {"comment", {}}, + {"version", {}}, + {"maintainer", {}}, + {"n_mels", {}}, + {"n_audio_ctx", {}}, + {"n_audio_state", {}}, + {"n_audio_head", {}}, + {"n_audio_layer", {}}, + {"n_vocab", {}}, + {"n_text_ctx", {}}, + {"n_text_state", {}}, + {"n_text_head", {}}, + {"n_text_layer", {}}, + {"sot_sequence", {}}, + {"all_language_tokens", {}}, + {"all_language_codes", {}}, + {"sot", {}}, + {"sot_index", {}}, + {"eot", {}}, + {"blank_id", {}}, + {"is_multilingual", {}}, + {"no_speech", {}}, + {"non_speech_tokens", {}}, + {"transcribe", {}}, + {"translate", {}}, + {"sot_prev", {}}, + {"sot_lm", {}}, + {"no_timestamps", {}}, + }; + + if (config.use_gpu) { + device_ = torch::Device{torch::kCUDA}; + } + + model_ = torch::jit::load(config.whisper.model, device_, meta_data); + model_.eval(); + + if (meta_data.at("model_type") != "whisper" && + meta_data.at("model_type") != "Whisper") { + SHERPA_LOGE("Expect a whisper model. 
Given: '%s'", + meta_data.at("model_type").c_str()); + SHERPA_EXIT(-1); + } + + InitMetaData(meta_data); + + if (config.debug) { + SHERPA_LOGE("%s", meta_data_.ToString().c_str()); + } + } + + const OfflineWhisperModelMetaData &GetModelMetadata() const { + return meta_data_; + } + + torch::Device Device() const { return device_; } + + std::pair RunEncoder( + const torch::Tensor &features) { + InferenceMode no_grad; + + auto outputs = model_.run_method("run_encoder", features).toTuple(); + + auto n_layer_cross_k_cache = outputs->elements()[0].toTensor(); + auto n_layer_cross_v_cache = outputs->elements()[1].toTensor(); + + return {n_layer_cross_k_cache, n_layer_cross_v_cache}; + } + + std::tuple RunDecoder( + const torch::Tensor &tokens, torch::Tensor n_layer_self_k_cache, + torch::Tensor n_layer_self_v_cache, torch::Tensor n_layer_cross_k_cache, + torch::Tensor n_layer_cross_v_cache, const torch::Tensor &offset) { + InferenceMode no_grad; + + auto outputs = model_ + .run_method("run_decoder", tokens, n_layer_self_k_cache, + n_layer_self_v_cache, n_layer_cross_k_cache, + n_layer_cross_v_cache, offset) + .toTuple(); + + auto logits = outputs->elements().vec()[0].toTensor(); + n_layer_self_k_cache = outputs->elements().vec()[1].toTensor(); + n_layer_self_v_cache = outputs->elements().vec()[2].toTensor(); + + return std::make_tuple(logits, n_layer_self_k_cache, n_layer_self_v_cache); + } + + torch::Tensor DetectLanguage(const torch::Tensor &n_layer_cross_k_cache, + const torch::Tensor &n_layer_cross_v_cache) { + InferenceMode no_grad; + + int32_t batch_size = n_layer_cross_v_cache.size(1); + torch::Tensor tokens = + torch::tensor({meta_data_.sot}, + torch::dtype(torch::kInt).device(device_)) + .unsqueeze(0) + .repeat({batch_size, 1}); + + torch::Tensor offset = + torch::zeros({batch_size}, torch::dtype(torch::kInt).device(device_)); + + torch::Tensor n_layer_self_k_cache = + torch::zeros({meta_data_.n_text_layer, batch_size, + meta_data_.n_text_ctx, meta_data_.n_text_state}, + torch::dtype(torch::kFloat).device(device_)); + + torch::Tensor n_layer_self_v_cache = + torch::zeros({meta_data_.n_text_layer, batch_size, + meta_data_.n_text_ctx, meta_data_.n_text_state}, + torch::dtype(torch::kFloat).device(device_)); + + auto out = RunDecoder(tokens, n_layer_self_k_cache, n_layer_self_v_cache, + n_layer_cross_k_cache, n_layer_cross_v_cache, offset); + auto logits = std::get<0>(out); + + torch::Tensor all_languages_id = + torch::tensor(meta_data_.all_languages_id, + torch::dtype(torch::kLong).device(device_)); + torch::Tensor mask = + torch::ones(logits.size(2), torch::dtype(torch::kLong).device(device_)); + + mask.index_put_({all_languages_id}, 0); + + torch::Tensor non_language_indexes = mask.nonzero().squeeze(); + + logits.index_put_({"...", non_language_indexes}, + -std::numeric_limits::infinity()); + + return logits.argmax(-1).squeeze(); + } + + private: + void InitMetaData(const torch::jit::ExtraFilesMap &meta_data) { + meta_data_.comment = meta_data.at("comment"); + meta_data_.n_mels = atoi(meta_data.at("n_mels").c_str()); + meta_data_.n_audio_ctx = atoi(meta_data.at("n_audio_ctx").c_str()); + meta_data_.n_audio_state = atoi(meta_data.at("n_audio_state").c_str()); + meta_data_.n_audio_head = atoi(meta_data.at("n_audio_head").c_str()); + meta_data_.n_audio_layer = atoi(meta_data.at("n_audio_layer").c_str()); + meta_data_.n_vocab = atoi(meta_data.at("n_vocab").c_str()); + meta_data_.n_text_ctx = atoi(meta_data.at("n_text_ctx").c_str()); + meta_data_.n_text_state = 
atoi(meta_data.at("n_text_state").c_str()); + meta_data_.n_text_head = atoi(meta_data.at("n_text_head").c_str()); + meta_data_.n_text_layer = atoi(meta_data.at("n_text_layer").c_str()); + meta_data_.sot = atoi(meta_data.at("sot").c_str()); + meta_data_.sot_index = atoi(meta_data.at("sot_index").c_str()); + meta_data_.eot = atoi(meta_data.at("eot").c_str()); + meta_data_.blank_id = atoi(meta_data.at("blank_id").c_str()); + meta_data_.is_multilingual = atoi(meta_data.at("is_multilingual").c_str()); + meta_data_.no_speech = atoi(meta_data.at("no_speech").c_str()); + meta_data_.non_speech_tokens = + atoi(meta_data.at("non_speech_tokens").c_str()); + meta_data_.transcribe = atoi(meta_data.at("transcribe").c_str()); + meta_data_.translate = atoi(meta_data.at("translate").c_str()); + meta_data_.sot_prev = atoi(meta_data.at("sot_prev").c_str()); + meta_data_.sot_lm = atoi(meta_data.at("sot_lm").c_str()); + meta_data_.no_timestamps = atoi(meta_data.at("no_timestamps").c_str()); + + std::vector all_language_codes; + SplitStringToIntegers(meta_data.at("sot_sequence"), ",", true, + &meta_data_.sot_sequence); + + SplitStringToVector(meta_data.at("all_language_codes"), ",", true, + &all_language_codes); + + SplitStringToIntegers(meta_data.at("all_language_tokens"), ",", true, + &meta_data_.all_languages_id); + + for (int32_t i = 0; i < static_cast(all_language_codes.size()); + ++i) { + meta_data_.lang2id[all_language_codes[i]] = + meta_data_.all_languages_id[i]; + + meta_data_.id2lang[meta_data_.all_languages_id[i]] = + std::move(all_language_codes[i]); + } + } + + private: + torch::jit::Module model_; + OfflineWhisperModelMetaData meta_data_; + torch::Device device_{torch::kCPU}; +}; + +OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +OfflineWhisperModel::~OfflineWhisperModel() = default; + +const OfflineWhisperModelMetaData &OfflineWhisperModel::GetModelMetadata() + const { + return impl_->GetModelMetadata(); +} + +torch::Device OfflineWhisperModel::Device() const { return impl_->Device(); } + +std::pair OfflineWhisperModel::RunEncoder( + const torch::Tensor &features) const { + return impl_->RunEncoder(features); +} + +std::tuple +OfflineWhisperModel::RunDecoder(const torch::Tensor &tokens, + const torch::Tensor &n_layer_self_k_cache, + const torch::Tensor &n_layer_self_v_cache, + const torch::Tensor &n_layer_cross_k_cache, + const torch::Tensor &n_layer_cross_v_cache, + const torch::Tensor &offset) const { + return impl_->RunDecoder(tokens, n_layer_self_k_cache, n_layer_self_v_cache, + n_layer_cross_k_cache, n_layer_cross_v_cache, + offset); +} + +torch::Tensor OfflineWhisperModel::DetectLanguage( + const torch::Tensor &n_layer_cross_k_cache, + const torch::Tensor &n_layer_cross_v_cache) const { + return impl_->DetectLanguage(n_layer_cross_k_cache, n_layer_cross_v_cache); +} + +} // namespace sherpa diff --git a/sherpa/csrc/offline-whisper-model.h b/sherpa/csrc/offline-whisper-model.h new file mode 100644 index 000000000..c4045a30d --- /dev/null +++ b/sherpa/csrc/offline-whisper-model.h @@ -0,0 +1,68 @@ +// sherpa/csrc/offline-whisper-model.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_CSRC_OFFLINE_WHISPER_MODEL_H_ +#define SHERPA_CSRC_OFFLINE_WHISPER_MODEL_H_ + +#include +#include +#include + +#include "sherpa/csrc/offline-model-config.h" +#include "sherpa/csrc/offline-whisper-model-meta-data.h" +#include "torch/script.h" + +namespace sherpa { + +class OfflineWhisperModel { + public: + explicit 
diff --git a/sherpa/csrc/offline-whisper-model.h b/sherpa/csrc/offline-whisper-model.h
new file mode 100644
index 000000000..c4045a30d
--- /dev/null
+++ b/sherpa/csrc/offline-whisper-model.h
@@ -0,0 +1,68 @@
+// sherpa/csrc/offline-whisper-model.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+#ifndef SHERPA_CSRC_OFFLINE_WHISPER_MODEL_H_
+#define SHERPA_CSRC_OFFLINE_WHISPER_MODEL_H_
+
+#include <memory>
+#include <tuple>
+#include <utility>
+
+#include "sherpa/csrc/offline-model-config.h"
+#include "sherpa/csrc/offline-whisper-model-meta-data.h"
+#include "torch/script.h"
+
+namespace sherpa {
+
+class OfflineWhisperModel {
+ public:
+  explicit OfflineWhisperModel(const OfflineModelConfig &config);
+
+  ~OfflineWhisperModel();
+
+  const OfflineWhisperModelMetaData &GetModelMetadata() const;
+
+  torch::Device Device() const;
+
+  /**
+   * @param features 3-D tensor of shape (N, C, T).
+   * @returns Return two tensors:
+   *          - n_layer_cross_k_cache, 4-D tensor (num_layers, N, T, C)
+   *          - n_layer_cross_v_cache, 4-D tensor (num_layers, N, T, C)
+   */
+  std::pair<torch::Tensor, torch::Tensor> RunEncoder(
+      const torch::Tensor &features) const;
+
+  /*
+   *
+   * @param tokens A 2-D tensor of shape (N, num_tokens)
+   * @param n_layer_self_k_cache (num_layers, N, dim1, dim2)
+   * @param n_layer_self_v_cache (num_layers, N, dim1, dim2)
+   * @param n_layer_cross_k_cache (num_layers, N, T, dim)
+   * @param n_layer_cross_v_cache (num_layers, N, T, dim)
+   * @param offset A 1-D int32 tensor of shape (N,)
+   *
+   * @returns Return a tuple of 3 tensors:
+   *          - logits, (N, num_tokens, dim)
+   *          - n_layer_self_k_cache, (num_layers, batch_size, dim1, dim2)
+   *          - n_layer_self_v_cache, (num_layers, batch_size, dim1, dim2)
+   */
+  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RunDecoder(
+      const torch::Tensor &tokens, const torch::Tensor &n_layer_self_k_cache,
+      const torch::Tensor &n_layer_self_v_cache,
+      const torch::Tensor &n_layer_cross_k_cache,
+      const torch::Tensor &n_layer_cross_v_cache,
+      const torch::Tensor &offset) const;
+
+  torch::Tensor DetectLanguage(
+      const torch::Tensor &n_layer_cross_k_cache,
+      const torch::Tensor &n_layer_cross_v_cache) const;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa
+
+#endif  // SHERPA_CSRC_OFFLINE_WHISPER_MODEL_H_
diff --git a/sherpa/csrc/online-stream.cc b/sherpa/csrc/online-stream.cc
index 902b5fdb7..05320390f 100644
--- a/sherpa/csrc/online-stream.cc
+++ b/sherpa/csrc/online-stream.cc
@@ -37,7 +37,9 @@ class OnlineStream::OnlineStreamImpl {
  public:
   explicit OnlineStreamImpl(const FeatureConfig &feat_config,
                             ContextGraphPtr context_graph /*=nullptr*/)
-      : opts_(feat_config.fbank_opts), feat_config_(feat_config), context_graph_(context_graph) {
+      : opts_(feat_config.fbank_opts),
+        feat_config_(feat_config),
+        context_graph_(context_graph) {
     fbank_ = std::make_unique<kaldifeat::OnlineFbank>(opts_);
   }
 
diff --git a/sherpa/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa/csrc/online-transducer-modified-beam-search-decoder.cc
index f8b19b93d..59b14f0ef 100644
--- a/sherpa/csrc/online-transducer-modified-beam-search-decoder.cc
+++ b/sherpa/csrc/online-transducer-modified-beam-search-decoder.cc
@@ -232,7 +232,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
       }
       cur.push_back(std::move(hyps));
     }  // for (int32_t k = 0; k != N; ++k)
-  }  // for (int32_t t = 0; t != T; ++t)
+  }    // for (int32_t t = 0; t != T; ++t)
 
   for (int32_t i = 0; i != N; ++i) {
     (*results)[i].hyps = std::move(cur[i]);
diff --git a/sherpa/csrc/online-transducer-modified-beam-search-decoder.h b/sherpa/csrc/online-transducer-modified-beam-search-decoder.h
index 3927e002c..3c6da975e 100644
--- a/sherpa/csrc/online-transducer-modified-beam-search-decoder.h
+++ b/sherpa/csrc/online-transducer-modified-beam-search-decoder.h
@@ -16,7 +16,9 @@ class OnlineTransducerModifiedBeamSearchDecoder
  public:
   explicit OnlineTransducerModifiedBeamSearchDecoder(
       OnlineTransducerModel *model, int32_t num_active_paths, float temperature)
-      : model_(model), num_active_paths_(num_active_paths), temperature_(temperature) {}
+      : model_(model),
+        num_active_paths_(num_active_paths),
+        temperature_(temperature) {}
 
   OnlineTransducerDecoderResult GetEmptyResult() override;
diff --git a/sherpa/csrc/parse-options.cc b/sherpa/csrc/parse-options.cc
index ff1106233..c47fd624f 100644
--- a/sherpa/csrc/parse-options.cc
+++ b/sherpa/csrc/parse-options.cc
@@ -22,6 +22,8 @@
 
 // This file is copied and modified from kaldi/src/util/parse-options.cu
 
+#include "sherpa/cpp_api/parse-options.h"
+
 #include <algorithm>
 #include <cassert>
 
@@ -33,135 +35,11 @@
 #include <utility>
 #include <vector>
 
-#include "sherpa/cpp_api/parse-options.h"
 #include "sherpa/csrc/log.h"
-
-#ifdef _MSC_VER
-#define SHERPA_STRTOLL(cur_cstr, end_cstr) _strtoi64(cur_cstr, end_cstr, 10);
-#else
-#define SHERPA_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10);
-#endif
+#include "sherpa/csrc/text-utils.h"
 
 namespace sherpa {
 
-/// Converts a string into an integer via strtoll and returns false if there was
-/// any kind of problem (i.e. the string was not an integer or contained extra
-/// non-whitespace junk, or the integer was too large to fit into the type it is
-/// being converted into). Only sets *out if everything was OK and it returns
-/// true.
-template <class Int>
-bool ConvertStringToInteger(const std::string &str, Int *out) {
-  // copied from kaldi/src/util/text-util.h
-  static_assert(std::is_integral<Int>::value, "");
-  const char *this_str = str.c_str();
-  char *end = nullptr;
-  errno = 0;
-  int64_t i = SHERPA_STRTOLL(this_str, &end);
-  if (end != this_str) {
-    while (isspace(*end)) ++end;
-  }
-  if (end == this_str || *end != '\0' || errno != 0) return false;
-  Int iInt = static_cast<Int>(i);
-  if (static_cast<int64_t>(iInt) != i ||
-      (i < 0 && !std::numeric_limits<Int>::is_signed)) {
-    return false;
-  }
-  *out = iInt;
-  return true;
-}
-
-// copied from kaldi/src/util/text-util.cc
-template <typename T>
-class NumberIstream {
- public:
-  explicit NumberIstream(std::istream &i) : in_(i) {}
-
-  NumberIstream &operator>>(T &x) {
-    if (!in_.good()) return *this;
-    in_ >> x;
-    if (!in_.fail() && RemainderIsOnlySpaces()) return *this;
-    return ParseOnFail(&x);
-  }
-
- private:
-  std::istream &in_;
-
-  bool RemainderIsOnlySpaces() {
-    if (in_.tellg() != std::istream::pos_type(-1)) {
-      std::string rem;
-      in_ >> rem;
-
-      if (rem.find_first_not_of(' ') != std::string::npos) {
-        // there is not only spaces
-        return false;
-      }
-    }
-
-    in_.clear();
-    return true;
-  }
-
-  NumberIstream &ParseOnFail(T *x) {
-    std::string str;
-    in_.clear();
-    in_.seekg(0);
-    // If the stream is broken even before trying
-    // to read from it or if there are many tokens,
-    // it's pointless to try.
-    if (!(in_ >> str) || !RemainderIsOnlySpaces()) {
-      in_.setstate(std::ios_base::failbit);
-      return *this;
-    }
-
-    std::unordered_map<std::string, T> inf_nan_map;
-    // we'll keep just uppercase values.
- inf_nan_map["INF"] = std::numeric_limits::infinity(); - inf_nan_map["+INF"] = std::numeric_limits::infinity(); - inf_nan_map["-INF"] = -std::numeric_limits::infinity(); - inf_nan_map["INFINITY"] = std::numeric_limits::infinity(); - inf_nan_map["+INFINITY"] = std::numeric_limits::infinity(); - inf_nan_map["-INFINITY"] = -std::numeric_limits::infinity(); - inf_nan_map["NAN"] = std::numeric_limits::quiet_NaN(); - inf_nan_map["+NAN"] = std::numeric_limits::quiet_NaN(); - inf_nan_map["-NAN"] = -std::numeric_limits::quiet_NaN(); - // MSVC - inf_nan_map["1.#INF"] = std::numeric_limits::infinity(); - inf_nan_map["-1.#INF"] = -std::numeric_limits::infinity(); - inf_nan_map["1.#QNAN"] = std::numeric_limits::quiet_NaN(); - inf_nan_map["-1.#QNAN"] = -std::numeric_limits::quiet_NaN(); - - std::transform(str.begin(), str.end(), str.begin(), ::toupper); - - if (inf_nan_map.find(str) != inf_nan_map.end()) { - *x = inf_nan_map[str]; - } else { - in_.setstate(std::ios_base::failbit); - } - - return *this; - } -}; - -/// ConvertStringToReal converts a string into either float or double -/// and returns false if there was any kind of problem (i.e. the string -/// was not a floating point number or contained extra non-whitespace junk). -/// Be careful- this function will successfully read inf's or nan's. -template -bool ConvertStringToReal(const std::string &str, T *out) { - std::istringstream iss(str); - - NumberIstream i(iss); - - i >> *out; - - if (iss.fail()) { - // Number conversion failed. - return false; - } - - return true; -} - ParseOptions::ParseOptions(const std::string &prefix, ParseOptions *po) : print_args_(false), help_(false), usage_(""), argc_(0), argv_(nullptr) { if (po != nullptr && po->other_parser_ != nullptr) { diff --git a/sherpa/csrc/symbol-table.cc b/sherpa/csrc/symbol-table.cc index 68fadd62d..c8d885288 100644 --- a/sherpa/csrc/symbol-table.cc +++ b/sherpa/csrc/symbol-table.cc @@ -21,6 +21,7 @@ #include #include +#include "sherpa/csrc/base64-decode.h" #include "sherpa/csrc/log.h" namespace sherpa { @@ -66,9 +67,9 @@ int32_t SymbolTable::operator[](const std::string &sym) const { return sym2id_.at(sym); } -bool SymbolTable::contains(int32_t id) const { return id2sym_.count(id) != 0; } +bool SymbolTable::Contains(int32_t id) const { return id2sym_.count(id) != 0; } -bool SymbolTable::contains(const std::string &sym) const { +bool SymbolTable::Contains(const std::string &sym) const { return sym2id_.count(sym) != 0; } @@ -84,4 +85,12 @@ void SymbolTable::Replace(int32_t id, const std::string &new_sym, sym2id_[new_sym] = id; } +void SymbolTable::ApplyBase64Decode() { + sym2id_.clear(); + for (auto &p : id2sym_) { + p.second = Base64Decode(p.second); + sym2id_[p.second] = p.first; + } +} + } // namespace sherpa diff --git a/sherpa/csrc/symbol-table.h b/sherpa/csrc/symbol-table.h index db6ab3293..484ddf5d3 100644 --- a/sherpa/csrc/symbol-table.h +++ b/sherpa/csrc/symbol-table.h @@ -49,10 +49,13 @@ class SymbolTable { const std::string &old_sym); /// Return true if there is a symbol with the given ID. - bool contains(int32_t id) const; + bool Contains(int32_t id) const; /// Return true if there is a given symbol in the symbol table. 
diff --git a/sherpa/csrc/symbol-table.h b/sherpa/csrc/symbol-table.h
index db6ab3293..484ddf5d3 100644
--- a/sherpa/csrc/symbol-table.h
+++ b/sherpa/csrc/symbol-table.h
@@ -49,10 +49,13 @@ class SymbolTable {
                const std::string &old_sym);
 
   /// Return true if there is a symbol with the given ID.
-  bool contains(int32_t id) const;
+  bool Contains(int32_t id) const;
 
   /// Return true if there is a given symbol in the symbol table.
-  bool contains(const std::string &sym) const;
+  bool Contains(const std::string &sym) const;
+
+  // for tokens.txt from Whisper
+  void ApplyBase64Decode();
 
  private:
   std::unordered_map<std::string, int32_t> sym2id_;
diff --git a/sherpa/csrc/text-utils.cc b/sherpa/csrc/text-utils.cc
new file mode 100644
index 000000000..dc9f4af14
--- /dev/null
+++ b/sherpa/csrc/text-utils.cc
@@ -0,0 +1,349 @@
+// sherpa/csrc/text-utils.cc
+//
+// Copyright 2009-2011 Saarland University; Microsoft Corporation
+// Copyright 2023-2025 Xiaomi Corporation
+
+#include "sherpa/csrc/text-utils.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cctype>
+#include <cstdint>
+#include <limits>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "sherpa/csrc/macros.h"
+
+// This file is copied/modified from
+// https://github.com/kaldi-asr/kaldi/blob/master/src/util/text-utils.cc
+
+namespace sherpa {
+
+// copied from kaldi/src/util/text-util.cc
+template <typename T>
+class NumberIstream {
+ public:
+  explicit NumberIstream(std::istream &i) : in_(i) {}
+
+  NumberIstream &operator>>(T &x) {
+    if (!in_.good()) return *this;
+    in_ >> x;
+    if (!in_.fail() && RemainderIsOnlySpaces()) return *this;
+    return ParseOnFail(&x);
+  }
+
+ private:
+  std::istream &in_;
+
+  bool RemainderIsOnlySpaces() {
+    if (in_.tellg() != std::istream::pos_type(-1)) {
+      std::string rem;
+      in_ >> rem;
+
+      if (rem.find_first_not_of(' ') != std::string::npos) {
+        // there is not only spaces
+        return false;
+      }
+    }
+
+    in_.clear();
+    return true;
+  }
+
+  NumberIstream &ParseOnFail(T *x) {
+    std::string str;
+    in_.clear();
+    in_.seekg(0);
+    // If the stream is broken even before trying
+    // to read from it or if there are many tokens,
+    // it's pointless to try.
+    if (!(in_ >> str) || !RemainderIsOnlySpaces()) {
+      in_.setstate(std::ios_base::failbit);
+      return *this;
+    }
+
+    std::unordered_map<std::string, T> inf_nan_map;
+    // we'll keep just uppercase values.
+    inf_nan_map["INF"] = std::numeric_limits<T>::infinity();
+    inf_nan_map["+INF"] = std::numeric_limits<T>::infinity();
+    inf_nan_map["-INF"] = -std::numeric_limits<T>::infinity();
+    inf_nan_map["INFINITY"] = std::numeric_limits<T>::infinity();
+    inf_nan_map["+INFINITY"] = std::numeric_limits<T>::infinity();
+    inf_nan_map["-INFINITY"] = -std::numeric_limits<T>::infinity();
+    inf_nan_map["NAN"] = std::numeric_limits<T>::quiet_NaN();
+    inf_nan_map["+NAN"] = std::numeric_limits<T>::quiet_NaN();
+    inf_nan_map["-NAN"] = -std::numeric_limits<T>::quiet_NaN();
+    // MSVC
+    inf_nan_map["1.#INF"] = std::numeric_limits<T>::infinity();
+    inf_nan_map["-1.#INF"] = -std::numeric_limits<T>::infinity();
+    inf_nan_map["1.#QNAN"] = std::numeric_limits<T>::quiet_NaN();
+    inf_nan_map["-1.#QNAN"] = -std::numeric_limits<T>::quiet_NaN();
+
+    std::transform(str.begin(), str.end(), str.begin(), ::toupper);
+
+    if (inf_nan_map.find(str) != inf_nan_map.end()) {
+      *x = inf_nan_map[str];
+    } else {
+      in_.setstate(std::ios_base::failbit);
+    }
+
+    return *this;
+  }
+};
+
+/// ConvertStringToReal converts a string into either float or double
+/// and returns false if there was any kind of problem (i.e. the string
+/// was not a floating point number or contained extra non-whitespace junk).
+/// Be careful- this function will successfully read inf's or nan's.
+template <typename T>
+bool ConvertStringToReal(const std::string &str, T *out) {
+  std::istringstream iss(str);
+
+  NumberIstream<T> i(iss);
+
+  i >> *out;
+
+  if (iss.fail()) {
+    // Number conversion failed.
+    return false;
+  }
+
+  return true;
+}
+
+template bool ConvertStringToReal<float>(const std::string &str, float *out);
+
+template bool ConvertStringToReal<double>(const std::string &str, double *out);
+
+void SplitStringToVector(const std::string &full, const char *delim,
+                         bool omit_empty_strings,
+                         std::vector<std::string> *out) {
+  size_t start = 0, found = 0, end = full.size();
+  out->clear();
+  while (found != std::string::npos) {
+    found = full.find_first_of(delim, start);
+    // start != end condition is for when the delimiter is at the end
+    if (!omit_empty_strings || (found != start && start != end))
+      out->push_back(full.substr(start, found - start));
+    start = found + 1;
+  }
+}
+
+template <class F>
+bool SplitStringToFloats(const std::string &full, const char *delim,
+                         bool omit_empty_strings,  // typically false
+                         std::vector<F> *out) {
+  assert(out != nullptr);
+  if (*(full.c_str()) == '\0') {
+    out->clear();
+    return true;
+  }
+  std::vector<std::string> split;
+  SplitStringToVector(full, delim, omit_empty_strings, &split);
+  out->resize(split.size());
+  for (size_t i = 0; i < split.size(); ++i) {
+    // assume atof never fails
+    F f = 0;
+    if (!ConvertStringToReal(split[i], &f)) return false;
+    (*out)[i] = f;
+  }
+  return true;
+}
+
+// Instantiate the template above for float and double.
+template bool SplitStringToFloats(const std::string &full, const char *delim,
+                                  bool omit_empty_strings,
+                                  std::vector<float> *out);
+template bool SplitStringToFloats(const std::string &full, const char *delim,
+                                  bool omit_empty_strings,
+                                  std::vector<double> *out);
+
+static bool IsPunct(char c) { return c != '\'' && std::ispunct(c); }
+
+static bool IsGermanUmlaut(const std::string &word) {
+  // ä 0xC3 0xA4
+  // ö 0xC3 0xB6
+  // ü 0xC3 0xBC
+  // Ä 0xC3 0x84
+  // Ö 0xC3 0x96
+  // Ü 0xC3 0x9C
+  // ß 0xC3 0x9F
+
+  if (word.size() != 2 || static_cast<uint8_t>(word[0]) != 0xc3) {
+    return false;
+  }
+
+  auto c = static_cast<uint8_t>(word[1]);
+  if (c == 0xa4 || c == 0xb6 || c == 0xbc || c == 0x84 || c == 0x96 ||
+      c == 0x9c || c == 0x9f) {
+    return true;
+  }
+
+  return false;
+}
+
+// see https://www.tandem.net/blog/spanish-accents
+// https://www.compart.com/en/unicode/U+00DC
+static bool IsSpanishDiacritic(const std::string &word) {
+  // á 0xC3 0xA1
+  // é 0xC3 0xA9
+  // í 0xC3 0xAD
+  // ó 0xC3 0xB3
+  // ú 0xC3 0xBA
+  // ü 0xC3 0xBC
+  // ñ 0xC3 0xB1
+  //
+  // uppercase
+  //
+  // Á 0xC3 0x81
+  // É 0xC3 0x89
+  // Í 0xC3 0x8D
+  // Ó 0xC3 0x93
+  // Ú 0xC3 0x9A
+  // Ü 0xC3 0x9C
+  // Ñ 0xC3 0x91
+
+  if (word.size() != 2 || static_cast<uint8_t>(word[0]) != 0xc3) {
+    return false;
+  }
+
+  auto c = static_cast<uint8_t>(word[1]);
+  if (c == 0xa1 || c == 0xa9 || c == 0xad || c == 0xb3 || c == 0xba ||
+      c == 0xbc || c == 0xb1 || c == 0x81 || c == 0x89 || c == 0x8d ||
+      c == 0x93 || c == 0x9a || c == 0x9c || c == 0x91) {
+    return true;
+  }
+
+  return false;
+}
+
+// see https://www.busuu.com/en/french/accent-marks
+static bool IsFrenchDiacritic(const std::string &word) {
+  // acute accent
+  // é 0xC3 0xA9
+  //
+  // grave accent
+  // à 0xC3 0xA0
+  // è 0xC3 0xA8
+  // ù 0xC3 0xB9
+  //
+  // cedilla
+  // ç 0xC3 0xA7
+  //
+  // circumflex
+  // â 0xC3 0xA2
+  // ê 0xC3 0xAA
+  // î 0xC3 0xAE
+  // ô 0xC3 0xB4
+  // û 0xC3 0xBB
+  //
+  // trema
+  // ë 0xC3 0xAB
+  // ï 0xC3 0xAF
+  // ü 0xC3 0xBC
+  //
+  // É 0xC3 0x89
+  //
+  // À 0xC3 0x80
+  // È 0xC3 0x88
+  // Ù 0xC3 0x99
+  // Ç 0xC3 0x87
+  // Â 0xC3 0x82
+  // Ê 0xC3 0x8A
+  // Î 0xC3 0x8E
+  // Ô 0xC3 0x94
+  // Û 0xC3 0x9B
+  // Ë 0xC3 0x8B
+  // Ï 0xC3 0x8F
+  // Ü 0xC3 0x9C
+
+  if (word.size() != 2 || static_cast<uint8_t>(word[0]) != 0xc3) {
+    return false;
+  }
+
+  auto c = static_cast<uint8_t>(word[1]);
+  if (c == 0xa9 || c == 0xa0 || c == 0xa8 || c == 0xb9 || c == 0xa7 ||
+      c == 0xa2 || c == 0xaa || c == 0xae || c == 0xb4 || c == 0xbb ||
+      c == 0xab || c == 0xaf || c == 0xbc || c == 0x89 || c == 0x80 ||
+      c == 0x88 || c == 0x99 || c == 0x87 || c == 0x82 || c == 0x8a ||
+      c == 0x8e || c == 0x94 || c == 0x9b || c == 0x8b || c == 0x8f ||
+      c == 0x9c) {
+    return true;
+  }
+  return false;
+}
+
+static bool IsSpecial(const std::string &w) {
+  bool ans = IsGermanUmlaut(w) || IsSpanishDiacritic(w) || IsFrenchDiacritic(w);
+
+  // for french d’impossible
+  // ’ 0xE2 0x80 0x99
+  bool ans2 = false;
+  if (w.size() == 3) {
+    auto c0 = static_cast<uint8_t>(w[0]);
+    auto c1 = static_cast<uint8_t>(w[1]);
+    auto c2 = static_cast<uint8_t>(w[2]);
+    if (c0 == 0xe2 && c1 == 0x80 && c2 == 0x99) {
+      ans2 = true;
+    }
+  }
+
+  return ans || ans2;
+}
+
+static std::vector<std::string> MergeCharactersIntoWords(
+    const std::vector<std::string> &words) {
+  std::vector<std::string> ans;
+
+  int32_t n = static_cast<int32_t>(words.size());
+  int32_t i = 0;
+  int32_t prev = -1;
+
+  while (i < n) {
+    const auto &w = words[i];
+    if (w.size() >= 3 || (w.size() == 2 && !IsSpecial(w)) ||
+        (w.size() == 1 && (IsPunct(w[0]) || std::isspace(w[0])))) {
+      if (prev != -1) {
+        std::string t;
+        for (; prev < i; ++prev) {
+          t.append(words[prev]);
+        }
+        prev = -1;
+        ans.push_back(std::move(t));
+      }
+
+      if (!std::isspace(w[0])) {
+        ans.push_back(w);
+      }
+      ++i;
+      continue;
+    }
+
+    // e.g., öffnen
+    if (w.size() == 1 || (w.size() == 2 && IsSpecial(w))) {
+      if (prev == -1) {
+        prev = i;
+      }
+      ++i;
+      continue;
+    }
+
+    SHERPA_LOGE("Ignore %s", w.c_str());
+    ++i;
+  }
+
+  if (prev != -1) {
+    std::string t;
+    for (; prev < i; ++prev) {
+      t.append(words[prev]);
+    }
+    ans.push_back(std::move(t));
+  }
+
+  return ans;
+}
+
+}  // namespace sherpa
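The diacritic helpers above work purely on the two-byte UTF-8 encoding of Latin-1 letters: the lead byte is 0xC3 and the continuation byte is 0x80 + (code point - 0xC0). A self-contained sketch showing the exact bytes IsGermanUmlaut() tests (it assumes the source file itself is saved as UTF-8):

```cpp
// Sketch only: inspecting the UTF-8 bytes of a two-byte Latin-1 letter.
#include <cstdint>
#include <cstdio>
#include <string>

int main() {
  std::string word = "ö";  // U+00F6, two bytes in a UTF-8 source file
  std::printf("size = %zu, bytes = 0x%02X 0x%02X\n", word.size(),
              static_cast<uint8_t>(word[0]), static_cast<uint8_t>(word[1]));
  // Prints: size = 2, bytes = 0xC3 0xB6 -- the pair IsGermanUmlaut checks.
  return 0;
}
```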
diff --git a/sherpa/csrc/text-utils.h b/sherpa/csrc/text-utils.h
new file mode 100644
index 000000000..7619df2d6
--- /dev/null
+++ b/sherpa/csrc/text-utils.h
@@ -0,0 +1,123 @@
+// sherpa/csrc/text-utils.h
+//
+// Copyright 2009-2011 Saarland University; Microsoft Corporation
+// Copyright 2023-2025 Xiaomi Corporation
+#ifndef SHERPA_CSRC_TEXT_UTILS_H_
+#define SHERPA_CSRC_TEXT_UTILS_H_
+#include <errno.h>
+#include <string.h>
+
+#include <limits>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#ifdef _MSC_VER
+#define SHERPA_STRTOLL(cur_cstr, end_cstr) _strtoi64(cur_cstr, end_cstr, 10);
+#else
+#define SHERPA_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10);
+#endif
+
+// This file is copied/modified from
+// https://github.com/kaldi-asr/kaldi/blob/master/src/util/text-utils.h
+
+namespace sherpa {
+
+/// Converts a string into an integer via strtoll and returns false if there was
+/// any kind of problem (i.e. the string was not an integer or contained extra
+/// non-whitespace junk, or the integer was too large to fit into the type it is
+/// being converted into). Only sets *out if everything was OK and it returns
+/// true.
+template <class Int>
+bool ConvertStringToInteger(const std::string &str, Int *out) {
+  // copied from kaldi/src/util/text-util.h
+  static_assert(std::is_integral<Int>::value, "");
+  const char *this_str = str.c_str();
+  char *end = nullptr;
+  errno = 0;
+  int64_t i = SHERPA_STRTOLL(this_str, &end);
+  if (end != this_str) {
+    while (isspace(*end)) ++end;
+  }
+  if (end == this_str || *end != '\0' || errno != 0) return false;
+  Int iInt = static_cast<Int>(i);
+  if (static_cast<int64_t>(iInt) != i ||
+      (i < 0 && !std::numeric_limits<Int>::is_signed)) {
+    return false;
+  }
+  *out = iInt;
+  return true;
+}
+
+/// Split a string using any of the single character delimiters.
+/// If omit_empty_strings == true, the output will contain any
+/// nonempty strings after splitting on any of the
+/// characters in the delimiter. If omit_empty_strings == false,
+/// the output will contain n+1 strings if there are n characters
+/// in the set "delim" within the input string. In this case
+/// the empty string is split to a single empty string.
+void SplitStringToVector(const std::string &full, const char *delim,
+                         bool omit_empty_strings,
+                         std::vector<std::string> *out);
+
+/**
+  \brief Split a string (e.g. 1:2:3) into a vector of integers.
+
+  \param [in] delim  String containing a list of characters, any of which
+                     is allowed as a delimiter.
+  \param [in] omit_empty_strings If true, empty strings between delimiters are
+                     allowed and will not produce an output integer; if false,
+                     instances of characters in 'delim' that are consecutive or
+                     at the start or end of the string would be an error.
+                     You'll normally want this to be true if 'delim' consists
+                     of spaces, and false otherwise.
+  \param [out] out   The output list of integers.
+*/
+template <class I>
+bool SplitStringToIntegers(const std::string &full, const char *delim,
+                           bool omit_empty_strings,  // typically false [but
+                                                     // should probably be true
+                                                     // if "delim" is spaces].
+                           std::vector<I> *out) {
+  static_assert(std::is_integral<I>::value, "");
+  if (*(full.c_str()) == '\0') {
+    out->clear();
+    return true;
+  }
+  std::vector<std::string> split;
+  SplitStringToVector(full, delim, omit_empty_strings, &split);
+  out->resize(split.size());
+  for (size_t i = 0; i < split.size(); i++) {
+    const char *this_str = split[i].c_str();
+    char *end = NULL;
+    int64_t j = 0;
+    j = SHERPA_STRTOLL(this_str, &end);
+    if (end == this_str || *end != '\0') {
+      out->clear();
+      return false;
+    } else {
+      I jI = static_cast<I>(j);
+      if (static_cast<int64_t>(jI) != j) {
+        // output type cannot fit this integer.
+        out->clear();
+        return false;
+      }
+      (*out)[i] = jI;
+    }
+  }
+  return true;
+}
+
+// This is defined for F = float and double.
+template <class F>
+bool SplitStringToFloats(const std::string &full, const char *delim,
+                         bool omit_empty_strings,  // typically false
+                         std::vector<F> *out);
+
+// This is defined for T = float and double.
+template <typename T>
+bool ConvertStringToReal(const std::string &str, T *out);
+
+}  // namespace sherpa
+
+#endif  // SHERPA_CSRC_TEXT_UTILS_H_
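These helpers are what the Whisper loader uses to parse comma-separated metadata strings such as sot_sequence and all_language_tokens. A short usage sketch; the literal values are illustrative, not taken from a real model:

```cpp
// Sketch only: parsing comma-separated metadata with the helpers above.
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

#include "sherpa/csrc/text-utils.h"

int main() {
  std::vector<int32_t> sot_sequence;
  bool ok = sherpa::SplitStringToIntegers("50258,50259,50359", ",",
                                          /*omit_empty_strings=*/true,
                                          &sot_sequence);
  assert(ok && sot_sequence.size() == 3);

  std::vector<std::string> codes;
  sherpa::SplitStringToVector("en,zh,de", ",", true, &codes);
  assert(codes.size() == 3 && codes[0] == "en");
  return 0;
}
```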
diff --git a/sherpa/python/csrc/CMakeLists.txt b/sherpa/python/csrc/CMakeLists.txt
index c93c19ba8..39df89d43 100644
--- a/sherpa/python/csrc/CMakeLists.txt
+++ b/sherpa/python/csrc/CMakeLists.txt
@@ -10,6 +10,7 @@ pybind11_add_module(_sherpa
   offline-recognizer.cc
   offline-sense-voice-model-config.cc
   offline-stream.cc
+  offline-whisper-model-config.cc
   online-recognizer.cc
   online-stream.cc
   resample.cc
diff --git a/sherpa/python/csrc/offline-model-config.cc b/sherpa/python/csrc/offline-model-config.cc
index 689fb0216..054967444 100644
--- a/sherpa/python/csrc/offline-model-config.cc
+++ b/sherpa/python/csrc/offline-model-config.cc
@@ -9,20 +9,24 @@
 
 #include "sherpa/csrc/offline-model-config.h"
 #include "sherpa/python/csrc/offline-sense-voice-model-config.h"
+#include "sherpa/python/csrc/offline-whisper-model-config.h"
 
 namespace sherpa {
 
 void PybindOfflineModelConfig(py::module *m) {
   PybindOfflineSenseVoiceModelConfig(m);
+  PybindOfflineWhisperModelConfig(m);
 
   using PyClass = OfflineModelConfig;
   py::class_<PyClass>(*m, "OfflineModelConfig")
-      .def(py::init<const OfflineSenseVoiceModelConfig &, const std::string &,
-                    bool, bool>(),
-           py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
-           py::arg("tokens"), py::arg("debug") = false,
-           py::arg("use_gpu") = false)
+      .def(py::init<const OfflineSenseVoiceModelConfig &,
+                    const OfflineWhisperModelConfig &, const std::string &,
+                    bool, bool>(),
+           py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
+           py::arg("whisper") = OfflineWhisperModelConfig(), py::arg("tokens"),
+           py::arg("debug") = false, py::arg("use_gpu") = false)
       .def_readwrite("sense_voice", &PyClass::sense_voice)
+      .def_readwrite("whisper", &PyClass::whisper)
       .def_readwrite("tokens", &PyClass::tokens)
       .def_readwrite("debug", &PyClass::debug)
       .def_readwrite("use_gpu", &PyClass::use_gpu)
diff --git a/sherpa/python/csrc/offline-sense-voice-model-config.cc b/sherpa/python/csrc/offline-sense-voice-model-config.cc
index 32b792276..68c887bb5 100644
--- a/sherpa/python/csrc/offline-sense-voice-model-config.cc
+++ b/sherpa/python/csrc/offline-sense-voice-model-config.cc
@@ -20,6 +20,7 @@ void PybindOfflineSenseVoiceModelConfig(py::module *m) {
       .def_readwrite("model", &PyClass::model)
       .def_readwrite("language", &PyClass::language)
       .def_readwrite("use_itn", &PyClass::use_itn)
+      .def("validate", &PyClass::Validate)
       .def("__str__", &PyClass::ToString);
 }
 
diff --git a/sherpa/python/csrc/offline-whisper-model-config.cc b/sherpa/python/csrc/offline-whisper-model-config.cc
new file mode 100644
index 000000000..26c0c5e38
--- /dev/null
+++ b/sherpa/python/csrc/offline-whisper-model-config.cc
@@ -0,0 +1,27 @@
+// sherpa/python/csrc/offline-whisper-model-config.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa/csrc/offline-whisper-model-config.h"
+
+#include <string>
+#include <vector>
+
+#include "sherpa/python/csrc/offline-whisper-model-config.h"
+
+namespace sherpa {
+
+void PybindOfflineWhisperModelConfig(py::module *m) {
+  using PyClass = OfflineWhisperModelConfig;
+  py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
+      .def(py::init<const std::string &, const std::string &,
+                    const std::string &>(),
+           py::arg("model"), py::arg("language"), py::arg("task"))
+      .def_readwrite("model", &PyClass::model)
+      .def_readwrite("language", &PyClass::language)
+      .def_readwrite("task", &PyClass::task)
+      .def("validate", &PyClass::Validate)
+      .def("__str__", &PyClass::ToString);
+}
+
+}  // namespace sherpa
diff --git a/sherpa/python/csrc/offline-whisper-model-config.h b/sherpa/python/csrc/offline-whisper-model-config.h
new file mode 100644
index 000000000..942015633
--- /dev/null
+++ b/sherpa/python/csrc/offline-whisper-model-config.h
@@ -0,0 +1,15 @@
+// sherpa/python/csrc/offline-whisper-model-config.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
+#define SHERPA_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
+
+#include "sherpa/python/csrc/sherpa.h"
+
+namespace sherpa {
+
+void PybindOfflineWhisperModelConfig(py::module *m);
+}
+
+#endif  // SHERPA_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
diff --git a/sherpa/python/sherpa/__init__.py b/sherpa/python/sherpa/__init__.py
index 924be39f4..dd5d810c6 100644
--- a/sherpa/python/sherpa/__init__.py
+++ b/sherpa/python/sherpa/__init__.py
@@ -19,7 +19,9 @@
     OfflineModelConfig,
     OfflineRecognizer,
     OfflineRecognizerConfig,
+    OfflineSenseVoiceModelConfig,
     OfflineStream,
+    OfflineWhisperModelConfig,
     OnlineRecognitionResult,
     OnlineRecognizer,
     OnlineRecognizerConfig,