Source code for dorieh.platform.dbt.create_test

"""
A utility to generate a test based on a sample table.

The tool introspects a table given as input and generates a set of
queries, each of which tests that a certain statistic of the data in
a column has not changed.

The queries are output into a file that can be executed
as a single SQL query producing a table with the following columns:

1. Name of the column being tested
2. What value is being tested not to change:
   MD5 hash, number of distinct values, mean value or variance
3. Whether the value has remained the same: a boolean column `passed`,
   which is `true` if the value is unchanged and `false` otherwise

Individual queries are separated by the comment strings:

* `-- Test case end`
* `-- Test case start`

so a test runner can execute them individually if desired.

"""
#  Copyright (c) 2021. Harvard University
#
#  Developed by Research Software Engineering,
#  Faculty of Arts and Sciences, Research Computing (FAS RC)
#  Author: Michael A Bouzinier
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

from enum import Enum
from typing import List, Dict, Optional, Callable

from dorieh.platform import init_logging
from dorieh.platform.db import Connection
from dorieh.platform.dbt.dbt_config import DBTConfig


class CType(Enum):
    """
    Classification of a table column.

    The classification decides which kinds of test queries are
    generated for the column.
    """

    # text columns with only a few distinct values
    categorical = "categorical"
    # fallback for column types that are not otherwise recognized
    text = "text"
    # integer columns with a small number of distinct values
    integral = "integral"
    # numeric columns, and integer columns with many distinct values
    numeric = "numeric"
    # date columns
    date = "date"
class Column:
    """
    Metadata collected for a single column of the introspected table.

    :param name: column name as reported by ``information_schema``
    :param ctype: classification of the column, a ``CType`` member
    :param is_indexed: whether the column was found to be indexed
        (``TableFingerprint.get_columns`` sets this for the leading
        column of an index)
    """

    def __init__(self, name: str, ctype: CType, is_indexed: bool):
        self.name, self.type, self.is_indexed = name, ctype, is_indexed
