From a3f8ba000973705d78dde99d6e68ae155060ade8 Mon Sep 17 00:00:00 2001 From: kevinheavey Date: Wed, 16 Oct 2024 00:14:38 +0400 Subject: [PATCH] Add optional `token_program_id` param to `get_associated_token_address` --- CHANGELOG.md | 1 + crates/token/src/associated.rs | 19 +++++++++++++++++-- python/solders/token/associated.pyi | 6 +++++- tests/test_message.py | 2 ++ tests/test_rpc_responses.py | 1 + tests/test_transaction.py | 2 ++ tests/token/test_ata.py | 14 ++++++++++++++ 7 files changed, 42 insertions(+), 3 deletions(-) create mode 100644 tests/token/test_ata.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 1064c7ef..bd7be81d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### Changed +- Add optional `token_program_id` param to `get_associated_token_address` [(#117)](https://github.com/kevinheavey/solders/pull/117). - Upgrade Solana deps to 2.0 [(#116)](https://github.com/kevinheavey/solders/pull/116). - Remove GetStakeActivationResp (no longer exists) [(#116)](https://github.com/kevinheavey/solders/pull/116). diff --git a/crates/token/src/associated.rs b/crates/token/src/associated.rs index 5d0ad403..8ba2659b 100644 --- a/crates/token/src/associated.rs +++ b/crates/token/src/associated.rs @@ -1,14 +1,29 @@ use pyo3::prelude::*; use solders_pubkey::Pubkey; -use spl_associated_token_account_client::address::get_associated_token_address as get_ata; +use spl_associated_token_account_client::address::get_associated_token_address_with_program_id as get_ata; /// Derives the associated token account address for the given wallet address and token mint. +/// +/// Args: +/// wallet_address (Pubkey): The address of the wallet that owns the token account. +/// token_mint_address (Pubkey): The token mint. +/// token_program_id (Pubkey | None): The token program ID. Defaults to the SPL Token Program. +/// +/// Returns: +/// Pubkey: The associated token address +/// #[pyfunction] pub fn get_associated_token_address( wallet_address: &Pubkey, token_mint_address: &Pubkey, + token_program_id: Option<&Pubkey>, ) -> Pubkey { - get_ata(wallet_address.as_ref(), token_mint_address.as_ref()).into() + get_ata( + wallet_address.as_ref(), + token_mint_address.as_ref(), + token_program_id.map_or(&spl_token::ID, |x| x.as_ref()), + ) + .into() } pub fn create_associated_mod(py: Python<'_>) -> PyResult<&PyModule> { diff --git a/python/solders/token/associated.pyi b/python/solders/token/associated.pyi index 949ea1d9..f36350b3 100644 --- a/python/solders/token/associated.pyi +++ b/python/solders/token/associated.pyi @@ -1,5 +1,9 @@ +from typing import Optional + from solders.pubkey import Pubkey def get_associated_token_address( - wallet_address: Pubkey, token_mint_address: Pubkey + wallet_address: Pubkey, + token_mint_address: Pubkey, + token_program_id: Optional[Pubkey] = None, ) -> Pubkey: ... diff --git a/tests/test_message.py b/tests/test_message.py index 09e1eb07..f874432b 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -78,6 +78,7 @@ def test_program_position() -> None: assert message.program_position(1) == 0 assert message.program_position(2) == 1 + def test_program_ids() -> None: key0 = Pubkey.new_unique() key1 = Pubkey.new_unique() @@ -93,6 +94,7 @@ def test_program_ids() -> None: ) assert message.program_ids() == [loader2] + def test_message_header_len_constant() -> None: assert MessageHeader.LENGTH == 3 diff --git a/tests/test_rpc_responses.py b/tests/test_rpc_responses.py index bb0ea2e6..f326a790 100644 --- a/tests/test_rpc_responses.py +++ b/tests/test_rpc_responses.py @@ -1523,6 +1523,7 @@ def test_get_slot_leaders() -> None: ) ] + def test_get_supply() -> None: raw = """{ "jsonrpc": "2.0", diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 869e7cc4..b7c363fc 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -913,6 +913,7 @@ def test_tx_uses_nonce_first_prog_id_not_nonce_fail() -> None: tx = Transaction([from_keypair, nonce_keypair], message, Hash.default()) assert tx.uses_durable_nonce() is None + def test_tx_uses_nonce_wrong_first_nonce_ix_fail() -> None: from_keypair = Keypair() from_pubkey = from_keypair.pubkey() @@ -935,6 +936,7 @@ def test_tx_uses_nonce_wrong_first_nonce_ix_fail() -> None: tx = Transaction([from_keypair, nonce_keypair], message, Hash.default()) assert tx.uses_durable_nonce() is None + def test_tx_keypair_pubkey_mismatch() -> None: from_keypair = Keypair() from_pubkey = from_keypair.pubkey() diff --git a/tests/token/test_ata.py b/tests/token/test_ata.py new file mode 100644 index 00000000..f40e75a9 --- /dev/null +++ b/tests/token/test_ata.py @@ -0,0 +1,14 @@ +from solders.pubkey import Pubkey +from solders.token.associated import get_associated_token_address + + +def test_ata() -> None: + wallet_address = Pubkey.from_string("5d21Nx19eZBThbExCn1ESAk3RGmE8Rdp9PKMWZ2VedSK") + token_mint = Pubkey.from_string("3CqfBkrmRsK3uXZaxktvTeeBkJp4yeFKs4mUi2jhKExz") + assert get_associated_token_address( + wallet_address, token_mint + ) == Pubkey.from_string("Aumq2SPVzZccYL3UAhvXoDDkNYLZr2zpyxLuJiyx79te") + token22_id = Pubkey.from_string("TokenzQdBNbLqP5VEhdkAS6EPFLC1PHnBqCXEpPxuEb") + assert get_associated_token_address( + wallet_address, token_mint, token22_id + ) == Pubkey.from_string("4xoV4cxTM3GcaWP7bKbUdu2Gp9P9nEgpmCPV8ykFGo4U")