diff --git a/CHANGELOG.md b/CHANGELOG.md index f7ea88d24ca..d5be8964c7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,11 +3,14 @@ ### Breaking changes - `adapter_macro` is no longer a macro; instead, it is a builtin context method. Any custom macros that intercepted it by going through `context['dbt']` will need to instead access it via `context['builtins']` ([#2302](https://github.com/fishtown-analytics/dbt/issues/2302), [#2673](https://github.com/fishtown-analytics/dbt/pull/2673)) - `adapter_macro` is now deprecated. Use `adapter.dispatch` instead. +- Data tests are now written as CTEs instead of subqueries. Adapter plugins for databases that don't support CTEs may require modification. ([#2712](https://github.com/fishtown-analytics/dbt/pull/2712)) ### Under the hood - Upgraded snowflake-connector-python dependency to 2.2.10 and enabled the SSO token cache ([#2613](https://github.com/fishtown-analytics/dbt/issues/2613), [#2689](https://github.com/fishtown-analytics/dbt/issues/2689), [#2698](https://github.com/fishtown-analytics/dbt/pull/2698)) - Add deprecation warnings to anonymous usage tracking ([#2688](https://github.com/fishtown-analytics/dbt/issues/2688), [#2710](https://github.com/fishtown-analytics/dbt/issues/2710)) +- Data tests are now compiled through the same CTE-injection path as ephemeral models ([#2609](https://github.com/fishtown-analytics/dbt/issues/2609), [#2712](https://github.com/fishtown-analytics/dbt/pull/2712)) +- Adapter plugins can now override the CTE prefix by overriding their `Relation` attribute with a class that has a custom `add_ephemeral_prefix` implementation. ([#2660](https://github.com/fishtown-analytics/dbt/issues/2660), [#2712](https://github.com/fishtown-analytics/dbt/pull/2712)) ### Features - Add a BigQuery adapter macro to enable usage of CopyJobs ([#2709](https://github.com/fishtown-analytics/dbt/pull/2709)) @@ -20,6 +23,7 @@ - Add state:modified and state:new selectors ([#2641](https://github.com/fishtown-analytics/dbt/issues/2641), [#2695](https://github.com/fishtown-analytics/dbt/pull/2695)) - Add two new flags `--use-colors` and `--no-use-colors` to the `dbt run` command to enable or disable log colorization from the command line ([#2708](https://github.com/fishtown-analytics/dbt/pull/2708)) + ### Fixes - Fix Redshift table size estimation; e.g. 44 GB tables are no longer reported as 44 KB.
[#2702](https://github.com/fishtown-analytics/dbt/issues/2702) diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index 7c1cc592530..5718db8c69d 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -25,7 +25,9 @@ ) from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows from dbt.clients.jinja import MacroGenerator -from dbt.contracts.graph.compiled import CompileResultNode, CompiledSeedNode +from dbt.contracts.graph.compiled import ( + CompileResultNode, CompiledSeedNode +) from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.parsed import ParsedSeedNode from dbt.exceptions import warn_or_error @@ -289,7 +291,10 @@ def _get_cache_schemas(self, manifest: Manifest) -> Set[BaseRelation]: return { self.Relation.create_from(self.config, node).without_identifier() for node in manifest.nodes.values() - if node.resource_type in NodeType.executable() + if ( + node.resource_type in NodeType.executable() and + not node.is_ephemeral_model + ) } def _get_catalog_schemas(self, manifest: Manifest) -> SchemaSearchMap: @@ -1142,6 +1147,10 @@ def get_rows_different_sql( return sql + def get_compiler(self): + from dbt.compilation import Compiler + return Compiler(self.config) + COLUMNS_EQUAL_SQL = ''' with diff_count as ( diff --git a/core/dbt/adapters/base/relation.py b/core/dbt/adapters/base/relation.py index 3f4026fe8e5..8cbe4b8f17b 100644 --- a/core/dbt/adapters/base/relation.py +++ b/core/dbt/adapters/base/relation.py @@ -201,6 +201,23 @@ def create_from_source( **kwargs ) + @staticmethod + def add_ephemeral_prefix(name: str): + return f'__dbt__CTE__{name}' + + @classmethod + def create_ephemeral_from_node( + cls: Type[Self], + config: HasQuoting, + node: Union[ParsedNode, CompiledNode], + ) -> Self: + # Note that ephemeral models are based on the name. + identifier = cls.add_ephemeral_prefix(node.name) + return cls.create( + type=cls.CTE, + identifier=identifier, + ).quote(identifier=False) + @classmethod def create_from_node( cls: Type[Self], diff --git a/core/dbt/adapters/protocol.py b/core/dbt/adapters/protocol.py index 68809b77283..ab5cff64f41 100644 --- a/core/dbt/adapters/protocol.py +++ b/core/dbt/adapters/protocol.py @@ -1,19 +1,23 @@ from dataclasses import dataclass from typing import ( Type, Hashable, Optional, ContextManager, List, Generic, TypeVar, ClassVar, - Tuple, Union + Tuple, Union, Dict, Any ) from typing_extensions import Protocol import agate from dbt.contracts.connection import Connection, AdapterRequiredConfig -from dbt.contracts.graph.compiled import CompiledNode +from dbt.contracts.graph.compiled import ( + CompiledNode, NonSourceNode, NonSourceCompiledNode +) from dbt.contracts.graph.parsed import ParsedNode, ParsedSourceDefinition from dbt.contracts.graph.model_config import BaseConfig from dbt.contracts.graph.manifest import Manifest from dbt.contracts.relation import Policy, HasQuoting +from dbt.graph import Graph + @dataclass class AdapterConfig(BaseConfig): @@ -45,6 +49,19 @@ def create_from( ... +class CompilerProtocol(Protocol): + def compile(self, manifest: Manifest, write=True) -> Graph: + ... + + def compile_node( + self, + node: NonSourceNode, + manifest: Manifest, + extra_context: Optional[Dict[str, Any]] = None, + ) -> NonSourceCompiledNode: + ... 
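The `relation.py` and `protocol.py` hunks above are what back the changelog note about overridable CTE prefixes: the prefix is now a `staticmethod` on the adapter's `Relation` class instead of a helper in `dbt.utils`. A minimal sketch of a plugin override, with a hypothetical adapter name (not part of this PR):

```python
from dbt.adapters.base.relation import BaseRelation


class ExampleRelation(BaseRelation):
    @staticmethod
    def add_ephemeral_prefix(name: str) -> str:
        # Hypothetical: a database whose identifier rules reject the
        # default '__dbt__CTE__' spelling can substitute its own prefix.
        return f'dbt_cte__{name}'
```

Because both `create_ephemeral_from_node` and the compiler's `add_ephemeral_prefix` (see the `compilation.py` hunk below) resolve the prefix through the adapter's `Relation` attribute, an adapter class that sets `Relation = ExampleRelation` picks up the new spelling everywhere ephemeral CTE names are generated.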
+ + AdapterConfig_T = TypeVar( 'AdapterConfig_T', bound=AdapterConfig ) @@ -57,11 +74,18 @@ def create_from( Column_T = TypeVar( 'Column_T', bound=ColumnProtocol ) +Compiler_T = TypeVar('Compiler_T', bound=CompilerProtocol) class AdapterProtocol( Protocol, - Generic[AdapterConfig_T, ConnectionManager_T, Relation_T, Column_T] + Generic[ + AdapterConfig_T, + ConnectionManager_T, + Relation_T, + Column_T, + Compiler_T, + ] ): AdapterSpecificConfigs: ClassVar[Type[AdapterConfig_T]] Column: ClassVar[Type[Column_T]] @@ -132,3 +156,6 @@ def execute( self, sql: str, auto_begin: bool = False, fetch: bool = False ) -> Tuple[str, agate.Table]: ... + + def get_compiler(self) -> Compiler_T: + ... diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index d9e449ff22c..e8713d686a4 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -1,10 +1,12 @@ import os from collections import defaultdict -from typing import List, Dict, Any, Tuple, cast +from typing import List, Dict, Any, Tuple, cast, Optional import networkx as nx # type: ignore +import sqlparse from dbt import flags +from dbt.adapters.factory import get_adapter from dbt.clients import jinja from dbt.clients.system import make_directory from dbt.context.providers import generate_runtime_model @@ -14,14 +16,19 @@ COMPILED_TYPES, NonSourceNode, NonSourceCompiledNode, + CompiledDataTestNode, CompiledSchemaTestNode, ) from dbt.contracts.graph.parsed import ParsedNode -from dbt.exceptions import dependency_not_found, InternalException +from dbt.exceptions import ( + dependency_not_found, + InternalException, + RuntimeException, +) from dbt.graph import Graph from dbt.logger import GLOBAL_LOGGER as logger from dbt.node_types import NodeType -from dbt.utils import add_ephemeral_model_prefix, pluralize +from dbt.utils import pluralize graph_file_name = 'graph.gpickle' @@ -156,6 +163,11 @@ def _create_node_context( return context + def add_ephemeral_prefix(self, name: str): + adapter = get_adapter(self.config) + relation_cls = adapter.Relation + return relation_cls.add_ephemeral_prefix(name) + def _get_compiled_model( self, manifest: Manifest, @@ -186,6 +198,90 @@ def _get_compiled_model( f'was not an ephemeral model: {cte_id}' ) + def _inject_ctes_into_sql(self, sql: str, ctes: List[InjectedCTE]) -> str: + """ + `ctes` is a list of InjectedCTEs like: + + [ + InjectedCTE( + id="cte_id_1", + sql="__dbt__CTE__ephemeral as (select * from table)", + ), + InjectedCTE( + id="cte_id_2", + sql="__dbt__CTE__events as (select id, type from events)", + ), + ] + + Given `sql` like: + + "with internal_cte as (select * from sessions) + select * from internal_cte" + + This will spit out: + + "with __dbt__CTE__ephemeral as (select * from table), + __dbt__CTE__events as (select id, type from events), + internal_cte as (select * from sessions) + select * from internal_cte" + + (Whitespace enhanced for readability.)
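+ If `sql` contains no `with` keyword at all, a new one is inserted before the first token and the CTEs are injected immediately after it; no trailing comma is needed, because the original statement body follows directly.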
+ """ + if len(ctes) == 0: + return sql + + parsed_stmts = sqlparse.parse(sql) + parsed = parsed_stmts[0] + + with_stmt = None + for token in parsed.tokens: + if token.is_keyword and token.normalized == 'WITH': + with_stmt = token + break + + if with_stmt is None: + # no with stmt, add one, and inject CTEs right at the beginning + first_token = parsed.token_first() + with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, 'with') + parsed.insert_before(first_token, with_stmt) + else: + # stmt exists, add a comma (which will come after injected CTEs) + trailing_comma = sqlparse.sql.Token( + sqlparse.tokens.Punctuation, ',' + ) + parsed.insert_after(with_stmt, trailing_comma) + + token = sqlparse.sql.Token( + sqlparse.tokens.Keyword, + ", ".join(c.sql for c in ctes) + ) + parsed.insert_after(with_stmt, token) + + return str(parsed) + + def _model_prepend_ctes( + self, + model: NonSourceCompiledNode, + prepended_ctes: List[InjectedCTE] + ) -> NonSourceCompiledNode: + if model.compiled_sql is None: + raise RuntimeException( + 'Cannot prepend ctes to an unparsed node', model + ) + injected_sql = self._inject_ctes_into_sql( + model.compiled_sql, + prepended_ctes, + ) + + model.extra_ctes_injected = True + model.extra_ctes = prepended_ctes + model.injected_sql = injected_sql + model.validate(model.to_dict()) + return model + + def _get_dbt_test_name(self) -> str: + return 'dbt__CTE__INTERNAL_test' + def _recursively_prepend_ctes( self, model: NonSourceCompiledNode, @@ -203,28 +299,68 @@ def _recursively_prepend_ctes( prepended_ctes: List[InjectedCTE] = [] + dbt_test_name = self._get_dbt_test_name() + for cte in model.extra_ctes: - cte_model = self._get_compiled_model( - manifest, - cte.id, - extra_context, - ) - cte_model, new_prepended_ctes = self._recursively_prepend_ctes( - cte_model, manifest, extra_context - ) - _extend_prepended_ctes(prepended_ctes, new_prepended_ctes) - new_cte_name = add_ephemeral_model_prefix(cte_model.name) - sql = f' {new_cte_name} as (\n{cte_model.compiled_sql}\n)' + if cte.id == dbt_test_name: + sql = cte.sql + else: + cte_model = self._get_compiled_model( + manifest, + cte.id, + extra_context, + ) + cte_model, new_prepended_ctes = self._recursively_prepend_ctes( + cte_model, manifest, extra_context + ) + _extend_prepended_ctes(prepended_ctes, new_prepended_ctes) + + new_cte_name = self.add_ephemeral_prefix(cte_model.name) + sql = f' {new_cte_name} as (\n{cte_model.compiled_sql}\n)' _add_prepended_cte(prepended_ctes, InjectedCTE(id=cte.id, sql=sql)) - model.prepend_ctes(prepended_ctes) + model = self._model_prepend_ctes(model, prepended_ctes) manifest.update_node(model) return model, prepended_ctes - def compile_node( - self, node: NonSourceNode, manifest, extra_context=None + def _insert_ctes( + self, + compiled_node: NonSourceCompiledNode, + manifest: Manifest, + extra_context: Dict[str, Any], + ) -> NonSourceCompiledNode: + """Insert the CTEs for the model.""" + + # for data tests, we need to insert a special CTE at the end of the + # list containing the test query, and then have the "real" query be a + # select count(*) from that model. + # the benefit of doing it this way is that _insert_ctes() can be + # rewritten for different adapters to handle databses that don't + # support CTEs, or at least don't have full support. + if isinstance(compiled_node, CompiledDataTestNode): + # the last prepend (so last in order) should be the data test body. 
+ # then we can add our select count(*) from _that_ cte as the "real" + # compiled_sql, and do the regular prepend logic from CTEs. + name = self._get_dbt_test_name() + cte = InjectedCTE( + id=name, + sql=f' {name} as (\n{compiled_node.compiled_sql}\n)' + ) + compiled_node.extra_ctes.append(cte) + compiled_node.compiled_sql = f'\nselect count(*) from {name}' + + injected_node, _ = self._recursively_prepend_ctes( + compiled_node, manifest, extra_context + ) + return injected_node + + def _compile_node( + self, + node: NonSourceNode, + manifest: Manifest, + extra_context: Optional[Dict[str, Any]] = None, ) -> NonSourceCompiledNode: if extra_context is None: extra_context = {} @@ -248,11 +384,12 @@ def compile_node( compiled_node.compiled_sql = jinja.get_rendered( node.raw_sql, context, - node) + node, + ) compiled_node.compiled = True - injected_node, _ = self._recursively_prepend_ctes( + injected_node = self._insert_ctes( compiled_node, manifest, extra_context ) @@ -295,6 +432,7 @@ def link_graph(self, linker: Linker, manifest: Manifest): raise RuntimeError("Found a cycle: {}".format(cycle)) def compile(self, manifest: Manifest, write=True) -> Graph: + self.initialize() linker = Linker() self.link_graph(linker, manifest) @@ -307,11 +445,38 @@ def compile(self, manifest: Manifest, write=True) -> Graph: return Graph(linker.graph) + def _write_node(self, node: NonSourceCompiledNode) -> NonSourceNode: + if not _is_writable(node): + return node + logger.debug(f'Writing injected SQL for node "{node.unique_id}"') + + if node.injected_sql is None: + # this should not really happen, but it'd be a shame to crash + # over it + logger.error( + f'Compiled node "{node.unique_id}" had no injected_sql, ' + 'cannot write sql!' + ) + else: + node.build_path = node.write_node( + self.config.target_path, + 'compiled', + node.injected_sql + ) + return node -def compile_manifest(config, manifest, write=True) -> Graph: - compiler = Compiler(config) - compiler.initialize() - return compiler.compile(manifest, write=write) + def compile_node( + self, + node: NonSourceNode, + manifest: Manifest, + extra_context: Optional[Dict[str, Any]] = None, + write: bool = True, + ) -> NonSourceCompiledNode: + node = self._compile_node(node, manifest, extra_context) + + if write and _is_writable(node): + self._write_node(node) + return node def _is_writable(node): @@ -322,20 +487,3 @@ def _is_writable(node): return False return True - - -def compile_node(adapter, config, node, manifest, extra_context, write=True): - compiler = Compiler(config) - node = compiler.compile_node(node, manifest, extra_context) - - if write and _is_writable(node): - logger.debug('Writing injected SQL for node "{}"'.format( - node.unique_id)) - - node.build_path = node.write_node( - config.target_path, - 'compiled', - node.injected_sql - ) - - return node diff --git a/core/dbt/context/base.py b/core/dbt/context/base.py index 33c9681c15d..51d0ec44e73 100644 --- a/core/dbt/context/base.py +++ b/core/dbt/context/base.py @@ -105,39 +105,39 @@ def __init__( cli_vars: Mapping[str, Any], node: Optional[CompiledResource] = None ) -> None: - self.context: Mapping[str, Any] = context - self.cli_vars: Mapping[str, Any] = cli_vars - self.node: Optional[CompiledResource] = node - self.merged: Mapping[str, Any] = self._generate_merged() + self._context: Mapping[str, Any] = context + self._cli_vars: Mapping[str, Any] = cli_vars + self._node: Optional[CompiledResource] = node + self._merged: Mapping[str, Any] = self._generate_merged() def _generate_merged(self) -> 
Mapping[str, Any]: - return self.cli_vars + return self._cli_vars @property def node_name(self): - if self.node is not None: - return self.node.name + if self._node is not None: + return self._node.name else: return '' def get_missing_var(self, var_name): - dct = {k: self.merged[k] for k in self.merged} + dct = {k: self._merged[k] for k in self._merged} pretty_vars = json.dumps(dct, sort_keys=True, indent=4) msg = self.UndefinedVarError.format( var_name, self.node_name, pretty_vars ) - raise_compiler_error(msg, self.node) + raise_compiler_error(msg, self._node) def has_var(self, var_name: str): - return var_name in self.merged + return var_name in self._merged def get_rendered_var(self, var_name): - raw = self.merged[var_name] + raw = self._merged[var_name] # if bool/int/float/etc are passed in, don't compile anything if not isinstance(raw, str): return raw - return get_rendered(raw, self.context) + return get_rendered(raw, self._context) def __call__(self, var_name, default=_VAR_NOTSET): if self.has_var(var_name): diff --git a/core/dbt/context/configured.py b/core/dbt/context/configured.py index 4e820c9496f..baf2f60b34c 100644 --- a/core/dbt/context/configured.py +++ b/core/dbt/context/configured.py @@ -36,23 +36,23 @@ def __init__( project_name: str, ): super().__init__(context, config.cli_vars) - self.config = config - self.project_name = project_name + self._config = config + self._project_name = project_name def __call__(self, var_name, default=Var._VAR_NOTSET): - my_config = self.config.load_dependencies()[self.project_name] + my_config = self._config.load_dependencies()[self._project_name] # cli vars > active project > local project - if var_name in self.config.cli_vars: - return self.config.cli_vars[var_name] + if var_name in self._config.cli_vars: + return self._config.cli_vars[var_name] - if self.config.config_version == 2 and my_config.config_version == 2: - adapter_type = self.config.credentials.type - lookup = FQNLookup(self.project_name) - active_vars = self.config.vars.vars_for(lookup, adapter_type) + if self._config.config_version == 2 and my_config.config_version == 2: + adapter_type = self._config.credentials.type + lookup = FQNLookup(self._project_name) + active_vars = self._config.vars.vars_for(lookup, adapter_type) all_vars = MultiDict([active_vars]) - if self.config.project_name != my_config.project_name: + if self._config.project_name != my_config.project_name: all_vars.add(my_config.vars.vars_for(lookup, adapter_type)) if var_name in all_vars: diff --git a/core/dbt/context/context_config.py b/core/dbt/context/context_config.py index 3cdf9bf92e1..847417477a5 100644 --- a/core/dbt/context/context_config.py +++ b/core/dbt/context/context_config.py @@ -26,16 +26,16 @@ def __init__( node_type: NodeType, ): self._config = None - self.active_project: RuntimeConfig = active_project - self.own_project: Project = own_project + self._active_project: RuntimeConfig = active_project + self._own_project: Project = own_project - self.model = ModelParts( + self._model = ModelParts( fqn=fqn, resource_type=node_type, - package_name=self.own_project.project_name, + package_name=self._own_project.project_name, ) - self.updater = ConfigUpdater(active_project.credentials.type) + self._updater = ConfigUpdater(active_project.credentials.type) # the config options defined within the model self.in_model_config: Dict[str, Any] = {} @@ -43,12 +43,12 @@ def __init__( def get_default(self) -> Dict[str, Any]: defaults = {"enabled": True, "materialized": "view"} - if self.model.resource_type == 
NodeType.Seed: + if self._model.resource_type == NodeType.Seed: defaults['materialized'] = 'seed' - elif self.model.resource_type == NodeType.Snapshot: + elif self._model.resource_type == NodeType.Snapshot: defaults['materialized'] = 'snapshot' - if self.model.resource_type == NodeType.Test: + if self._model.resource_type == NodeType.Test: defaults['severity'] = 'ERROR' return defaults @@ -57,31 +57,34 @@ def build_config_dict(self, base: bool = False) -> Dict[str, Any]: defaults = self.get_default() active_config = self.load_config_from_active_project() - if self.active_project.project_name == self.own_project.project_name: - cfg = self.updater.merge( + if self._active_project.project_name == self._own_project.project_name: + cfg = self._updater.merge( defaults, active_config, self.in_model_config ) else: own_config = self.load_config_from_own_project() - cfg = self.updater.merge( + cfg = self._updater.merge( defaults, own_config, self.in_model_config, active_config ) return cfg def _translate_adapter_aliases(self, config: Dict[str, Any]): - return self.active_project.credentials.translate_aliases(config) + return self._active_project.credentials.translate_aliases(config) def update_in_model_config(self, config: Dict[str, Any]) -> None: config = self._translate_adapter_aliases(config) - self.updater.update_into(self.in_model_config, config) + self._updater.update_into(self.in_model_config, config) def load_config_from_own_project(self) -> Dict[str, Any]: - return self.updater.get_project_config(self.model, self.own_project) + return self._updater.get_project_config(self._model, self._own_project) def load_config_from_active_project(self) -> Dict[str, Any]: - return self.updater.get_project_config(self.model, self.active_project) + return self._updater.get_project_config( + self._model, + self._active_project, + ) T = TypeVar('T', bound=BaseConfig) @@ -89,12 +92,12 @@ def load_config_from_active_project(self) -> Dict[str, Any]: class ContextConfigGenerator: def __init__(self, active_project: RuntimeConfig): - self.active_project = active_project + self._active_project = active_project def get_node_project(self, project_name: str): - if project_name == self.active_project.project_name: - return self.active_project - dependencies = self.active_project.load_dependencies() + if project_name == self._active_project.project_name: + return self._active_project + dependencies = self._active_project.load_dependencies() if project_name not in dependencies: raise InternalException( f'Project name {project_name} not found in dependencies ' @@ -102,7 +105,7 @@ def get_node_project(self, project_name: str): ) return dependencies[project_name] - def project_configs( + def _project_configs( self, project: Project, fqn: List[str], resource_type: NodeType ) -> Iterator[Dict[str, Any]]: if resource_type == NodeType.Seed: @@ -123,18 +126,20 @@ def project_configs( yield result - def active_project_configs( + def _active_project_configs( self, fqn: List[str], resource_type: NodeType ) -> Iterator[Dict[str, Any]]: - return self.project_configs(self.active_project, fqn, resource_type) + return self._project_configs(self._active_project, fqn, resource_type) def _update_from_config( self, result: T, partial: Dict[str, Any], validate: bool = False ) -> T: - translated = self.active_project.credentials.translate_aliases(partial) + translated = self._active_project.credentials.translate_aliases( + partial + ) return result.update_from( translated, - self.active_project.credentials.type, + 
self._active_project.credentials.type, validate=validate ) @@ -153,13 +158,16 @@ def calculate_node_config( # because it might be invalid in the case of required config members # (such as on snapshots!) result = config_cls.from_dict({}, validate=False) - for fqn_config in self.project_configs(own_config, fqn, resource_type): + + project_configs = self._project_configs(own_config, fqn, resource_type) + for fqn_config in project_configs: result = self._update_from_config(result, fqn_config) + for config_call in config_calls: result = self._update_from_config(result, config_call) - if own_config.project_name != self.active_project.project_name: - for fqn_config in self.active_project_configs(fqn, resource_type): + if own_config.project_name != self._active_project.project_name: + for fqn_config in self._active_project_configs(fqn, resource_type): result = self._update_from_config(result, fqn_config) # this is mostly impactful in the snapshot config case @@ -174,21 +182,21 @@ def __init__( resource_type: NodeType, project_name: str, ) -> None: - self.config_calls: List[Dict[str, Any]] = [] - self.cfg_source = ContextConfigGenerator(active_project) - self.fqn = fqn - self.resource_type = resource_type - self.project_name = project_name + self._config_calls: List[Dict[str, Any]] = [] + self._cfg_source = ContextConfigGenerator(active_project) + self._fqn = fqn + self._resource_type = resource_type + self._project_name = project_name def update_in_model_config(self, opts: Dict[str, Any]) -> None: - self.config_calls.append(opts) + self._config_calls.append(opts) def build_config_dict(self, base: bool = False) -> Dict[str, Any]: - return self.cfg_source.calculate_node_config( - config_calls=self.config_calls, - fqn=self.fqn, - resource_type=self.resource_type, - project_name=self.project_name, + return self._cfg_source.calculate_node_config( + config_calls=self._config_calls, + fqn=self._fqn, + resource_type=self._resource_type, + project_name=self._project_name, base=base, ).to_dict() diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index ed7d7d56615..16d79213f30 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -46,7 +46,7 @@ from dbt.node_types import NodeType from dbt.utils import ( - add_ephemeral_model_prefix, merge, AttrDict, MultiDict + merge, AttrDict, MultiDict ) import agate @@ -58,23 +58,23 @@ # base classes class RelationProxy: def __init__(self, adapter): - self.quoting_config = adapter.config.quoting - self.relation_type = adapter.Relation + self._quoting_config = adapter.config.quoting + self._relation_type = adapter.Relation def __getattr__(self, key): - return getattr(self.relation_type, key) + return getattr(self._relation_type, key) def create_from_source(self, *args, **kwargs): # bypass our create when creating from source so as not to mess up # the source quoting - return self.relation_type.create_from_source(*args, **kwargs) + return self._relation_type.create_from_source(*args, **kwargs) def create(self, *args, **kwargs): kwargs['quote_policy'] = merge( - self.quoting_config, + self._quoting_config, kwargs.pop('quote_policy', {}) ) - return self.relation_type.create(*args, **kwargs) + return self._relation_type.create(*args, **kwargs) class BaseDatabaseWrapper: @@ -83,28 +83,28 @@ class BaseDatabaseWrapper: via a relation proxy. 
""" def __init__(self, adapter, namespace: MacroNamespace): - self.adapter = adapter + self._adapter = adapter self.Relation = RelationProxy(adapter) - self.namespace = namespace + self._namespace = namespace def __getattr__(self, name): raise NotImplementedError('subclasses need to implement this') @property def config(self): - return self.adapter.config + return self._adapter.config def type(self): - return self.adapter.type() + return self._adapter.type() def commit(self): - return self.adapter.commit_if_has_connection() + return self._adapter.commit_if_has_connection() def _get_adapter_macro_prefixes(self) -> List[str]: # a future version of this could have plugins automatically call fall # back to their dependencies' dependencies by using # `get_adapter_type_names` instead of `[self.config.credentials.type]` - search_prefixes = [self.adapter.type(), 'default'] + search_prefixes = [self._adapter.type(), 'default'] return search_prefixes def dispatch( @@ -138,7 +138,7 @@ def dispatch( for prefix in self._get_adapter_macro_prefixes(): search_name = f'{prefix}__{macro_name}' try: - macro = self.namespace.get_from_package( + macro = self._namespace.get_from_package( package_name, search_name ) except CompilationException as exc: @@ -379,13 +379,13 @@ class ParseDatabaseWrapper(BaseDatabaseWrapper): parse-time overrides. """ def __getattr__(self, name): - override = (name in self.adapter._available_ and - name in self.adapter._parse_replacements_) + override = (name in self._adapter._available_ and + name in self._adapter._parse_replacements_) if override: - return self.adapter._parse_replacements_[name] - elif name in self.adapter._available_: - return getattr(self.adapter, name) + return self._adapter._parse_replacements_[name] + elif name in self._adapter._available_: + return getattr(self._adapter, name) else: raise AttributeError( "'{}' object has no attribute '{}'".format( @@ -399,8 +399,8 @@ class RuntimeDatabaseWrapper(BaseDatabaseWrapper): available. 
""" def __getattr__(self, name): - if name in self.adapter._available_: - return getattr(self.adapter, name) + if name in self._adapter._available_: + return getattr(self._adapter, name) else: raise AttributeError( "'{}' object has no attribute '{}'".format( @@ -443,20 +443,14 @@ def resolve( self.validate(target_model, target_name, target_package) return self.create_relation(target_model, target_name) - def create_ephemeral_relation( - self, target_model: NonSourceNode, name: str - ) -> RelationProxy: - self.model.set_cte(target_model.unique_id, None) - return self.Relation.create( - type=self.Relation.CTE, - identifier=add_ephemeral_model_prefix(name) - ).quote(identifier=False) - def create_relation( self, target_model: NonSourceNode, name: str ) -> RelationProxy: - if target_model.get_materialization() == 'ephemeral': - return self.create_ephemeral_relation(target_model, name) + if target_model.is_ephemeral_model: + self.model.set_cte(target_model.unique_id, None) + return self.Relation.create_ephemeral_from_node( + self.config, target_model + ) else: return self.Relation.create_from(self.config, target_model) @@ -480,16 +474,19 @@ def validate( ) -> None: pass - def create_ephemeral_relation( + def create_relation( self, target_model: NonSourceNode, name: str ) -> RelationProxy: - # In operations, we can't ref() ephemeral nodes, because ParsedMacros - # do not support set_cte - raise_compiler_error( - 'Operations can not ref() ephemeral nodes, but {} is ephemeral' - .format(target_model.name), - self.model - ) + if target_model.is_ephemeral_model: + # In operations, we can't ref() ephemeral nodes, because + # ParsedMacros do not support set_cte + raise_compiler_error( + 'Operations can not ref() ephemeral nodes, but {} is ephemeral' + .format(target_model.name), + self.model + ) + else: + return super().create_relation(target_model, name) # `source` implementations @@ -526,37 +523,37 @@ def __init__( config: RuntimeConfig, node: CompiledResource, ) -> None: - self.node: CompiledResource - self.config: RuntimeConfig = config + self._node: CompiledResource + self._config: RuntimeConfig = config super().__init__(context, config.cli_vars, node=node) def packages_for_node(self) -> Iterable[Project]: - dependencies = self.config.load_dependencies() - package_name = self.node.package_name + dependencies = self._config.load_dependencies() + package_name = self._node.package_name - if package_name != self.config.project_name: + if package_name != self._config.project_name: if package_name not in dependencies: # I don't think this is actually reachable raise_compiler_error( f'Node package named {package_name} not found!', - self.node + self._node ) yield dependencies[package_name] - yield self.config + yield self._config def _generate_merged(self) -> Mapping[str, Any]: search_node: IsFQNResource - if isinstance(self.node, IsFQNResource): - search_node = self.node + if isinstance(self._node, IsFQNResource): + search_node = self._node else: - search_node = FQNLookup(self.node.package_name) + search_node = FQNLookup(self._node.package_name) - adapter_type = self.config.credentials.type + adapter_type = self._config.credentials.type merged = MultiDict() for project in self.packages_for_node(): merged.add(project.vars.vars_for(search_node, adapter_type)) - merged.add(self.cli_vars) + merged.add(self._cli_vars) return merged diff --git a/core/dbt/contracts/graph/compiled.py b/core/dbt/contracts/graph/compiled.py index 6fc6f1fe5d7..1406475f969 100644 --- a/core/dbt/contracts/graph/compiled.py +++ 
b/core/dbt/contracts/graph/compiled.py @@ -17,11 +17,9 @@ ) from dbt.node_types import NodeType from dbt.contracts.util import Replaceable -from dbt.exceptions import RuntimeException from hologram import JsonSchemaMixin from dataclasses import dataclass, field -import sqlparse # type: ignore from typing import Optional, List, Union, Dict, Type @@ -45,19 +43,6 @@ class CompiledNode(ParsedNode, CompiledNodeMixin): extra_ctes: List[InjectedCTE] = field(default_factory=list) injected_sql: Optional[str] = None - def prepend_ctes(self, prepended_ctes: List[InjectedCTE]): - self.extra_ctes_injected = True - self.extra_ctes = prepended_ctes - if self.compiled_sql is None: - raise RuntimeException( - 'Cannot prepend ctes to an unparsed node', self - ) - self.injected_sql = _inject_ctes_into_sql( - self.compiled_sql, - prepended_ctes, - ) - self.validate(self.to_dict()) - def set_cte(self, cte_id: str, sql: str): """This is the equivalent of what self.extra_ctes[cte_id] = sql would do if extra_ctes were an OrderedDict @@ -146,66 +131,6 @@ def same_contents(self, other) -> bool: CompiledTestNode = Union[CompiledDataTestNode, CompiledSchemaTestNode] -def _inject_ctes_into_sql(sql: str, ctes: List[InjectedCTE]) -> str: - """ - `ctes` is a list of InjectedCTEs like: - - [ - InjectedCTE( - id="cte_id_1", - sql="__dbt__CTE__ephemeral as (select * from table)", - ), - InjectedCTE( - id="cte_id_2", - sql="__dbt__CTE__events as (select id, type from events)", - ), - ] - - Given `sql` like: - - "with internal_cte as (select * from sessions) - select * from internal_cte" - - This will spit out: - - "with __dbt__CTE__ephemeral as (select * from table), - __dbt__CTE__events as (select id, type from events), - with internal_cte as (select * from sessions) - select * from internal_cte" - - (Whitespace enhanced for readability.) 
- """ - if len(ctes) == 0: - return sql - - parsed_stmts = sqlparse.parse(sql) - parsed = parsed_stmts[0] - - with_stmt = None - for token in parsed.tokens: - if token.is_keyword and token.normalized == 'WITH': - with_stmt = token - break - - if with_stmt is None: - # no with stmt, add one, and inject CTEs right at the beginning - first_token = parsed.token_first() - with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, 'with') - parsed.insert_before(first_token, with_stmt) - else: - # stmt exists, add a comma (which will come after injected CTEs) - trailing_comma = sqlparse.sql.Token(sqlparse.tokens.Punctuation, ',') - parsed.insert_after(with_stmt, trailing_comma) - - token = sqlparse.sql.Token( - sqlparse.tokens.Keyword, - ", ".join(c.sql for c in ctes) - ) - parsed.insert_after(with_stmt, token) - - return str(parsed) - - PARSED_TYPES: Dict[Type[CompiledNode], Type[ParsedResource]] = { CompiledAnalysisNode: ParsedAnalysisNode, CompiledModelNode: ParsedModelNode, diff --git a/core/dbt/rpc/node_runners.py b/core/dbt/rpc/node_runners.py index 58130b17415..805687767cc 100644 --- a/core/dbt/rpc/node_runners.py +++ b/core/dbt/rpc/node_runners.py @@ -2,7 +2,6 @@ from typing import Generic, TypeVar import dbt.exceptions -from dbt.compilation import compile_node from dbt.contracts.rpc import ( RemoteCompileResult, RemoteRunResult, ResultTable, ) @@ -38,8 +37,8 @@ def after_execute(self, result): pass def compile(self, manifest): - return compile_node(self.adapter, self.config, self.node, manifest, {}, - write=False) + compiler = self.adapter.get_compiler() + return compiler.compile_node(self.node, manifest, {}, write=False) @abstractmethod def execute(self, compiled_node, manifest) -> RPCSQLResult: diff --git a/core/dbt/task/compile.py b/core/dbt/task/compile.py index e7476a5ebbd..e4469f39728 100644 --- a/core/dbt/task/compile.py +++ b/core/dbt/task/compile.py @@ -1,7 +1,6 @@ from .runnable import GraphRunnableTask from .base import BaseRunner -from dbt.compilation import compile_node from dbt.contracts.results import RunModelResult from dbt.exceptions import InternalException from dbt.graph import ResourceTypeSelector, SelectionSpec, parse_difference @@ -20,7 +19,8 @@ def execute(self, compiled_node, manifest): return RunModelResult(compiled_node) def compile(self, manifest): - return compile_node(self.adapter, self.config, self.node, manifest, {}) + compiler = self.adapter.get_compiler() + return compiler.compile_node(self.node, manifest, {}) class CompileTask(GraphRunnableTask): diff --git a/core/dbt/task/rpc/sql_commands.py b/core/dbt/task/rpc/sql_commands.py index a871a34c446..deea0c92580 100644 --- a/core/dbt/task/rpc/sql_commands.py +++ b/core/dbt/task/rpc/sql_commands.py @@ -7,7 +7,6 @@ from dbt import flags from dbt.adapters.factory import get_adapter from dbt.clients.jinja import extract_toplevel_blocks -from dbt.compilation import compile_manifest from dbt.config.runtime import RuntimeConfig from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.parsed import ParsedRPCNode @@ -129,7 +128,9 @@ def _get_exec_node(self): ) # don't write our new, weird manifest! - self.graph = compile_manifest(self.config, self.manifest, write=False) + adapter = get_adapter(self.config) + compiler = adapter.get_compiler() + self.graph = compiler.compile(self.manifest, write=False) # previously, this compiled the ancestors, but they are compiled at # runtime now. 
return rpc_node diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index 917d38dc90b..449900649a2 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -18,7 +18,6 @@ from dbt import utils from dbt.adapters.base import BaseRelation from dbt.clients.jinja import MacroGenerator -from dbt.compilation import compile_node from dbt.context.providers import generate_runtime_model from dbt.contracts.graph.compiled import CompileResultNode from dbt.contracts.graph.manifest import WritableManifest @@ -254,8 +253,8 @@ def raise_on_first_error(self): return False def get_hook_sql(self, adapter, hook, idx, num_hooks, extra_context): - compiled = compile_node(adapter, self.config, hook, self.manifest, - extra_context) + compiler = adapter.get_compiler() + compiled = compiler.compile_node(hook, self.manifest, extra_context) statement = compiled.injected_sql hook_index = hook.index or num_hooks hook_obj = get_hook(statement, index=hook_index) diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index 90f1e14bd98..dbedfcd6d05 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -26,7 +26,6 @@ NodeCount, print_timestamped_line, ) -from dbt.compilation import compile_manifest from dbt.contracts.graph.compiled import CompileResultNode from dbt.contracts.graph.manifest import Manifest @@ -71,7 +70,9 @@ def compile_manifest(self): raise InternalException( 'compile_manifest called before manifest was loaded' ) - self.graph = compile_manifest(self.config, self.manifest) + adapter = get_adapter(self.config) + compiler = adapter.get_compiler() + self.graph = compiler.compile(self.manifest) def _runtime_initialize(self): self.load_manifest() diff --git a/core/dbt/task/test.py b/core/dbt/task/test.py index 74bd101abc9..f37c70c2fa9 100644 --- a/core/dbt/task/test.py +++ b/core/dbt/task/test.py @@ -41,10 +41,9 @@ def print_start_line(self): print_start_line(description, self.node_index, self.num_nodes) def execute_data_test(self, test: CompiledDataTestNode): - sql = ( - f'select count(*) as errors from (\n{test.injected_sql}\n) sbq' + res, table = self.adapter.execute( + test.injected_sql, auto_begin=True, fetch=True ) - res, table = self.adapter.execute(sql, auto_begin=True, fetch=True) num_rows = len(table.rows) if num_rows != 1: diff --git a/test/unit/test_compiler.py b/test/unit/test_compiler.py index f7cec475158..367462a08a3 100644 --- a/test/unit/test_compiler.py +++ b/test/unit/test_compiler.py @@ -3,6 +3,7 @@ import dbt.flags import dbt.compilation +from dbt.adapters.postgres import Plugin from dbt.contracts.files import FileHash from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.parsed import NodeConfig, DependsOn, ParsedModelNode @@ -11,6 +12,8 @@ from datetime import datetime +from .utils import inject_adapter, clear_plugin, config_from_parts_or_dicts + class CompilerTest(unittest.TestCase): def assertEqualIgnoreWhitespace(self, a, b): @@ -34,10 +37,46 @@ def setUp(self): 'column_types': {}, 'tags': [], }) - self.mock_config = MagicMock(credentials=MagicMock(type='postgres')) + + project_cfg = { + 'name': 'X', + 'version': '0.1', + 'profile': 'test', + 'project-root': '/tmp/dbt/does-not-exist', + } + profile_cfg = { + 'outputs': { + 'test': { + 'type': 'postgres', + 'dbname': 'postgres', + 'user': 'root', + 'host': 'thishostshouldnotexist', + 'pass': 'password', + 'port': 5432, + 'schema': 'public' + } + }, + 'target': 'test' + } + + self.config = config_from_parts_or_dicts(project_cfg, profile_cfg) + 
self._generate_runtime_model_patch = patch.object(dbt.compilation, 'generate_runtime_model') self.mock_generate_runtime_model = self._generate_runtime_model_patch.start() + inject_adapter(Plugin.adapter(self.config), Plugin) + def mock_generate_runtime_model_context(model, config, manifest): def ref(name): result = f'__dbt__CTE__{name}' @@ -50,6 +89,7 @@ def ref(name): def tearDown(self): self._generate_runtime_model_patch.stop() + clear_plugin(Plugin) def test__prepend_ctes__already_has_cte(self): ephemeral_config = self.model_config.replace(materialized='ephemeral') @@ -118,7 +158,7 @@ def test__prepend_ctes__already_has_cte(self): files={}, ) - compiler = dbt.compilation.Compiler(self.mock_config) + compiler = dbt.compilation.Compiler(self.config) result, _ = compiler._recursively_prepend_ctes( manifest.nodes['model.root.view'], manifest, @@ -202,7 +242,7 @@ def test__prepend_ctes__no_ctes(self): files={}, ) - compiler = dbt.compilation.Compiler(self.mock_config) + compiler = dbt.compilation.Compiler(self.config) result, _ = compiler._recursively_prepend_ctes( manifest.nodes['model.root.view'], manifest, @@ -217,7 +257,7 @@ def test__prepend_ctes__no_ctes(self): result.injected_sql, manifest.nodes.get('model.root.view').compiled_sql) - compiler = dbt.compilation.Compiler(self.mock_config) + compiler = dbt.compilation.Compiler(self.config) result, _ = compiler._recursively_prepend_ctes( manifest.nodes.get('model.root.view_no_cte'), manifest, @@ -295,7 +335,7 @@ def test__prepend_ctes(self): files={}, ) - compiler = dbt.compilation.Compiler(self.mock_config) + compiler = dbt.compilation.Compiler(self.config) result, _ = compiler._recursively_prepend_ctes( manifest.nodes['model.root.view'], manifest, @@ -399,7 +439,7 @@ def test__prepend_ctes__cte_not_compiled(self): files={}, ) - compiler = dbt.compilation.Compiler(self.mock_config) + compiler = dbt.compilation.Compiler(self.config) with patch.object(compiler, 'compile_node') as compile_node: compile_node.return_value = compiled_ephemeral @@ -504,9 +544,7 @@ def test__prepend_ctes__multiple_levels(self): files={}, ) - compiler = dbt.compilation.Compiler( - MagicMock(credentials=MagicMock(type='postgres')) - ) + compiler = dbt.compilation.Compiler(self.config) result, _ = compiler._recursively_prepend_ctes( manifest.nodes['model.root.view'], manifest, diff --git a/test/unit/test_context.py b/test/unit/test_context.py index 0e545cdf3da..f62733f7b29 100644 --- a/test/unit/test_context.py +++ b/test/unit/test_context.py @@ -367,7 +367,6 @@ def get_include_paths(): def config(): return config_from_parts_or_dicts(PROJECT_DATA, PROFILE_DATA) - @pytest.fixture def manifest_fx(config): return mock_manifest(config) diff --git a/test/unit/test_postgres_adapter.py b/test/unit/test_postgres_adapter.py index ebeb5b9b306..f3e773b4093 100644 --- a/test/unit/test_postgres_adapter.py +++ b/test/unit/test_postgres_adapter.py @@ -16,7 +16,7 @@ from psycopg2 import extensions as psycopg2_extensions from psycopg2 import DatabaseError -from .utils import config_from_parts_or_dicts, inject_adapter, mock_connection, TestAdapterConversions,
load_internal_manifest_macros +from .utils import config_from_parts_or_dicts, inject_adapter, mock_connection, TestAdapterConversions, load_internal_manifest_macros, clear_plugin class TestPostgresAdapter(unittest.TestCase): @@ -297,6 +297,7 @@ def tearDown(self): self.qh_patch.stop() self.patcher.stop() self.load_patch.stop() + clear_plugin(PostgresPlugin) def test_quoting_on_drop_schema(self): relation = self.adapter.Relation.create( diff --git a/test/unit/test_source_config.py b/test/unit/test_source_config.py index 69377551e94..fd42e89d579 100644 --- a/test/unit/test_source_config.py +++ b/test/unit/test_source_config.py @@ -172,6 +172,6 @@ def test__context_config_wrong_type(self): model = mock.MagicMock(resource_type=NodeType.Model, fqn=['root', 'x'], project_name='root') with self.assertRaises(dbt.exceptions.CompilationException) as exc: - cfg.updater.get_project_config(model, self.root_project_config) + cfg._updater.get_project_config(model, self.root_project_config) self.assertIn('must be a dict', str(exc.exception)) diff --git a/test/unit/utils.py b/test/unit/utils.py index 03701695e52..4701c4bcd5c 100644 --- a/test/unit/utils.py +++ b/test/unit/utils.py @@ -112,6 +112,13 @@ def inject_plugin(plugin): FACTORY.plugins[key] = plugin + +def inject_plugin_for(config): + from dbt.adapters.factory import FACTORY + FACTORY.load_plugin(config.credentials.type) + adapter = FACTORY.get_adapter(config) + return adapter + + def inject_adapter(value, plugin): """Inject the given adapter into the adapter factory, so your hand-crafted artisanal adapter will be available from get_adapter() as if dbt loaded it. diff --git a/third-party-stubs/sqlparse/__init__.pyi b/third-party-stubs/sqlparse/__init__.pyi new file mode 100644 index 00000000000..4e2a9a2b00d --- /dev/null +++ b/third-party-stubs/sqlparse/__init__.pyi @@ -0,0 +1,7 @@ +from typing import Tuple +from . import sql +from . import tokens + + +def parse(sql: str) -> Tuple[sql.Statement, ...]: + ... diff --git a/third-party-stubs/sqlparse/sql.pyi b/third-party-stubs/sqlparse/sql.pyi new file mode 100644 index 00000000000..103a576d55a --- /dev/null +++ b/third-party-stubs/sqlparse/sql.pyi @@ -0,0 +1,32 @@ +from typing import Tuple, Iterator + + +class Token: + def __init__(self, ttype, value): + ... + + is_keyword: bool + normalized: str + + +class TokenList(Token): + tokens: Tuple[Token, ...] + + def __getitem__(self, key) -> Token: + ... + + def __iter__(self) -> Iterator[Token]: + ... + + def insert_before(self, where, token, skip_ws=True): + ... + + def insert_after(self, where, token, skip_ws=True): + ... + + def token_first(self) -> Token: + ... + + +class Statement(TokenList): + ... diff --git a/third-party-stubs/sqlparse/tokens.pyi b/third-party-stubs/sqlparse/tokens.pyi new file mode 100644 index 00000000000..68e0ed8b097 --- /dev/null +++ b/third-party-stubs/sqlparse/tokens.pyi @@ -0,0 +1,6 @@ + +class _TokenType(tuple): + ... + +Keyword: _TokenType = _TokenType() +Punctuation: _TokenType = _TokenType()
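For reference, a small standalone sketch of the token surgery those stubs describe, mirroring what `Compiler._inject_ctes_into_sql` does when the compiled SQL already opens with a `with` clause (assumes `sqlparse` is installed; the exact whitespace sqlparse emits may differ):

```python
import sqlparse

sql = (
    "with internal_cte as (select * from sessions) "
    "select * from internal_cte"
)
parsed = sqlparse.parse(sql)[0]

# find the WITH keyword, exactly as _inject_ctes_into_sql does
with_stmt = next(
    tok for tok in parsed.tokens
    if tok.is_keyword and tok.normalized == 'WITH'
)

# insert a comma first, then the CTE body; both land directly after WITH,
# so the CTE text ends up ahead of the comma
parsed.insert_after(
    with_stmt, sqlparse.sql.Token(sqlparse.tokens.Punctuation, ',')
)
parsed.insert_after(
    with_stmt,
    sqlparse.sql.Token(
        sqlparse.tokens.Keyword,
        '__dbt__CTE__ephemeral as (select * from table)',
    ),
)

print(str(parsed))
# -> with __dbt__CTE__ephemeral as (select * from table),internal_cte as
#    (select * from sessions) select * from internal_cte
```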