# Source code for asynctradier.clients.account_clients

from typing import List, Optional

from asynctradier.common import EventType
from asynctradier.common.account_balance import AccountBalance
from asynctradier.common.event import Event
from asynctradier.common.gain_loss import ProfitLoss
from asynctradier.common.order import Order
from asynctradier.common.position import Position
from asynctradier.common.user_profile import UserAccount
from asynctradier.exceptions import APINotAvailable, InvalidDateFormat
from asynctradier.utils.common import is_valid_expiration_date
from asynctradier.utils.webutils import WebUtil


class AccountClient:
    """
    A client for interacting with the Tradier Account API.

    Args:
        session (WebUtil): The session object used for making HTTP requests.
        account_id (str): The account ID.
        token (str): The API token.
        sandbox (bool, optional): Whether to use the sandbox environment.
            Defaults to False.
    """

    def __init__(
        self, session: WebUtil, account_id: str, token: str, sandbox: bool = False
    ) -> None:
        self.session = session
        self.account_id = account_id
        self.token = token
        self.sandbox = sandbox

    @staticmethod
    def _as_list(value):
        """Return ``value`` unchanged if it is a list, else wrap it in one.

        Response payloads handled by this client carry either a single
        object or a list of objects under the same key.
        """
        return value if isinstance(value, list) else [value]

    async def get_user_profile(self) -> List[UserAccount]:
        """
        Retrieves the user profile information.

        Returns:
            List[UserAccount]: One UserAccount per account in the profile;
            an empty list when the response has no profile.

        Raises:
            APINotAvailable: If the client is configured for the sandbox.
        """
        if self.sandbox:
            raise APINotAvailable(
                "please check the documentation for more details: https://documentation.tradier.com/brokerage-api/user/get-profile"
            )

        response = await self.session.get("/v1/user/profile")
        profile = response.get("profile")
        if profile is None:
            return []

        # The profile-level id and name are copied onto every account entry.
        return [
            UserAccount(**account, id=profile["id"], name=profile["name"])
            for account in self._as_list(profile["account"])
        ]

    async def get_balance(self) -> AccountBalance:
        """
        Retrieves the account balance.

        Returns:
            AccountBalance: The account balance.
        """
        url = f"/v1/accounts/{self.account_id}/balances"
        response = await self.session.get(url)
        return AccountBalance(**response["balances"])

    async def get_history(
        self,
        page: int = 1,
        limit: int = 25,
        event_type: Optional[EventType] = None,
        start: Optional[str] = None,
        end: Optional[str] = None,
        symbol: Optional[str] = None,
        exact_match: bool = False,
    ) -> List[Event]:
        """
        Retrieves the account history.

        Args:
            page (int, optional): The page number of the history to retrieve.
                Values that are None or below 1 fall back to 1.
            limit (int, optional): The number of events to retrieve per page.
                Values that are None or below 1 fall back to 25.
            event_type (EventType, optional): The type of event to retrieve.
                Defaults to None.
            start (str, optional): The start date of the history to retrieve
                (YYYY-MM-DD). Defaults to None.
            end (str, optional): The end date of the history to retrieve
                (YYYY-MM-DD). Defaults to None.
            symbol (str, optional): The symbol of the event to retrieve.
                Defaults to None.
            exact_match (bool, optional): Whether to perform an exact match
                on the symbol. Defaults to False.

        Returns:
            List[Event]: A list of Event objects representing the account
            history; empty when the response carries no events.

        Raises:
            APINotAvailable: If the client is configured for the sandbox.
            InvalidDateFormat: If the start or end date fails validation.
        """
        if self.sandbox:
            raise APINotAvailable(
                "please check the documentation for more details: https://documentation.tradier.com/brokerage-api/accounts/get-account-history"
            )

        if start is not None and not is_valid_expiration_date(start):
            raise InvalidDateFormat(start)
        if end is not None and not is_valid_expiration_date(end):
            raise InvalidDateFormat(end)

        # Fall back to the defaults rather than sending invalid paging values.
        if page is None or page < 1:
            page = 1
        if limit is None or limit < 1:
            limit = 25
        if exact_match is None:
            exact_match = False

        url = f"/v1/accounts/{self.account_id}/history"
        params = {
            "page": page,
            "limit": limit,
            "exactMatch": str(exact_match).lower(),
        }
        # Optional filters are only sent when the caller supplied them.
        if event_type is not None:
            params["type"] = event_type.value
        if start is not None:
            params["start"] = start
        if end is not None:
            params["end"] = end
        if symbol is not None:
            params["symbol"] = symbol

        response = await self.session.get(url, params=params)
        history = response.get("history")
        if history is None:
            return []
        events = history.get("event")
        if events is None:
            return []

        return [Event(**event) for event in self._as_list(events)]

    async def get_positions(self) -> List[Position]:
        """
        Get the positions for the account.

        Returns:
            List[Position]: A list of Position objects; empty when the
            response reports no positions (missing, None or the string
            "null").
        """
        url = f"/v1/accounts/{self.account_id}/positions"
        response = await self.session.get(url)

        container = response.get("positions")
        # Treat a missing/None container like the "null" string marker so an
        # account without positions yields [] instead of raising.
        if container is None or container == "null":
            return []

        return [
            Position(**position)
            for position in self._as_list(container["position"])
        ]

    async def get_gainloss(
        self,
        page: int = 1,
        limit: int = 25,
        start: Optional[str] = None,
        end: Optional[str] = None,
        symbol: Optional[str] = None,
        sort_by_close_date: bool = True,
        desc: bool = True,
    ) -> List[ProfitLoss]:
        """
        Retrieves the gain/loss information for closed positions within a
        specified date range.

        Args:
            page (int): The page number of the results to retrieve (default
                is 1). Values that are None or below 1 fall back to 1.
            limit (int): The maximum number of results per page (default is
                25). Values that are None or below 1 fall back to 25.
            start (str, optional): The start date of the date range
                (format: "YYYY-MM-DD").
            end (str, optional): The end date of the date range
                (format: "YYYY-MM-DD").
            symbol (str, optional): The symbol of the positions to filter by.
            sort_by_close_date (bool): Sort by close date when True, by open
                date when False (default is True).
            desc (bool): Whether to sort the results in descending order
                (default is True).

        Returns:
            List[ProfitLoss]: A list of ProfitLoss objects representing the
            gain/loss information; empty when the response carries no closed
            positions.

        Raises:
            InvalidDateFormat: If the start or end date is not in the
                correct format.
        """
        if start is not None and not is_valid_expiration_date(start):
            raise InvalidDateFormat(start)
        if end is not None and not is_valid_expiration_date(end):
            raise InvalidDateFormat(end)

        # Fall back to the defaults rather than sending invalid paging values.
        if page is None or page < 1:
            page = 1
        if limit is None or limit < 1:
            limit = 25

        url = f"/v1/accounts/{self.account_id}/gainloss"
        params = {
            "page": page,
            "limit": limit,
            "sortBy": "closeDate" if sort_by_close_date else "openDate",
            "sort": "desc" if desc else "asc",
        }
        # Optional filters are only sent when the caller supplied them.
        if start is not None:
            params["start"] = start
        if end is not None:
            params["end"] = end
        if symbol is not None:
            params["symbol"] = symbol

        response = await self.session.get(url, params=params)
        gainloss = response.get("gainloss")
        if gainloss is None:
            return []
        closed = gainloss.get("closed_position")
        if closed is None:
            return []

        return [ProfitLoss(**position) for position in self._as_list(closed)]

    async def get_orders(self, page: int = 1) -> List[Order]:
        """
        Get a list of orders for the account.

        Starting at ``page``, pages are requested one after another until a
        page comes back empty, and all orders seen are returned together.

        Parameters:
            page (int, optional): The page number to start retrieving orders
                from. Values that are None or below 1 fall back to 1.
                Defaults to 1.

        Returns:
            List[Order]: A list of Order objects.
        """
        # Honour the caller's starting page (it used to be reset to 1).
        if page is None or page < 1:
            page = 1

        collected: List[Order] = []
        while True:
            batch = await self._get_orders(page)
            if not batch:
                # An empty page marks the end of the result set.
                break
            collected += batch
            page += 1
        return collected

    async def _get_orders(self, page: int) -> List[Order]:
        """
        Get a single page of orders for the account.

        Args:
            page (int): The page number of the orders to retrieve.

        Returns:
            List[Order]: A list of Order objects; empty when the response
            reports no orders (missing, None or the string "null").
        """
        url = f"/v1/accounts/{self.account_id}/orders"
        params = {
            "page": page,
            "includeTags": "true",
        }
        response = await self.session.get(url, params=params)

        container = response.get("orders")
        # Treat a missing/None container like the "null" string marker so an
        # empty page yields [] instead of raising.
        if container is None or container == "null":
            return []

        return [Order(**order) for order in self._as_list(container["order"])]

    async def get_order(self, order_id: str) -> Order:
        """
        Get an order by its ID.

        Args:
            order_id (str): The ID of the order.

        Returns:
            Order: The Order object.
        """
        url = f"/v1/accounts/{self.account_id}/orders/{order_id}"
        params = {"includeTags": "true"}
        response = await self.session.get(url, params=params)
        return Order(**response["order"])