Skip to content

Commit

Permalink
OpenAI Translator adds Azure OpenAI support
Browse files Browse the repository at this point in the history
Features: #3801, #3765
  • Loading branch information
YiiGuxing committed Feb 19, 2024
1 parent 8dd0fa5 commit 8d91c43
Show file tree
Hide file tree
Showing 9 changed files with 229 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,7 @@ internal class OpenAICredentials private constructor() {
ServiceProvider.OpenAI -> service.openAi
ServiceProvider.Azure -> service.azure
}

fun isCredentialSet(provider: ServiceProvider): Boolean = manager(provider).isCredentialSet
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,23 @@ class OpenAISettings : BaseState(), PersistentStateComponent<OpenAISettings> {
@get:OptionTag("AZURE")
var azure: Azure by property(Azure())


@get:Transient
val model: OpenAIModel
get() = getOptions().model

@get:Transient
val isConfigured: Boolean
get() = when (provider) {
ServiceProvider.OpenAI -> openAi.model
ServiceProvider.Azure -> azure.model
ServiceProvider.Azure -> !azure.endpoint.isNullOrEmpty()
else -> true
}

fun getOptions(provider: ServiceProvider = this.provider): OpenAIService.Options = when (provider) {
ServiceProvider.OpenAI -> openAi
ServiceProvider.Azure -> azure
}

override fun getState(): OpenAISettings = this

override fun loadState(state: OpenAISettings) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,34 @@ import com.intellij.openapi.ui.DialogWrapper
import com.intellij.ui.CollectionComboBoxModel
import com.intellij.ui.DocumentAdapter
import com.intellij.ui.SimpleListCellRenderer
import com.intellij.ui.components.JBLabel
import com.intellij.ui.components.JBPasswordField
import com.intellij.ui.components.fields.ExtendableTextComponent.Extension
import com.intellij.ui.components.fields.ExtendableTextField
import com.intellij.util.Alarm
import com.intellij.util.ui.JBUI
import icons.TranslationIcons
import org.jetbrains.concurrency.runAsync
import javax.swing.JComponent
import javax.swing.JLabel
import javax.swing.JPanel
import javax.swing.SwingUtilities
import java.awt.event.ItemEvent
import javax.swing.*
import javax.swing.event.DocumentEvent
import javax.swing.text.JTextComponent

class OpenAISettingsDialog : DialogWrapper(false) {

private val settings = service<OpenAISettings>()
private val openAiState = OpenAISettings.OpenAI().apply { copyFrom(settings.openAi) }
private val azureState = OpenAISettings.Azure().apply { copyFrom(settings.azure) }
private val apiKeys: ApiKeys = ApiKeys()
private var isApiKeySet: Boolean = false

private val alarm: Alarm = Alarm(disposable)

private val apiKeyField: JBPasswordField = JBPasswordField()
private val apiKeyField: JBPasswordField = JBPasswordField().apply {
document.addDocumentListener(object : DocumentAdapter() {
override fun textChanged(e: DocumentEvent) {
apiKeys[provider] = String(password)
}
})
}
private val apiEndpointField: ExtendableTextField = ExtendableTextField().apply {
emptyText.text = DEFAULT_OPEN_AI_API_ENDPOINT
val extension = Extension.create(AllIcons.General.Reset, message("set.as.default.action.name")) {
Expand All @@ -48,33 +56,68 @@ class OpenAISettingsDialog : DialogWrapper(false) {
}
document.addDocumentListener(object : DocumentAdapter() {
override fun textChanged(e: DocumentEvent) {
alarm.cancelAllRequests()
alarm.addRequest(::verifyApiEndpoint, 300)
apiEndpointChanged()
if (apiEndpoint.isNullOrEmpty()) {
removeExtension(extension)
} else {
} else if (provider == ServiceProvider.OpenAI) {
addExtension(extension)
}
}
})
}

private val apiServiceProviderComboBox: ComboBox<ServiceProvider> =
ComboBox(CollectionComboBoxModel(ServiceProvider.values().toList())).apply {
renderer = SimpleListCellRenderer.create { label, model, _ ->
label.text = model.name
label.icon = getProviderIcon(model)
}
addItemListener { event ->
if (event.stateChange == ItemEvent.SELECTED) {
providerUpdated(event.item as ServiceProvider)
}
}
}
private val apiModelComboBox: ComboBox<OpenAIModel> =
ComboBox(CollectionComboBoxModel(OpenAIModel.values().toList())).apply {
renderer = SimpleListCellRenderer.create { label, model, _ ->
label.text = model.modelName
}
addItemListener {
if (it.stateChange == ItemEvent.SELECTED) {
currentState.model = it.item as OpenAIModel
}
}
}
private val azureServiceVersionComboBox: ComboBox<AzureServiceVersion> =
ComboBox(CollectionComboBoxModel(AzureServiceVersion.values().toList())).apply {
isVisible = false
renderer = SimpleListCellRenderer.create { label, model, _ ->
label.text = model.value
}
addItemListener {
if (it.stateChange == ItemEvent.SELECTED) {
azureState.apiVersion = it.item as AzureServiceVersion
}
}
}

private var apiKey: String?
get() = apiKeyField.password
?.takeIf { it.isNotEmpty() }
?.let { String(it) }
private val apiVersionLabel = JLabel(message("openai.settings.dialog.label.api.version")).apply {
isVisible = false
}
private lateinit var hintComponent: JTextComponent

private var provider: ServiceProvider
get() = apiServiceProviderComboBox.selected ?: ServiceProvider.OpenAI
set(value) {
apiKeyField.text = if (value.isNullOrEmpty()) null else value
apiServiceProviderComboBox.selectedItem = value
}

private var isOK: Boolean = false
private val currentState: OpenAISettings.OpenAI
get() = when (provider) {
ServiceProvider.OpenAI -> openAiState
ServiceProvider.Azure -> azureState
}

private var apiEndpoint: String?
get() = apiEndpointField.text?.trim()?.takeIf { it.isNotEmpty() }
Expand All @@ -87,8 +130,10 @@ class OpenAISettingsDialog : DialogWrapper(false) {
setResizable(false)
init()

apiEndpoint = settings.openAi.endpoint
apiModelComboBox.selectedItem = settings.openAi.model
provider = settings.provider
azureServiceVersionComboBox.selectedItem = settings.azure.apiVersion

providerUpdated(settings.provider)
}


Expand All @@ -100,32 +145,80 @@ class OpenAISettingsDialog : DialogWrapper(false) {

private fun createConfigurationPanel(): JPanel {
val fieldWidth = 320
val apiPathLabel = JBLabel(OPEN_AI_API_PATH).apply {
border = JBUI.Borders.emptyRight(apiEndpointField.insets.right)
isEnabled = false
}
return JPanel(UI.migLayout()).apply {
hintComponent = UI.createHint("", fieldWidth, apiKeyField)
return JPanel(UI.migLayout(lcBuilder = { hideMode(2) })).apply {
val gapCC = UI.cc().gapRight(migSize(8))
add(JLabel(message("openai.settings.dialog.label.model")), gapCC)
add(apiModelComboBox, UI.fillX().wrap())
val comboBoxCC = UI.cc().width(migSize((fieldWidth * 0.6).toInt())).wrap()
add(JLabel(message("openai.settings.dialog.label.api.provider")), gapCC)
add(apiServiceProviderComboBox, comboBoxCC)
add(JLabel(message("openai.settings.dialog.label.api.model")), gapCC)
add(apiModelComboBox, comboBoxCC)
add(apiVersionLabel, gapCC)
add(azureServiceVersionComboBox, comboBoxCC)
add(JLabel(message("openai.settings.dialog.label.api.endpoint")), gapCC)
add(apiEndpointField, UI.fillX())
add(apiPathLabel, UI.cc().gapLeft(migSize(2)).wrap())
add(apiEndpointField, UI.fillX().wrap())
add(JLabel(message("openai.settings.dialog.label.api.key")), gapCC)
add(apiKeyField, UI.fillX().spanX(2).minWidth(migSize(fieldWidth)).wrap())
add(
UI.createHint(message("openai.settings.dialog.hint"), fieldWidth, apiKeyField),
UI.cc().cell(1, 3).spanX(2).wrap()
)
add(apiKeyField, UI.fillX().minWidth(migSize(fieldWidth)).wrap())
add(hintComponent, UI.cc().cell(1, 5).wrap())
}
}

override fun getHelpId(): String = HelpTopic.OPEN_AI.id

override fun isOK(): Boolean = isOK
override fun isOK(): Boolean {
return isApiKeySet && currentState.endpoint.isValidEndpoint(provider == ServiceProvider.OpenAI)
}

private fun getHint(provider: ServiceProvider): String {
return when (provider) {
ServiceProvider.OpenAI -> message("openai.settings.dialog.hint")
ServiceProvider.Azure -> message("openai.settings.dialog.hint.azure")
}
}

private fun getProviderIcon(provider: ServiceProvider): Icon {
return when (provider) {
ServiceProvider.OpenAI -> TranslationIcons.Engines.OpenAI
ServiceProvider.Azure -> AllIcons.Providers.Azure
}
}

private fun providerUpdated(newProvider: ServiceProvider) {
val isAzure = newProvider == ServiceProvider.Azure
apiVersionLabel.isVisible = isAzure
azureServiceVersionComboBox.isVisible = isAzure
hintComponent.text = getHint(newProvider)

if (isAzure) {
apiEndpointField.setExtensions(emptyList())
apiEndpointField.emptyText.text = ""
} else {
apiEndpointField.emptyText.text = DEFAULT_OPEN_AI_API_ENDPOINT
}

currentState.let {
apiModelComboBox.selected = it.model
apiEndpoint = it.endpoint
}
apiKeyField.text = apiKeys[newProvider]

verifyApiEndpoint()
}

private fun apiEndpointChanged() {
alarm.cancelAllRequests()
alarm.addRequest(::verifyApiEndpoint, 300)

val endpoint = apiEndpoint
currentState.endpoint = if (endpoint.isValidEndpoint(provider == ServiceProvider.OpenAI)) {
endpoint
} else {
settings.getOptions(provider).endpoint
}
}

private fun verifyApiEndpoint(): Boolean {
if (apiEndpoint.let { it == null || URL_REGEX.matches(it) }) {
if (apiEndpoint.isValidEndpoint(provider == ServiceProvider.OpenAI)) {
setErrorText(null)
return true
}
Expand All @@ -139,19 +232,25 @@ class OpenAISettingsDialog : DialogWrapper(false) {
return
}

OpenAICredentials.manager(ServiceProvider.OpenAI).credential = apiKey
isOK = OpenAICredentials.manager(ServiceProvider.OpenAI).isCredentialSet
settings.openAi.endpoint = apiEndpoint
OpenAICredentials.manager(ServiceProvider.OpenAI).credential = apiKeys.openAi
OpenAICredentials.manager(ServiceProvider.Azure).credential = apiKeys.azure

val oldModel = settings.openAi.model
val newModel = apiModelComboBox.selected ?: OpenAIModel.GPT_3_5_TURBO
if (oldModel != newModel) {
settings.openAi.model = newModel
val oldProvider = settings.provider
val newProvider = provider
if (oldProvider != newProvider ||
openAiState.model != settings.openAi.model ||
azureState.model != settings.azure.model
) {
service<CacheService>().removeMemoryCache { key, _ ->
key.translator == TranslationEngine.OPEN_AI.id
}
}

settings.provider = newProvider
settings.openAi.copyFrom(openAiState)
settings.azure.copyFrom(azureState)
isApiKeySet = OpenAICredentials.manager(provider).isCredentialSet

super.doOKAction()
}

Expand All @@ -160,13 +259,16 @@ class OpenAISettingsDialog : DialogWrapper(false) {
SwingUtilities.invokeLater {
val dialogRef = DisposableRef.create(disposable, this)
runAsync {
OpenAICredentials.manager(ServiceProvider.OpenAI)
.let { it.credential to it.isCredentialSet }
ApiKeys(
OpenAICredentials.manager(ServiceProvider.OpenAI).credential,
OpenAICredentials.manager(ServiceProvider.Azure).credential
)
}
.expireWith(disposable)
.successOnUiThread(dialogRef) { dialog, (apiKey, isApiKeySet) ->
dialog.apiKey = apiKey
dialog.isOK = isApiKeySet
.successOnUiThread(dialogRef) { dialog, apiKeys ->
dialog.apiKeys.copyFrom(apiKeys)
dialog.apiKeyField.text = apiKeys[dialog.provider]
dialog.isApiKeySet = !apiKeys[dialog.provider].isNullOrEmpty()
}
.disposeAfterProcessing(dialogRef)
}
Expand All @@ -176,5 +278,31 @@ class OpenAISettingsDialog : DialogWrapper(false) {

private companion object {
val URL_REGEX = "^https?://([^/?#\\s]+)([^?#;\\s]*)$".toRegex()

fun String?.isValidEndpoint(canBeNull: Boolean = true): Boolean {
return this?.let { URL_REGEX.matches(it) } ?: canBeNull
}
}

private data class ApiKeys(var openAi: String? = null, var azure: String? = null) {

operator fun get(provider: ServiceProvider): String? {
return when (provider) {
ServiceProvider.OpenAI -> openAi
ServiceProvider.Azure -> azure
}
}

operator fun set(provider: ServiceProvider, value: String?) {
when (provider) {
ServiceProvider.OpenAI -> openAi = value
ServiceProvider.Azure -> azure = value
}
}

fun copyFrom(apiKeys: ApiKeys) {
openAi = apiKeys.openAi
azure = apiKeys.azure
}
}
}
7 changes: 4 additions & 3 deletions src/main/kotlin/cn/yiiguxing/plugin/translate/ui/UI.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import javax.swing.Icon
import javax.swing.JComponent
import javax.swing.JEditorPane
import javax.swing.border.Border
import javax.swing.text.JTextComponent

/**
* UI
Expand Down Expand Up @@ -66,8 +67,8 @@ object UI {
setHoveringIcon(IconUtil.darker(baseIcon, 3))
}

fun migLayout(gapX: String = "0!", gapY: String = "0!", insets: String = "0") =
MigLayout(LC().fill().gridGap(gapX, gapY).insets(insets))
fun migLayout(gapX: String = "0!", gapY: String = "0!", insets: String = "0", lcBuilder: (LC.() -> Unit)? = null) =
MigLayout(LC().fill().gridGap(gapX, gapY).insets(insets).also { lcBuilder?.invoke(it) })

fun migLayoutVertical() =
MigLayout(LC().flowY().fill().gridGap("0!", "0!").insets("0"))
Expand Down Expand Up @@ -101,7 +102,7 @@ object UI {

operator fun Border.plus(external: Border): Border = JBUI.Borders.merge(this, external, true)

fun createHint(content: String, componentWidth: Int = 300, hintForComponent: JComponent? = null): JComponent =
fun createHint(content: String, componentWidth: Int = 300, hintForComponent: JComponent? = null): JTextComponent =
JEditorPane().apply {
isEditable = false
isFocusable = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ enum class TranslationEngine(
BAIDU -> isConfigured(Settings.baiduTranslateSettings)
ALI -> isConfigured(Settings.aliTranslateSettings)
DEEPL -> DeeplCredential.isAuthKeySet
OPEN_AI -> OpenAICredentials.manager(service<OpenAISettings>().provider).isCredentialSet
OPEN_AI -> service<OpenAISettings>().let { it.isConfigured && OpenAICredentials.isCredentialSet(it.provider) }
}
}

Expand Down
Loading

0 comments on commit 8d91c43

Please sign in to comment.