Skip to content

Commit

Permalink
Merge pull request #13 from 78/mqtt
Browse files Browse the repository at this point in the history
add settings
  • Loading branch information
78 authored Nov 14, 2024
2 parents 01fe53c + 58de385 commit 1deb477
Show file tree
Hide file tree
Showing 15 changed files with 237 additions and 89 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# CMakeLists in this exact order for cmake to work correctly
cmake_minimum_required(VERSION 3.16)

set(PROJECT_VER "0.8.0")
set(PROJECT_VER "0.8.1")

include($ENV{IDF_PATH}/tools/cmake/project.cmake)
project(xiaozhi)
1 change: 1 addition & 0 deletions main/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ set(SOURCES "audio_codec.cc"
"button.cc"
"led.cc"
"ota.cc"
"settings.cc"
"main.cc"
)
set(INCLUDE_DIRS ".")
Expand Down
97 changes: 49 additions & 48 deletions main/application.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,30 +37,41 @@ Application::~Application() {
if (audio_encode_task_stack_ != nullptr) {
heap_caps_free(audio_encode_task_stack_);
}
if (main_loop_task_stack_ != nullptr) {
heap_caps_free(main_loop_task_stack_);
}

vEventGroupDelete(event_group_);
}

void Application::CheckNewVersion() {
// Check if there is a new firmware version available
ota_.SetPostData(Board::GetInstance().GetJson());
ota_.CheckVersion();
if (ota_.HasNewVersion()) {
SetChatState(kChatStateUpgrading);
ota_.StartUpgrade([](int progress, size_t speed) {
char buffer[64];
snprintf(buffer, sizeof(buffer), "Upgrading...\n %d%% %zuKB/s", progress, speed / 1024);
auto display = Board::GetInstance().GetDisplay();
display->SetText(buffer);
});
// If upgrade success, the device will reboot and never reach here
ESP_LOGI(TAG, "Firmware upgrade failed...");
SetChatState(kChatStateIdle);
} else {
ota_.MarkCurrentVersionValid();

while (true) {
if (ota_.CheckVersion()) {
if (ota_.HasNewVersion()) {
// Wait for the chat state to be idle
while (chat_state_ != kChatStateIdle) {
vTaskDelay(100);
}

SetChatState(kChatStateUpgrading);
ota_.StartUpgrade([](int progress, size_t speed) {
char buffer[64];
snprintf(buffer, sizeof(buffer), "Upgrading...\n %d%% %zuKB/s", progress, speed / 1024);
auto display = Board::GetInstance().GetDisplay();
display->SetText(buffer);
});

// If upgrade success, the device will reboot and never reach here
ESP_LOGI(TAG, "Firmware upgrade failed...");
SetChatState(kChatStateIdle);
} else {
ota_.MarkCurrentVersionValid();
}
return;
}

// Check again in 60 seconds
vTaskDelay(pdMS_TO_TICKS(60000));
}
}

Expand Down Expand Up @@ -191,24 +202,18 @@ void Application::Start() {
/* Wait for the network to be ready */
board.StartNetwork();

const size_t main_loop_stack_size = 4096 * 8;
main_loop_task_stack_ = (StackType_t*)heap_caps_malloc(main_loop_stack_size, MALLOC_CAP_SPIRAM);
xTaskCreateStatic([](void* arg) {
xTaskCreate([](void* arg) {
Application* app = (Application*)arg;
app->MainLoop();
vTaskDelete(NULL);
}, "main_loop", main_loop_stack_size, this, 1, main_loop_task_stack_, &main_loop_task_buffer_);
}, "main_loop", 4096 * 2, this, 1, nullptr);

// Check for new firmware version or get the MQTT broker address
while (true) {
CheckNewVersion();

if (ota_.HasMqttConfig()) {
break;
}
Alert("Error", "Missing MQTT config");
vTaskDelay(pdMS_TO_TICKS(10000));
}
xTaskCreate([](void* arg) {
Application* app = (Application*)arg;
app->CheckNewVersion();
vTaskDelete(NULL);
}, "check_new_version", 4096 * 2, this, 1, nullptr);

#ifdef CONFIG_USE_AFE_SR
audio_processor_.Initialize(codec->input_channels(), codec->input_reference());
Expand Down Expand Up @@ -264,12 +269,19 @@ void Application::Start() {

// Initialize the protocol
display->SetText("Starting\nProtocol...");
protocol_ = new MqttProtocol(ota_.GetMqttConfig());
protocol_ = new MqttProtocol();
protocol_->OnIncomingAudio([this](const std::string& data) {
std::lock_guard<std::mutex> lock(mutex_);
audio_decode_queue_.emplace_back(std::move(data));
cv_.notify_all();
});
protocol_->OnAudioChannelOpened([this, codec]() {
if (protocol_->GetServerSampleRate() != codec->output_sample_rate()) {
ESP_LOGW(TAG, "服务器的音频采样率 %d 与设备输出的采样率 %d 不一致,重采样后可能会失真",
protocol_->GetServerSampleRate(), codec->output_sample_rate());
}
SetDecodeSampleRate(protocol_->GetServerSampleRate());
});
protocol_->OnAudioChannelClosed([this]() {
Schedule([this]() {
SetChatState(kChatStateIdle);
Expand All @@ -289,7 +301,9 @@ void Application::Start() {
Schedule([this]() {
auto codec = Board::GetInstance().GetAudioCodec();
codec->WaitForOutputDone();
SetChatState(kChatStateListening);
if (chat_state_ == kChatStateSpeaking) {
SetChatState(kChatStateListening);
}
});
} else if (strcmp(state->valuestring, "sentence_start") == 0) {
auto text = cJSON_GetObjectItem(root, "text");
Expand All @@ -307,15 +321,6 @@ void Application::Start() {
if (emotion != NULL) {
ESP_LOGD(TAG, "EMOTION: %s", emotion->valuestring);
}
} else if (strcmp(type->valuestring, "hello") == 0) {
// Get sample rate from hello message
auto audio_params = cJSON_GetObjectItem(root, "audio_params");
if (audio_params != NULL) {
auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
if (sample_rate != NULL) {
SetDecodeSampleRate(sample_rate->valueint);
}
}
}
});

Expand Down Expand Up @@ -351,8 +356,7 @@ void Application::MainLoop() {

void Application::AbortSpeaking() {
ESP_LOGI(TAG, "Abort speaking");
std::string json = "{\"type\":\"abort\"}";
protocol_->SendText(json);
protocol_->SendAbort();

skip_to_end_ = true;
auto codec = Board::GetInstance().GetAudioCodec();
Expand Down Expand Up @@ -420,10 +424,7 @@ void Application::SetChatState(ChatState state) {
break;
}

std::string json = "{\"type\":\"state\",\"state\":\"";
json += state_str[chat_state_];
json += "\"}";
protocol_->SendText(json);
protocol_->SendState(state_str[chat_state_]);
}

void Application::AudioEncodeTask() {
Expand Down Expand Up @@ -455,7 +456,7 @@ void Application::AudioEncodeTask() {
continue;
}

int frame_size = opus_decode_sample_rate_ * opus_duration_ms_ / 1000;
int frame_size = opus_decode_sample_rate_ * OPUS_FRAME_DURATION_MS / 1000;
std::vector<int16_t> pcm(frame_size);

int ret = opus_decode(opus_decoder_, (const unsigned char*)opus.data(), opus.size(), pcm.data(), frame_size, 0);
Expand Down
7 changes: 2 additions & 5 deletions main/application.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ enum ChatState {
kChatStateUpgrading
};

#define OPUS_FRAME_DURATION_MS 60

class Application {
public:
static Application& GetInstance() {
Expand Down Expand Up @@ -84,16 +86,11 @@ class Application {
OpusEncoder opus_encoder_;
OpusDecoder* opus_decoder_ = nullptr;

int opus_duration_ms_ = 60;
int opus_decode_sample_rate_ = -1;
OpusResampler input_resampler_;
OpusResampler reference_resampler_;
OpusResampler output_resampler_;

TaskHandle_t main_loop_task_ = nullptr;
StaticTask_t main_loop_task_buffer_;
StackType_t* main_loop_task_stack_ = nullptr;

void MainLoop();
void SetDecodeSampleRate(int sample_rate);
void CheckNewVersion();
Expand Down
7 changes: 7 additions & 0 deletions main/audio_codec.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "audio_codec.h"
#include "board.h"
#include "settings.h"

#include <esp_log.h>
#include <cstring>
Expand Down Expand Up @@ -40,6 +41,9 @@ IRAM_ATTR bool AudioCodec::on_sent(i2s_chan_handle_t handle, i2s_event_data_t *e
}

void AudioCodec::Start() {
Settings settings("audio", false);
output_volume_ = settings.GetInt("output_volume", output_volume_);

// 注册音频输出回调
i2s_event_callbacks_t callbacks = {};
callbacks.on_sent = on_sent;
Expand Down Expand Up @@ -124,6 +128,9 @@ void AudioCodec::ClearOutputQueue() {
void AudioCodec::SetOutputVolume(int volume) {
output_volume_ = volume;
ESP_LOGI(TAG, "Set output volume to %d", output_volume_);

Settings settings("audio", true);
settings.SetInt("output_volume", output_volume_);
}

void AudioCodec::EnableInput(bool enable) {
Expand Down
1 change: 1 addition & 0 deletions main/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ extern "C" void app_main(void)
// Initialize NVS flash for WiFi configuration
esp_err_t ret = nvs_flash_init();
if (ret == ESP_ERR_NVS_NO_FREE_PAGES || ret == ESP_ERR_NVS_NEW_VERSION_FOUND) {
ESP_LOGW(TAG, "Erasing NVS flash to fix corruption");
ESP_ERROR_CHECK(nvs_flash_erase());
ret = nvs_flash_init();
}
Expand Down
20 changes: 12 additions & 8 deletions main/ota.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "ota.h"
#include "system_info.h"
#include "board.h"
#include "settings.h"

#include <cJSON.h>
#include <esp_log.h>
Expand Down Expand Up @@ -34,13 +35,13 @@ void Ota::SetPostData(const std::string& post_data) {
post_data_ = post_data;
}

void Ota::CheckVersion() {
bool Ota::CheckVersion() {
std::string current_version = esp_app_get_description()->version;
ESP_LOGI(TAG, "Current version: %s", current_version.c_str());

if (check_version_url_.length() < 10) {
ESP_LOGE(TAG, "Check version URL is not properly set");
return;
return false;
}

auto http = Board::GetInstance().CreateHttp();
Expand All @@ -67,16 +68,18 @@ void Ota::CheckVersion() {
cJSON *root = cJSON_Parse(response.c_str());
if (root == NULL) {
ESP_LOGE(TAG, "Failed to parse JSON response");
return;
return false;
}

cJSON *mqtt = cJSON_GetObjectItem(root, "mqtt");
if (mqtt != NULL) {
Settings settings("mqtt", true);
cJSON *item = NULL;
cJSON_ArrayForEach(item, mqtt) {
if (item->type == cJSON_String) {
mqtt_config_[item->string] = item->valuestring;
ESP_LOGI(TAG, "MQTT config: %s = %s", item->string, item->valuestring);
if (settings.GetString(item->string) != item->valuestring) {
settings.SetString(item->string, item->valuestring);
}
}
}
has_mqtt_config_ = true;
Expand All @@ -86,19 +89,19 @@ void Ota::CheckVersion() {
if (firmware == NULL) {
ESP_LOGE(TAG, "Failed to get firmware object");
cJSON_Delete(root);
return;
return false;
}
cJSON *version = cJSON_GetObjectItem(firmware, "version");
if (version == NULL) {
ESP_LOGE(TAG, "Failed to get version object");
cJSON_Delete(root);
return;
return false;
}
cJSON *url = cJSON_GetObjectItem(firmware, "url");
if (url == NULL) {
ESP_LOGE(TAG, "Failed to get url object");
cJSON_Delete(root);
return;
return false;
}

firmware_version_ = version->valuestring;
Expand All @@ -112,6 +115,7 @@ void Ota::CheckVersion() {
} else {
ESP_LOGI(TAG, "Current is the latest version");
}
return true;
}

void Ota::MarkCurrentVersionValid() {
Expand Down
5 changes: 1 addition & 4 deletions main/ota.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@ class Ota {
void SetCheckVersionUrl(std::string check_version_url);
void SetHeader(const std::string& key, const std::string& value);
void SetPostData(const std::string& post_data);
void CheckVersion();
bool CheckVersion();
bool HasNewVersion() { return has_new_version_; }
bool HasMqttConfig() { return has_mqtt_config_; }
void StartUpgrade(std::function<void(int progress, size_t speed)> callback);
void MarkCurrentVersionValid();

std::map<std::string, std::string>& GetMqttConfig() { return mqtt_config_; }

private:
std::string check_version_url_;
bool has_new_version_ = false;
Expand All @@ -29,7 +27,6 @@ class Ota {
std::string firmware_url_;
std::string post_data_;
std::map<std::string, std::string> headers_;
std::map<std::string, std::string> mqtt_config_;

void Upgrade(const std::string& firmware_url);
std::function<void(int progress, size_t speed)> upgrade_callback_;
Expand Down
4 changes: 4 additions & 0 deletions main/protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ class Protocol {
virtual void OnIncomingJson(std::function<void(const cJSON* root)> callback) = 0;
virtual void SendAudio(const std::string& data) = 0;
virtual void SendText(const std::string& text) = 0;
virtual void SendState(const std::string& state) = 0;
virtual void SendAbort() = 0;
virtual bool OpenAudioChannel() = 0;
virtual void CloseAudioChannel() = 0;
virtual void OnAudioChannelOpened(std::function<void()> callback) = 0;
virtual void OnAudioChannelClosed(std::function<void()> callback) = 0;
virtual bool IsAudioChannelOpened() const = 0;
virtual int GetServerSampleRate() const = 0;
};

#endif // PROTOCOL_H
Expand Down
Loading

0 comments on commit 1deb477

Please sign in to comment.