# Source code for tdw_catalog.utils

from dataclasses import dataclass
from datetime import datetime
from enum import Enum, IntEnum
from io import BytesIO
from sys import getsizeof
from typing import BinaryIO, Callable, Dict, List, Optional, TYPE_CHECKING, Union
import collections
import gzip
import math
import sys
import typing
import zlib

from aiohttp import ClientResponse, ClientSession
import dateutil
# `import dateutil` alone does not import the parser submodule that this
# module calls as `dateutil.parser.parse`; import it explicitly.
import dateutil.parser

from tdw_catalog.errors import CatalogException

if sys.version_info >= (3, 11):
    from enum import StrEnum
else:
    from backports.strenum import StrEnum

if TYPE_CHECKING:
    from urllib.request import Request


class _ExportFormat(IntEnum):
    # Internal wire formats available when exporting a dataset
    # (consumed by _download_export below).
    CSV = 0
    PARQUET = 1
    CSV_GZIP = 2

class FilterSortOrder(Enum):
    """Sort directions accepted by :class:`.FilterSort`."""
    ASC = 1
    DESC = 2
[docs]class ConnectionPortalType(StrEnum): GS = 'Gs' S3 = 'S3' UNITY = 'Unity' FTP = 'Ftp' SFTP = 'Sftp' EXTERNAL = 'External' NULL = 'Null' IMPORT_LITE = 'ImportLite' HTTP = 'Http' CATALOG = 'Namara'
# TODO remove and replace with constants from namara-go generated code
class MetadataFieldType(IntEnum):
    """
    The different possible data types for values stored in MetadataFields
    and default values stored in MetadataTemplateFields
    """
    # NOTE: the original definitions carried trailing commas
    # (e.g. ``FT_STRING = 0,``) which made each member value a 1-tuple that
    # only worked because IntEnum unpacks tuples into ``int(...)``; removed.
    FT_STRING = 0
    FT_INTEGER = 1
    FT_DECIMAL = 2
    FT_DATE = 3
    FT_DATETIME = 4
    FT_DATASET = 5
    FT_URL = 6
    FT_USER = 7
    FT_ATTACHMENT = 8
    FT_LIST = 9
    FT_CURRENCY = 10
    FT_TEAM = 11
    FT_ALIAS = 12
[docs]class ColumnType(StrEnum): """ The different possible data types for :class:`.Column`\ s within a :class:`.DataDictionary` """ BOOLEAN = 'boolean' DATE = 'date' DATETIME = 'datetime' INTEGER = 'integer' DECIMAL = 'decimal' PERCENT = 'percent' CURRENCY = 'currency' STRING = 'string' TEXT = 'text' GEOMETRY = 'geometry' GEOJSON = 'geojson'
[docs]class ImportState(StrEnum): """ The different possible states an imported dataset might occupy. Virtualized datasets will always show state ``IMPORTED``. """ IMPORTED = 'imported' IMPORTING = 'importing' QUEUED = 'queued' FAILED = 'failed'
@dataclass
class CurrencyFieldValue:
    """
    :class:`.CurrencyFieldValue` models the value of a currency field

    Attributes
    ----------
    value : float
        The currency value
    currency : str
        The specific currency to which the value belongs
    """
    value: float
    currency: str
@dataclass
class FilterSort:
    """
    :class:`.FilterSort` describes a desired sort field and order for results.

    Attributes
    ----------
    field : str
        The field to sort by
    order : FilterSortOrder, optional
        The order to sort in (`FilterSortOrder.ASC` by default)
    """
    field: str
    order: FilterSortOrder = FilterSortOrder.ASC
@dataclass
class LegacyFilter:
    """
    :class:`.LegacyFilter` describes the ways in which results should be
    filtered and/or paginated

    Attributes
    ----------
    limit : int, optional
        Limits the number of results. Useful for pagination.
        (`None` by default)
    offset : int, optional
        Offsets the result list by the given number of results. Useful for
        pagination. (`None` by default)
    """
    # annotations corrected to Optional[int]: both fields default to None
    limit: Optional[int] = None
    offset: Optional[int] = None

    def serialize(self) -> dict:
        """Return the wire-format dict for this filter; unset fields are omitted."""
        new_filter = {}
        if self.limit is not None:
            new_filter["limit"] = {"value": self.limit}
        if self.offset is not None:
            new_filter["offset"] = {"value": self.offset}
        return new_filter
class Filter(LegacyFilter):
    """
    :class:`.Filter` describes the ways in which results should be filtered
    and/or paginated. It is serialized in a new way vs :class:`.LegacyFilter`

    Attributes
    ----------
    limit : int, optional
        Limits the number of results. Useful for pagination.
        (`None` by default)
    offset : int, optional
        Offsets the result list by the given number of results. Useful for
        pagination. (`None` by default)
    """

    def serialize(self):
        serialized = super().serialize()
        # NOTE(review): offset is emitted unconditionally (even when None) and
        # nested under "offset" rather than "value" — presumably the newer wire
        # format; confirm against the API before changing.
        serialized["offset"] = {"offset": self.offset}
        return serialized
