Skip to content

Commit

Permalink
Add Azure OpenAI service
Browse files Browse the repository at this point in the history
  • Loading branch information
YiiGuxing committed Jan 28, 2024
1 parent 302098b commit 7c095af
Show file tree
Hide file tree
Showing 12 changed files with 215 additions and 122 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
@file:Suppress("unused")

package cn.yiiguxing.plugin.translate.trans.openai

enum class AzureServiceVersion(val value: String) {
V2023_05_15("2023-05-15"),
V2023_12_01_PREVIEW("2023-12-01-preview");
}

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package cn.yiiguxing.plugin.translate.trans.openai

import cn.yiiguxing.plugin.translate.TranslationPlugin
import cn.yiiguxing.plugin.translate.util.credential.SimpleStringCredentialManager
import cn.yiiguxing.plugin.translate.util.credential.StringCredentialManager
import com.intellij.credentialStore.generateServiceName
import com.intellij.openapi.components.Service
import com.intellij.openapi.components.service

@Service
internal class OpenAICredentials private constructor() {

val openAi = SimpleStringCredentialManager(OPEN_AI_SERVICE_NAME)
val azure = SimpleStringCredentialManager(AZURE_SERVICE_NAME)

companion object {
private val OPEN_AI_SERVICE_NAME =
generateServiceName("OpenAI Credentials", "${TranslationPlugin.PLUGIN_ID}.OPENAI_API_KEY")
private val AZURE_SERVICE_NAME =
generateServiceName("OpenAI Credentials", "${TranslationPlugin.PLUGIN_ID}.AZURE_OPENAI_API_KEY")

private val service: OpenAICredentials get() = service()

fun manager(provider: ServiceProvider): StringCredentialManager = when (provider) {
ServiceProvider.OpenAI -> service.openAi
ServiceProvider.Azure -> service.azure
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
package cn.yiiguxing.plugin.translate.trans.openai

/**
* See: [OpenAI Models](https://platform.openai.com/docs/models)
* See: [OpenAIService Models](https://platform.openai.com/docs/models)
*/
enum class OpenAIModel(val value: String, val modelName: String) {
GPT_3_5_TURBO("gpt-3.5-turbo", "GPT-3.5-Turbo"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package cn.yiiguxing.plugin.translate.trans.openai

import cn.yiiguxing.plugin.translate.trans.openai.chat.ChatCompletion
import cn.yiiguxing.plugin.translate.trans.openai.chat.ChatMessage
import cn.yiiguxing.plugin.translate.trans.openai.chat.chatCompletionRequest
import com.intellij.util.concurrency.annotations.RequiresBackgroundThread
import com.intellij.util.io.RequestBuilder

const val DEFAULT_OPEN_AI_API_ENDPOINT = "https://api.openai.com"
const val OPEN_AI_API_PATH = "/v1/chat/completions"

private const val AZURE_OPEN_AI_API_PATH = "/openai/deployments/%s/chat/completions"

interface OpenAIService {

@RequiresBackgroundThread
fun chatCompletion(messages: List<ChatMessage>): ChatCompletion

interface Options {
val model: OpenAIModel
val endpoint: String?
}

interface AzureOptions : Options {
val apiVersion: AzureServiceVersion
}

companion object {
fun get(settings: OpenAISettings): OpenAIService {
return when (settings.provider) {
ServiceProvider.OpenAI -> OpenAI(settings.openAi)
ServiceProvider.Azure -> Azure(settings.azure)
}
}
}
}


class OpenAI(private val options: OpenAIService.Options) : OpenAIService {
private val apiUrl: String
get() = (options.endpoint ?: DEFAULT_OPEN_AI_API_ENDPOINT).trimEnd('/') + OPEN_AI_API_PATH

private fun RequestBuilder.auth() {
val apiKey = OpenAICredentials.manager(ServiceProvider.OpenAI).credential
tuner { it.setRequestProperty("Authorization", "Bearer $apiKey") }
}

override fun chatCompletion(messages: List<ChatMessage>): ChatCompletion {
val request = chatCompletionRequest {
model = options.model.value
this.messages = messages
}
return OpenAIHttp.post<ChatCompletion>(apiUrl, request) { auth() }
}
}

class Azure(options: OpenAIService.AzureOptions) : OpenAIService {

private val apiUrl: String = requireNotNull(options.endpoint) { "Azure OpenAI API endpoint is required" } +
AZURE_OPEN_AI_API_PATH.format(options.model.value) +
"?api-version=${options.apiVersion.value}"

private fun RequestBuilder.auth() {
val apiKey = OpenAICredentials.manager(ServiceProvider.Azure).credential
tuner { it.setRequestProperty("api-key", apiKey) }
}

override fun chatCompletion(messages: List<ChatMessage>): ChatCompletion {
val request = chatCompletionRequest(false) {
this.messages = messages
}
return OpenAIHttp.post<ChatCompletion>(apiUrl, request) { auth() }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package cn.yiiguxing.plugin.translate.trans.openai
import cn.yiiguxing.plugin.translate.TranslationStorages
import com.intellij.openapi.components.*
import com.intellij.util.xmlb.annotations.OptionTag
import com.intellij.util.xmlb.annotations.Tag
import com.intellij.util.xmlb.annotations.Transient


/**
Expand All @@ -12,15 +14,40 @@ import com.intellij.util.xmlb.annotations.OptionTag
@State(name = "Translation.OpenAISettings", storages = [Storage(TranslationStorages.PREFERENCES_STORAGE_NAME)])
class OpenAISettings : BaseState(), PersistentStateComponent<OpenAISettings> {

@get:OptionTag("MODEL")
var model: OpenAIModel by enum(OpenAIModel.GPT_3_5_TURBO)
@get:OptionTag("PROVIDER")
var provider: ServiceProvider by enum(ServiceProvider.OpenAI)

@get:OptionTag("API_ENDPOINT")
var apiEndpoint: String? by string()
@get:OptionTag("OPEN_AI")
var openAi: OpenAI by property(OpenAI())

@get:OptionTag("AZURE")
var azure: Azure by property(Azure())

@get:Transient
val model: OpenAIModel
get() = when (provider) {
ServiceProvider.OpenAI -> openAi.model
ServiceProvider.Azure -> azure.model
}

override fun getState(): OpenAISettings = this

override fun loadState(state: OpenAISettings) {
copyFrom(state)
}

@Tag("open-ai")
open class OpenAI : BaseState(), OpenAIService.Options {
@get:OptionTag("MODEL")
override var model: OpenAIModel by enum(OpenAIModel.GPT_3_5_TURBO)

@get:OptionTag("ENDPOINT")
override var endpoint: String? by string()
}

@Tag("azure")
class Azure : OpenAI(), OpenAIService.AzureOptions {
@get:OptionTag("API_VERSION")
override var apiVersion: AzureServiceVersion by enum(AzureServiceVersion.V2023_05_15)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class OpenAISettingsDialog : DialogWrapper(false) {

private val apiKeyField: JBPasswordField = JBPasswordField()
private val apiEndpointField: ExtendableTextField = ExtendableTextField().apply {
emptyText.text = OpenAI.DEFAULT_API_ENDPOINT
emptyText.text = DEFAULT_OPEN_AI_API_ENDPOINT
val extension = Extension.create(AllIcons.General.Reset, message("set.as.default.action.name")) {
text = null
setErrorText(null)
Expand Down Expand Up @@ -87,8 +87,8 @@ class OpenAISettingsDialog : DialogWrapper(false) {
setResizable(false)
init()

apiEndpoint = settings.apiEndpoint
apiModelComboBox.selectedItem = settings.model
apiEndpoint = settings.openAi.endpoint
apiModelComboBox.selectedItem = settings.openAi.model
}


Expand All @@ -100,7 +100,7 @@ class OpenAISettingsDialog : DialogWrapper(false) {

private fun createConfigurationPanel(): JPanel {
val fieldWidth = 320
val apiPathLabel = JBLabel(OpenAI.API_PATH).apply {
val apiPathLabel = JBLabel(OPEN_AI_API_PATH).apply {
border = JBUI.Borders.emptyRight(apiEndpointField.insets.right)
isEnabled = false
}
Expand Down Expand Up @@ -139,14 +139,14 @@ class OpenAISettingsDialog : DialogWrapper(false) {
return
}

OpenAICredential.apiKey = apiKey
isOK = OpenAICredential.isApiKeySet
settings.apiEndpoint = apiEndpoint
OpenAICredentials.manager(ServiceProvider.OpenAI).credential = apiKey
isOK = OpenAICredentials.manager(ServiceProvider.OpenAI).isCredentialSet
settings.openAi.endpoint = apiEndpoint

val oldModel = settings.model
val oldModel = settings.openAi.model
val newModel = apiModelComboBox.selected ?: OpenAIModel.GPT_3_5_TURBO
if (oldModel != newModel) {
settings.model = newModel
settings.openAi.model = newModel
service<CacheService>().removeMemoryCache { key, _ ->
key.translator == TranslationEngine.OPEN_AI.id
}
Expand All @@ -159,7 +159,10 @@ class OpenAISettingsDialog : DialogWrapper(false) {
// This is a modal dialog, so it needs to be invoked later.
SwingUtilities.invokeLater {
val dialogRef = DisposableRef.create(disposable, this)
runAsync { OpenAICredential.apiKey to OpenAICredential.isApiKeySet }
runAsync {
OpenAICredentials.manager(ServiceProvider.OpenAI)
.let { it.credential to it.isCredentialSet }
}
.expireWith(disposable)
.successOnUiThread(dialogRef) { dialog, (apiKey, isApiKeySet) ->
dialog.apiKey = apiKey
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import cn.yiiguxing.plugin.translate.message
import cn.yiiguxing.plugin.translate.service.CacheService
import cn.yiiguxing.plugin.translate.trans.*
import cn.yiiguxing.plugin.translate.trans.openai.chat.ChatRole
import cn.yiiguxing.plugin.translate.trans.openai.chat.chatCompletionRequest
import cn.yiiguxing.plugin.translate.trans.openai.chat.chatMessages
import cn.yiiguxing.plugin.translate.trans.openai.exception.OpenAIStatusException
import cn.yiiguxing.plugin.translate.ui.settings.TranslationEngine.OPEN_AI
import cn.yiiguxing.plugin.translate.util.md5
Expand All @@ -25,11 +25,11 @@ object OpenAITranslator : AbstractTranslator(), DocumentationTranslator {
OpenAILanguages.languages.toMutableList().apply { add(0, Lang.AUTO) }
override val supportedTargetLanguages: List<Lang> = OpenAILanguages.languages

private val openAIModel: OpenAIModel get() = service<OpenAISettings>().model
private val settings: OpenAISettings get() = service<OpenAISettings>()


override fun checkConfiguration(force: Boolean): Boolean {
if (force || !OpenAICredential.isApiKeySet) {
if (force || !OpenAICredentials.manager(settings.provider).isCredentialSet) {
return OPEN_AI.showConfigurationDialog()
}

Expand Down Expand Up @@ -58,16 +58,15 @@ object OpenAITranslator : AbstractTranslator(), DocumentationTranslator {
targetLang: Lang,
isFofDocumentation: Boolean
): String {
val model = openAIModel
val cacheService = service<CacheService>()
val cacheKey = getCacheKey(model, text, srcLang, targetLang)
val cacheKey = getCacheKey(text, srcLang, targetLang)
val cache = cacheService.getDiskCache(cacheKey)
if (!cache.isNullOrEmpty()) {
return cache
}

val request = getChatCompletionRequest(model, text, srcLang, targetLang, isFofDocumentation)
val chatCompletion = OpenAI.chatCompletion(request)
val request = getChatCompletionRequest(text, srcLang, targetLang, isFofDocumentation)
val chatCompletion = OpenAIService.get(settings).chatCompletion(request)
var result = chatCompletion.choices.first().message!!.content
if (!isFofDocumentation && result.length > 1 && result.first() == '"' && result.last() == '"') {
result = result.substring(1, result.lastIndex)
Expand All @@ -78,36 +77,32 @@ object OpenAITranslator : AbstractTranslator(), DocumentationTranslator {
}

private fun getChatCompletionRequest(
openAIModel: OpenAIModel,
text: String,
srcLang: Lang,
targetLang: Lang,
isFofDocumentation: Boolean = false
) =
chatCompletionRequest {
model = openAIModel.value
messages {
message {
role = ChatRole.SYSTEM
content = "You are a translation engine that can " + if (isFofDocumentation) {
"translate HTML document."
} else {
"only translate text and cannot interpret it."
}
}
message {
role = ChatRole.USER
content =
"Translate ${if (srcLang == Lang.AUTO) "" else "from ${srcLang.openAILanguage} "}to ${targetLang.openAILanguage}."
}
message {
role = ChatRole.USER
content = if (isFofDocumentation) text else """"$text""""
}
) = chatMessages {
message {
role = ChatRole.SYSTEM
content = "You are a translation engine that can " + if (isFofDocumentation) {
"translate HTML document."
} else {
"only translate text and cannot interpret it."
}
}
message {
role = ChatRole.USER
content =
"Translate ${if (srcLang == Lang.AUTO) "" else "from ${srcLang.openAILanguage} "}to ${targetLang.openAILanguage}."
}
message {
role = ChatRole.USER
content = if (isFofDocumentation) text else """"$text""""
}
}

private fun getCacheKey(model: OpenAIModel, text: String, srcLang: Lang, targetLang: Lang): String {
private fun getCacheKey(text: String, srcLang: Lang, targetLang: Lang): String {
val model = settings.model
return "$id$model$text$srcLang$targetLang".md5()
}

Expand Down
Loading

0 comments on commit 7c095af

Please sign in to comment.