diff --git a/examples/common/common.hpp b/examples/common/common.hpp index 50f35aed8..da23af597 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -934,7 +934,7 @@ struct SDContextParams { return oss.str(); } - sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview) { + sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview, bool is_server = false) { embedding_vec.clear(); embedding_vec.reserve(embedding_map.size()); for (const auto& kv : embedding_map) { @@ -944,6 +944,11 @@ struct SDContextParams { embedding_vec.emplace_back(item); } + if(is_server && lora_apply_mode == LORA_APPLY_AUTO) + { + lora_apply_mode = LORA_APPLY_AT_RUNTIME; + } + sd_ctx_params_t sd_ctx_params = { model_path.c_str(), clip_l_path.c_str(), @@ -1628,6 +1633,23 @@ struct SDGenerationParams { return true; } + static bool sanitize_lora_path(const std::string& lora_model_dir, + const std::string& raw_path_str, + fs::path& full_path) { + if (lora_model_dir.empty()) + return false; + + fs::path raw_p(raw_path_str); + + if (raw_p.is_absolute() || + !raw_p.root_name().empty() || + raw_path_str.find("..") != std::string::npos) { + return false; + } + full_path = fs::path(lora_model_dir) / raw_p; + return true; + } + void extract_and_remove_lora(const std::string& lora_model_dir) { if (lora_model_dir.empty()) { return; @@ -1659,10 +1681,10 @@ struct SDGenerationParams { } fs::path final_path; - if (is_absolute_path(raw_path)) { - final_path = raw_path; - } else { - final_path = fs::path(lora_model_dir) / raw_path; + if (!sanitize_lora_path(lora_model_dir, raw_path, final_path)) { + tmp = m.suffix().str(); + prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only); + continue; } if (!fs::exists(final_path)) { bool found = false; diff --git a/examples/server/main.cpp b/examples/server/main.cpp index def499755..0969505a6 100644 --- a/examples/server/main.cpp +++ b/examples/server/main.cpp @@ -283,7 +283,7 @@ int main(int argc, const char** argv) { LOG_DEBUG("%s", ctx_params.to_string().c_str()); LOG_DEBUG("%s", default_gen_params.to_string().c_str()); - sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false); + sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false, true); sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params); if (sd_ctx == nullptr) { @@ -408,7 +408,7 @@ int main(int argc, const char** argv) { if (gen_params.sample_params.sample_steps > 100) gen_params.sample_params.sample_steps = 100; - if (!gen_params.process_and_check(IMG_GEN, "")) { + if (!gen_params.process_and_check(IMG_GEN, ctx_params.lora_model_dir)) { res.status = 400; res.set_content(R"({"error":"invalid params"})", "application/json"); return; @@ -589,7 +589,7 @@ int main(int argc, const char** argv) { if (gen_params.sample_params.sample_steps > 100) gen_params.sample_params.sample_steps = 100; - if (!gen_params.process_and_check(IMG_GEN, "")) { + if (!gen_params.process_and_check(IMG_GEN, ctx_params.lora_model_dir)) { res.status = 400; res.set_content(R"({"error":"invalid params"})", "application/json"); return;