@dataclass
class SortableFilter(LegacyFilter):
    """
    :class:`.SortableFilter` describes the ways in which results should be
    filtered, paginated and/or sorted.

    Attributes
    ----------
    limit : int, optional
        Limits the number of results. Useful for pagination.
        (`None` by default)
    offset : int, optional
        Offsets the result list by the given number of results. Useful for
        pagination. (`None` by default)
    sort : FilterSort, optional
        Specifies a desired sort field and order for results
        (`None` by default).
    """
    # annotation corrected to Optional[FilterSort]: defaults to None
    sort: Optional[FilterSort] = None

    def serialize(self):
        """Extend the parent serialization with the optional sort clause."""
        new_filter = super().serialize()
        if self.sort is not None:
            new_filter["sort"] = {
                "value": self.sort.field,
                "order": "ASC" if self.sort.order == FilterSortOrder.ASC else "DESC",
            }
        return new_filter
@dataclass
class ListOrganizationsFilter(LegacyFilter):
    """
    :class:`.ListOrganizationsFilter` filters :class:`.Organization` results
    according to a set of provided ids

    Attributes
    ----------
    organization_ids : str[], optional
        Filters results according to a set of provided ids
    """
    organization_ids: Optional[List[str]] = None

    def serialize(self):
        serialized = super().serialize()
        if self.organization_ids is not None:
            serialized["organization_ids"] = self.organization_ids
        return serialized
@dataclass
class ListSourcesFilter(LegacyFilter):
    """
    :class:`.ListSourcesFilter` filters results according to
    :class:`.Source` fields

    Attributes
    ----------
    labels : Optional[str]
        Filters results by label. This will match label substrings.
    """
    labels: Optional[str] = None

    def serialize(self):
        """Extend the parent serialization with the optional labels clause."""
        new_filter = super().serialize()
        if self.labels is not None:
            # wrap the single label string in a list; the previous
            # ``list(self.labels)`` exploded the string into a list of
            # individual characters
            new_filter["labels"] = [self.labels]
        return new_filter
@dataclass
class ListConnectionsFilter(LegacyFilter):
    """
    :class:`.ListConnectionsFilter` filters results according to
    Connection fields

    Attributes
    ----------
    organization_id : Optional[str]
        Filters results by `organization_id`
    source_ids : Optional[List[str]]
        Filters results to the given `source_id(s)`
    portals : Optional[List[ConnectionPortalType]]
        Filters results to the given :class:`.ConnectionPortalType`\\ (s)
    """
    organization_id: Optional[str] = None
    source_ids: Optional[List[str]] = None
    portals: Optional[List[ConnectionPortalType]] = None

    def serialize(self):
        serialized = super().serialize()
        # include only the optional fields that were actually provided
        optional_fields = {
            "organization_id": self.organization_id,
            "source_ids": self.source_ids,
            "portals": self.portals,
        }
        serialized.update(
            {key: value for key, value in optional_fields.items() if value is not None}
        )
        return serialized
@dataclass
class ListGlossaryTermsFilter(Filter):
    """
    :class:`.ListGlossaryTermsFilter` filters results according to
    :class:`.GlossaryTerm` ids

    Attributes
    ----------
    glossary_term_ids : Optional[List[str]]
        Filters results to the given `glossary_term_id(s)`
    """
    glossary_term_ids: Optional[List[str]] = None

    def serialize(self):
        serialized = super().serialize()
        if self.glossary_term_ids is not None:
            serialized["glossary_term_ids"] = self.glossary_term_ids
        return serialized
@dataclass
class QueryFilter(SortableFilter):
    """
    :class:`.QueryFilter` filters results according to a NiQL query

    Attributes
    ----------
    query : str, optional
        Filters results according to a NiQL query
    """
    query: Optional[str] = None

    def serialize(self):
        # super() resolves to SortableFilter here, same as the previous
        # explicit SortableFilter.serialize(self) call
        serialized = super().serialize()
        if self.query is not None:
            serialized["query"] = {"value": self.query}
        return serialized
def _parse_timestamp(timestamp: Union[str, datetime, dict]) -> datetime: if isinstance(timestamp, datetime): return timestamp # handle NullableTimestamp elif isinstance(timestamp, dict) and ('timestamp' in timestamp or 'is_null' in timestamp): if 'is_null' in timestamp and timestamp['is_null'] == True: return None return dateutil.parser.parse(timestamp["timestamp"]) else: return dateutil.parser.parse(timestamp) def _convert_datetime_to_nullable_timestamp(date: datetime): timestamp = date.timestamp() seconds = int(timestamp) # getting the trailing numbers and converting to nanoseconds nanos = int(((timestamp % 1) * 1000) * 10000) return { "is_null": False, "timestamp": { "seconds": seconds, "nanos": nanos, } } # with help from https://stackoverflow.com/questions/56832881/check-if-a-field-is-typing-optional def _type_is_optional(field): return typing.get_origin(field) is Union and \ type(None) in typing.get_args(field) async def _download_export( download_url: str, format: Optional[_ExportFormat] = None, f_out: Optional[BinaryIO] = None) -> Optional[Union[str, BinaryIO]]: async with ClientSession() as session: # TODO support gunzip async with session.get(download_url, allow_redirects=True) as response: if f_out is not None: decompressor = zlib.decompressobj( 16 + zlib.MAX_WBITS ) if format == _ExportFormat.CSV_GZIP else None try: async for chunk in response.content.iter_chunked(4096): if decompressor is not None: dchunk = decompressor.decompress(chunk) decompressor.flush() f_out.write(dchunk) else: f_out.write(chunk) return None finally: f_out.flush() else: async with session.get(download_url, allow_redirects=True) as response: if format is _ExportFormat.PARQUET: return BytesIO(await response.read()) elif format is _ExportFormat.CSV_GZIP: return gzip.decompress(await response.read()).decode("utf-8") else: return await response.text("utf-8")