From f80a759488cf69664fd60eadede753fa9026ee2a Mon Sep 17 00:00:00 2001
From: Jacob Beck
Date: Mon, 17 Aug 2020 09:10:56 -0600
Subject: [PATCH 1/5] Have the adapter be responsible for producing the compiler

The adapter's Relation is consulted for adding the ephemeral model prefix
Also hide some things from Jinja
Have the adapter be responsible for producing the compiler, move CTE
generation into the Relation object
---
 core/dbt/adapters/base/impl.py     | 13 ++++-
 core/dbt/adapters/base/relation.py | 17 +++++++
 core/dbt/adapters/protocol.py      | 33 +++++++++++--
 core/dbt/compilation.py            | 73 ++++++++++++++++----------
 core/dbt/context/providers.py      | 77 ++++++++++++++----------------
 core/dbt/rpc/node_runners.py       |  5 +-
 core/dbt/task/compile.py           |  4 +-
 core/dbt/task/rpc/sql_commands.py  |  5 +-
 core/dbt/task/run.py               |  5 +-
 core/dbt/task/runnable.py          |  5 +-
 test/unit/test_compiler.py         | 56 ++++++++++++++++++----
 test/unit/test_context.py          |  1 -
 test/unit/test_postgres_adapter.py |  3 +-
 test/unit/utils.py                 |  8 ++++
 14 files changed, 211 insertions(+), 94 deletions(-)

diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py
index 7c1cc592530..5718db8c69d 100644
--- a/core/dbt/adapters/base/impl.py
+++ b/core/dbt/adapters/base/impl.py
@@ -25,7 +25,9 @@
 )
 from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows
 from dbt.clients.jinja import MacroGenerator
-from dbt.contracts.graph.compiled import CompileResultNode, CompiledSeedNode
+from dbt.contracts.graph.compiled import (
+    CompileResultNode, CompiledSeedNode
+)
 from dbt.contracts.graph.manifest import Manifest
 from dbt.contracts.graph.parsed import ParsedSeedNode
 from dbt.exceptions import warn_or_error
@@ -289,7 +291,10 @@ def _get_cache_schemas(self, manifest: Manifest) -> Set[BaseRelation]:
         return {
             self.Relation.create_from(self.config, node).without_identifier()
             for node in manifest.nodes.values()
-            if node.resource_type in NodeType.executable()
+            if (
+                node.resource_type in NodeType.executable() and
+                not node.is_ephemeral_model
+            )
         }
 
     def _get_catalog_schemas(self, manifest: Manifest) -> SchemaSearchMap:
@@ -1142,6 +1147,10 @@ def get_rows_different_sql(
 
         return sql
 
+    def get_compiler(self):
+        from dbt.compilation import Compiler
+        return Compiler(self.config)
+
 
 COLUMNS_EQUAL_SQL = '''
 with diff_count as (
diff --git a/core/dbt/adapters/base/relation.py b/core/dbt/adapters/base/relation.py
index 3f4026fe8e5..8cbe4b8f17b 100644
--- a/core/dbt/adapters/base/relation.py
+++ b/core/dbt/adapters/base/relation.py
@@ -201,6 +201,23 @@ def create_from_source(
             **kwargs
         )
 
+    @staticmethod
+    def add_ephemeral_prefix(name: str):
+        return f'__dbt__CTE__{name}'
+
+    @classmethod
+    def create_ephemeral_from_node(
+        cls: Type[Self],
+        config: HasQuoting,
+        node: Union[ParsedNode, CompiledNode],
+    ) -> Self:
+        # Note that ephemeral models are based on the name.
+        identifier = cls.add_ephemeral_prefix(node.name)
+        return cls.create(
+            type=cls.CTE,
+            identifier=identifier,
+        ).quote(identifier=False)
+
     @classmethod
     def create_from_node(
         cls: Type[Self],
diff --git a/core/dbt/adapters/protocol.py b/core/dbt/adapters/protocol.py
index 68809b77283..ab5cff64f41 100644
--- a/core/dbt/adapters/protocol.py
+++ b/core/dbt/adapters/protocol.py
@@ -1,19 +1,23 @@
 from dataclasses import dataclass
 from typing import (
     Type, Hashable, Optional, ContextManager, List, Generic, TypeVar, ClassVar,
-    Tuple, Union
+    Tuple, Union, Dict, Any
 )
 from typing_extensions import Protocol
 
 import agate
 
 from dbt.contracts.connection import Connection, AdapterRequiredConfig
-from dbt.contracts.graph.compiled import CompiledNode
+from dbt.contracts.graph.compiled import (
+    CompiledNode, NonSourceNode, NonSourceCompiledNode
+)
 from dbt.contracts.graph.parsed import ParsedNode, ParsedSourceDefinition
 from dbt.contracts.graph.model_config import BaseConfig
 from dbt.contracts.graph.manifest import Manifest
 from dbt.contracts.relation import Policy, HasQuoting
+from dbt.graph import Graph
+
 
 @dataclass
 class AdapterConfig(BaseConfig):
@@ -45,6 +49,19 @@ def create_from(
     ...
 
 
+class CompilerProtocol(Protocol):
+    def compile(self, manifest: Manifest, write=True) -> Graph:
+        ...
+
+    def compile_node(
+        self,
+        node: NonSourceNode,
+        manifest: Manifest,
+        extra_context: Optional[Dict[str, Any]] = None,
+    ) -> NonSourceCompiledNode:
+        ...
+
+
 AdapterConfig_T = TypeVar(
     'AdapterConfig_T', bound=AdapterConfig
 )
@@ -57,11 +74,18 @@ def create_from(
 Column_T = TypeVar(
     'Column_T', bound=ColumnProtocol
 )
+Compiler_T = TypeVar('Compiler_T', bound=CompilerProtocol)
 
 
 class AdapterProtocol(
     Protocol,
-    Generic[AdapterConfig_T, ConnectionManager_T, Relation_T, Column_T]
+    Generic[
+        AdapterConfig_T,
+        ConnectionManager_T,
+        Relation_T,
+        Column_T,
+        Compiler_T,
+    ]
 ):
     AdapterSpecificConfigs: ClassVar[Type[AdapterConfig_T]]
     Column: ClassVar[Type[Column_T]]
@@ -132,3 +156,6 @@ def execute(
         self, sql: str, auto_begin: bool = False, fetch: bool = False
     ) -> Tuple[str, agate.Table]:
         ...
+
+    def get_compiler(self) -> Compiler_T:
+        ...
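
An aside on the seam the protocol above creates: because the Compiler now
resolves the CTE prefix through `adapter.Relation` (see the compilation.py
diff just below), a plugin can rename injected CTEs by overriding
`add_ephemeral_prefix` on its Relation class. A minimal sketch, assuming
hypothetical `MyRelation`/`MyAdapter` classes that are not part of this diff:

    from dbt.adapters.base.impl import BaseAdapter
    from dbt.adapters.base.relation import BaseRelation


    class MyRelation(BaseRelation):
        @staticmethod
        def add_ephemeral_prefix(name: str):
            # injected CTEs become dbt_cte__<name> instead of __dbt__CTE__<name>
            return f'dbt_cte__{name}'


    class MyAdapter(BaseAdapter):
        # BaseAdapter.get_compiler() returns the stock Compiler, which asks
        # adapter.Relation for the prefix, so overriding Relation is enough
        Relation = MyRelation
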
diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py
index d9e449ff22c..8f68a818335 100644
--- a/core/dbt/compilation.py
+++ b/core/dbt/compilation.py
@@ -1,10 +1,11 @@
 import os
 from collections import defaultdict
-from typing import List, Dict, Any, Tuple, cast
+from typing import List, Dict, Any, Tuple, cast, Optional
 
 import networkx as nx  # type: ignore
 
 from dbt import flags
+from dbt.adapters.factory import get_adapter
 from dbt.clients import jinja
 from dbt.clients.system import make_directory
 from dbt.context.providers import generate_runtime_model
@@ -21,7 +22,7 @@
 from dbt.graph import Graph
 from dbt.logger import GLOBAL_LOGGER as logger
 from dbt.node_types import NodeType
-from dbt.utils import add_ephemeral_model_prefix, pluralize
+from dbt.utils import pluralize
 
 
 graph_file_name = 'graph.gpickle'
@@ -156,6 +157,11 @@ def _create_node_context(
 
         return context
 
+    def add_ephemeral_prefix(self, name: str):
+        adapter = get_adapter(self.config)
+        relation_cls = adapter.Relation
+        return relation_cls.add_ephemeral_prefix(name)
+
     def _get_compiled_model(
         self,
         manifest: Manifest,
@@ -213,7 +219,8 @@ def _recursively_prepend_ctes(
                 cte_model, manifest, extra_context
             )
             _extend_prepended_ctes(prepended_ctes, new_prepended_ctes)
-            new_cte_name = add_ephemeral_model_prefix(cte_model.name)
+
+            new_cte_name = self.add_ephemeral_prefix(cte_model.name)
             sql = f' {new_cte_name} as (\n{cte_model.compiled_sql}\n)'
 
             _add_prepended_cte(prepended_ctes, InjectedCTE(id=cte.id, sql=sql))
@@ -223,8 +230,11 @@
 
         return model, prepended_ctes
 
-    def compile_node(
-        self, node: NonSourceNode, manifest, extra_context=None
+    def _compile_node(
+        self,
+        node: NonSourceNode,
+        manifest: Manifest,
+        extra_context: Optional[Dict[str, Any]] = None,
     ) -> NonSourceCompiledNode:
         if extra_context is None:
             extra_context = {}
@@ -295,6 +305,7 @@ def link_graph(self, linker: Linker, manifest: Manifest):
             raise RuntimeError("Found a cycle: {}".format(cycle))
 
     def compile(self, manifest: Manifest, write=True) -> Graph:
+        self.initialize()
         linker = Linker()
         self.link_graph(linker, manifest)
 
@@ -307,11 +318,38 @@
 
         return Graph(linker.graph)
 
+    def _write_node(self, node: NonSourceNode) -> NonSourceNode:
+        if not _is_writable(node):
+            return node
+        logger.debug(f'Writing injected SQL for node "{node.unique_id}"')
+
+        if node.injected_sql is None:
+            # this should not really happen, but it'd be a shame to crash
+            # over it
+            logger.error(
+                f'Compiled node "{node.unique_id}" had no injected_sql, '
+                'cannot write sql!'
+            )
+        else:
+            node.build_path = node.write_node(
+                self.config.target_path,
+                'compiled',
+                node.injected_sql
+            )
+        return node
+
+    def compile_node(
+        self,
+        node: NonSourceNode,
+        manifest: Manifest,
+        extra_context: Optional[Dict[str, Any]] = None,
+        write: bool = True,
+    ) -> NonSourceCompiledNode:
+        node = self._compile_node(node, manifest, extra_context)
 
-def compile_manifest(config, manifest, write=True) -> Graph:
-    compiler = Compiler(config)
-    compiler.initialize()
-    return compiler.compile(manifest, write=write)
+        if write and _is_writable(node):
+            self._write_node(node)
+        return node
 
 
 def _is_writable(node):
@@ -322,20 +360,3 @@ def _is_writable(node):
             return False
 
     return True
-
-
-def compile_node(adapter, config, node, manifest, extra_context, write=True):
-    compiler = Compiler(config)
-    node = compiler.compile_node(node, manifest, extra_context)
-
-    if write and _is_writable(node):
-        logger.debug('Writing injected SQL for node "{}"'.format(
-            node.unique_id))
-
-        node.build_path = node.write_node(
-            config.target_path,
-            'compiled',
-            node.injected_sql
-        )
-
-    return node
diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py
index ed7d7d56615..a9345504acf 100644
--- a/core/dbt/context/providers.py
+++ b/core/dbt/context/providers.py
@@ -46,7 +46,7 @@
 from dbt.node_types import NodeType
 
 from dbt.utils import (
-    add_ephemeral_model_prefix, merge, AttrDict, MultiDict
+    merge, AttrDict, MultiDict
 )
 
 import agate
@@ -58,23 +58,23 @@
 # base classes
 class RelationProxy:
     def __init__(self, adapter):
-        self.quoting_config = adapter.config.quoting
-        self.relation_type = adapter.Relation
+        self._quoting_config = adapter.config.quoting
+        self._relation_type = adapter.Relation
 
     def __getattr__(self, key):
-        return getattr(self.relation_type, key)
+        return getattr(self._relation_type, key)
 
     def create_from_source(self, *args, **kwargs):
         # bypass our create when creating from source so as not to mess up
         # the source quoting
-        return self.relation_type.create_from_source(*args, **kwargs)
+        return self._relation_type.create_from_source(*args, **kwargs)
 
     def create(self, *args, **kwargs):
         kwargs['quote_policy'] = merge(
-            self.quoting_config,
+            self._quoting_config,
             kwargs.pop('quote_policy', {})
         )
-        return self.relation_type.create(*args, **kwargs)
+        return self._relation_type.create(*args, **kwargs)
 
 
 class BaseDatabaseWrapper:
@@ -83,28 +83,28 @@ class BaseDatabaseWrapper:
     via a relation proxy.
""" def __init__(self, adapter, namespace: MacroNamespace): - self.adapter = adapter + self._adapter = adapter self.Relation = RelationProxy(adapter) - self.namespace = namespace + self._namespace = namespace def __getattr__(self, name): raise NotImplementedError('subclasses need to implement this') @property def config(self): - return self.adapter.config + return self._adapter.config def type(self): - return self.adapter.type() + return self._adapter.type() def commit(self): - return self.adapter.commit_if_has_connection() + return self._adapter.commit_if_has_connection() def _get_adapter_macro_prefixes(self) -> List[str]: # a future version of this could have plugins automatically call fall # back to their dependencies' dependencies by using # `get_adapter_type_names` instead of `[self.config.credentials.type]` - search_prefixes = [self.adapter.type(), 'default'] + search_prefixes = [self._adapter.type(), 'default'] return search_prefixes def dispatch( @@ -138,7 +138,7 @@ def dispatch( for prefix in self._get_adapter_macro_prefixes(): search_name = f'{prefix}__{macro_name}' try: - macro = self.namespace.get_from_package( + macro = self._namespace.get_from_package( package_name, search_name ) except CompilationException as exc: @@ -379,13 +379,13 @@ class ParseDatabaseWrapper(BaseDatabaseWrapper): parse-time overrides. """ def __getattr__(self, name): - override = (name in self.adapter._available_ and - name in self.adapter._parse_replacements_) + override = (name in self._adapter._available_ and + name in self._adapter._parse_replacements_) if override: - return self.adapter._parse_replacements_[name] - elif name in self.adapter._available_: - return getattr(self.adapter, name) + return self._adapter._parse_replacements_[name] + elif name in self._adapter._available_: + return getattr(self._adapter, name) else: raise AttributeError( "'{}' object has no attribute '{}'".format( @@ -399,8 +399,8 @@ class RuntimeDatabaseWrapper(BaseDatabaseWrapper): available. 
""" def __getattr__(self, name): - if name in self.adapter._available_: - return getattr(self.adapter, name) + if name in self._adapter._available_: + return getattr(self._adapter, name) else: raise AttributeError( "'{}' object has no attribute '{}'".format( @@ -443,20 +443,14 @@ def resolve( self.validate(target_model, target_name, target_package) return self.create_relation(target_model, target_name) - def create_ephemeral_relation( - self, target_model: NonSourceNode, name: str - ) -> RelationProxy: - self.model.set_cte(target_model.unique_id, None) - return self.Relation.create( - type=self.Relation.CTE, - identifier=add_ephemeral_model_prefix(name) - ).quote(identifier=False) - def create_relation( self, target_model: NonSourceNode, name: str ) -> RelationProxy: - if target_model.get_materialization() == 'ephemeral': - return self.create_ephemeral_relation(target_model, name) + if target_model.is_ephemeral_model: + self.model.set_cte(target_model.unique_id, None) + return self.Relation.create_ephemeral_from_node( + self.config, target_model + ) else: return self.Relation.create_from(self.config, target_model) @@ -480,16 +474,19 @@ def validate( ) -> None: pass - def create_ephemeral_relation( + def create_relation( self, target_model: NonSourceNode, name: str ) -> RelationProxy: - # In operations, we can't ref() ephemeral nodes, because ParsedMacros - # do not support set_cte - raise_compiler_error( - 'Operations can not ref() ephemeral nodes, but {} is ephemeral' - .format(target_model.name), - self.model - ) + if target_model.is_ephemeral_model: + # In operations, we can't ref() ephemeral nodes, because + # ParsedMacros do not support set_cte + raise_compiler_error( + 'Operations can not ref() ephemeral nodes, but {} is ephemeral' + .format(target_model.name), + self.model + ) + else: + return super().create_relation(target_model, name) # `source` implementations diff --git a/core/dbt/rpc/node_runners.py b/core/dbt/rpc/node_runners.py index 58130b17415..805687767cc 100644 --- a/core/dbt/rpc/node_runners.py +++ b/core/dbt/rpc/node_runners.py @@ -2,7 +2,6 @@ from typing import Generic, TypeVar import dbt.exceptions -from dbt.compilation import compile_node from dbt.contracts.rpc import ( RemoteCompileResult, RemoteRunResult, ResultTable, ) @@ -38,8 +37,8 @@ def after_execute(self, result): pass def compile(self, manifest): - return compile_node(self.adapter, self.config, self.node, manifest, {}, - write=False) + compiler = self.adapter.get_compiler() + return compiler.compile_node(self.node, manifest, {}, write=False) @abstractmethod def execute(self, compiled_node, manifest) -> RPCSQLResult: diff --git a/core/dbt/task/compile.py b/core/dbt/task/compile.py index e7476a5ebbd..e4469f39728 100644 --- a/core/dbt/task/compile.py +++ b/core/dbt/task/compile.py @@ -1,7 +1,6 @@ from .runnable import GraphRunnableTask from .base import BaseRunner -from dbt.compilation import compile_node from dbt.contracts.results import RunModelResult from dbt.exceptions import InternalException from dbt.graph import ResourceTypeSelector, SelectionSpec, parse_difference @@ -20,7 +19,8 @@ def execute(self, compiled_node, manifest): return RunModelResult(compiled_node) def compile(self, manifest): - return compile_node(self.adapter, self.config, self.node, manifest, {}) + compiler = self.adapter.get_compiler() + return compiler.compile_node(self.node, manifest, {}) class CompileTask(GraphRunnableTask): diff --git a/core/dbt/task/rpc/sql_commands.py b/core/dbt/task/rpc/sql_commands.py index 
a871a34c446..deea0c92580 100644 --- a/core/dbt/task/rpc/sql_commands.py +++ b/core/dbt/task/rpc/sql_commands.py @@ -7,7 +7,6 @@ from dbt import flags from dbt.adapters.factory import get_adapter from dbt.clients.jinja import extract_toplevel_blocks -from dbt.compilation import compile_manifest from dbt.config.runtime import RuntimeConfig from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.parsed import ParsedRPCNode @@ -129,7 +128,9 @@ def _get_exec_node(self): ) # don't write our new, weird manifest! - self.graph = compile_manifest(self.config, self.manifest, write=False) + adapter = get_adapter(self.config) + compiler = adapter.get_compiler() + self.graph = compiler.compile(self.manifest, write=False) # previously, this compiled the ancestors, but they are compiled at # runtime now. return rpc_node diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index 917d38dc90b..449900649a2 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -18,7 +18,6 @@ from dbt import utils from dbt.adapters.base import BaseRelation from dbt.clients.jinja import MacroGenerator -from dbt.compilation import compile_node from dbt.context.providers import generate_runtime_model from dbt.contracts.graph.compiled import CompileResultNode from dbt.contracts.graph.manifest import WritableManifest @@ -254,8 +253,8 @@ def raise_on_first_error(self): return False def get_hook_sql(self, adapter, hook, idx, num_hooks, extra_context): - compiled = compile_node(adapter, self.config, hook, self.manifest, - extra_context) + compiler = adapter.get_compiler() + compiled = compiler.compile_node(hook, self.manifest, extra_context) statement = compiled.injected_sql hook_index = hook.index or num_hooks hook_obj = get_hook(statement, index=hook_index) diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index 90f1e14bd98..dbedfcd6d05 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -26,7 +26,6 @@ NodeCount, print_timestamped_line, ) -from dbt.compilation import compile_manifest from dbt.contracts.graph.compiled import CompileResultNode from dbt.contracts.graph.manifest import Manifest @@ -71,7 +70,9 @@ def compile_manifest(self): raise InternalException( 'compile_manifest called before manifest was loaded' ) - self.graph = compile_manifest(self.config, self.manifest) + adapter = get_adapter(self.config) + compiler = adapter.get_compiler() + self.graph = compiler.compile(self.manifest) def _runtime_initialize(self): self.load_manifest() diff --git a/test/unit/test_compiler.py b/test/unit/test_compiler.py index f7cec475158..367462a08a3 100644 --- a/test/unit/test_compiler.py +++ b/test/unit/test_compiler.py @@ -3,6 +3,7 @@ import dbt.flags import dbt.compilation +from dbt.adapters.postgres import Plugin from dbt.contracts.files import FileHash from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.parsed import NodeConfig, DependsOn, ParsedModelNode @@ -11,6 +12,8 @@ from datetime import datetime +from .utils import inject_adapter, clear_plugin, config_from_parts_or_dicts + class CompilerTest(unittest.TestCase): def assertEqualIgnoreWhitespace(self, a, b): @@ -34,10 +37,46 @@ def setUp(self): 'column_types': {}, 'tags': [], }) - self.mock_config = MagicMock(credentials=MagicMock(type='postgres')) + + project_cfg = { + 'name': 'X', + 'version': '0.1', + 'profile': 'test', + 'project-root': '/tmp/dbt/does-not-exist', + } + profile_cfg = { + 'outputs': { + 'test': { + 'type': 'postgres', + 'dbname': 'postgres', + 'user': 'root', + 'host': 
'thishostshouldnotexist', + 'pass': 'password', + 'port': 5432, + 'schema': 'public' + } + }, + 'target': 'test' + } + + self.config = config_from_parts_or_dicts(project_cfg, profile_cfg) + self._generate_runtime_model_patch = patch.object(dbt.compilation, 'generate_runtime_model') self.mock_generate_runtime_model = self._generate_runtime_model_patch.start() + inject_adapter(Plugin.adapter(self.config), Plugin) + + # self.mock_adapter = PostgresAdapter MagicMock(type=MagicMock(return_value='postgres')) + # self.mock_adapter.Relation = + # self.mock_adapter.get_compiler.return_value = dbt.compilation.Compiler + # self.mock_plugin = MagicMock( + # adapter=MagicMock( + # credentials=MagicMock(return_value='postgres') + # ) + # ) + # inject_adapter(self.mock_adapter, self.mock_plugin) + # so we can make an adapter + def mock_generate_runtime_model_context(model, config, manifest): def ref(name): result = f'__dbt__CTE__{name}' @@ -50,6 +89,7 @@ def ref(name): def tearDown(self): self._generate_runtime_model_patch.stop() + clear_plugin(Plugin) def test__prepend_ctes__already_has_cte(self): ephemeral_config = self.model_config.replace(materialized='ephemeral') @@ -118,7 +158,7 @@ def test__prepend_ctes__already_has_cte(self): files={}, ) - compiler = dbt.compilation.Compiler(self.mock_config) + compiler = dbt.compilation.Compiler(self.config) result, _ = compiler._recursively_prepend_ctes( manifest.nodes['model.root.view'], manifest, @@ -202,7 +242,7 @@ def test__prepend_ctes__no_ctes(self): files={}, ) - compiler = dbt.compilation.Compiler(self.mock_config) + compiler = dbt.compilation.Compiler(self.config) result, _ = compiler._recursively_prepend_ctes( manifest.nodes['model.root.view'], manifest, @@ -217,7 +257,7 @@ def test__prepend_ctes__no_ctes(self): result.injected_sql, manifest.nodes.get('model.root.view').compiled_sql) - compiler = dbt.compilation.Compiler(self.mock_config) + compiler = dbt.compilation.Compiler(self.config) result, _ = compiler._recursively_prepend_ctes( manifest.nodes.get('model.root.view_no_cte'), manifest, @@ -295,7 +335,7 @@ def test__prepend_ctes(self): files={}, ) - compiler = dbt.compilation.Compiler(self.mock_config) + compiler = dbt.compilation.Compiler(self.config) result, _ = compiler._recursively_prepend_ctes( manifest.nodes['model.root.view'], manifest, @@ -399,7 +439,7 @@ def test__prepend_ctes__cte_not_compiled(self): files={}, ) - compiler = dbt.compilation.Compiler(self.mock_config) + compiler = dbt.compilation.Compiler(self.config) with patch.object(compiler, 'compile_node') as compile_node: compile_node.return_value = compiled_ephemeral @@ -504,9 +544,7 @@ def test__prepend_ctes__multiple_levels(self): files={}, ) - compiler = dbt.compilation.Compiler( - MagicMock(credentials=MagicMock(type='postgres')) - ) + compiler = dbt.compilation.Compiler(self.config) result, _ = compiler._recursively_prepend_ctes( manifest.nodes['model.root.view'], manifest, diff --git a/test/unit/test_context.py b/test/unit/test_context.py index 0e545cdf3da..f62733f7b29 100644 --- a/test/unit/test_context.py +++ b/test/unit/test_context.py @@ -367,7 +367,6 @@ def get_include_paths(): def config(): return config_from_parts_or_dicts(PROJECT_DATA, PROFILE_DATA) - @pytest.fixture def manifest_fx(config): return mock_manifest(config) diff --git a/test/unit/test_postgres_adapter.py b/test/unit/test_postgres_adapter.py index ebeb5b9b306..f3e773b4093 100644 --- a/test/unit/test_postgres_adapter.py +++ b/test/unit/test_postgres_adapter.py @@ -16,7 +16,7 @@ from psycopg2 import 
extensions as psycopg2_extensions from psycopg2 import DatabaseError -from .utils import config_from_parts_or_dicts, inject_adapter, mock_connection, TestAdapterConversions, load_internal_manifest_macros +from .utils import config_from_parts_or_dicts, inject_adapter, mock_connection, TestAdapterConversions, load_internal_manifest_macros, clear_plugin class TestPostgresAdapter(unittest.TestCase): @@ -297,6 +297,7 @@ def tearDown(self): self.qh_patch.stop() self.patcher.stop() self.load_patch.stop() + clear_plugin(PostgresPlugin) def test_quoting_on_drop_schema(self): relation = self.adapter.Relation.create( diff --git a/test/unit/utils.py b/test/unit/utils.py index 03701695e52..4701c4bcd5c 100644 --- a/test/unit/utils.py +++ b/test/unit/utils.py @@ -112,6 +112,14 @@ def inject_plugin(plugin): FACTORY.plugins[key] = plugin +def inject_plugin_for(config): + # from dbt.adapters.postgres import Plugin, PostgresAdapter + from dbt.adapters.factory import FACTORY + FACTORY.load_plugin(config.credentials.type) + adapter = FACTORY.get_adapter(config) + return adapter + + def inject_adapter(value, plugin): """Inject the given adapter into the adapter factory, so your hand-crafted artisanal adapter will be available from get_adapter() as if dbt loaded it. From 123771163a5967a02beeee1bab851dba2a45133d Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 17 Aug 2020 15:08:20 -0600 Subject: [PATCH 2/5] hide more things from the context --- core/dbt/compilation.py | 2 +- core/dbt/context/base.py | 24 ++++----- core/dbt/context/configured.py | 20 +++---- core/dbt/context/context_config.py | 84 ++++++++++++++++-------------- core/dbt/context/providers.py | 24 ++++----- test/unit/test_source_config.py | 2 +- 6 files changed, 82 insertions(+), 74 deletions(-) diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index 8f68a818335..3c50e3bd20f 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -318,7 +318,7 @@ def compile(self, manifest: Manifest, write=True) -> Graph: return Graph(linker.graph) - def _write_node(self, node: NonSourceNode) -> NonSourceNode: + def _write_node(self, node: NonSourceCompiledNode) -> NonSourceNode: if not _is_writable(node): return node logger.debug(f'Writing injected SQL for node "{node.unique_id}"') diff --git a/core/dbt/context/base.py b/core/dbt/context/base.py index 33c9681c15d..51d0ec44e73 100644 --- a/core/dbt/context/base.py +++ b/core/dbt/context/base.py @@ -105,39 +105,39 @@ def __init__( cli_vars: Mapping[str, Any], node: Optional[CompiledResource] = None ) -> None: - self.context: Mapping[str, Any] = context - self.cli_vars: Mapping[str, Any] = cli_vars - self.node: Optional[CompiledResource] = node - self.merged: Mapping[str, Any] = self._generate_merged() + self._context: Mapping[str, Any] = context + self._cli_vars: Mapping[str, Any] = cli_vars + self._node: Optional[CompiledResource] = node + self._merged: Mapping[str, Any] = self._generate_merged() def _generate_merged(self) -> Mapping[str, Any]: - return self.cli_vars + return self._cli_vars @property def node_name(self): - if self.node is not None: - return self.node.name + if self._node is not None: + return self._node.name else: return '' def get_missing_var(self, var_name): - dct = {k: self.merged[k] for k in self.merged} + dct = {k: self._merged[k] for k in self._merged} pretty_vars = json.dumps(dct, sort_keys=True, indent=4) msg = self.UndefinedVarError.format( var_name, self.node_name, pretty_vars ) - raise_compiler_error(msg, self.node) + raise_compiler_error(msg, self._node) 
 
     def has_var(self, var_name: str):
-        return var_name in self.merged
+        return var_name in self._merged
 
     def get_rendered_var(self, var_name):
-        raw = self.merged[var_name]
+        raw = self._merged[var_name]
         # if bool/int/float/etc are passed in, don't compile anything
         if not isinstance(raw, str):
             return raw
 
-        return get_rendered(raw, self.context)
+        return get_rendered(raw, self._context)
 
     def __call__(self, var_name, default=_VAR_NOTSET):
         if self.has_var(var_name):
diff --git a/core/dbt/context/configured.py b/core/dbt/context/configured.py
index 4e820c9496f..baf2f60b34c 100644
--- a/core/dbt/context/configured.py
+++ b/core/dbt/context/configured.py
@@ -36,23 +36,23 @@ def __init__(
         project_name: str,
     ):
         super().__init__(context, config.cli_vars)
-        self.config = config
-        self.project_name = project_name
+        self._config = config
+        self._project_name = project_name
 
     def __call__(self, var_name, default=Var._VAR_NOTSET):
-        my_config = self.config.load_dependencies()[self.project_name]
+        my_config = self._config.load_dependencies()[self._project_name]
 
         # cli vars > active project > local project
-        if var_name in self.config.cli_vars:
-            return self.config.cli_vars[var_name]
+        if var_name in self._config.cli_vars:
+            return self._config.cli_vars[var_name]
 
-        if self.config.config_version == 2 and my_config.config_version == 2:
-            adapter_type = self.config.credentials.type
-            lookup = FQNLookup(self.project_name)
-            active_vars = self.config.vars.vars_for(lookup, adapter_type)
+        if self._config.config_version == 2 and my_config.config_version == 2:
+            adapter_type = self._config.credentials.type
+            lookup = FQNLookup(self._project_name)
+            active_vars = self._config.vars.vars_for(lookup, adapter_type)
             all_vars = MultiDict([active_vars])
 
-            if self.config.project_name != my_config.project_name:
+            if self._config.project_name != my_config.project_name:
                 all_vars.add(my_config.vars.vars_for(lookup, adapter_type))
 
             if var_name in all_vars:
diff --git a/core/dbt/context/context_config.py b/core/dbt/context/context_config.py
index 3cdf9bf92e1..847417477a5 100644
--- a/core/dbt/context/context_config.py
+++ b/core/dbt/context/context_config.py
@@ -26,16 +26,16 @@ def __init__(
         node_type: NodeType,
     ):
         self._config = None
-        self.active_project: RuntimeConfig = active_project
-        self.own_project: Project = own_project
+        self._active_project: RuntimeConfig = active_project
+        self._own_project: Project = own_project
 
-        self.model = ModelParts(
+        self._model = ModelParts(
             fqn=fqn,
             resource_type=node_type,
-            package_name=self.own_project.project_name,
+            package_name=self._own_project.project_name,
         )
 
-        self.updater = ConfigUpdater(active_project.credentials.type)
+        self._updater = ConfigUpdater(active_project.credentials.type)
 
         # the config options defined within the model
         self.in_model_config: Dict[str, Any] = {}
@@ -43,12 +43,12 @@ def get_default(self) -> Dict[str, Any]:
         defaults = {"enabled": True, "materialized": "view"}
 
-        if self.model.resource_type == NodeType.Seed:
+        if self._model.resource_type == NodeType.Seed:
             defaults['materialized'] = 'seed'
-        elif self.model.resource_type == NodeType.Snapshot:
+        elif self._model.resource_type == NodeType.Snapshot:
             defaults['materialized'] = 'snapshot'
 
-        if self.model.resource_type == NodeType.Test:
+        if self._model.resource_type == NodeType.Test:
             defaults['severity'] = 'ERROR'
 
         return defaults
@@ -57,31 +57,34 @@ def build_config_dict(self, base: bool = False) -> Dict[str, Any]:
         defaults = self.get_default()
         active_config = self.load_config_from_active_project()
-        if self.active_project.project_name == self.own_project.project_name:
-            cfg = self.updater.merge(
+        if self._active_project.project_name == self._own_project.project_name:
+            cfg = self._updater.merge(
                 defaults, active_config, self.in_model_config
             )
         else:
             own_config = self.load_config_from_own_project()
 
-            cfg = self.updater.merge(
+            cfg = self._updater.merge(
                 defaults, own_config, self.in_model_config, active_config
             )
 
         return cfg
 
     def _translate_adapter_aliases(self, config: Dict[str, Any]):
-        return self.active_project.credentials.translate_aliases(config)
+        return self._active_project.credentials.translate_aliases(config)
 
     def update_in_model_config(self, config: Dict[str, Any]) -> None:
         config = self._translate_adapter_aliases(config)
-        self.updater.update_into(self.in_model_config, config)
+        self._updater.update_into(self.in_model_config, config)
 
     def load_config_from_own_project(self) -> Dict[str, Any]:
-        return self.updater.get_project_config(self.model, self.own_project)
+        return self._updater.get_project_config(self._model, self._own_project)
 
     def load_config_from_active_project(self) -> Dict[str, Any]:
-        return self.updater.get_project_config(self.model, self.active_project)
+        return self._updater.get_project_config(
+            self._model,
+            self._active_project,
+        )
 
 
 T = TypeVar('T', bound=BaseConfig)
@@ -89,12 +92,12 @@
 
 class ContextConfigGenerator:
     def __init__(self, active_project: RuntimeConfig):
-        self.active_project = active_project
+        self._active_project = active_project
 
     def get_node_project(self, project_name: str):
-        if project_name == self.active_project.project_name:
-            return self.active_project
-        dependencies = self.active_project.load_dependencies()
+        if project_name == self._active_project.project_name:
+            return self._active_project
+        dependencies = self._active_project.load_dependencies()
         if project_name not in dependencies:
             raise InternalException(
                 f'Project name {project_name} not found in dependencies '
@@ -102,7 +105,7 @@ def get_node_project(self, project_name: str):
             )
         return dependencies[project_name]
 
-    def project_configs(
+    def _project_configs(
         self, project: Project, fqn: List[str], resource_type: NodeType
     ) -> Iterator[Dict[str, Any]]:
         if resource_type == NodeType.Seed:
@@ -123,18 +126,20 @@ def _project_configs(
 
         yield result
 
-    def active_project_configs(
+    def _active_project_configs(
         self, fqn: List[str], resource_type: NodeType
     ) -> Iterator[Dict[str, Any]]:
-        return self.project_configs(self.active_project, fqn, resource_type)
+        return self._project_configs(self._active_project, fqn, resource_type)
 
     def _update_from_config(
         self, result: T, partial: Dict[str, Any], validate: bool = False
     ) -> T:
-        translated = self.active_project.credentials.translate_aliases(partial)
+        translated = self._active_project.credentials.translate_aliases(
+            partial
+        )
         return result.update_from(
             translated,
-            self.active_project.credentials.type,
+            self._active_project.credentials.type,
             validate=validate
         )
@@ -153,13 +158,16 @@ def calculate_node_config(
         # because it might be invalid in the case of required config members
         # (such as on snapshots!)
         result = config_cls.from_dict({}, validate=False)
-        for fqn_config in self.project_configs(own_config, fqn, resource_type):
+
+        project_configs = self._project_configs(own_config, fqn, resource_type)
+        for fqn_config in project_configs:
             result = self._update_from_config(result, fqn_config)
+
         for config_call in config_calls:
             result = self._update_from_config(result, config_call)
 
-        if own_config.project_name != self.active_project.project_name:
-            for fqn_config in self.active_project_configs(fqn, resource_type):
+        if own_config.project_name != self._active_project.project_name:
+            for fqn_config in self._active_project_configs(fqn, resource_type):
                 result = self._update_from_config(result, fqn_config)
 
         # this is mostly impactful in the snapshot config case
@@ -174,21 +182,21 @@ def __init__(
         resource_type: NodeType,
         project_name: str,
     ) -> None:
-        self.config_calls: List[Dict[str, Any]] = []
-        self.cfg_source = ContextConfigGenerator(active_project)
-        self.fqn = fqn
-        self.resource_type = resource_type
-        self.project_name = project_name
+        self._config_calls: List[Dict[str, Any]] = []
+        self._cfg_source = ContextConfigGenerator(active_project)
+        self._fqn = fqn
+        self._resource_type = resource_type
+        self._project_name = project_name
 
     def update_in_model_config(self, opts: Dict[str, Any]) -> None:
-        self.config_calls.append(opts)
+        self._config_calls.append(opts)
 
     def build_config_dict(self, base: bool = False) -> Dict[str, Any]:
-        return self.cfg_source.calculate_node_config(
-            config_calls=self.config_calls,
-            fqn=self.fqn,
-            resource_type=self.resource_type,
-            project_name=self.project_name,
+        return self._cfg_source.calculate_node_config(
+            config_calls=self._config_calls,
+            fqn=self._fqn,
+            resource_type=self._resource_type,
+            project_name=self._project_name,
             base=base,
         ).to_dict()
diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py
index a9345504acf..16d79213f30 100644
--- a/core/dbt/context/providers.py
+++ b/core/dbt/context/providers.py
@@ -523,37 +523,37 @@ def __init__(
         config: RuntimeConfig,
         node: CompiledResource,
     ) -> None:
-        self.node: CompiledResource
-        self.config: RuntimeConfig = config
+        self._node: CompiledResource
+        self._config: RuntimeConfig = config
         super().__init__(context, config.cli_vars, node=node)
 
     def packages_for_node(self) -> Iterable[Project]:
-        dependencies = self.config.load_dependencies()
-        package_name = self.node.package_name
+        dependencies = self._config.load_dependencies()
+        package_name = self._node.package_name
 
-        if package_name != self.config.project_name:
+        if package_name != self._config.project_name:
             if package_name not in dependencies:
                 # I don't think this is actually reachable
                 raise_compiler_error(
                     f'Node package named {package_name} not found!',
-                    self.node
+                    self._node
                 )
             yield dependencies[package_name]
-        yield self.config
+        yield self._config
 
     def _generate_merged(self) -> Mapping[str, Any]:
         search_node: IsFQNResource
-        if isinstance(self.node, IsFQNResource):
-            search_node = self.node
+        if isinstance(self._node, IsFQNResource):
+            search_node = self._node
         else:
-            search_node = FQNLookup(self.node.package_name)
+            search_node = FQNLookup(self._node.package_name)
 
-        adapter_type = self.config.credentials.type
+        adapter_type = self._config.credentials.type
 
         merged = MultiDict()
         for project in self.packages_for_node():
             merged.add(project.vars.vars_for(search_node, adapter_type))
-        merged.add(self.cli_vars)
+        merged.add(self._cli_vars)
 
         return merged
diff --git a/test/unit/test_source_config.py b/test/unit/test_source_config.py
index 69377551e94..fd42e89d579 100644
--- a/test/unit/test_source_config.py
+++ b/test/unit/test_source_config.py
@@ -172,6 +172,6 @@ def test__context_config_wrong_type(self):
         model = mock.MagicMock(resource_type=NodeType.Model, fqn=['root', 'x'],
                                project_name='root')
         with self.assertRaises(dbt.exceptions.CompilationException) as exc:
-            cfg.updater.get_project_config(model, self.root_project_config)
+            cfg._updater.get_project_config(model, self.root_project_config)
 
         self.assertIn('must be a dict', str(exc.exception))

From 8ad1551b153e871e4c5c62a35e1fa3ead23816ea Mon Sep 17 00:00:00 2001
From: Jacob Beck
Date: Tue, 18 Aug 2020 10:36:07 -0600
Subject: [PATCH 3/5] when you think about it, data tests are really just ctes

---
 core/dbt/compilation.py                 | 157 +++++++++++++++++++++---
 core/dbt/contracts/graph/compiled.py    |  75 -----------
 core/dbt/task/test.py                   |   8 +-
 third-party-stubs/sqlparse/__init__.pyi |   7 ++
 third-party-stubs/sqlparse/sql.pyi      |  32 +++++
 third-party-stubs/sqlparse/tokens.pyi   |   6 +
 6 files changed, 192 insertions(+), 93 deletions(-)
 create mode 100644 third-party-stubs/sqlparse/__init__.pyi
 create mode 100644 third-party-stubs/sqlparse/sql.pyi
 create mode 100644 third-party-stubs/sqlparse/tokens.pyi

diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py
index 3c50e3bd20f..e8713d686a4 100644
--- a/core/dbt/compilation.py
+++ b/core/dbt/compilation.py
@@ -3,6 +3,7 @@
 from typing import List, Dict, Any, Tuple, cast, Optional
 
 import networkx as nx  # type: ignore
+import sqlparse
 
 from dbt import flags
 from dbt.adapters.factory import get_adapter
@@ -15,10 +16,15 @@
     COMPILED_TYPES,
     NonSourceNode,
     NonSourceCompiledNode,
+    CompiledDataTestNode,
     CompiledSchemaTestNode,
 )
 from dbt.contracts.graph.parsed import ParsedNode
-from dbt.exceptions import dependency_not_found, InternalException
+from dbt.exceptions import (
+    dependency_not_found,
+    InternalException,
+    RuntimeException,
+)
 from dbt.graph import Graph
 from dbt.logger import GLOBAL_LOGGER as logger
 from dbt.node_types import NodeType
@@ -192,6 +198,90 @@ def _get_compiled_model(
                 f'was not an ephemeral model: {cte_id}'
             )
 
+    def _inject_ctes_into_sql(self, sql: str, ctes: List[InjectedCTE]) -> str:
+        """
+        `ctes` is a list of InjectedCTEs like:
+
+            [
+                InjectedCTE(
+                    id="cte_id_1",
+                    sql="__dbt__CTE__ephemeral as (select * from table)",
+                ),
+                InjectedCTE(
+                    id="cte_id_2",
+                    sql="__dbt__CTE__events as (select id, type from events)",
+                ),
+            ]
+
+        Given `sql` like:
+
+          "with internal_cte as (select * from sessions)
+           select * from internal_cte"
+
+        This will spit out:
+
+          "with __dbt__CTE__ephemeral as (select * from table),
+                __dbt__CTE__events as (select id, type from events),
+                with internal_cte as (select * from sessions)
+           select * from internal_cte"
+
+        (Whitespace enhanced for readability.)
+ """ + if len(ctes) == 0: + return sql + + parsed_stmts = sqlparse.parse(sql) + parsed = parsed_stmts[0] + + with_stmt = None + for token in parsed.tokens: + if token.is_keyword and token.normalized == 'WITH': + with_stmt = token + break + + if with_stmt is None: + # no with stmt, add one, and inject CTEs right at the beginning + first_token = parsed.token_first() + with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, 'with') + parsed.insert_before(first_token, with_stmt) + else: + # stmt exists, add a comma (which will come after injected CTEs) + trailing_comma = sqlparse.sql.Token( + sqlparse.tokens.Punctuation, ',' + ) + parsed.insert_after(with_stmt, trailing_comma) + + token = sqlparse.sql.Token( + sqlparse.tokens.Keyword, + ", ".join(c.sql for c in ctes) + ) + parsed.insert_after(with_stmt, token) + + return str(parsed) + + def _model_prepend_ctes( + self, + model: NonSourceCompiledNode, + prepended_ctes: List[InjectedCTE] + ) -> NonSourceCompiledNode: + if model.compiled_sql is None: + raise RuntimeException( + 'Cannot prepend ctes to an unparsed node', model + ) + injected_sql = self._inject_ctes_into_sql( + model.compiled_sql, + prepended_ctes, + ) + + model.extra_ctes_injected = True + model.extra_ctes = prepended_ctes + model.injected_sql = injected_sql + model.validate(model.to_dict()) + return model + + def _get_dbt_test_name(self) -> str: + return 'dbt__CTE__INTERNAL_test' + def _recursively_prepend_ctes( self, model: NonSourceCompiledNode, @@ -209,27 +299,63 @@ def _recursively_prepend_ctes( prepended_ctes: List[InjectedCTE] = [] + dbt_test_name = self._get_dbt_test_name() + for cte in model.extra_ctes: - cte_model = self._get_compiled_model( - manifest, - cte.id, - extra_context, - ) - cte_model, new_prepended_ctes = self._recursively_prepend_ctes( - cte_model, manifest, extra_context - ) - _extend_prepended_ctes(prepended_ctes, new_prepended_ctes) + if cte.id == dbt_test_name: + sql = cte.sql + else: + cte_model = self._get_compiled_model( + manifest, + cte.id, + extra_context, + ) + cte_model, new_prepended_ctes = self._recursively_prepend_ctes( + cte_model, manifest, extra_context + ) + _extend_prepended_ctes(prepended_ctes, new_prepended_ctes) - new_cte_name = self.add_ephemeral_prefix(cte_model.name) - sql = f' {new_cte_name} as (\n{cte_model.compiled_sql}\n)' + new_cte_name = self.add_ephemeral_prefix(cte_model.name) + sql = f' {new_cte_name} as (\n{cte_model.compiled_sql}\n)' _add_prepended_cte(prepended_ctes, InjectedCTE(id=cte.id, sql=sql)) - model.prepend_ctes(prepended_ctes) + model = self._model_prepend_ctes(model, prepended_ctes) manifest.update_node(model) return model, prepended_ctes + def _insert_ctes( + self, + compiled_node: NonSourceCompiledNode, + manifest: Manifest, + extra_context: Dict[str, Any], + ) -> NonSourceCompiledNode: + """Insert the CTEs for the model.""" + + # for data tests, we need to insert a special CTE at the end of the + # list containing the test query, and then have the "real" query be a + # select count(*) from that model. + # the benefit of doing it this way is that _insert_ctes() can be + # rewritten for different adapters to handle databses that don't + # support CTEs, or at least don't have full support. + if isinstance(compiled_node, CompiledDataTestNode): + # the last prepend (so last in order) should be the data test body. + # then we can add our select count(*) from _that_ cte as the "real" + # compiled_sql, and do the regular prepend logic from CTEs. 
+            name = self._get_dbt_test_name()
+            cte = InjectedCTE(
+                id=name,
+                sql=f' {name} as (\n{compiled_node.compiled_sql}\n)'
+            )
+            compiled_node.extra_ctes.append(cte)
+            compiled_node.compiled_sql = f'\nselect count(*) from {name}'
+
+        injected_node, _ = self._recursively_prepend_ctes(
+            compiled_node, manifest, extra_context
+        )
+        return injected_node
+
     def _compile_node(
         self,
         node: NonSourceNode,
@@ -258,11 +384,12 @@ def _compile_node(
         compiled_node.compiled_sql = jinja.get_rendered(
             node.raw_sql,
             context,
-            node)
+            node,
+        )
 
         compiled_node.compiled = True
 
-        injected_node, _ = self._recursively_prepend_ctes(
+        injected_node = self._insert_ctes(
             compiled_node, manifest, extra_context
         )
diff --git a/core/dbt/contracts/graph/compiled.py b/core/dbt/contracts/graph/compiled.py
index 6fc6f1fe5d7..1406475f969 100644
--- a/core/dbt/contracts/graph/compiled.py
+++ b/core/dbt/contracts/graph/compiled.py
@@ -17,11 +17,9 @@
 )
 from dbt.node_types import NodeType
 from dbt.contracts.util import Replaceable
-from dbt.exceptions import RuntimeException
 
 from hologram import JsonSchemaMixin
 from dataclasses import dataclass, field
-import sqlparse  # type: ignore
 
 from typing import Optional, List, Union, Dict, Type
 
@@ -45,19 +43,6 @@ class CompiledNode(ParsedNode, CompiledNodeMixin):
     extra_ctes: List[InjectedCTE] = field(default_factory=list)
     injected_sql: Optional[str] = None
 
-    def prepend_ctes(self, prepended_ctes: List[InjectedCTE]):
-        self.extra_ctes_injected = True
-        self.extra_ctes = prepended_ctes
-        if self.compiled_sql is None:
-            raise RuntimeException(
-                'Cannot prepend ctes to an unparsed node', self
-            )
-        self.injected_sql = _inject_ctes_into_sql(
-            self.compiled_sql,
-            prepended_ctes,
-        )
-        self.validate(self.to_dict())
-
     def set_cte(self, cte_id: str, sql: str):
         """This is the equivalent of what self.extra_ctes[cte_id] = sql would
         do if extra_ctes were an OrderedDict
@@ -146,66 +131,6 @@ def same_contents(self, other) -> bool:
 CompiledTestNode = Union[CompiledDataTestNode, CompiledSchemaTestNode]
 
 
-def _inject_ctes_into_sql(sql: str, ctes: List[InjectedCTE]) -> str:
-    """
-    `ctes` is a list of InjectedCTEs like:
-
-        [
-            InjectedCTE(
-                id="cte_id_1",
-                sql="__dbt__CTE__ephemeral as (select * from table)",
-            ),
-            InjectedCTE(
-                id="cte_id_2",
-                sql="__dbt__CTE__events as (select id, type from events)",
-            ),
-        ]
-
-    Given `sql` like:
-
-      "with internal_cte as (select * from sessions)
-       select * from internal_cte"
-
-    This will spit out:
-
-      "with __dbt__CTE__ephemeral as (select * from table),
-            __dbt__CTE__events as (select id, type from events),
-            with internal_cte as (select * from sessions)
-       select * from internal_cte"
-
-    (Whitespace enhanced for readability.)
- """ - if len(ctes) == 0: - return sql - - parsed_stmts = sqlparse.parse(sql) - parsed = parsed_stmts[0] - - with_stmt = None - for token in parsed.tokens: - if token.is_keyword and token.normalized == 'WITH': - with_stmt = token - break - - if with_stmt is None: - # no with stmt, add one, and inject CTEs right at the beginning - first_token = parsed.token_first() - with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, 'with') - parsed.insert_before(first_token, with_stmt) - else: - # stmt exists, add a comma (which will come after injected CTEs) - trailing_comma = sqlparse.sql.Token(sqlparse.tokens.Punctuation, ',') - parsed.insert_after(with_stmt, trailing_comma) - - token = sqlparse.sql.Token( - sqlparse.tokens.Keyword, - ", ".join(c.sql for c in ctes) - ) - parsed.insert_after(with_stmt, token) - - return str(parsed) - - PARSED_TYPES: Dict[Type[CompiledNode], Type[ParsedResource]] = { CompiledAnalysisNode: ParsedAnalysisNode, CompiledModelNode: ParsedModelNode, diff --git a/core/dbt/task/test.py b/core/dbt/task/test.py index 74bd101abc9..200669ea839 100644 --- a/core/dbt/task/test.py +++ b/core/dbt/task/test.py @@ -41,10 +41,12 @@ def print_start_line(self): print_start_line(description, self.node_index, self.num_nodes) def execute_data_test(self, test: CompiledDataTestNode): - sql = ( - f'select count(*) as errors from (\n{test.injected_sql}\n) sbq' + # sql = ( + # f'select count(*) as errors from (\n{test.injected_sql}\n) sbq' + # ) + res, table = self.adapter.execute( + test.injected_sql, auto_begin=True, fetch=True ) - res, table = self.adapter.execute(sql, auto_begin=True, fetch=True) num_rows = len(table.rows) if num_rows != 1: diff --git a/third-party-stubs/sqlparse/__init__.pyi b/third-party-stubs/sqlparse/__init__.pyi new file mode 100644 index 00000000000..4e2a9a2b00d --- /dev/null +++ b/third-party-stubs/sqlparse/__init__.pyi @@ -0,0 +1,7 @@ +from typing import Tuple +from . import sql +from . import tokens + + +def parse(sql: str) -> Tuple[sql.Statement]: + ... diff --git a/third-party-stubs/sqlparse/sql.pyi b/third-party-stubs/sqlparse/sql.pyi new file mode 100644 index 00000000000..103a576d55a --- /dev/null +++ b/third-party-stubs/sqlparse/sql.pyi @@ -0,0 +1,32 @@ +from typing import Tuple, Iterable + + +class Token: + def __init__(self, ttype, value): + ... + + is_keyword: bool + normalized: str + + +class TokenList(Token): + tokens: Tuple[Token] + + def __getitem__(self, key) -> Token: + ... + + def __iter__(self) -> Iterable[Token]: + ... + + def insert_before(self, where, token, skip_ws=True): + ... + + def insert_after(self, where, token, skip_ws=True): + ... + + def token_first(self) -> Token: + ... + + +class Statement(TokenList): + ... diff --git a/third-party-stubs/sqlparse/tokens.pyi b/third-party-stubs/sqlparse/tokens.pyi new file mode 100644 index 00000000000..68e0ed8b097 --- /dev/null +++ b/third-party-stubs/sqlparse/tokens.pyi @@ -0,0 +1,6 @@ + +class _TokenType(tuple): + ... + +Keyword: _TokenType = _TokenType() +Punctuation: _TokenType = _TokenType() From 58a3cb4fbd1f6028095e273d9a8456586382fe89 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 18 Aug 2020 12:09:23 -0600 Subject: [PATCH 4/5] changelog update --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 10a8c5d778b..262503a10ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,11 +3,14 @@ ### Breaking changes - `adapter_macro` is no longer a macro, instead it is a builtin context method. 
Any custom macros that intercepted it by going through `context['dbt']` will need to instead access it via `context['builtins']` ([#2302](https://github.com/fishtown-analytics/dbt/issues/2302), [#2673](https://github.com/fishtown-analytics/dbt/pull/2673)) - `adapter_macro` is now deprecated. Use `adapter.dispatch` instead. +- Data tests are now written as CTEs instead of subqueries. Adapter plugins for adapters that don't support CTEs may require modification. ([#2712](https://github.com/fishtown-analytics/dbt/pull/2712)) ### Under the hood - Upgraded snowflake-connector-python dependency to 2.2.10 and enabled the SSO token cache ([#2613](https://github.com/fishtown-analytics/dbt/issues/2613), [#2689](https://github.com/fishtown-analytics/dbt/issues/2689), [#2698](https://github.com/fishtown-analytics/dbt/pull/2698)) - Add deprecation warnings to anonymous usage tracking ([#2688](https://github.com/fishtown-analytics/dbt/issues/2688), [#2710](https://github.com/fishtown-analytics/dbt/issues/2710)) +- Data tests now behave like dbt CTEs ([#2609](https://github.com/fishtown-analytics/dbt/issues/2609), [#2712](https://github.com/fishtown-analytics/dbt/pull/2712)) +- Adapter plugins can now override the CTE prefix by overriding their `Relation` attribute with a class that has a custom `add_ephemeral_prefix` implementation. ([#2660](https://github.com/fishtown-analytics/dbt/issues/2660), [#2712](https://github.com/fishtown-analytics/dbt/pull/2712)) ### Features - Add better retry support when using the BigQuery adapter ([#2694](https://github.com/fishtown-analytics/dbt/pull/2694), follow-up to [#1963](https://github.com/fishtown-analytics/dbt/pull/1963)) @@ -17,6 +20,7 @@ - Macros in the current project can override internal dbt macros that are called through `execute_macros`. ([#2301](https://github.com/fishtown-analytics/dbt/issues/2301), [#2686](https://github.com/fishtown-analytics/dbt/pull/2686)) - Add state:modified and state:new selectors ([#2641](https://github.com/fishtown-analytics/dbt/issues/2641), [#2695](https://github.com/fishtown-analytics/dbt/pull/2695)) + ### Fixes - Fix Redshift table size estimation; e.g. 44 GB tables are no longer reported as 44 KB. [#2702](https://github.com/fishtown-analytics/dbt/issues/2702) From d3e4d3fbcbeeb887190913f2ef3c3711ed056ddf Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 18 Aug 2020 14:33:46 -0600 Subject: [PATCH 5/5] pr feedback: remove commented out code --- core/dbt/task/test.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/core/dbt/task/test.py b/core/dbt/task/test.py index 200669ea839..f37c70c2fa9 100644 --- a/core/dbt/task/test.py +++ b/core/dbt/task/test.py @@ -41,9 +41,6 @@ def print_start_line(self): print_start_line(description, self.node_index, self.num_nodes) def execute_data_test(self, test: CompiledDataTestNode): - # sql = ( - # f'select count(*) as errors from (\n{test.injected_sql}\n) sbq' - # ) res, table = self.adapter.execute( test.injected_sql, auto_begin=True, fetch=True )
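
To make the data test change in PATCH 3/5 concrete, here is a rough sketch of
what the new compilation path produces. The test file and its SQL are
hypothetical; the CTE name comes from _get_dbt_test_name() and the count(*)
wrapper from _insert_ctes() above, though the exact whitespace of the final
string is whatever sqlparse emits:

    # tests/assert_no_null_ids.sql (hypothetical data test):
    #     select * from analytics.events where id is null

    compiler = adapter.get_compiler()  # adapter from get_adapter(config)
    compiled = compiler.compile_node(test_node, manifest, write=False)

    # compiled.compiled_sql becomes a count over the internal test CTE:
    #     select count(*) from dbt__CTE__INTERNAL_test
    #
    # and compiled.injected_sql, which execute_data_test() now runs as-is,
    # is roughly:
    #     with dbt__CTE__INTERNAL_test as (
    #         select * from analytics.events where id is null
    #     ) select count(*) from dbt__CTE__INTERNAL_test

The test still passes only when that single count comes back as 0, but because
the test body now rides along as one more InjectedCTE, an adapter for a
database without full CTE support can change the behavior by overriding
_insert_ctes() rather than patching the test task.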