171 changes: 170 additions & 1 deletion ggml_extend.hpp
@@ -1577,7 +1577,7 @@ struct WeightAdapter {
bool force_prec_f32 = false;
float scale = 1.f;
} linear;
struct {
struct conv2d_params_t {
int s0 = 1;
int s1 = 1;
int p0 = 0;
@@ -2630,4 +2630,173 @@ class MultiheadAttention : public GGMLBlock {
}
};

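// LoKr (Kronecker-product adaptation) applied on the fly to the activations. The weight
// delta has the structure delta_W = scale * (w1 ⊗ w2), where w1 is the outer factor
// (uq -> up) and w2 the inner factor (vq -> vp); either factor may instead be stored in
// low-rank form (w1 ~ w1b * w1a, w2 ~ w2b * w2a). Instead of materializing delta_W, this
// helper uses the identity (w1 ⊗ w2) * vec(X) = vec(w1 * X * w2^T) (inner index fastest):
// the input features are reshaped into uq groups of vq, the inner factor is applied to
// each group, the result is transposed, and the outer factor then mixes the uq groups.
// In the conv branch the same grouping is applied to the input channels, with w2 applied
// as a convolution (the uq groups folded into the batch dimension) and w1 as a matmul
// over the flattened group outputs.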
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_lokr_forward(
struct ggml_context* ctx,
struct ggml_tensor* h, // Input: [q, batch] or [W, H, q, batch]
struct ggml_tensor* w1, // Outer C (Full rank)
struct ggml_tensor* w1a, // Outer A (Low rank part 1)
struct ggml_tensor* w1b, // Outer B (Low rank part 2)
struct ggml_tensor* w2, // Inner BA (Full rank)
struct ggml_tensor* w2a, // Inner A (Low rank part 1)
struct ggml_tensor* w2b, // Inner B (Low rank part 2)
bool is_conv,
WeightAdapter::ForwardParams::conv2d_params_t conv_params,
float scale) {
GGML_ASSERT((w1 != NULL || (w1a != NULL && w1b != NULL)));
GGML_ASSERT((w2 != NULL || (w2a != NULL && w2b != NULL)));

int uq = (w1 != NULL) ? (int)w1->ne[0] : (int)w1a->ne[0];
int up = (w1 != NULL) ? (int)w1->ne[1] : (int)w1b->ne[1];

int q_actual = is_conv ? (int)h->ne[2] : (int)h->ne[0];
int vq = q_actual / uq;

int vp = (w2 != NULL) ? (is_conv ? (int)w2->ne[3] : (int)w2->ne[1])
: (int)w2a->ne[1];
GGML_ASSERT(q_actual == (uq * vq) && "Input dimension mismatch for LoKR split");

struct ggml_tensor* hb;

if (!is_conv) {
int batch = (int)h->ne[1];
int merge_batch_uq = batch;
int merge_batch_vp = batch;

#if SD_USE_VULKAN
if (batch > 1) {
// no access to the backend here; worst case is slightly worse performance for other
// backends when built alongside the Vulkan backend. Merge only the largest divisor of
// batch that keeps uq * merge_batch_uq (resp. vp * merge_batch_vp) within the 65535 limit.
int max_batch = 65535;
int max_batch_uq = max_batch / uq;
merge_batch_uq = 1;
for (int i = max_batch_uq; i > 0; i--) {
if (batch % i == 0) {
merge_batch_uq = i;
break;
}
}

int max_batch_vp = max_batch / vp;
merge_batch_vp = 1;
for (int i = max_batch_vp; i > 0; i--) {
if (batch % i == 0) {
merge_batch_vp = i;
break;
}
}
}
#endif

struct ggml_tensor* h_split = ggml_reshape_3d(ctx, h, vq, uq * merge_batch_uq, batch / merge_batch_uq);
if (w2 != NULL) {
hb = ggml_mul_mat(ctx, w2, h_split);
} else {
hb = ggml_mul_mat(ctx, w2b, ggml_mul_mat(ctx, w2a, h_split));
}

if (batch > 1) {
hb = ggml_reshape_3d(ctx, hb, vp, uq, batch);
}
struct ggml_tensor* hb_t = ggml_cont(ctx, ggml_transpose(ctx, hb));
hb_t = ggml_reshape_3d(ctx, hb_t, uq, vp * merge_batch_vp, batch / merge_batch_vp);

struct ggml_tensor* hc_t;
if (w1 != NULL) {
hc_t = ggml_mul_mat(ctx, w1, hb_t);
} else {
hc_t = ggml_mul_mat(ctx, w1b, ggml_mul_mat(ctx, w1a, hb_t));
}

if (batch > 1) {
hc_t = ggml_reshape_3d(ctx, hc_t, up, vp, batch);
}

struct ggml_tensor* hc = ggml_transpose(ctx, hc_t);
struct ggml_tensor* out = ggml_reshape_2d(ctx, ggml_cont(ctx, hc), up * vp, batch);
return ggml_scale(ctx, out, scale);
} else {
int batch = (int)h->ne[3];
// 1. Reshape input: [W, H, vq*uq, batch] -> [W, H, vq, uq * batch]
struct ggml_tensor* h_split = ggml_reshape_4d(ctx, h, h->ne[0], h->ne[1], vq, uq * batch);

if (w2 != NULL) {
hb = ggml_ext_conv_2d(ctx, h_split, w2, nullptr,
conv_params.s0,
conv_params.s1,
conv_params.p0,
conv_params.p1,
conv_params.d0,
conv_params.d1,
conv_params.direct,
conv_params.circular_x,
conv_params.circular_y,
conv_params.scale);
} else {
// swap a and b order for conv lora
struct ggml_tensor* a = w2b;
struct ggml_tensor* b = w2a;

// unpack conv2d weights if needed
if (ggml_n_dims(a) < 4) {
int k = (int)sqrt(a->ne[0] / h_split->ne[2]);
GGML_ASSERT(k * k * h_split->ne[2] == a->ne[0]);
a = ggml_reshape_4d(ctx, a, k, k, a->ne[0] / (k * k), a->ne[1]);
} else if (a->ne[2] != h_split->ne[2]) {
int k = (int)sqrt(a->ne[2] / h_split->ne[2]);
GGML_ASSERT(k * k * h_split->ne[2] == a->ne[2]);
a = ggml_reshape_4d(ctx, a, a->ne[0] * k, a->ne[1] * k, a->ne[2] / (k * k), a->ne[3]);
}
struct ggml_tensor* ha = ggml_ext_conv_2d(ctx, h_split, a, nullptr,
conv_params.s0,
conv_params.s1,
conv_params.p0,
conv_params.p1,
conv_params.d0,
conv_params.d1,
conv_params.direct,
conv_params.circular_x,
conv_params.circular_y,
conv_params.scale);

// not supporting lora_mid here
hb = ggml_ext_conv_2d(ctx,
ha,
b,
nullptr,
1,
1,
0,
0,
1,
1,
conv_params.direct,
conv_params.circular_x,
conv_params.circular_y,
conv_params.scale);
}

// Current hb shape: [W_out, H_out, vp, uq * batch]
int w_out = (int)hb->ne[0];
int h_out = (int)hb->ne[1];

// struct ggml_tensor* hb_cat = ggml_reshape_4d(ctx, hb, w_out, h_out, vp * uq, batch);
// [W_out, H_out, vp * uq, batch]
// Remaining step: applying (W1 ⊗ Id) to hb_cat is equivalent to applying (W1 ⊗ W2) as a convolution to h

// merge the uq groups of size vp*w_out*h_out
struct ggml_tensor* hb_merged = ggml_reshape_2d(ctx, hb, w_out * h_out * vp, uq * batch);
struct ggml_tensor* hc_t;
struct ggml_tensor* hb_merged_t = ggml_cont(ctx, ggml_transpose(ctx, hb_merged));
if (w1 != NULL) {
// Would be great to be able to transpose w1 instead to avoid transposing both hb and hc
hc_t = ggml_mul_mat(ctx, w1, hb_merged_t);
} else {
hc_t = ggml_mul_mat(ctx, w1b, ggml_mul_mat(ctx, w1a, hb_merged_t));
}
struct ggml_tensor* hc = ggml_transpose(ctx, hc_t);
// ungroup
struct ggml_tensor* out = ggml_reshape_4d(ctx, ggml_cont(ctx, hc), w_out, h_out, up * vp, batch);
return ggml_scale(ctx, out, scale);
}
}

#endif // __GGML_EXTEND__HPP__
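As a sanity check on the linear branch of `ggml_ext_lokr_forward`, here is a small self-contained reference in plain row-major C++ (independent of ggml; all sizes and names are illustrative) that compares a naive matvec against the materialized Kronecker product `W1 ⊗ W2` with the two-step factored evaluation the helper builds as a graph:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int up = 2, uq = 3;  // outer factor W1: up x uq
    const int vp = 4, vq = 5;  // inner factor W2: vp x vq
    std::vector<float> W1(up * uq), W2(vp * vq), x(uq * vq);
    for (size_t i = 0; i < W1.size(); i++) W1[i] = 0.1f * (float)(i + 1);
    for (size_t i = 0; i < W2.size(); i++) W2[i] = 0.01f * (float)(i + 2);
    for (size_t i = 0; i < x.size(); i++)  x[i]  = 0.5f - 0.03f * (float)i;

    // Naive: y[i*vp + a] = sum_{j,b} W1[i][j] * W2[a][b] * x[j*vq + b]
    std::vector<float> y_naive(up * vp, 0.f);
    for (int i = 0; i < up; i++)
        for (int a = 0; a < vp; a++)
            for (int j = 0; j < uq; j++)
                for (int b = 0; b < vq; b++)
                    y_naive[i * vp + a] += W1[i * uq + j] * W2[a * vq + b] * x[j * vq + b];

    // Factored: reshape x into X (uq rows of vq), compute M = X * W2^T, then Y = W1 * M.
    std::vector<float> M(uq * vp, 0.f), y_fact(up * vp, 0.f);
    for (int j = 0; j < uq; j++)
        for (int a = 0; a < vp; a++)
            for (int b = 0; b < vq; b++)
                M[j * vp + a] += x[j * vq + b] * W2[a * vq + b];
    for (int i = 0; i < up; i++)
        for (int a = 0; a < vp; a++)
            for (int j = 0; j < uq; j++)
                y_fact[i * vp + a] += W1[i * uq + j] * M[j * vp + a];

    float max_err = 0.f;
    for (int k = 0; k < up * vp; k++)
        max_err = std::fmax(max_err, std::fabs(y_naive[k] - y_fact[k]));
    std::printf("max abs difference: %g\n", max_err);  // expected: ~0 (float rounding only)
    return 0;
}
```

The factored path does roughly `uq*vq*vp + uq*vp*up` multiply-adds instead of `up*vp*uq*vq`, which is what makes applying LoKr at forward time cheap compared to materializing the full delta weight.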
116 changes: 108 additions & 8 deletions lora.hpp
@@ -468,10 +468,10 @@ struct LoraModel : public GGMLRunner {
return updown;
}

ggml_tensor* get_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_tensor* model_tensor, bool with_lora = true) {
ggml_tensor* get_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_tensor* model_tensor, bool with_lora_and_lokr = true) {
// lora
ggml_tensor* diff = nullptr;
if (with_lora) {
if (with_lora_and_lokr) {
diff = get_lora_weight_diff(model_tensor_name, ctx);
}
// diff
@@ -483,7 +483,7 @@
diff = get_loha_weight_diff(model_tensor_name, ctx);
}
// lokr
if (diff == nullptr) {
if (diff == nullptr && with_lora_and_lokr) {
diff = get_lokr_weight_diff(model_tensor_name, ctx);
}
if (diff != nullptr) {
@@ -514,6 +514,108 @@
} else {
key = model_tensor_name + "." + std::to_string(index);
}
bool is_conv2d = forward_params.op_type == WeightAdapter::ForwardParams::op_type_t::OP_CONV2D;

std::string lokr_w1_name = "lora." + key + ".lokr_w1";
std::string lokr_w1_a_name = "lora." + key + ".lokr_w1_a";
// if either of these is present, this key is a LoKr adapter
auto iter = lora_tensors.find(lokr_w1_name);
auto iter_a = lora_tensors.find(lokr_w1_a_name);
if (iter != lora_tensors.end() || iter_a != lora_tensors.end()) {
std::string lokr_w1_b_name = "lora." + key + ".lokr_w1_b";
std::string lokr_w2_name = "lora." + key + ".lokr_w2";
std::string lokr_w2_a_name = "lora." + key + ".lokr_w2_a";
std::string lokr_w2_b_name = "lora." + key + ".lokr_w2_b";
std::string alpha_name = "lora." + key + ".alpha";

ggml_tensor* lokr_w1 = nullptr;
ggml_tensor* lokr_w1_a = nullptr;
ggml_tensor* lokr_w1_b = nullptr;
ggml_tensor* lokr_w2 = nullptr;
ggml_tensor* lokr_w2_a = nullptr;
ggml_tensor* lokr_w2_b = nullptr;

if (iter != lora_tensors.end()) {
lokr_w1 = iter->second;
}
iter = iter_a;
if (iter != lora_tensors.end()) {
lokr_w1_a = iter->second;
}
iter = lora_tensors.find(lokr_w1_b_name);
if (iter != lora_tensors.end()) {
lokr_w1_b = iter->second;
}

iter = lora_tensors.find(lokr_w2_name);
if (iter != lora_tensors.end()) {
lokr_w2 = iter->second;
if (is_conv2d && lokr_w2->type != GGML_TYPE_F16) {
lokr_w2 = ggml_cast(ctx, lokr_w2, GGML_TYPE_F16);
}
}
iter = lora_tensors.find(lokr_w2_a_name);
if (iter != lora_tensors.end()) {
lokr_w2_a = iter->second;
if (is_conv2d && lokr_w2_a->type != GGML_TYPE_F16) {
lokr_w2_a = ggml_cast(ctx, lokr_w2_a, GGML_TYPE_F16);
}
}
iter = lora_tensors.find(lokr_w2_b_name);
if (iter != lora_tensors.end()) {
lokr_w2_b = iter->second;
if (is_conv2d && lokr_w2_b->type != GGML_TYPE_F16) {
lokr_w2_b = ggml_cast(ctx, lokr_w2_b, GGML_TYPE_F16);
}
}

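// LyCORIS-style scaling: the effective scale is alpha / rank, with the rank taken from the
// trailing dimension of whichever low-rank B factor is present. A fully dense LoKr keeps the
// default rank of 1, in which case the alpha scaling is skipped.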
int rank = 1;
if (lokr_w1_b) {
rank = (int)lokr_w1_b->ne[ggml_n_dims(lokr_w1_b) - 1];
}
if (lokr_w2_b) {
rank = (int)lokr_w2_b->ne[ggml_n_dims(lokr_w2_b) - 1];
}

float scale_value = 1.0f;
iter = lora_tensors.find(alpha_name);
if (iter != lora_tensors.end()) {
float alpha = ggml_ext_backend_tensor_get_f32(iter->second);
scale_value = alpha / rank;
applied_lora_tensors.insert(alpha_name);
}

if (rank == 1) {
scale_value = 1.0f;
}
scale_value *= multiplier;

auto curr_out_diff = ggml_ext_lokr_forward(ctx, x, lokr_w1, lokr_w1_a, lokr_w1_b, lokr_w2, lokr_w2_a, lokr_w2_b, is_conv2d, forward_params.conv2d, scale_value);
if (out_diff == nullptr) {
out_diff = curr_out_diff;
} else {
out_diff = ggml_concat(ctx, out_diff, curr_out_diff, 0);
}

if (lokr_w1)
applied_lora_tensors.insert(lokr_w1_name);
if (lokr_w1_a)
applied_lora_tensors.insert(lokr_w1_a_name);
if (lokr_w1_b)
applied_lora_tensors.insert(lokr_w1_b_name);
if (lokr_w2)
applied_lora_tensors.insert(lokr_w2_name);
if (lokr_w2_a)
applied_lora_tensors.insert(lokr_w2_a_name);
if (lokr_w2_b)
applied_lora_tensors.insert(lokr_w2_b_name);
applied_lora_tensors.insert(alpha_name);

index++;
continue;
}

// not a LoKr; fall back to the normal LoRA path

std::string lora_down_name = "lora." + key + ".lora_down";
std::string lora_up_name = "lora." + key + ".lora_up";
@@ -525,9 +627,7 @@
ggml_tensor* lora_mid = nullptr;
ggml_tensor* lora_down = nullptr;

bool is_conv2d = forward_params.op_type == WeightAdapter::ForwardParams::op_type_t::OP_CONV2D;

auto iter = lora_tensors.find(lora_up_name);
iter = lora_tensors.find(lora_up_name);
if (iter != lora_tensors.end()) {
lora_up = iter->second;
if (is_conv2d && lora_up->type != GGML_TYPE_F16) {
@@ -741,9 +841,9 @@ struct MultiLoraAdapter : public WeightAdapter {
: lora_models(lora_models) {
}

ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name, bool with_lora) {
ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name, bool with_lora_and_lokr) {
for (auto& lora_model : lora_models) {
ggml_tensor* diff = lora_model->get_weight_diff(weight_name, ctx, weight, with_lora);
ggml_tensor* diff = lora_model->get_weight_diff(weight_name, ctx, weight, with_lora_and_lokr);
if (diff == nullptr) {
continue;
}