class TableFingerprint:
    """
    Introspects a table and generates SQL regression tests for it.

    Each generated test is a ``SELECT`` returning one row with the
    tested column, the kind of test, and a boolean ``passed`` flag that
    is true when the recomputed statistic matches the value observed at
    generation time.
    """

    # Text columns with fewer distinct values than this become
    # categorical; integer columns with more become numeric.
    CATEGORICAL_THRESHOLD = 24

    # ``information_schema.columns.data_type`` values treated as integers
    INTEGRAL_TYPES = ('integer', 'smallint', 'bigint')
    # ``information_schema.columns.data_type`` values treated as numbers
    NUMERIC_TYPES = ('numeric', 'real', 'double precision')

    # Placeholder distinct count used when counting is skipped
    _UNCOUNTED = 1024 * 1024 * 1024

    def __init__(self, context: Optional[DBTConfig] = None):
        """
        :param context: tool configuration; when omitted it is built
            from the command line via ``DBTConfig``
        :raises ValueError: if no table was specified
        """
        if not context:
            context = DBTConfig(None, __doc__).instantiate()
        self.context = context
        if not self.context.table:
            raise ValueError("'--table' is a required option")
        init_logging(name="generate-test-" + self.context.table)
        if '.' in self.context.table:
            t = self.context.table.split('.')
            self.table = t[1]
            self.schema = t[0]
        else:
            self.table = self.context.table
            self.schema = "public"
        self.fqtn = f"{self.schema}.{self.table}"
        # Template: format with ``c=<column name>``
        self.count_distinct = "SELECT COUNT(DISTINCT {c}) FROM " + self.fqtn
        self.columns: List[Column] = []
        self.test_cases: List[str] = []
        cnxn = Connection(self.context.db, self.context.connection)
        self.catalog = cnxn.parameters["database"]
        self.get_columns()

    @staticmethod
    def _scalar(cnxn, sql: str):
        """Execute *sql* on *cnxn* and return the first column of the
        first row, or ``None`` if the query returned no rows."""
        with cnxn.cursor() as cursor:
            cursor.execute(sql)
            row = cursor.fetchone()
        return None if row is None else row[0]

    def get_columns(self):
        """
        Introspect the table and populate ``self.columns``.

        Columns are classified by their ``information_schema`` data type
        and, except for numeric columns, by the number of distinct
        values they contain. Columns of user-defined types are skipped.
        """
        # Name of the k-th indexed column; for expression indexes fall
        # back to the k-th piece of the index expression text
        # (NOTE(review): an approximation for multi-expression indexes)
        i1 = """
            coalesce(
                a.attname,
                (('{' || pg_get_expr(i.indexprs, i.indrelid) || '}')::text[])[k.i]
            ) AS index_column
        """
        # true when the column is the leading (first) column of an index
        i2 = f"""
            CASE WHEN COLUMN_NAME IN (
                SELECT {i1}
                FROM pg_index i
                CROSS JOIN LATERAL unnest(i.indkey) WITH ORDINALITY AS k(attnum, i)
                LEFT JOIN pg_attribute AS a
                    ON i.indrelid = a.attrelid AND k.attnum = a.attnum
                WHERE i.indrelid = '{self.fqtn}'::regclass AND k.i = 1
            ) THEN true ELSE false END AS indexed
        """
        sql = f"""
            SELECT COLUMN_NAME, data_type, {i2}
            FROM information_schema.columns
            WHERE table_name = '{self.table}'
                AND table_schema = '{self.schema}'
                AND table_catalog = '{self.catalog}'
            ORDER BY 1
        """
        with Connection(self.context.db, self.context.connection) as cnxn:
            with cnxn.cursor() as cursor:
                cursor.execute(sql)
                columns = cursor.fetchall()
            for name, data_type, indexed in columns:
                if data_type in self.INTEGRAL_TYPES:
                    t = CType.integral
                elif data_type in self.NUMERIC_TYPES:
                    t = CType.numeric
                elif data_type == 'date':
                    t = CType.date
                elif data_type == 'USER-DEFINED':
                    continue
                else:
                    t = CType.text
                dc = self._UNCOUNTED
                # Distinct counts are not needed to classify numeric
                # columns, so skip the (potentially expensive) query.
                if t != CType.numeric:
                    counted = self._scalar(cnxn, self.count_distinct.format(c=name))
                    if counted is not None:
                        dc = counted
                if dc < self.CATEGORICAL_THRESHOLD and t == CType.text:
                    t = CType.categorical
                if t == CType.integral and dc > self.CATEGORICAL_THRESHOLD:
                    t = CType.numeric
                self.columns.append(Column(name, t, indexed))

    def get_categories(self) -> List[Column]:
        """Return indexed columns that are integral or categorical."""
        return [
            c for c in self.columns
            if c.is_indexed and c.type in [CType.integral, CType.categorical]
        ]

    def generate_tests(self):
        """Generate test cases for every introspected column."""
        for c in self.columns:
            self.test_column(c)

    def test_column(self, c: Column):
        """
        Append the test cases appropriate for column *c* to
        ``self.test_cases``. The tests are: count of distinct values
        (non-numeric types), MD5 of the ordered values (text and
        categorical types), and mean and variance (numeric and integral
        types).
        """
        if c.type in [CType.text, CType.categorical, CType.integral, CType.date]:
            s1 = self.count_distinct.format(c=c.name)
            self.test_cases.append(self.test_exact(s1, c.name, "count distinct"))
        if c.type in [CType.text, CType.categorical]:
            s2 = (
                f"SELECT MD5(string_agg({c.name}::varchar, '' order by {c.name})) "
                f"FROM {self.fqtn}"
            )
            self.test_cases.append(self.test_exact(s2, c.name, "MD5 value"))
        if c.type in [CType.numeric, CType.integral]:
            s3 = f"SELECT AVG({c.name}) FROM {self.fqtn}"
            self.test_cases.append(self.test_approximate(s3, c.name, "Mean value"))
            s4 = f"SELECT VARIANCE({c.name}) FROM {self.fqtn}"
            self.test_cases.append(self.test_approximate(s4, c.name, "Variance"))

    def test_case_sql(self, name: str, test: str, condition: str) -> str:
        """
        Wrap *condition* into a one-row ``SELECT`` with columns
        ``table_column``, ``Testing`` and boolean ``passed``.
        """
        return "SELECT \n" \
            + f"\t'{self.fqtn}.{name}' As table_column,\n" \
            + f"\t'{test}' As Testing,\n" \
            + "\tCASE \n" \
            + f"\t\tWHEN {condition} \n" \
            + "\t\tTHEN true ELSE false END AS passed\n"

    def test_exact(self, sql: str, name: str, test: str) -> str:
        """
        Run *sql* now and build a test asserting it later returns
        exactly the same value (or NULL if it returned NULL now).
        """
        with Connection(self.context.db, self.context.connection) as cnxn:
            v = self._scalar(cnxn, sql)
        if v is None:
            condition = f"({sql}) IS NULL"
        else:
            # Double embedded single quotes so the SQL literal stays valid
            literal = str(v).replace("'", "''")
            condition = f"({sql}) = '{literal}'"
        return self.test_case_sql(name, test, condition)

    def test_approximate(self, sql: str, name: str, test: str) -> str:
        """
        Run *sql* now and build a test asserting it later returns a
        value within +/-1% of the current one (or NULL if it returned
        NULL now).
        """
        with Connection(self.context.db, self.context.connection) as cnxn:
            v = self._scalar(cnxn, sql)
        if v is None:
            return self.test_case_sql(name, test, f"({sql}) IS NULL")
        x = float(v)
        # sorted() keeps the BETWEEN bounds ordered for negative values
        low, high = sorted((0.99 * x, 1.01 * x))
        condition = f"({sql}) BETWEEN {low} AND {high}"
        return self.test_case_sql(name, test, condition)

    def union(self):
        """Return all test cases combined into one ``UNION ALL`` query."""
        return "UNION ALL\n".join(self.test_cases)

    def write_test_script(self):
        """
        Write all test cases to the first script file from the context,
        each enclosed in ``-- Test case start`` / ``-- Test case end``
        comments and joined with ``UNION ALL``.
        """
        with open(self.context.script[0], "wt", encoding="utf-8") as script:
            for i, test_case in enumerate(self.test_cases):
                if i > 0:
                    print("UNION ALL", file=script)
                print("-- Test case start", file=script)
                print(test_case, file=script)
                print("-- Test case end", file=script)
if __name__ == '__main__':
    # Build the fingerprint from command-line options, generate the
    # per-column tests and write them to the configured script file.
    table_fingerprint = TableFingerprint()
    table_fingerprint.generate_tests()
    table_fingerprint.write_test_script()