diff --git a/packages/protocol-utils-evm/src/endpoint/factory.ts b/packages/protocol-utils-evm/src/endpoint/factory.ts index 2e1d14815..195400243 100644 --- a/packages/protocol-utils-evm/src/endpoint/factory.ts +++ b/packages/protocol-utils-evm/src/endpoint/factory.ts @@ -1,7 +1,8 @@ import pMemoize from 'p-memoize' -import { OmniContractFactory } from '@layerzerolabs/utils-evm' +import type { EndpointFactory, Uln302Factory } from '@layerzerolabs/protocol-utils' +import type { OmniContractFactory } from '@layerzerolabs/utils-evm' import { Endpoint } from './sdk' -import { EndpointFactory } from '@layerzerolabs/protocol-utils' +import { createUln302Factory } from '@/uln302/factory' /** * Syntactic sugar that creates an instance of EVM `Endpoint` SDK @@ -10,5 +11,7 @@ import { EndpointFactory } from '@layerzerolabs/protocol-utils' * @param {OmniContractFactory} contractFactory * @returns {EndpointFactory} */ -export const createEndpointFactory = (contractFactory: OmniContractFactory): EndpointFactory => - pMemoize(async (point) => new Endpoint(await contractFactory(point))) +export const createEndpointFactory = ( + contractFactory: OmniContractFactory, + uln302Factory: Uln302Factory = createUln302Factory(contractFactory) +): EndpointFactory => pMemoize(async (point) => new Endpoint(await contractFactory(point), uln302Factory)) diff --git a/packages/protocol-utils-evm/src/endpoint/sdk.ts b/packages/protocol-utils-evm/src/endpoint/sdk.ts index 3f64cc71f..6f9b58f35 100644 --- a/packages/protocol-utils-evm/src/endpoint/sdk.ts +++ b/packages/protocol-utils-evm/src/endpoint/sdk.ts @@ -1,9 +1,28 @@ -import type { IEndpoint } from '@layerzerolabs/protocol-utils' -import { formatEid, type Address, type OmniTransaction } from '@layerzerolabs/utils' +import assert from 'assert' +import type { IEndpoint, IUln302, Uln302Factory } from '@layerzerolabs/protocol-utils' +import { formatEid, type Address, type OmniTransaction, formatOmniPoint } from '@layerzerolabs/utils' import type { EndpointId } from '@layerzerolabs/lz-definitions' -import { ignoreZero, makeZeroAddress, OmniSDK } from '@layerzerolabs/utils-evm' +import { ignoreZero, isZero, makeZeroAddress, type OmniContract, OmniSDK } from '@layerzerolabs/utils-evm' export class Endpoint extends OmniSDK implements IEndpoint { + constructor( + contract: OmniContract, + private readonly uln302Factory: Uln302Factory + ) { + super(contract) + } + + async getUln302SDK(address: Address): Promise { + assert( + !isZero(address), + `Uln302 cannot be instantiated: Uln302 address cannot be a zero value for Endpoint ${formatOmniPoint( + this.point + )}` + ) + + return await this.uln302Factory({ eid: this.point.eid, address }) + } + async getDefaultReceiveLibrary(eid: EndpointId): Promise
{ return ignoreZero(await this.contract.contract.defaultReceiveLibrary(eid)) } diff --git a/packages/protocol-utils-evm/src/uln302/factory.ts b/packages/protocol-utils-evm/src/uln302/factory.ts new file mode 100644 index 000000000..61c1e1474 --- /dev/null +++ b/packages/protocol-utils-evm/src/uln302/factory.ts @@ -0,0 +1,14 @@ +import pMemoize from 'p-memoize' +import type { OmniContractFactory } from '@layerzerolabs/utils-evm' +import type { Uln302Factory } from '@layerzerolabs/protocol-utils' +import { Uln302 } from './sdk' + +/** + * Syntactic sugar that creates an instance of EVM `Uln302` SDK + * based on an `OmniPoint` with help of an `OmniContractFactory` + * + * @param {OmniContractFactory} contractFactory + * @returns {Uln302Factory} + */ +export const createUln302Factory = (contractFactory: OmniContractFactory): Uln302Factory => + pMemoize(async (point) => new Uln302(await contractFactory(point))) diff --git a/packages/protocol-utils-evm/src/uln302/index.ts b/packages/protocol-utils-evm/src/uln302/index.ts index e28e54c22..008b25bf8 100644 --- a/packages/protocol-utils-evm/src/uln302/index.ts +++ b/packages/protocol-utils-evm/src/uln302/index.ts @@ -1,2 +1,3 @@ +export * from './factory' export * from './sdk' export * from './types' diff --git a/packages/protocol-utils-evm/test/endpoint/sdk.test.ts b/packages/protocol-utils-evm/test/endpoint/sdk.test.ts new file mode 100644 index 000000000..10ee01968 --- /dev/null +++ b/packages/protocol-utils-evm/test/endpoint/sdk.test.ts @@ -0,0 +1,65 @@ +import fc from 'fast-check' +import { endpointArbitrary, evmAddressArbitrary } from '@layerzerolabs/test-utils' +import type { Contract } from '@ethersproject/contracts' +import { isZero, makeBytes32, makeZeroAddress, type OmniContract } from '@layerzerolabs/utils-evm' +import { Endpoint } from '@/endpoint' + +describe('endpoint/sdk', () => { + describe('getUln302SDK', () => { + const zeroishAddressArbitrary = fc.constantFrom(makeZeroAddress(), makeBytes32()) + const jestFunctionArbitrary = fc.anything().map(() => jest.fn()) + const oappOmniContractArbitrary = fc.record({ + address: evmAddressArbitrary, + peers: jestFunctionArbitrary, + endpoint: jestFunctionArbitrary, + interface: fc.record({ + encodeFunctionData: jestFunctionArbitrary, + }), + }) as fc.Arbitrary as fc.Arbitrary + + const omniContractArbitrary: fc.Arbitrary = fc.record({ + eid: endpointArbitrary, + contract: oappOmniContractArbitrary, + }) + + const uln302Factory = jest.fn().mockRejectedValue('No endpoint') + + it('should reject if the address is a zeroish address', async () => { + await fc.assert( + fc.asyncProperty(omniContractArbitrary, zeroishAddressArbitrary, async (omniContract, address) => { + const sdk = new Endpoint(omniContract, uln302Factory) + + await expect(sdk.getUln302SDK(address)).rejects.toThrow( + /Uln302 cannot be instantiated: Uln302 address cannot be a zero value for Endpoint/ + ) + }) + ) + }) + + it('should call the uln302Factory if the address is a non-zeroish address', async () => { + await fc.assert( + fc.asyncProperty( + omniContractArbitrary, + evmAddressArbitrary, + fc.anything(), + async (omniContract, address, uln302Sdk) => { + fc.pre(!isZero(address)) + + uln302Factory.mockReset().mockResolvedValue(uln302Sdk) + + const sdk = new Endpoint(omniContract, uln302Factory) + const uln302 = await sdk.getUln302SDK(address) + + expect(uln302).toBe(uln302Sdk) + + expect(uln302Factory).toHaveBeenCalledTimes(1) + expect(uln302Factory).toHaveBeenCalledWith({ + eid: omniContract.eid, + address, + }) + } + ) + ) + }) + }) +}) diff --git a/packages/protocol-utils/src/endpoint/types.ts b/packages/protocol-utils/src/endpoint/types.ts index 4789f8dea..38eb6a8e9 100644 --- a/packages/protocol-utils/src/endpoint/types.ts +++ b/packages/protocol-utils/src/endpoint/types.ts @@ -7,8 +7,11 @@ import type { Bytes32, } from '@layerzerolabs/utils' import type { EndpointId } from '@layerzerolabs/lz-definitions' +import type { IUln302 } from '@/uln302/types' export interface IEndpoint extends IOmniSDK { + getUln302SDK(address: Address): Promise + getDefaultReceiveLibrary(eid: EndpointId): Promise
setDefaultReceiveLibrary( eid: EndpointId, diff --git a/packages/ua-utils-evm-hardhat-test/test/__utils__/endpoint.ts b/packages/ua-utils-evm-hardhat-test/test/__utils__/endpoint.ts index 7917328c5..2014385d7 100644 --- a/packages/ua-utils-evm-hardhat-test/test/__utils__/endpoint.ts +++ b/packages/ua-utils-evm-hardhat-test/test/__utils__/endpoint.ts @@ -12,14 +12,12 @@ import { omniContractToPoint } from '@layerzerolabs/utils-evm' import { configureEndpoint, EndpointEdgeConfig, - EndpointFactory, Uln302NodeConfig, Uln302ExecutorConfig, configureUln302, - Uln302Factory, Uln302UlnConfig, } from '@layerzerolabs/protocol-utils' -import { Endpoint, Uln302 } from '@layerzerolabs/protocol-utils-evm' +import { createEndpointFactory, createUln302Factory } from '@layerzerolabs/protocol-utils-evm' import { formatOmniPoint } from '@layerzerolabs/utils' export const ethEndpoint = { eid: EndpointId.ETHEREUM_MAINNET, contractName: 'EndpointV2' } @@ -76,8 +74,8 @@ export const setupDefaultEndpoint = async (): Promise => { const environmentFactory = createNetworkEnvironmentFactory() const contractFactory = createConnectedContractFactory() const signerFactory = createSignerFactory() - const endpointSdkFactory: EndpointFactory = async (point) => new Endpoint(await contractFactory(point)) - const ulnSdkFactory: Uln302Factory = async (point) => new Uln302(await contractFactory(point)) + const ulnSdkFactory = createUln302Factory(contractFactory) + const endpointSdkFactory = createEndpointFactory(contractFactory, ulnSdkFactory) // First we deploy the endpoint await deploy(await environmentFactory(EndpointId.ETHEREUM_MAINNET)) diff --git a/packages/ua-utils-evm-hardhat-test/test/endpoint/config.test.ts b/packages/ua-utils-evm-hardhat-test/test/endpoint/config.test.ts index c0e549de3..517fc2c12 100644 --- a/packages/ua-utils-evm-hardhat-test/test/endpoint/config.test.ts +++ b/packages/ua-utils-evm-hardhat-test/test/endpoint/config.test.ts @@ -1,10 +1,9 @@ import 'hardhat' import { createConnectedContractFactory } from '@layerzerolabs/utils-evm-hardhat' -import type { OmniPoint } from '@layerzerolabs/utils' import { omniContractToPoint } from '@layerzerolabs/utils-evm' import { EndpointId } from '@layerzerolabs/lz-definitions' import { getDefaultUlnConfig, setupDefaultEndpoint } from '../__utils__/endpoint' -import { Endpoint, Uln302 } from '@layerzerolabs/protocol-utils-evm' +import { createEndpointFactory, createUln302Factory } from '@layerzerolabs/protocol-utils-evm' describe('endpoint/config', () => { const ethEndpoint = { eid: EndpointId.ETHEREUM_MAINNET, contractName: 'EndpointV2' } @@ -24,7 +23,7 @@ describe('endpoint/config', () => { it('should have default libraries configured', async () => { // This is the required tooling we need to set up const connectedContractFactory = createConnectedContractFactory() - const sdkFactory = async (point: OmniPoint) => new Endpoint(await connectedContractFactory(point)) + const sdkFactory = createEndpointFactory(connectedContractFactory) // Now for the purposes of the test, we need to get coordinates of our contracts const ethEndpointPoint = omniContractToPoint(await connectedContractFactory(ethEndpoint)) @@ -59,7 +58,7 @@ describe('endpoint/config', () => { it('should have default executors configured', async () => { // This is the required tooling we need to set up const connectedContractFactory = createConnectedContractFactory() - const sdkFactory = async (point: OmniPoint) => new Uln302(await connectedContractFactory(point)) + const sdkFactory = createUln302Factory(connectedContractFactory) const ethSendUlnPoint = omniContractToPoint(await connectedContractFactory(ethSendUln)) const avaxSendUlnPoint = omniContractToPoint(await connectedContractFactory(avaxSendUln)) diff --git a/packages/ua-utils-evm-hardhat/src/utils/taskHelpers.ts b/packages/ua-utils-evm-hardhat/src/utils/taskHelpers.ts index 5a8e934da..78e072fd9 100644 --- a/packages/ua-utils-evm-hardhat/src/utils/taskHelpers.ts +++ b/packages/ua-utils-evm-hardhat/src/utils/taskHelpers.ts @@ -1,7 +1,7 @@ import { Address } from '@layerzerolabs/utils' import { Uln302ExecutorConfig, Uln302UlnConfig } from '@layerzerolabs/protocol-utils' import { createConnectedContractFactory, getEidForNetworkName } from '@layerzerolabs/utils-evm-hardhat' -import { Endpoint, Uln302 } from '@layerzerolabs/protocol-utils-evm' +import { Endpoint, Uln302, createUln302Factory } from '@layerzerolabs/protocol-utils-evm' export async function getSendConfig( localNetworkName: string, @@ -11,7 +11,11 @@ export async function getSendConfig( const localEid = getEidForNetworkName(localNetworkName) const remoteEid = getEidForNetworkName(remoteNetworkName) const contractFactory = createConnectedContractFactory() - const localEndpointSDK = new Endpoint(await contractFactory({ eid: localEid, contractName: 'EndpointV2' })) + const uln302Factory = createUln302Factory(contractFactory) + const localEndpointSDK = new Endpoint( + await contractFactory({ eid: localEid, contractName: 'EndpointV2' }), + uln302Factory + ) // First we get the SDK for the local send library let sendLibrary: Address @@ -38,7 +42,11 @@ export async function getReceiveConfig( const localEid = getEidForNetworkName(localNetworkName) const remoteEid = getEidForNetworkName(remoteNetworkName) const contractFactory = createConnectedContractFactory() - const localEndpointSDK = new Endpoint(await contractFactory({ eid: localEid, contractName: 'EndpointV2' })) + const uln302Factory = createUln302Factory(contractFactory) + const localEndpointSDK = new Endpoint( + await contractFactory({ eid: localEid, contractName: 'EndpointV2' }), + uln302Factory + ) // First we get the SDK for the local send library let receiveLibrary: Address diff --git a/packages/ua-utils-evm/src/oapp/sdk.ts b/packages/ua-utils-evm/src/oapp/sdk.ts index b7ebeea82..1edd043a7 100644 --- a/packages/ua-utils-evm/src/oapp/sdk.ts +++ b/packages/ua-utils-evm/src/oapp/sdk.ts @@ -20,7 +20,7 @@ export class OApp extends OmniSDK implements IOApp { super(contract) } - async getEndpoint(): Promise { + async getEndpointSDK(): Promise { let address: string // First we'll need the endpoint address from the contract diff --git a/packages/ua-utils-evm/test/oapp/sdk.test.ts b/packages/ua-utils-evm/test/oapp/sdk.test.ts index 6e5548305..3320b58d1 100644 --- a/packages/ua-utils-evm/test/oapp/sdk.test.ts +++ b/packages/ua-utils-evm/test/oapp/sdk.test.ts @@ -228,7 +228,7 @@ describe('oapp/sdk', () => { }) }) - describe('getEndpoint', () => { + describe('getEndpointSDK', () => { it('should reject if the call to endpoint() rejects', async () => { await fc.assert( fc.asyncProperty(omniContractArbitrary, async (omniContract) => { @@ -236,7 +236,7 @@ describe('oapp/sdk', () => { const sdk = new OApp(omniContract, endpointFactory) - await expect(sdk.getEndpoint()).rejects.toThrow(/Failed to get endpoint address for OApp/) + await expect(sdk.getEndpointSDK()).rejects.toThrow(/Failed to get endpoint address for OApp/) }) ) }) @@ -251,7 +251,7 @@ describe('oapp/sdk', () => { const sdk = new OApp(omniContract, endpointFactory) - await expect(sdk.getEndpoint()).rejects.toThrow( + await expect(sdk.getEndpointSDK()).rejects.toThrow( /Endpoint cannot be instantiated: Endpoint address has been set to a zero value for OApp/ ) } @@ -273,7 +273,7 @@ describe('oapp/sdk', () => { endpointFactory.mockReset().mockResolvedValue(endpointSdk) const sdk = new OApp(omniContract, endpointFactory) - const endpoint = await sdk.getEndpoint() + const endpoint = await sdk.getEndpointSDK() expect(endpoint).toBe(endpointSdk) diff --git a/packages/ua-utils/src/oapp/types.ts b/packages/ua-utils/src/oapp/types.ts index dc3c177a7..0ad67a692 100644 --- a/packages/ua-utils/src/oapp/types.ts +++ b/packages/ua-utils/src/oapp/types.ts @@ -5,7 +5,7 @@ import type { Bytes32 } from '@layerzerolabs/utils' import type { OmniPointBasedFactory } from '@layerzerolabs/utils' export interface IOApp extends IOmniSDK { - getEndpoint(): Promise + getEndpointSDK(): Promise getPeer(eid: EndpointId): Promise hasPeer(eid: EndpointId, address: Bytes32 | Address | null | undefined): Promise setPeer(eid: EndpointId, peer: Bytes32 | Address | null | undefined): Promise