# Demo

import asyncio
import os

import torch
import numpy as np
import sounddevice as sd
import time
import threading
import queue
import re
from funasr import AutoModel
import pyttsx3
import requests
from config.logger import setup_logging
from llama_cpp import Llama
from sentence_transformers import SentenceTransformer
import faiss


import edge_tts
import tempfile
from pydub import AudioSegment
import simpleaudio as sa

from textUtils import get_string_no_punctuation_or_emoji

logger = setup_logging()


# ================= 1. 配置 =================


def get_project_dir():
    """获取项目根目录"""
    return os.path.dirname(__file__) + "/"

is_tts_playing = False
SAMPLE_RATE = 16000
BLOCK_SIZE = 512

THRESHOLD_HIGH = 0.5
THRESHOLD_LOW = 0.2
MIN_SPEECH_FRAMES = 3
SILENCE_DURATION_MS = 600

VAD_MODEL_PATH = f'{get_project_dir()}models/vad/snakers4_silero-vad'
ASR_MODEL_DIR = f'{get_project_dir()}models/asr/SenseVoiceSmall'
EMBEDDING_MODEL_DIR = f'{get_project_dir()}models/embedding/sentence-transformers/all-MiniLM-L6-v2'
LLM_MODEL_DIR = f'{get_project_dir()}models/llm/Qwen3-4B-Instruct-2507-Q4_K_M-GGUF/qwen3-4b-instruct-2507-q4_k_m.gguf'

# ================= 2. 全局变量 =================

audio_queue = queue.Queue()
llm_queue = queue.Queue()
tts_queue = queue.Queue()

vad_model = None
asr_model = None
llm_model = None
tts_engine = None

audio_buffer = []
is_recording = False

class EdgeTTS:
    def __init__(self, voice="zh-CN-XiaoxiaoNeural"):
        self.voice = voice
        self.loop = asyncio.new_event_loop()
        threading.Thread(target=self.loop.run_forever, daemon=True).start()

    async def _speak_async(self, text):
        global is_tts_playing
        is_tts_playing = True   # 🔴 开始播放

        try:
            communicate = edge_tts.Communicate(text, self.voice)

            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as f:
                tmp_path = f.name

            await communicate.save(tmp_path)

            audio = AudioSegment.from_file(tmp_path, format="mp3")
            play_obj = sa.play_buffer(
                audio.raw_data,
                num_channels=audio.channels,
                bytes_per_sample=audio.sample_width,
                sample_rate=audio.frame_rate,
            )
            play_obj.wait_done()

            os.remove(tmp_path)

        finally:
            is_tts_playing = False   # 🟢 播放结束

    def speak(self, text):
        asyncio.run_coroutine_threadsafe(
            self._speak_async(text),
            self.loop
        )

class VadState:
    def __init__(self):
        self.is_speaking = False
        self.consecutive_voice_frames = 0
        self.silence_start_time = None


state = VadState()


# ================= 3. 模型初始化 =================

def load_models():
    global vad_model, asr_model, tts_engine, llm_model

    logger.info("🚀 初始化模型...")

    # VAD
    model, utils = torch.hub.load(
        repo_or_dir=VAD_MODEL_PATH,
        source="local",
        model='silero_vad',
        force_reload=False,
        trust_repo=True
    )
    vad_model = model
    logger.info("✅ VAD 加载完成")

    # ASR
    asr_model = AutoModel(
        model=ASR_MODEL_DIR,
        vad_model=f"{get_project_dir()}models/vad/iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
        vad_kwargs={"max_single_segment_time": 30000},
        device="cpu",
        disable_update=True
    )
    logger.info("✅ ASR 加载完成")

    # LLM

    llm_model = Llama(
        model_path=LLM_MODEL_DIR,
        n_ctx=4096,
        n_threads=4,
        n_gpu_layers=-1,
        verbose=False
    )

    # TTS
    # tts_engine = pyttsx3.init()
    # tts_engine.setProperty('rate', 150)
    # tts_engine.setProperty('volume', 1.0)
    # logger.info("✅ TTS 初始化完成")

    # Edge TTS

    tts_engine = EdgeTTS()
    logger.info("✅ Edge TTS 初始化完成")

# ================= 4. 文本清理 =================

def clean_text(text):
    # 去除 <|xxx|> 标签
    text = re.sub(r"<\|.*?\|>", "", text)
    return text.strip()


