Skip to content

Commit

Permalink
Merge pull request #319 from InjectiveLabs/feat/refactor_network_mixe…
Browse files Browse the repository at this point in the history
…d_secure_insecure_channels

feat/refactor_network_mixed_secure_insecure_channels
  • Loading branch information
aarmoa authored Apr 19, 2024
2 parents db9b6c0 + c47c94e commit b43fc0a
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 33 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

All notable changes to this project will be documented in this file.

## [1.5.0] - 2024-04-19
### Changed
- Refactoring in Network class to support mixed secure and insecure endpoints.
- Marked the Network parameter `use_secure_connection` as deprecated.

## [1.4.2] - 2024-03-19
### Changed
- Updated `aiohttp` dependency version to ">=3.9.2" to solve a security vulnerability detected by Dependabot
Expand Down
33 changes: 12 additions & 21 deletions pyinjective/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
self,
network: Network,
insecure: Optional[bool] = None,
credentials=grpc.ssl_channel_credentials(),
credentials=None,
):
# the `insecure` parameter is ignored and will be deprecated soon. The value is taken directly from `network`
if insecure is not None:
Expand All @@ -97,6 +97,13 @@ def __init__(
DeprecationWarning,
stacklevel=2,
)
# the `credentials` parameter is ignored and will be deprecated soon. The value is taken directly from `network`
if credentials is not None:
warn(
"credentials parameter in AsyncClient is no longer used and will be deprecated",
DeprecationWarning,
stacklevel=2,
)

