Skip to content

Commit

Permalink
Merge pull request #19 from RockChinQ/feat/random-ad
Browse files Browse the repository at this point in the history
Feat: add supports for optional ads
  • Loading branch information
RockChinQ authored Oct 6, 2023
2 parents 6213923 + e6f5949 commit a1124d0
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 2 deletions.
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,19 @@ database:
# SQLite 数据库文件路径
path: ./data/free_one_api.db
type: sqlite
logging:
debug: false # 是否开启调试日志
# 随机广告
# 会随机追加到每个响应的末尾
random_ad:
# 广告列表
ad_list:
- ' (This response is sponsored by Free One API. Consider star the project on GitHub:
https://github.com/RockChinQ/free-one-api )'
# 是否开启随机广告
enabled: false
# 广告出现概率 (0-1)
rate: 0.05
router:
# 后端监听端口
port: 3000
Expand Down
12 changes: 12 additions & 0 deletions README_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,18 @@ database:
# SQLite DB file path
path: ./data/free_one_api.db
type: sqlite
logging:
debug: false # Enable debug log
# Random advertisement, will be appended to the end of each response
random_ad:
# advertisement list
ad_list:
- ' (This response is sponsored by Free One API. Consider star the project on GitHub:
https://github.com/RockChinQ/free-one-api )'
# Enable random ad
enabled: false
# Random ad rate
rate: 0.05
router:
# Backend listen port
port: 3000
Expand Down
27 changes: 27 additions & 0 deletions free_one_api/common/randomad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import random
import typing

ads = []

rate = 0.01

enabled = False

def generate_ad() -> typing.Generator[str, None, None]:
"""Generate random ad."""
global ads
global rate
global enabled

if not enabled:
return

if len(ads) == 0:
return

if random.random() < rate:
ad_words = random.choice(ads).split(" ")

for word in ad_words:
yield word
yield " "
19 changes: 18 additions & 1 deletion free_one_api/impls/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ async def run(self):
},
"logging": {
"debug": False,
},
"random_ad": {
"enabled": False,
"rate": 0.05,
"ad_list": [
" (This response is sponsored by Free One API. Consider star the project on GitHub: https://github.com/RockChinQ/free-one-api )",
]
}
}

Expand Down Expand Up @@ -131,7 +138,17 @@ async def make_application(config_path: str) -> Application:
for handler in logging.getLogger().handlers:
logging.getLogger().removeHandler(handler)

logging.getLogger().addHandler(terminal_out)
logging.getLogger().addHandler(terminal_out)

# save ad to runtime
if 'random_ad' in config and config['random_ad']['enabled']:
from ..common import randomad

randomad.enabled = config['random_ad']['enabled']
randomad.rate = config['random_ad']['rate']
randomad.ads = config['random_ad']['ad_list']

from ..common import randomad

# make database manager
from .database import sqlite as sqlitedb
Expand Down
25 changes: 24 additions & 1 deletion free_one_api/impls/forward/mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@
import string
import random
import logging
import typing

import quart

from ...models.forward import mgr as forwardmgr
from ...models.channel import mgr as channelmgr
from ...models.key import mgr as apikeymgr
from ...entities import channel, apikey, request, response, exceptions
from ...common import randomad


class ForwardManager(forwardmgr.AbsForwardManager):

def __init__(self, chanmgr: channelmgr.AbsChannelManager, keymgr: apikeymgr.AbsAPIKeyManager):
self.chanmgr = chanmgr
self.keymgr = keymgr

async def __stream_query(
self,
chan: channel.Channel,
Expand Down Expand Up @@ -47,6 +49,23 @@ async def _gen():
"finish_reason": resp.finish_reason.value
}]
}))

if randomad.enabled:
for word in randomad.generate_ad():
yield "data: {}\n\n".format(json.dumps({
"id": "chatcmpl-"+id_suffix,
"object": "chat.completion.chunk",
"created": t,
"model": req.model,
"choices": [{
"index": 0,
"delta": {
"content": word,
},
"finish_reason": response.FinishReason.NULL.value
}]
}))

yield "data: [DONE]\n\n"
except exceptions.QueryHandlingError as e:
yield "data: {}\n\ndata: [DONE]\n\n".format(json.dumps({
Expand Down Expand Up @@ -95,6 +114,10 @@ async def __non_stream_query(
resp_tmp = resp
normal_message += resp.normal_message

if randomad.enabled:
for word in randomad.generate_ad():
normal_message += word

except exceptions.QueryHandlingError as e:
# check for custom error raised by adapter
return quart.jsonify({
Expand Down

0 comments on commit a1124d0

Please sign in to comment.