# =========================
# 1️⃣ 向量数据库
# =========================
class VectorStore:
    def __init__(self, index_path="faiss.index"):
        self.model = SentenceTransformer(EMBEDDING_MODEL_DIR)
        self.dimension = 384
        self.index_path = index_path

        if os.path.exists(index_path):
            self.index = faiss.read_index(index_path)
            self.texts = np.load("faiss_texts.npy", allow_pickle=True).tolist()
        else:
            self.index = faiss.IndexFlatL2(self.dimension)
            self.texts = []

    def save(self):
        faiss.write_index(self.index, self.index_path)
        np.save("faiss_texts.npy", np.array(self.texts, dtype=object))

    def add(self, text):
        embedding = self.model.encode([text])
        self.index.add(np.array(embedding).astype("float32"))
        self.texts.append(text)
        self.save()

    def search(self, query, k=3):
        if len(self.texts) == 0:
            return []

        embedding = self.model.encode([query])
        D, I = self.index.search(np.array(embedding).astype("float32"), k)
        return [self.texts[i] for i in I[0] if i < len(self.texts)]

    def rebuild_index(self):
        """重建索引(删除后使用)"""
        self.index = faiss.IndexFlatL2(self.dimension)
        if len(self.texts) > 0:
            embeddings = self.model.encode(self.texts)
            self.index.add(np.array(embeddings).astype("float32"))
        self.save()

    def delete(self, index):
        """删除指定索引的文本"""
        if 0 <= index < len(self.texts):
            del self.texts[index]
            self.rebuild_index()


# =========================
# 2️⃣ 记忆管理
# =========================
class MemoryManager:
    def __init__(self, max_short=6):
        self.short_term = []
        self.long_term = ""
        self.max_short = max_short

    def add(self, user, assistant):
        self.short_term.append((user, assistant))
        if len(self.short_term) > self.max_short:
            self.summarize()

    def summarize(self):
        summary = ""
        for u, a in self.short_term:
            summary += f"用户:{u}\n助手:{a}\n"

        self.long_term += summary
        self.short_term = []

    def get_context(self):
        recent = ""
        for u, a in self.short_term:
            recent += f"用户:{u}\n助手:{a}\n"

        return self.long_term + "\n" + recent


# =========================
# 3️⃣ 防幻觉 Prompt
# =========================
def create_prompt(rag_context, memory_context, question):
    return f"""<|im_start|>system
    你是一个严谨的智能助手。

    规则:
    1. 优先使用【知识库内容】回答。
    2. 如果知识库没有相关内容,可以使用常识。
    3. 不允许编造数据。
    4. 不确定就说“这个我不清楚”。

    【知识库内容】
    {rag_context}

    【历史对话】
    {memory_context}
    <|im_end|>
    <|im_start|>user
    {question}
    <|im_end|>
    <|im_start|>assistant
"""


# =========================
# 5️⃣ 初始化模块
# =========================
vector_store = VectorStore()
memory = MemoryManager()

# 如果第一次运行,初始化知识库
if len(vector_store.texts) == 0:
    vector_store.add("机器狗通过底部磁吸触点充电,耗时约2小时。")
    vector_store.add("机器狗支持语音控制和自动避障功能。")

def llm_thread():
    logger.info("🧠 LLM线程启动")

    while True:
        llm_data = llm_queue.get()
        if llm_data is None:
            break

        try:
            # ========= RAG =========
            rag_results = vector_store.search(llm_data)
            rag_context = "\n".join(rag_results)

            memory_context = memory.get_context()
            prompt = create_prompt(rag_context, memory_context, llm_data)

            output_stream = llm_model(
                prompt,
                max_tokens=512,
                temperature=0.7,
                stop=["<|im_end|>"],
                stream=True
            )

            full_response = ""
            sentence_buffer = ""

            print("\n🤖 助手: ", end="", flush=True)

            for token in output_stream:
                text = token["choices"][0]["text"]

                full_response += text
                sentence_buffer += text

                print(text, end="", flush=True)

                # =====================
                # 句子检测(中文标点)
                # =====================
                if any(p in sentence_buffer for p in ["。", "!", "?", ".", "!", "?"]):
                    clean_sentence = sentence_buffer.strip()
                    if clean_sentence:
                        tts_queue.put(clean_sentence)
                    sentence_buffer = ""

            # 剩余未播报
            if sentence_buffer.strip():
                tts_queue.put(sentence_buffer.strip())

            print("\n")

            memory.add(llm_data, full_response.strip())

        except Exception as e:
            print(f"❌ LLM错误: {e}")

        finally:
            llm_queue.task_done()


