Skip to content

Commit

Permalink
model tags classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
josStorer committed Dec 8, 2023
1 parent 51e1629 commit f590017
Show file tree
Hide file tree
Showing 7 changed files with 354 additions and 45 deletions.
10 changes: 9 additions & 1 deletion frontend/src/_locales/ja/main.json
Original file line number Diff line number Diff line change
Expand Up @@ -303,5 +303,13 @@
"Please select a MIDI device first": "まずMIDIデバイスを選択してください",
"Piano is the main instrument": "ピアノはメインの楽器です",
"Loss is too high, please check the training data, and ensure your gpu driver is up to date.": "Lossが大きすぎます、トレーニングデータを確認し、GPUドライバが最新であることを確認してください。",
"This version of RWKV is not supported yet.": "このバージョンのRWKVはまだサポートされていません。"
"This version of RWKV is not supported yet.": "このバージョンのRWKVはまだサポートされていません。",
"Main": "メイン",
"Finetuned": "微調整",
"Global": "グローバル",
"Local": "ローカル",
"CN": "中国語",
"JP": "日本語",
"Music": "音楽",
"Other": "その他"
}
10 changes: 9 additions & 1 deletion frontend/src/_locales/zh-hans/main.json
Original file line number Diff line number Diff line change
Expand Up @@ -303,5 +303,13 @@
"Please select a MIDI device first": "请先选择一个MIDI设备",
"Piano is the main instrument": "钢琴为主",
"Loss is too high, please check the training data, and ensure your gpu driver is up to date.": "Loss过高,请检查训练数据,并确保你的显卡驱动是最新的",
"This version of RWKV is not supported yet.": "暂不支持此版本的RWKV"
"This version of RWKV is not supported yet.": "暂不支持此版本的RWKV",
"Main": "主干",
"Finetuned": "微调",
"Global": "全球",
"Local": "本地",
"CN": "中文",
"JP": "日文",
"Music": "音乐",
"Other": "其他"
}
36 changes: 34 additions & 2 deletions frontend/src/pages/Models.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import React, { FC } from 'react';
import React, { FC, useEffect, useState } from 'react';
import {
Button,
Checkbox,
createTableColumn,
DataGrid,
Expand Down Expand Up @@ -154,6 +155,22 @@ const columns: TableColumnDefinition<ModelSourceItem>[] = [

const Models: FC = observer(() => {
const { t } = useTranslation();
const [tags, setTags] = useState<Array<string>>([]);
const [modelSourceList, setModelSourceList] = useState<ModelSourceItem[]>(commonStore.modelSourceList);

useEffect(() => {
setTags(Array.from(new Set(
[...commonStore.modelSourceList.map(item => item.tags || []).flat()
.filter(i => !i.includes('Other') && !i.includes('Local'))
, 'Other', 'Local'])));
}, [commonStore.modelSourceList]);

useEffect(() => {
if (commonStore.activeModelListTags.length === 0)
setModelSourceList(commonStore.modelSourceList);
else
setModelSourceList(commonStore.modelSourceList.filter(item => commonStore.activeModelListTags.some(tag => item.tags?.includes(tag))));
}, [commonStore.modelSourceList, commonStore.activeModelListTags]);

return (
<Page title={t('Models')} content={
Expand Down Expand Up @@ -184,9 +201,24 @@ const Models: FC = observer(() => {
value={commonStore.modelSourceManifestList}
onChange={(e, data) => commonStore.setModelSourceManifestList(data.value)} />
</div>
<div className="flex gap-2 flex-wrap overflow-y-auto" style={{ minHeight: '88px' }}>
{tags.map(tag =>
<div key={tag} className="mt-auto">
<Button
appearance={commonStore.activeModelListTags.includes(tag) ? 'primary' : 'secondary'} onClick={
() => {
if (commonStore.activeModelListTags.includes(tag))
commonStore.setActiveModelListTags(commonStore.activeModelListTags.filter(t => t !== tag));
else
commonStore.setActiveModelListTags([...commonStore.activeModelListTags, tag]);
}
}>{t(tag)}</Button>
</div>)
}
</div>
<div className="flex grow overflow-hidden">
<DataGrid
items={commonStore.modelSourceList}
items={modelSourceList}
columns={columns}
sortable={true}
defaultSortState={{ sortColumn: 'actions', sortDirection: 'ascending' }}
Expand Down
5 changes: 5 additions & 0 deletions frontend/src/stores/commonStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class CommonStore {
modelConfigs: ModelConfig[] = [];
modelParamsCollapsed: boolean = true;
// models
activeModelListTags: string[] = [];
modelSourceManifestList: string = 'https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/manifest.json;';
modelSourceList: ModelSourceItem[] = [];
// downloads
Expand Down Expand Up @@ -453,6 +454,10 @@ class CommonStore {
setPlayingTrackId(value: string) {
this.playingTrackId = value;
}

setActiveModelListTags(value: string[]) {
this.activeModelListTags = value;
}
}

export default new CommonStore();
1 change: 1 addition & 0 deletions frontend/src/types/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ export type ModelSourceItem = {
isLocal?: boolean;
localSize?: number;
lastUpdatedMs?: number;
tags?: string[];
hide?: boolean;
};
7 changes: 6 additions & 1 deletion frontend/src/utils/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ export async function refreshLocalModels(cache: {
size: d.size,
lastUpdated: d.modTime,
isComplete: true,
isLocal: true
isLocal: true,
tags: ['Local']
}] as ModelSourceItem[];
return [];
}));
Expand All @@ -89,12 +90,15 @@ export async function refreshLocalModels(cache: {
for (let i = 0; i < cache.models.length; i++) {
if (!cache.models[i].lastUpdatedMs)
cache.models[i].lastUpdatedMs = Date.parse(cache.models[i].lastUpdated);
if (!cache.models[i].tags)
cache.models[i].tags = ['Other'];

for (let j = i + 1; j < cache.models.length; j++) {
if (!cache.models[j].lastUpdatedMs)
cache.models[j].lastUpdatedMs = Date.parse(cache.models[j].lastUpdated);

if (cache.models[i].name === cache.models[j].name) {
const tags = Array.from(new Set([...cache.models[i].tags as string[], ...cache.models[j].tags as string[]]));
if (cache.models[i].size <= cache.models[j].size) { // j is local file
if (cache.models[i].lastUpdatedMs! < cache.models[j].lastUpdatedMs!) {
cache.models[i] = Object.assign({}, cache.models[i], cache.models[j]);
Expand All @@ -104,6 +108,7 @@ export async function refreshLocalModels(cache: {
} // else is not complete local file
cache.models[i].isLocal = true;
cache.models[i].localSize = cache.models[j].size;
cache.models[i].tags = tags;
cache.models.splice(j, 1);
j--;
}
Expand Down
Loading

0 comments on commit f590017

Please sign in to comment.