Skip to content

Commit

Permalink
add sml/svm
Browse files Browse the repository at this point in the history
  • Loading branch information
lwxxxxxxx committed Oct 27, 2023
1 parent 641471f commit d542440
Showing 1 changed file with 4 additions and 12 deletions.
16 changes: 4 additions & 12 deletions sml/svm/svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,20 @@ def __init__(self, kernel="rbf", C=1.0, gamma='scale', max_iter=300, tol=1e-3):
self.tol = tol
self.n_features = None

self.alpha = None
self.alpha_y = None
self.b = None

self.X = None

assert self.gamma in {'scale', 'auto'}, "Gamma only support 'scale' and 'auto'"
assert self.kernel == "rbf", "Kernel function only support 'rbf'"

def cal_kernel(self, x, x_):
"""Calculate kernel."""
assert self.gamma in {'scale', 'auto'}, "Gamma only support 'scale' and 'auto'"
gamma = {
'scale': 1 / (self.n_features * x.var()),
'auto': 1 / self.n_features,
}[self.gamma]

assert self.kernel == "rbf", "Kernel function only support 'rbf'"
kernel_res = jnp.exp(
-gamma
* (
Expand Down Expand Up @@ -111,9 +110,8 @@ def fit(self, X, y):
j = smo.working_set_select_j(i, alpha, y, neg_y_grad, Q)
neg_y_grad, alpha = smo.update(i, j, Q, y, alpha, neg_y_grad)

self.alpha = alpha
self.b = smo.cal_b(alpha, neg_y_grad, y)
self.alpha_y = self.alpha * y
self.alpha_y = alpha * y

self.X = X

Expand All @@ -124,12 +122,6 @@ def predict(self, x):
Parameters
----------
X : {array-like}, shape (n_samples, n_features)
Input data.
y : {array-like}, shape (n_samples)
Lable of the input data.
x : {array-like}, shape (n_samples, n_features)
Input data for prediction.
Expand Down

0 comments on commit d542440

Please sign in to comment.