# ================= 5. ASR线程 =================
def asr_thread():
    logger.info("🧠 ASR线程启动")

    while True:
        audio_data = audio_queue.get()

        if audio_data is None:
            break

        try:
            print(f"\n🔍 识别中 ({len(audio_data) / 16000:.2f}秒)...")

            res = asr_model.generate(
                input=[audio_data],
                cache={},
                batch_size_s=0
            )

            if res and len(res) > 0:
                raw_text = res[0].get("text", "").strip()
                text = clean_text(raw_text)

                if text:
                    print(f"✨ 识别结果: {text}")
                    # tts_queue.put(text)
                    llm_queue.put(text)
                else:
                    logger.debug("⚠️ 空识别结果")

        except Exception as e:
            print(f"❌ ASR错误: {e}")

        finally:
            audio_queue.task_done()


# ================= 6. TTS线程 =================

def tts_edge_thread():
    print("🔊 Edge-TTS线程启动")

    while True:
        text = tts_queue.get()
        if text is None:
            break

        try:
            text = get_string_no_punctuation_or_emoji(text)
            print(f"\n🔊 播放: {text}")
            tts_engine.speak(text)

        except Exception as e:
            print(f"❌ TTS错误: {e}")

        finally:
            tts_queue.task_done()
def tts_thread():
    print("🔊 TTS线程启动")

    while True:
        text = tts_queue.get()

        if text is None:
            break

        try:
            print("🔊 朗读中...")
            tts_engine.say(text)
            tts_engine.runAndWait()
            print("✅ 朗读完成")

        except Exception as e:
            print(f"❌ TTS错误: {e}")

        finally:
            tts_queue.task_done()


# ================= 7. 音频回调 =================

def audio_callback(indata, frames, time_info, status):
    global audio_buffer, is_recording

    if is_tts_playing:
        return  # 🔴 播放期间直接丢弃麦克风数据

    try:
        audio_data = indata[:, 0].astype(np.float32)
        audio_tensor = torch.from_numpy(audio_data)

        with torch.no_grad():
            speech_prob = vad_model(audio_tensor, SAMPLE_RATE).item()

        if speech_prob >= THRESHOLD_HIGH:
            current_voice = True
        elif speech_prob <= THRESHOLD_LOW:
            current_voice = False
        else:
            current_voice = state.is_speaking

        now_ms = time.time() * 1000

        if current_voice:
            state.consecutive_voice_frames += 1

            if not state.is_speaking and state.consecutive_voice_frames >= MIN_SPEECH_FRAMES:
                state.is_speaking = True
                state.silence_start_time = None
                audio_buffer.clear()
                is_recording = True
                print(">>> [START]")

            if is_recording:
                audio_buffer.append(audio_data.copy())

        else:
            state.consecutive_voice_frames = 0

            if state.is_speaking:
                if state.silence_start_time is None:
                    state.silence_start_time = now_ms

                silence_duration = now_ms - state.silence_start_time

                if silence_duration >= SILENCE_DURATION_MS:
                    state.is_speaking = False
                    state.silence_start_time = None
                    is_recording = False
                    print("<<< [END]")

                    if len(audio_buffer) > 0:
                        full_audio = np.concatenate(audio_buffer)
                        audio_queue.put(full_audio)

    except Exception as e:
        print(f"❌ 回调异常: {e}")


# ================= 8. 主函数 =================

def main():
    load_models()

    threading.Thread(target=asr_thread, daemon=True).start()
    threading.Thread(target=llm_thread, daemon=True).start()
    # threading.Thread(target=tts_thread, daemon=True).start()
    threading.Thread(target=tts_edge_thread, daemon=True).start()

    logger.info("🎙️ 系统就绪,请说话...")

    try:
        with sd.InputStream(
                samplerate=SAMPLE_RATE,
                channels=1,
                blocksize=BLOCK_SIZE,
                dtype=np.float32,
                callback=audio_callback
        ):
            while True:
                sd.sleep(1000)

    except KeyboardInterrupt:
        print("\n👋 已退出")
        audio_queue.put(None)
        llm_queue.put(None)
        tts_queue.put(None)


if __name__ == "__main__":
    main()
# NOTE(review): stray non-code text pasted below the entry point (hosting
# advertisement copy) — commented out so the module parses; consider deleting.
# ☁️ 部署建议
# 如果你打算长期运行项目(博客 / API / 自动化脚本),建议直接用云服务器,会比本地稳定很多。
# 👉 查看云服务器(新用户优惠)