self.addr = ""
self.number = 0
Expand All @@ -105,11 +112,7 @@ def __init__(
self.network = network

# chain stubs
self.chain_channel = (
grpc.aio.secure_channel(network.grpc_endpoint, credentials)
if (network.use_secure_connection and credentials is not None)
else grpc.aio.insecure_channel(network.grpc_endpoint)
)
self.chain_channel = self.network.create_chain_grpc_channel()

self.stubCosmosTendermint = tendermint_query_grpc.ServiceStub(self.chain_channel)
self.stubAuth = auth_query_grpc.QueryStub(self.chain_channel)
Expand All @@ -121,11 +124,7 @@ def __init__(
self.timeout_height = 1

# exchange stubs
self.exchange_channel = (
grpc.aio.secure_channel(network.grpc_exchange_endpoint, credentials)
if (network.use_secure_connection and credentials is not None)
else grpc.aio.insecure_channel(network.grpc_exchange_endpoint)
)
self.exchange_channel = self.network.create_exchange_grpc_channel()
self.stubMeta = exchange_meta_rpc_grpc.InjectiveMetaRPCStub(self.exchange_channel)
self.stubExchangeAccount = exchange_accounts_rpc_grpc.InjectiveAccountsRPCStub(self.exchange_channel)
self.stubOracle = oracle_rpc_grpc.InjectiveOracleRPCStub(self.exchange_channel)
Expand All @@ -138,18 +137,10 @@ def __init__(
self.stubPortfolio = portfolio_rpc_grpc.InjectivePortfolioRPCStub(self.exchange_channel)

# explorer stubs
self.explorer_channel = (
grpc.aio.secure_channel(network.grpc_explorer_endpoint, credentials)
if (network.use_secure_connection and credentials is not None)
else grpc.aio.insecure_channel(network.grpc_explorer_endpoint)
)
self.explorer_channel = self.network.create_explorer_grpc_channel()
self.stubExplorer = explorer_rpc_grpc.InjectiveExplorerRPCStub(self.explorer_channel)

self.chain_stream_channel = (
grpc.aio.secure_channel(network.chain_stream_endpoint, credentials)
if (network.use_secure_connection and credentials is not None)
else grpc.aio.insecure_channel(network.chain_stream_endpoint)
)
self.chain_stream_channel = self.network.create_chain_stream_grpc_channel()
self.chain_stream_stub = stream_rpc_grpc.StreamStub(channel=self.chain_stream_channel)

self._timeout_height_sync_task = None
Expand Down
111 changes: 101 additions & 10 deletions pyinjective/core/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from abc import ABC, abstractmethod
from http.cookies import SimpleCookie
from typing import Callable, Optional, Tuple
from warnings import warn

import grpc
from grpc import ChannelCredentials


class CookieAssistant(ABC):
Expand Down Expand Up @@ -181,8 +185,20 @@ def __init__(
fee_denom: str,
env: str,
cookie_assistant: CookieAssistant,
use_secure_connection: bool = False,
use_secure_connection: Optional[bool] = None,
grpc_channel_credentials: Optional[ChannelCredentials] = None,
grpc_exchange_channel_credentials: Optional[ChannelCredentials] = None,
grpc_explorer_channel_credentials: Optional[ChannelCredentials] = None,
chain_stream_channel_credentials: Optional[ChannelCredentials] = None,
):
# the `use_secure_connection` parameter is ignored and will be deprecated soon.
if use_secure_connection is not None:
warn(
"use_secure_connection parameter in Network is no longer used and will be deprecated",
DeprecationWarning,
stacklevel=2,
)

self.lcd_endpoint = lcd_endpoint
self.tm_websocket_endpoint = tm_websocket_endpoint
self.grpc_endpoint = grpc_endpoint
Expand All @@ -193,7 +209,10 @@ def __init__(
self.fee_denom = fee_denom
self.env = env
self.cookie_assistant = cookie_assistant
self.use_secure_connection = use_secure_connection
self.grpc_channel_credentials = grpc_channel_credentials
self.grpc_exchange_channel_credentials = grpc_exchange_channel_credentials
self.grpc_explorer_channel_credentials = grpc_explorer_channel_credentials
self.chain_stream_channel_credentials = chain_stream_channel_credentials

@classmethod
def devnet(cls):
Expand All @@ -219,6 +238,11 @@ def testnet(cls, node="lb"):
if node not in nodes:
raise ValueError("Must be one of {}".format(nodes))

grpc_channel_credentials = grpc.ssl_channel_credentials()
grpc_exchange_channel_credentials = grpc.ssl_channel_credentials()
grpc_explorer_channel_credentials = grpc.ssl_channel_credentials()
chain_stream_channel_credentials = grpc.ssl_channel_credentials()

if node == "lb":
lcd_endpoint = "https://testnet.sentry.lcd.injective.network:443"
tm_websocket_endpoint = "wss://testnet.sentry.tm.injective.network:443/websocket"
Expand All @@ -227,7 +251,6 @@ def testnet(cls, node="lb"):
grpc_explorer_endpoint = "testnet.sentry.explorer.grpc.injective.network:443"
chain_stream_endpoint = "testnet.sentry.chain.stream.injective.network:443"
cookie_assistant = BareMetalLoadBalancedCookieAssistant()
use_secure_connection = True
else:
lcd_endpoint = "https://testnet.lcd.injective.network:443"
tm_websocket_endpoint = "wss://testnet.tm.injective.network:443/websocket"
Expand All @@ -236,7 +259,6 @@ def testnet(cls, node="lb"):
grpc_explorer_endpoint = "testnet.explorer.grpc.injective.network:443"
chain_stream_endpoint = "testnet.chain.stream.injective.network:443"
cookie_assistant = DisabledCookieAssistant()
use_secure_connection = True

return cls(
lcd_endpoint=lcd_endpoint,
Expand All @@ -249,7 +271,10 @@ def testnet(cls, node="lb"):
fee_denom="inj",
env="testnet",
cookie_assistant=cookie_assistant,
use_secure_connection=use_secure_connection,
grpc_channel_credentials=grpc_channel_credentials,
grpc_exchange_channel_credentials=grpc_exchange_channel_credentials,
grpc_explorer_channel_credentials=grpc_explorer_channel_credentials,
chain_stream_channel_credentials=chain_stream_channel_credentials,
)

@classmethod
Expand All @@ -267,7 +292,10 @@ def mainnet(cls, node="lb"):
grpc_explorer_endpoint = "sentry.explorer.grpc.injective.network:443"
chain_stream_endpoint = "sentry.chain.stream.injective.network:443"
cookie_assistant = BareMetalLoadBalancedCookieAssistant()
use_secure_connection = True
grpc_channel_credentials = grpc.ssl_channel_credentials()
grpc_exchange_channel_credentials = grpc.ssl_channel_credentials()
grpc_explorer_channel_credentials = grpc.ssl_channel_credentials()
chain_stream_channel_credentials = grpc.ssl_channel_credentials()

return cls(
lcd_endpoint=lcd_endpoint,
Expand All @@ -280,7 +308,10 @@ def mainnet(cls, node="lb"):
fee_denom="inj",
env="mainnet",
cookie_assistant=cookie_assistant,
use_secure_connection=use_secure_connection,
grpc_channel_credentials=grpc_channel_credentials,
grpc_exchange_channel_credentials=grpc_exchange_channel_credentials,
grpc_explorer_channel_credentials=grpc_explorer_channel_credentials,
chain_stream_channel_credentials=chain_stream_channel_credentials,
)

@classmethod
Expand All @@ -296,7 +327,6 @@ def local(cls):
fee_denom="inj",
env="local",
cookie_assistant=DisabledCookieAssistant(),
use_secure_connection=False,
)

@classmethod
Expand All @@ -311,8 +341,20 @@ def custom(
chain_id,
env,
cookie_assistant: Optional[CookieAssistant] = None,
use_secure_connection: bool = False,
use_secure_connection: Optional[bool] = None,
grpc_channel_credentials: Optional[ChannelCredentials] = None,
grpc_exchange_channel_credentials: Optional[ChannelCredentials] = None,
grpc_explorer_channel_credentials: Optional[ChannelCredentials] = None,
chain_stream_channel_credentials: Optional[ChannelCredentials] = None,
):
# the `use_secure_connection` parameter is ignored and will be deprecated soon.
if use_secure_connection is not None:
warn(
"use_secure_connection parameter in Network is no longer used and will be deprecated",
DeprecationWarning,
stacklevel=2,
)

assistant = cookie_assistant or DisabledCookieAssistant()
return cls(
lcd_endpoint=lcd_endpoint,
Expand All @@ -325,7 +367,37 @@ def custom(
fee_denom="inj",
env=env,
cookie_assistant=assistant,
use_secure_connection=use_secure_connection,
grpc_channel_credentials=grpc_channel_credentials,
grpc_exchange_channel_credentials=grpc_exchange_channel_credentials,
grpc_explorer_channel_credentials=grpc_explorer_channel_credentials,
chain_stream_channel_credentials=chain_stream_channel_credentials,
)

@classmethod
def custom_chain_and_public_indexer_mainnet(
cls,
lcd_endpoint,
tm_websocket_endpoint,
grpc_endpoint,
chain_stream_endpoint,
cookie_assistant: Optional[CookieAssistant] = None,
):
mainnet_network = cls.mainnet()

return cls.custom(
lcd_endpoint=lcd_endpoint,
tm_websocket_endpoint=tm_websocket_endpoint,
grpc_endpoint=grpc_endpoint,
grpc_exchange_endpoint=mainnet_network.grpc_exchange_endpoint,
grpc_explorer_endpoint=mainnet_network.grpc_explorer_endpoint,
chain_stream_endpoint=chain_stream_endpoint,
chain_id="injective-1",
env="mainnet",
cookie_assistant=cookie_assistant,
grpc_channel_credentials=None,
grpc_exchange_channel_credentials=mainnet_network.grpc_exchange_channel_credentials,
grpc_explorer_channel_credentials=mainnet_network.grpc_explorer_channel_credentials,
chain_stream_channel_credentials=None,
)

def string(self):
Expand All @@ -336,3 +408,22 @@ async def chain_metadata(self, metadata_query_provider: Callable) -> Tuple[Tuple

async def exchange_metadata(self, metadata_query_provider: Callable) -> Tuple[Tuple[str, str]]:
return await self.cookie_assistant.exchange_metadata(metadata_query_provider=metadata_query_provider)

def create_chain_grpc_channel(self) -> grpc.Channel:
return self._create_grpc_channel(self.grpc_endpoint, self.grpc_channel_credentials)

def create_exchange_grpc_channel(self) -> grpc.Channel:
return self._create_grpc_channel(self.grpc_exchange_endpoint, self.grpc_exchange_channel_credentials)

def create_explorer_grpc_channel(self) -> grpc.Channel:
return self._create_grpc_channel(self.grpc_explorer_endpoint, self.grpc_explorer_channel_credentials)

def create_chain_stream_grpc_channel(self) -> grpc.Channel:
return self._create_grpc_channel(self.chain_stream_endpoint, self.chain_stream_channel_credentials)

def _create_grpc_channel(self, endpoint: str, credentials: Optional[ChannelCredentials]) -> grpc.Channel:
if credentials is None:
channel = grpc.aio.insecure_channel(endpoint)
else:
channel = grpc.aio.secure_channel(endpoint, credentials)
return channel
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "injective-py"
version = "1.4.2"
version = "1.5.0"
description = "Injective Python SDK, with Exchange API Client"
authors = ["Injective Labs <contact@injectivelabs.org>"]
license = "Apache-2.0"
Expand Down
28 changes: 28 additions & 0 deletions tests/core/test_network_deprecation_warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from warnings import catch_warnings

from pyinjective.core.network import DisabledCookieAssistant, Network


class TestNetworkDeprecationWarnings:
def test_use_secure_connection_parameter_deprecation_warning(self):
with catch_warnings(record=True) as all_warnings:
Network(
lcd_endpoint="lcd_endpoint",
tm_websocket_endpoint="tm_websocket_endpoint",
grpc_endpoint="grpc_endpoint",
grpc_exchange_endpoint="grpc_exchange_endpoint",
grpc_explorer_endpoint="grpc_explorer_endpoint",
chain_stream_endpoint="chain_stream_endpoint",
chain_id="chain_id",
fee_denom="fee_denom",
env="env",
cookie_assistant=DisabledCookieAssistant(),
use_secure_connection=True,
)

deprecation_warnings = [warning for warning in all_warnings if issubclass(warning.category, DeprecationWarning)]
assert len(deprecation_warnings) == 1
assert (
str(deprecation_warnings[0].message)
== "use_secure_connection parameter in Network is no longer used and will be deprecated"
)
17 changes: 16 additions & 1 deletion tests/test_async_client_deprecation_warnings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from warnings import catch_warnings

import grpc
import pytest

from pyinjective.async_client import AsyncClient
Expand Down Expand Up @@ -579,7 +580,7 @@ async def test_get_oracle_prices_deprecation_warning(
assert str(deprecation_warnings[0].message) == "This method is deprecated. Use fetch_oracle_price instead"

@pytest.mark.asyncio
async def test_stream_keepalive_deprecation_warning(
async def test_stream_oracle_prices_deprecation_warning(
self,
oracle_servicer,
):
Expand Down Expand Up @@ -1682,3 +1683,17 @@ async def test_chain_stream_deprecation_warning(
assert (
str(deprecation_warnings[0].message) == "This method is deprecated. Use listen_chain_stream_updates instead"
)

def test_credentials_parameter_deprecation_warning(
self,
auth_servicer,
):
with catch_warnings(record=True) as all_warnings:
AsyncClient(network=Network.local(), credentials=grpc.ssl_channel_credentials())

deprecation_warnings = [warning for warning in all_warnings if issubclass(warning.category, DeprecationWarning)]
assert len(deprecation_warnings) == 1
assert (
str(deprecation_warnings[0].message)
== "credentials parameter in AsyncClient is no longer used and will be deprecated"
)

0 comments on commit b43fc0a

Please sign in to comment.