From 0e9e05e6d5d13b3878714071ca2895e2312d02d5 Mon Sep 17 00:00:00 2001 From: Damien Broka Date: Fri, 13 Aug 2021 12:57:06 +0100 Subject: [PATCH] feat: semantic detection PoC --- Cargo.lock | 398 +++++++++++++++- Cargo.toml | 1 + core/src/graph/string/faker.rs | 2 +- semdet/Cargo.toml | 41 ++ semdet/build.rs | 22 + semdet/src/decode.rs | 102 +++++ semdet/src/encode.rs | 219 +++++++++ semdet/src/error.rs | 53 +++ semdet/src/lib.rs | 74 +++ semdet/src/module.rs | 426 ++++++++++++++++++ semdet/train/dummy.tch | Bin 0 -> 4634 bytes synth/Cargo.toml | 3 + synth/src/cli/import_utils.rs | 57 ++- synth/src/cli/mongo.rs | 5 +- synth/src/datasource/relational_datasource.rs | 2 +- 15 files changed, 1393 insertions(+), 12 deletions(-) create mode 100644 semdet/Cargo.toml create mode 100644 semdet/build.rs create mode 100644 semdet/src/decode.rs create mode 100644 semdet/src/encode.rs create mode 100644 semdet/src/error.rs create mode 100644 semdet/src/lib.rs create mode 100644 semdet/src/module.rs create mode 100644 semdet/train/dummy.tch diff --git a/Cargo.lock b/Cargo.lock index 1a7e56e94..9e162d6dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -118,6 +118,29 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" +[[package]] +name = "arrow" +version = "5.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddf189dff0c7e0f40588fc25adbe5bb6837b82fc61bb7cadf5d76de030f710bb" +dependencies = [ + "bitflags", + "chrono", + "csv", + "flatbuffers", + "hex", + "indexmap", + "lazy_static", + "lexical-core", + "multiversion", + "num 0.4.0", + "rand 0.8.4", + "regex", + "serde", + "serde_derive", + "serde_json", +] + [[package]] name = "async-attributes" version = "1.1.2" @@ -557,6 +580,18 @@ dependencies = [ "uuid", ] +[[package]] +name = "bstr" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90682c8d613ad3373e66de8c6411e0ae2ab2571e879d2efbf73558cc66f21279" +dependencies = [ + "lazy_static", + "memchr", + "regex-automata", + "serde", +] + [[package]] name = "build_const" version = "0.2.2" @@ -603,6 +638,27 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" +[[package]] +name = "bzip2" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6afcd980b5f3a45017c57e57a2fcccbb351cc43a356ce117ef760ef8052b89b0" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.11+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "cache-padded" version = "1.1.1" @@ -772,6 +828,15 @@ dependencies = [ "build_const", ] +[[package]] +name = "crc32fast" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81156fece84ab6a9f2afdb109ce3ae577e42b1228441eded99bd77f627953b1a" +dependencies = [ + "cfg-if 1.0.0", +] + [[package]] name = "crossbeam-channel" version = "0.5.1" @@ -856,6 +921,28 @@ dependencies = [ "subtle 2.4.1", ] +[[package]] +name = "csv" +version = "1.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22813a6dc45b335f9bade10bf7271dc477e81113e89eb251a0bc2a8a81c536e1" +dependencies = [ + "bstr", + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2466559f260f48ad25fe6317b3c8dac77b5bdb5763ac7d9d6103530663bc90" +dependencies = [ + "memchr", +] + [[package]] name = "ctor" version = "0.1.20" @@ -885,6 +972,36 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "curl" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "003cb79c1c6d1c93344c7e1201bb51c2148f24ec2bd9c253709d6b2efb796515" +dependencies = [ + "curl-sys", + "libc", + "openssl-probe", + "openssl-sys", + "schannel", + "socket2 0.4.1", + "winapi 0.3.9", +] + +[[package]] +name = "curl-sys" +version = "0.4.45+curl-7.78.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de9e5a72b1c744eb5dd20b2be4d7eb84625070bb5c4ab9b347b70464ab1e62eb" +dependencies = [ + "cc", + "libc", + "libz-sys", + "openssl-sys", + "pkg-config", + "vcpkg", + "winapi 0.3.9", +] + [[package]] name = "darling" version = "0.13.0" @@ -1133,6 +1250,29 @@ dependencies = [ "web-sys", ] +[[package]] +name = "flatbuffers" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef4c5738bcd7fad10315029c50026f83c9da5e4a21f8ed66826f43e0e2bde5f6" +dependencies = [ + "bitflags", + "smallvec", + "thiserror", +] + +[[package]] +name = "flate2" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd3aec53de10fe96d7d8c565eb17f2c687bb5518a2ec453b5b1252964526abe0" +dependencies = [ + "cfg-if 1.0.0", + "crc32fast", + "libc", + "miniz_oxide", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1409,6 +1549,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62aca2aba2d62b4a7f5b33f3712cb1b0692779a56fb510499d5c0aa594daeaf3" + [[package]] name = "hashbrown" version = "0.11.2" @@ -1678,6 +1824,29 @@ dependencies = [ "regex", ] +[[package]] +name = "indoc" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47741a8bc60fb26eb8d6e0238bbb26d8575ff623fdc97b1a2c00c050b9684ed8" +dependencies = [ + "indoc-impl", + "proc-macro-hack", +] + +[[package]] +name = "indoc-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce046d161f000fffde5f432a0d034d0341dc152643b2598ed5bfce44c4f3a8f0" +dependencies = [ + "proc-macro-hack", + "proc-macro2", + "quote", + "syn", + "unindent", +] + [[package]] name = "infer" version = "0.2.3" @@ -1898,6 +2067,15 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08" +[[package]] +name = "matrixmultiply" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a8a15b776d9dfaecd44b03c5828c2199cddff5247215858aac14624f8d6b741" +dependencies = [ + "rawpointer", +] + [[package]] name = "md-5" version = "0.8.0" @@ -2048,6 +2226,26 @@ dependencies = [ "webpki-roots", ] +[[package]] +name = "multiversion" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "025c962a3dd3cc5e0e520aa9c612201d127dcdf28616974961a649dca64f5373" +dependencies = [ + "multiversion-macros", +] + +[[package]] +name = "multiversion-macros" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8a3e2bde382ebf960c1f3e79689fa5941625fe9bf694a1cb64af3e85faff3af" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "native-tls" version = "0.2.7" @@ -2066,6 +2264,19 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.15.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08e854964160a323e65baa19a0b1a027f76d590faba01f05c0cbc3187221a8c9" +dependencies = [ + "matrixmultiply", + "num-complex 0.4.0", + "num-integer", + "num-traits", + "rawpointer", +] + [[package]] name = "net2" version = "0.2.37" @@ -2118,10 +2329,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b7a8e9be5e039e2ff869df49155f1c06bd01ade2117ec783e56ab0932b67a8f" dependencies = [ "num-bigint 0.3.2", - "num-complex", + "num-complex 0.3.1", + "num-integer", + "num-iter", + "num-rational 0.3.2", + "num-traits", +] + +[[package]] +name = "num" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43db66d1170d347f9a065114077f7dccb00c1b9478c89384490a3425279a4606" +dependencies = [ + "num-bigint 0.4.0", + "num-complex 0.4.0", "num-integer", "num-iter", - "num-rational", + "num-rational 0.4.0", "num-traits", ] @@ -2174,6 +2399,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-complex" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085" +dependencies = [ + "num-traits", +] + [[package]] name = "num-integer" version = "0.1.44" @@ -2207,6 +2441,18 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-rational" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d41702bd167c2df5520b384281bc111a4b5efcf7fbc4c9c222c815b07e0a6a6a" +dependencies = [ + "autocfg 1.0.1", + "num-bigint 0.4.0", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.14" @@ -2345,6 +2591,25 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "paste" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45ca20c77d80be666aef2b45486da86238fabe33e38306bd3118fe4af33fa880" +dependencies = [ + "paste-impl", + "proc-macro-hack", +] + +[[package]] +name = "paste-impl" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d95a7db200b97ef370c8e6de0088252f7e0dfff7d047a28528e47456c0fc98b6" +dependencies = [ + "proc-macro-hack", +] + [[package]] name = "pbkdf2" version = "0.3.0" @@ -2513,6 +2778,54 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "pyo3" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af205762ba65eec9f27a2fa1a57a40644e8e3368784b8c8b2f2de48f6e8ddd96" +dependencies = [ + "cfg-if 1.0.0", + "indoc", + "libc", + "parking_lot", + "paste", + "pyo3-build-config", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "755944027ce803c7238e59c5a18e59c1d0a4553db50b23e9ba209a568353028d" +dependencies = [ + "once_cell", +] + +[[package]] +name = "pyo3-macros" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd31b36bccfd902c78804bd96c28ea93eac6fa0ca311f9d21ef2230b6665b29a" +dependencies = [ + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c21c59ba36db9c823e931c662766b0dd01a030b1d96585b67d8857a96a56b972" +dependencies = [ + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + [[package]] name = "quick-error" version = "1.2.3" @@ -2635,6 +2948,12 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.5.1" @@ -2690,6 +3009,12 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" + [[package]] name = "regex-syntax" version = "0.6.25" @@ -2890,6 +3215,18 @@ dependencies = [ "libc", ] +[[package]] +name = "semantic-detection" +version = "0.1.0" +dependencies = [ + "arrow", + "fake", + "ndarray", + "pyo3", + "rand 0.8.4", + "tch", +] + [[package]] name = "semver" version = "0.9.0" @@ -2927,9 +3264,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.65" +version = "1.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28c5e91e4240b46c4c19219d6cc84784444326131a4210f496f948d5cc827a29" +checksum = "336b10da19a12ad094b59d870ebde26a45402e5b470add4b5fd03c5048a32127" dependencies = [ "indexmap", "itoa", @@ -3393,6 +3730,7 @@ name = "synth" version = "0.5.3" dependencies = [ "anyhow", + "arrow", "async-std", "async-trait", "beau_collector", @@ -3414,6 +3752,7 @@ dependencies = [ "rand 0.8.4", "regex", "rust_decimal", + "semantic-detection", "serde", "serde_json", "sqlx", @@ -3442,7 +3781,7 @@ dependencies = [ "humantime-serde", "lazy_static", "log", - "num", + "num 0.3.1", "num_cpus", "rand 0.8.4", "rand_regex", @@ -3512,6 +3851,22 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "tch" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26115a35496899782cc16c8e5a51c9375e7f3c0e7dcafa9f1883338e938a04f" +dependencies = [ + "half", + "lazy_static", + "libc", + "ndarray", + "rand 0.8.4", + "thiserror", + "torch-sys", + "zip", +] + [[package]] name = "tempfile" version = "3.2.0" @@ -3741,6 +4096,19 @@ dependencies = [ "tokio 1.9.0", ] +[[package]] +name = "torch-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88ade099a9f752b2e57c7852fe022e3542ec7e196675acbc9bb0d3a380e18c62" +dependencies = [ + "anyhow", + "cc", + "curl", + "libc", + "zip", +] + [[package]] name = "tower-service" version = "0.3.1" @@ -3872,6 +4240,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" +[[package]] +name = "unindent" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f14ee04d9415b52b3aeab06258a3f07093182b88ba0f9b8d203f211a7a7d41c7" + [[package]] name = "universal-hash" version = "0.4.1" @@ -4186,3 +4560,17 @@ dependencies = [ "syn", "synstructure", ] + +[[package]] +name = "zip" +version = "0.5.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93ab48844d61251bb3835145c521d88aa4031d7139e8485990f60ca911fa0815" +dependencies = [ + "byteorder", + "bzip2", + "crc32fast", + "flate2", + "thiserror", + "time 0.1.44", +] diff --git a/Cargo.toml b/Cargo.toml index 60ca881e6..8fbe9724f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "gen", "core", + "semdet", "synth", "dist/playground" ] diff --git a/core/src/graph/string/faker.rs b/core/src/graph/string/faker.rs index f1f7b68b7..b602f1c1d 100644 --- a/core/src/graph/string/faker.rs +++ b/core/src/graph/string/faker.rs @@ -32,7 +32,7 @@ impl Default for Locale { #[derive(Clone, Default, Deserialize, Debug, Serialize, PartialEq, Eq)] pub struct FakerArgs { #[serde(default)] - locales: Vec, + pub locales: Vec, } type FakerFunction = for<'r> fn(&'r mut dyn RngCore, &FakerArgs) -> String; diff --git a/semdet/Cargo.toml b/semdet/Cargo.toml new file mode 100644 index 000000000..c0d047ee7 --- /dev/null +++ b/semdet/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "semantic-detection" +version = "0.1.0" +edition = "2018" +authors = [ + "Damien Broka ", +] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +name = "semantic_detection" +crate-type=["lib", "dylib"] + +[features] +default = [ "dummy" ] +train = [ "pyo3" ] +dummy = [ ] +torch = [ "tch" ] + +[dependencies.arrow] +version = "5.1.0" + +[dependencies.fake] +version = "2.4.1" +features = ["http"] + +[dependencies.pyo3] +version = "0.14.2" +optional = true +features = [ "extension-module" ] + +[dependencies.tch] +version = "0.5.0" +optional = true + +[dependencies.ndarray] +version = "0.15.3" + +[dev-dependencies.rand] +version = "0.8.4" \ No newline at end of file diff --git a/semdet/build.rs b/semdet/build.rs new file mode 100644 index 000000000..6af6f37b6 --- /dev/null +++ b/semdet/build.rs @@ -0,0 +1,22 @@ +use std::env; +use std::fs; +use std::io::Result; +use std::path::{Path, PathBuf}; + +fn main() -> Result<()> { + let pretrained_path = env::var_os("PRETRAINED") + .map(PathBuf::from) + .unwrap_or_else(|| Path::new("train").join("dummy.tch")); + let target_path = PathBuf::from(env::var_os("OUT_DIR").unwrap()).join("pretrained.tch"); + eprintln!( + "attempting to copy pretrained weights:\n\t<- {}\n\t-> {}", + pretrained_path.to_str().unwrap(), + target_path.to_str().unwrap() + ); + fs::copy(&pretrained_path, &target_path)?; + println!( + "cargo:rustc-env=PRETRAINED={}", + target_path.to_str().unwrap() + ); + Ok(()) +} diff --git a/semdet/src/decode.rs b/semdet/src/decode.rs new file mode 100644 index 000000000..5060bfe44 --- /dev/null +++ b/semdet/src/decode.rs @@ -0,0 +1,102 @@ +use ndarray::{ArrayView, Ix1}; + +use std::convert::Infallible; + +/// Trait for functions that produce a value from an input [`Array`](ndarray::Array) of prescribed +/// shape. +/// +/// The type parameter `D` should probably be a [`Dimension`](ndarray::Dimension) for implementations +/// to be useful. +pub trait Decoder { + type Err: std::error::Error + 'static; + + /// The type of values returned. + type Value; + + /// Compute and return a [`Self::Value`](Decoder::Value) from the input `tensor`. + /// + /// Implementations are allowed to panic if `tensor.shape() != self.shape()`. + fn decode(&self, tensor: ArrayView) -> Result; + + /// The shape that is required of a valid input of this decoder. + fn shape(&self) -> D; +} + +impl<'d, D, Dm> Decoder for &'d D +where + D: Decoder, +{ + type Err = D::Err; + type Value = D::Value; + + fn decode(&self, tensor: ArrayView) -> Result { + >::decode(self, tensor) + } + + fn shape(&self) -> Dm { + >::shape(self) + } +} + +pub struct MaxIndexDecoder { + index: Vec, +} + +impl MaxIndexDecoder { + /// # Panics + /// + /// If `index` is empty. + pub fn from_vec(index: Vec) -> Self { + assert!( + !index.is_empty(), + "passed `index` to `from_values` must not be empty" + ); + Self { index } + } +} + +impl Decoder for MaxIndexDecoder +where + S: Clone, +{ + type Err = Infallible; + type Value = Option; + + fn decode(&self, tensor: ArrayView) -> Result { + let (idx, by) = tensor + .iter() + .enumerate() + .max_by(|(_, l), (_, r)| l.total_cmp(r)) + .unwrap(); + if *by > (1. / tensor.len() as f32) { + let value = self.index.get(idx).unwrap().clone(); + Ok(Some(value)) + } else { + Ok(None) + } + } + + fn shape(&self) -> Ix1 { + Ix1(self.index.len()) + } +} + +#[cfg(test)] +pub mod tests { + use super::Decoder; + use super::MaxIndexDecoder; + + use ndarray::{Array, Ix1}; + + #[test] + fn decoder_max_index() { + let decoder = MaxIndexDecoder::from_vec((0..10).collect()); + + for idx in 0..10 { + let mut input = Array::zeros(Ix1(10)); + *input.get_mut(idx).unwrap() = 1.; + let output = decoder.decode(input.view()).unwrap(); + assert_eq!(output, Some(idx)); + } + } +} diff --git a/semdet/src/encode.rs b/semdet/src/encode.rs new file mode 100644 index 000000000..6cd923be4 --- /dev/null +++ b/semdet/src/encode.rs @@ -0,0 +1,219 @@ +use arrow::array::{GenericStringArray, StringOffsetSizeTrait}; +use arrow::datatypes::DataType; +use arrow::record_batch::RecordBatch; + +use ndarray::{ArrayViewMut, Axis, Ix1, Slice}; + +use std::collections::HashSet; +use std::convert::Infallible; + +/// Trait for functions that compute an [`Array`](ndarray::Array) of prescribed shape from an input +/// [`RecordBatch`](arrow::record_batch::RecordBatch). +/// +/// The type parameter `D` should probably be a [`Dimension`](ndarray::Dimension) for +/// implementations to be useful. +pub trait Encoder { + type Err: std::error::Error + 'static; + + /// Compute the values from the `input` and store the result in the initialized mutable `buffer`. + /// + /// Implementations are allowed to panic if `buffer.shape() != self.shape()`. + fn encode<'f>( + &self, + input: &RecordBatch, + buffer: ArrayViewMut<'f, f32, D>, + ) -> Result<(), Self::Err>; + + /// The shape of the output of this encoder. + fn shape(&self) -> D; +} + +impl<'e, E, Dm> Encoder for &'e E +where + E: Encoder, +{ + type Err = E::Err; + + fn encode<'f>( + &self, + input: &RecordBatch, + buffer: ArrayViewMut<'f, f32, Dm>, + ) -> Result<(), Self::Err> { + >::encode(self, input, buffer) + } + + fn shape(&self) -> Dm { + >::shape(self) + } +} + +/// An [`Encoder`](Encoder) that simply counts how many rows in the input are present in a dictionary. +#[derive(Debug)] +pub struct Dictionary { + dict: HashSet, +} + +impl Dictionary { + /// Create a new dictionary from a collection of [`&str`](str). + pub fn new<'a, I>(values: I) -> Self + where + I: IntoIterator, + { + Self { + dict: values.into_iter().map(|s| s.to_string()).collect(), + } + } + + fn count(&self, data: &GenericStringArray) -> u64 { + data.into_iter() + .filter_map(|m_s| m_s.and_then(|s| self.dict.get(s))) + .count() as u64 + } +} + +impl Encoder for Dictionary { + type Err = Infallible; + + /// In the first column of `input`, count the number of rows which match an element of the + /// dictionary. + fn encode<'f>( + &self, + input: &RecordBatch, + mut buffer: ArrayViewMut<'f, f32, Ix1>, + ) -> Result<(), Self::Err> { + let column = input.column(0); + let count = match column.data_type() { + DataType::Utf8 => { + let sar: &GenericStringArray = column.as_any().downcast_ref().unwrap(); + Some(self.count(sar)) + } + DataType::LargeUtf8 => { + let sar: &GenericStringArray = column.as_any().downcast_ref().unwrap(); + Some(self.count(sar)) + } + _ => None, + } + .and_then(|matches| { + if column.len() != 0 { + Some(matches as f32) + } else { + None + } + }); + *buffer.get_mut(0).unwrap() = count.unwrap_or(f32::NAN); + Ok(()) + } + + fn shape(&self) -> Ix1 { + Ix1(1) + } +} + +/// An [`Encoder`](Encoder) that horizontally stacks the output of a collection of [`Encoder`](Encoder)s. +pub struct StackedEncoder { + stack: Vec, + shape: Ix1, +} + +impl StackedEncoder +where + D: Encoder, +{ + /// Construct a new [`StackedEncoder`](StackedEncoder) from a collection of [`Encoder`](Encoder)s. + pub fn from_vec(stack: Vec) -> Self { + let shape = stack.iter().map(|encoder| encoder.shape()[0]).sum(); + Self { + stack, + shape: Ix1(shape), + } + } +} + +impl Encoder for StackedEncoder +where + D: Encoder, +{ + type Err = D::Err; + + fn encode<'f>( + &self, + input: &RecordBatch, + mut buffer: ArrayViewMut<'f, f32, Ix1>, + ) -> Result<(), Self::Err> { + let mut idx = 0usize; + for encoder in self.stack.iter() { + let next = idx + encoder.shape()[0]; + let sliced = buffer.slice_axis_mut(Axis(0), Slice::from(idx..next)); + encoder.encode(input, sliced)?; + idx = next; + } + Ok(()) + } + + fn shape(&self) -> Ix1 { + self.shape + } +} + +#[cfg(test)] +pub mod tests { + use ndarray::{array, Array, Ix2}; + + use std::iter::once; + + use super::{Dictionary, Encoder, StackedEncoder}; + use crate::tests::*; + + macro_rules! encode { + ($encoder:ident, $input:ident, $output:expr) => {{ + $encoder.encode(&$input, $output).unwrap(); + }}; + ($encoder:ident, $input:ident) => {{ + let mut output = Array::zeros($encoder.shape()); + encode!($encoder, $input, output.view_mut()); + output + }}; + } + + #[test] + fn encoder_dictionary() { + let last_names = ::NAME_LAST_NAME; + let encoder = Dictionary::new(last_names.iter().copied()); + let names_array = string_array_of(fake::faker::name::en::LastName(), 1000); + let input = record_batch_of(once(names_array)); + let output = encode!(encoder, input); + assert_eq!(output, array![511.]); + } + + #[test] + fn encoder_stacked() { + let data = [ + ::NAME_LAST_NAME, + ::ADDRESS_COUNTRY_CODE, + ::JOB_FIELD, + ]; + let encoder = StackedEncoder::from_vec( + data.iter() + .map(|values| Dictionary::new(values.iter().copied())) + .collect(), + ); + + let mut output = Array::zeros(Ix2(encoder.stack.len(), encoder.shape()[0])); + let slices = output.outer_iter_mut(); + + let inputs = vec![ + string_array_of(fake::faker::name::en::LastName(), 1000), + string_array_of(fake::faker::address::en::CountryCode(), 1000), + string_array_of(fake::faker::job::en::Field(), 1000), + ]; + for (slice, array) in slices.zip(inputs) { + let input = record_batch_of(vec![array]); + encode!(encoder, input, slice); + } + + assert_eq!( + output, + array![[511.0, 0.0, 0.0], [0.0, 511.0, 1.0], [0.0, 21.0, 500.0]] + ); + } +} diff --git a/semdet/src/error.rs b/semdet/src/error.rs new file mode 100644 index 000000000..f880f7bb8 --- /dev/null +++ b/semdet/src/error.rs @@ -0,0 +1,53 @@ +use arrow::error::ArrowError; +use ndarray::ShapeError; + +#[derive(Debug)] +pub enum Error { + Arrow(ArrowError), + Shape(ShapeError), + Encoder(Box), + Model(Box), + Decoder(Box), + Implementation(String), +} + +impl Error { + pub fn encoder(err: E) -> Self { + Self::Encoder(Box::new(err)) + } + + pub fn model(err: E) -> Self { + Self::Model(Box::new(err)) + } + + pub fn decoder(err: E) -> Self { + Self::Decoder(Box::new(err)) + } +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Arrow(arrow) => write!(f, "arrow error: {}", arrow), + Self::Shape(shape) => write!(f, "shape error: {}", shape), + Self::Encoder(encoder) => write!(f, "encoder error: {}", encoder), + Self::Model(model) => write!(f, "model error: {}", model), + Self::Decoder(decoder) => write!(f, "decoder error: {}", decoder), + Self::Implementation(msg) => write!(f, "implementation error: {}", msg), + } + } +} + +impl std::error::Error for Error {} + +impl From for Error { + fn from(arrow: ArrowError) -> Self { + Self::Arrow(arrow) + } +} + +impl From for Error { + fn from(shape: ShapeError) -> Self { + Self::Shape(shape) + } +} diff --git a/semdet/src/lib.rs b/semdet/src/lib.rs new file mode 100644 index 000000000..bed04f2f1 --- /dev/null +++ b/semdet/src/lib.rs @@ -0,0 +1,74 @@ +#![feature(total_cmp)] + +pub mod encode; +pub use encode::Encoder; + +pub mod decode; +pub use decode::Decoder; + +pub mod module; +pub use module::Module; + +pub mod error; +pub use error::Error; + +#[cfg(feature = "train")] +#[pyo3::proc_macro::pymodule] +fn semantic_detection(py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { + let dummy = module::python_bindings::bind(py)?; + m.add_submodule(dummy)?; + Ok(()) +} + +#[cfg(test)] +pub mod tests { + use fake::{Dummy, Fake}; + use rand::{rngs::StdRng, SeedableRng}; + + use arrow::array::{ArrayRef, StringArray, StringBuilder}; + use arrow::record_batch::RecordBatch; + + use std::sync::Arc; + + pub fn rng() -> StdRng { + ::seed_from_u64(0xAAAAAAAAAAAAAAAA) + } + + pub fn string_array_of(f: F, len: usize) -> StringArray + where + String: Dummy, + { + let mut builder = StringBuilder::new(len); + let mut rng = rng(); + (0..len) + .into_iter() + .try_for_each(|_| builder.append_option(f.fake_with_rng::, _>(&mut rng))) + .unwrap(); + builder.finish() + } + + pub fn record_batch_of_with_names(iter: I) -> RecordBatch + where + I: IntoIterator, + S: AsRef, + A: arrow::array::Array + 'static, + { + RecordBatch::try_from_iter( + iter.into_iter() + .map(|(idx, array)| (idx.as_ref().to_string(), Arc::new(array) as ArrayRef)), + ) + .unwrap() + } + + pub fn record_batch_of(iter: I) -> RecordBatch + where + I: IntoIterator, + A: arrow::array::Array + 'static, + { + record_batch_of_with_names( + iter.into_iter() + .enumerate() + .map(|(k, v)| (k.to_string(), v)), + ) + } +} diff --git a/semdet/src/module.rs b/semdet/src/module.rs new file mode 100644 index 000000000..7d3aa8898 --- /dev/null +++ b/semdet/src/module.rs @@ -0,0 +1,426 @@ +use arrow::datatypes::Schema; +use arrow::record_batch::RecordBatch; + +use ndarray::{Array, Array2, ArrayView, Dimension, Ix1, Ix2}; + +use std::collections::HashMap; + +use std::sync::Arc; + +use crate::{Decoder, Encoder, Error}; + +/// A builder for [`Module`](Module). +#[derive(Default)] +pub struct ModuleBuilder { + encoder: Option, + model: Option, + decoder: Option, +} + +impl ModuleBuilder { + pub fn new() -> Self { + Self::default() + } +} + +impl ModuleBuilder { + /// Set the [`Module`](Module)'s [`Encoder`](crate::Encoder). + pub fn encoder>(self, encoder: EE) -> ModuleBuilder { + ModuleBuilder:: { + encoder: Some(encoder), + model: self.model, + decoder: self.decoder, + } + } + + /// Set the [`Module`](Module)'s [`Decoder`](crate::Decoder). + pub fn decoder>(self, decoder: DD) -> ModuleBuilder { + ModuleBuilder:: { + encoder: self.encoder, + model: self.model, + decoder: Some(decoder), + } + } + + /// Set the [`Module`](Module)'s [`Model`](Model). + pub fn model>(self, model: MM) -> ModuleBuilder { + ModuleBuilder:: { + encoder: self.encoder, + model: Some(model), + decoder: self.decoder, + } + } +} + +impl ModuleBuilder +where + E: Encoder, + M: Model, + D: Decoder, +{ + /// Build the [`Module`](Module). + /// + /// # Panics + /// If one of encoder, model or decoder has not been set. + fn build(self) -> Result, Error> { + let encoder = self.encoder.expect("missing an encoder"); + + let model = self.model.expect("missing a model"); + + let decoder = self.decoder.expect("missing a decoder"); + + Ok(Module { + encoder, + model, + decoder, + }) + } +} + +/// Trait for functions that compute an output [`Array`](ndarray::Array) of variable shape from +/// an input [`ArrayView`](ndarray::ArrayView) of variable shape. +pub trait Model { + type Err: std::error::Error; + type Output: Dimension; + + fn forward(&self, input: ArrayView) -> Result, Self::Err>; +} + +/// Put together an [`Encoder`](crate::Encoder), a [`Model`](Model) and a [`Decoder`](crate::Decoder) +/// into a pipeline that computes outputs from an input [`RecordBatch`](arrow::record_batch::RecordBatch). +pub struct Module { + encoder: E, + model: M, + decoder: D, +} + +impl Module { + /// Create a new [`Module`](Module) builder. + pub fn builder() -> ModuleBuilder { + ModuleBuilder::new() + } +} + +impl Module +where + E: Encoder, + M: Model, + D: Decoder, +{ + /// Compute the result of passing an input [`RecordBatch`](arrow::record_batch::RecordBatch) + /// through the pipeline. + /// + /// Each column of `input` is passed separately to the [`Encoder`](crate::Encoder), [`Model`](Model) + /// and [`Decoder`](crate::Decoder). The output [`Decoder::Value`](crate::Decoder::Value) are + /// then assembled into a [`HashMap`](HashMap) with keys corresponding to `input` column names. + pub fn forward(&self, input: &RecordBatch) -> Result, Error> { + let fields = input.schema().fields().clone(); + let columns = input.columns().iter().cloned(); + + let mut buffer = Array2::zeros(Ix2(columns.len(), self.encoder.shape()[0])); + let slices = buffer.outer_iter_mut(); + + for ((field, column), slice) in fields.iter().cloned().zip(columns).zip(slices) { + let column_input = + RecordBatch::try_new(Arc::new(Schema::new(vec![field])), vec![column])?; + self.encoder + .encode(&column_input, slice) + .map_err(Error::encoder)?; + } + + let output = self.model.forward(buffer.view())?; + + let slices = output.outer_iter(); + let mut decoded = HashMap::new(); + for (field, row) in fields.into_iter().zip(slices) { + decoded.insert( + field.name().to_string(), + self.decoder.decode(row).map_err(Error::decoder)?, + ); + } + Ok(decoded) + } +} + +pub mod dummy { + //! A simple proof-of-concept pipeline + use ndarray::{Array, ArrayView, Axis, Ix1, Ix2}; + + use super::{Model, Module}; + use crate::decode::{Decoder, MaxIndexDecoder}; + use crate::encode::{Dictionary, Encoder, StackedEncoder}; + + #[cfg(feature = "torch")] + use ndarray::ArrayD; + + #[cfg(feature = "torch")] + use tch::{jit::CModule, nn::Module as NnModule, TchError, Tensor}; + + #[cfg(feature = "torch")] + use std::{ + convert::{TryFrom, TryInto}, + io::{Cursor, Read}, + }; + + #[cfg(feature = "torch")] + static PRETRAINED: &[u8] = include_bytes!(env!("PRETRAINED")); + + use std::convert::Infallible; + + use crate::Error; + + macro_rules! dummy_feature { + ($id:ident, $locale:ident) => { + Dictionary::new( + ::$id + .into_iter() + .map(|s| *s), + ) + }; + } + + macro_rules! dummies { + {$(($id:ident, $locale:ident, $name:literal, $ty:path)$(,)?)+} => { + pub fn encoder() -> impl Encoder + Send { + StackedEncoder::from_vec(vec![ + $(dummy_feature!($id, $locale),)+ + ]) + } + + pub fn decoder() -> impl Decoder> + Send { + MaxIndexDecoder::from_vec(vec![ + $($name,)+ + ]) + } + + #[cfg(test)] + pub mod tests { + use arrow::record_batch::RecordBatch; + + use crate::tests::*; + + pub fn data(len: usize) -> RecordBatch { + record_batch_of_with_names(vec![ + $(($name, string_array_of($ty(fake::locales::$locale), len)),)+ + ]) + } + } + } + } + + dummies! { + (NAME_FIRST_NAME, EN, "name_first_name", fake::faker::name::raw::FirstName), + (NAME_LAST_NAME, EN, "name_last_name", fake::faker::name::raw::LastName), + (JOB_FIELD, EN, "job_field", fake::faker::job::raw::Field), + (ADDRESS_COUNTRY, EN, "address_country", fake::faker::address::raw::CountryName), + (ADDRESS_COUNTRY_CODE, EN, "address_country_code", fake::faker::address::raw::CountryCode), + (ADDRESS_TIME_ZONE, EN, "address_time_zone", fake::faker::address::raw::TimeZone), + (ADDRESS_STATE, EN, "address_state", fake::faker::address::raw::StateName), + (ADDRESS_STATE_ABBR, EN, "address_state_abbr", fake::faker::address::raw::StateAbbr), + (CURRENCY_NAME, EN, "currency_name", fake::faker::currency::raw::CurrencyName), + (CURRENCY_CODE, EN, "currency_code", fake::faker::currency::raw::CurrencyCode) + } + + #[cfg(feature = "torch")] + pub struct DummyCModule(CModule); + + #[cfg(feature = "torch")] + impl DummyCModule { + pub fn load_data(buffer: &mut R) -> Result { + let c_module = CModule::load_data(buffer)?; + Ok(Self(c_module)) + } + + pub fn pretrained() -> Self { + Self::load_data(&mut Cursor::new(PRETRAINED)) + .expect("failed to load pretrained module (torch backend)") + } + } + + #[cfg(feature = "torch")] + impl Model for DummyCModule { + type Err = Error; + type Output = Ix2; + + fn forward( + &self, + input: ArrayView, + ) -> Result, Self::Err> { + let input_t = Tensor::try_from(input).map_err(Error::model)?; + let output_t = self.0.forward(&input_t); + let output: ArrayD = (&output_t).try_into()?; + let output = output.into_dimensionality::()?; + Ok(output) + } + } + + pub struct DummyNativeModule; + + impl Model for DummyNativeModule { + type Err = Error; + type Output = Ix2; + + fn forward( + &self, + input: ArrayView, + ) -> Result, Self::Err> { + let max = input.fold_axis(Axis(0), f32::NEG_INFINITY, |m, v| m.max(*v)); + let mut exp = ((-max) + input).mapv(f32::exp); + let total = exp + .map_axis(Axis(1), |view| view.sum()) + .into_shape(Ix2(exp.shape()[0], 1)) + .unwrap(); + exp = exp / total; + Ok(exp) + } + } + + pub fn module() -> Module< + impl Encoder, + impl Model, + impl Decoder>, + > { + Module::builder() + .encoder(encoder()) + .decoder(decoder()) + .model(DummyNativeModule) + .build() + .expect("failed to build dummy module (torch backend)") + } + + #[cfg(feature = "torch")] + pub fn module_torch() -> Module< + impl Encoder, + impl Model, + impl Decoder>, + > { + Module::builder() + .encoder(encoder()) + .decoder(decoder()) + .model(DummyCModule::pretrained()) + .build() + .expect("failed to build dummy module (native backend)") + } +} + +#[cfg(feature = "train")] +pub mod python_bindings { + use arrow::{ + array::{Array as ArrowArrayTrait, StructArray}, + ffi::{ArrowArray, ArrowArrayRef}, + record_batch::RecordBatch, + }; + + use ndarray::{Array, Dimension}; + + use pyo3::prelude::*; + + use std::convert::Infallible; + + use super::*; + + trait IntoPyRes { + type Ok; + fn or_else_raise(self) -> PyResult; + } + + impl IntoPyRes for Result + where + E: std::error::Error, + { + type Ok = T; + fn or_else_raise(self) -> PyResult<::Ok> { + self.map_err(|arr_err| pyo3::panic::PanicException::new_err(arr_err.to_string())) + } + } + + pub unsafe fn import_record_batch(record_batch: &PyAny) -> PyResult { + let (p_array, p_schema) = ArrowArray::into_raw(ArrowArray::empty()); + record_batch + .getattr("_export_to_c")? + .call1((p_array as usize, p_schema as usize))?; + let c_array = ArrowArray::try_from_raw(p_array, p_schema).or_else_raise()?; + let s_array = StructArray::from(c_array.to_data().or_else_raise()?); + Ok(RecordBatch::from(&s_array)) + } + + pub unsafe fn export_record_batch( + pyarrow: &PyModule, + record_batch: RecordBatch, + ) -> PyResult<&PyAny> { + let struct_array = StructArray::from(record_batch); + let (p_array, p_schema) = struct_array.to_raw().or_else_raise()?; + let output = pyarrow + .getattr("RecordBatch")? + .getattr("_import_from_c")? + .call1((p_array as usize, p_schema as usize))?; + Ok(output) + } + + #[pyclass(name = "Encoder")] + struct BoundEncoder { + encoder: Box + Send>, + } + + #[pymethods] + impl BoundEncoder { + #[new] + fn new() -> Self { + Self { + encoder: Box::new(dummy::encoder()), + } + } + + fn encode<'p>(&self, py: Python<'p>, record_batch: &PyAny) -> PyResult<&'p PyAny> { + let record_batch = unsafe { import_record_batch(record_batch) }?; + let shape = self.encoder.shape(); + let mut buffer = Array::zeros(shape); + self.encoder + .encode(&record_batch, buffer.view_mut()) + .or_else_raise()?; + let as_py = buffer.into_raw_vec().into_py(py); + let py_tensor = py + .import("torch")? + .getattr("FloatTensor")? + .call1((as_py,))?; + py_tensor.getattr("reshape")?.call1((shape.into_pattern(),)) + } + } + + pub fn bind(py: Python) -> PyResult<&PyModule> { + let module = PyModule::new(py, "dummy")?; + module.add_class::()?; + Ok(module) + } +} + +#[cfg(test)] +pub mod tests { + use super::*; + + fn dummy_end_to_end(module: Module) + where + E: Encoder, + M: Model, + D: Decoder>, + { + let input = dummy::tests::data(1000); + let output = module.forward(&input).unwrap(); + for (column, target) in output.into_iter() { + assert_eq!(Some(column.as_str()), target) + } + } + + #[test] + fn dummy_end_to_end_native() { + let module = dummy::module(); + dummy_end_to_end(module) + } + + #[cfg(feature = "torch")] + #[test] + fn dummy_end_to_end_torch() { + let module = dummy::module_torch(); + dummy_end_to_end(module) + } +} diff --git a/semdet/train/dummy.tch b/semdet/train/dummy.tch new file mode 100644 index 0000000000000000000000000000000000000000..0b97f488eed4a938f0e581c3089b3e7c290cd994 GIT binary patch literal 4634 zcmbVQ2|U#48=o<*afNk8j3bN2ko&$ev?N2$p~e_9jccYc<7h0Q9J{g=8LFWW+M+_F z2sMNhZON4o>x$5!DDnS|Z9}Ww{qH-Uci#EE$M<=k=lwpv_jv~80O16IczHp8DZ(I8 z5H=J?3<$?-`eSKW#DSn-TWcW@dR>X-z+o^nGB6K=;d4CbNcP7E^HD%*anY)tY*1P# zHh>gB+D`!|drAP$m;fq<5EdMa@x$VRBCw(U7$TV*MB&)&$p?&LakPMNEG>XcqCi%0 z5i9rz5(yDP_74ljQxU84`1X(qv=D40ABB@GlL$^Ga_kQCqi_M08-a4*;Z5;M=*j`G z*%bw7D+Xx$S9PU9GGDaC0ZM3o(F$?kpsls!S89K;ft%9j-n^^C57j6?|Jagw?HB9VXHiYr;jONAMVIK{Si|F}EBTCAUKv9S zU+>Yoi&oCf^vNrz9jvLcW^UblYBwtCP#P@~SNna8g-vTtpAmm~W-n3xu@R(3ifll;y-;2%u0hWtW0zUw9KyAYsc;pw0jA0;F^4?+nFG;#qJeq zs{FWp^!prll!iImGN($o#;X9UD1-m7%*H%W0vUYUG6;XXU)X-k3N>`TWuq^~+Tcel zhpGd}6W9cZ0wRJOKuHv^1G`CdSqj8~FNwnGpyHqhh~x$n z|CBBcCjQy0^re?Z;aZ`T#EhO)DFDT;-=>rv2>uPjX|51jlSI<|OlZwjLa)$G^#W7CGUtLwe===ncg!j9I6x zypvJ%C&HjOrwXb0Fy_|OtDmY4IYzuH(1^00%XoP&@4}nZ&Zp_y^eAUG8@pF9`TAWg zY;ITG;9xWJ;~b3h5nx>N&&=G|XDvWh|INT_S!?#MO}zno(P-rHy4@@-;iPmIhIxYI zd3XD%bCxoS;`&GY-TZu;6LwQH|H%*>Jw=ud~O(jP% zvL4D{BhSJF-E zIsLYEfB`!C>FQt@XkY5q1v{h(I)v=Ouh&Yyz6si|gexcO=c43H2EYq?|^xmvbg7)k=+;_jm#`+vuB4mWKMPnv; z#HF5qlB}LtaDMO#5);~`RL30N&oru$d)0ndNxLQ@l3E#5w|EFmuY8(b@j`S7@C6=T z89AjCLEzYke|vNqAgO;mIzku;xDLo9EZb3St;IGxs8kn6GSgDv)|iK96vt1B`hCYM zGjr$xnntb{)q{+)@M2QqzNc-Rm!eGNFTu~GiXm9tG41U!k6n0LJ6>Io)4`2p$W$p+ z@48$CUO?;U&r%6GUV4lAjAhxFT?{`_N%s#i=azMKVK;VUTQ2H8*7Fp?)T(^2RWFT) zz+V4OJLZ{KeywuK{gTqP2BqyP(W-u&c;n=ZWP;UwWwfO)d3W0Toa^_(p;g;Ntz9V1 zHHMNH%NT1qGRJaS7v8~aLb^G9$U%1mVikt&6uk~Qy~wS9QMMG-koTY#VwVwkJ;baf zDK%PqSvE}hb)(Boh4vkqA#z?`kmnuHTLs(Dc!q5+dD`E`GahAXOxvGJ^6vCP zO;_=4%QTsgEQ&_85ng$0KKipxpX6YY@Zo_Dl5F0gQ-V%~_HBbDJL{<1LKyOSlF8AE zcb>ZL!5}+lPRv2Q&ekOc9!@F!)LG!cW&!b6t%h`f0cRn9_l7h8@&11-SnIf)D||?B zdypcF6g%6K{ZkoiNI@|kb#y1;K;eD{vRuFTstTAD+;kX<@MrBRsHWdcn|_gV`E<5) zr(lkc>HNW01G8!Qs{7(?O!?_--Ebu@5rt<2x#O`T7-7v9;@L=p9?bn zDLy}NO{Sf2?Wo&Q9xYMLX;|M40$1GcPDdJ_e_!Z4by;M7q{KbQq}8uYIk$jInF*Co z$kvOV*t=siEjHWdYAk9(!qgyFJ+vw6XN``y8j{FHD@$YQ!cJb^ji@_D^#c9lU&?mmfHO%4E3ruSBAN3uRRgfL- z$T||RJ56`LFemaLJhda-`~*_#RZ+TKFY>4fDel;fK}g)9x^J*jthX>>c+hkxhQ@@v z9C5T=)CCK|52$a=!aPD~Xb3!Mpk$s`nVjV?-lEk)n6~Y4<_Ix~X0YfZUN!%vwvH%#~#|E|bC0Z**0H`dRMkX3YP>{DMvQ@m=Gcb*a( z6FXJ35W*YL_M^C_LYcavM|gO+i&}M48uD!}{CQdjY3lrUw8!_r%qPs!yqtTN>#c|E z%M6P-4mxU!7Nt9%+J%4-bz}Q3QGMkTwYuU(GA@YaW^IL!p&kqK9X%JHGbLr0>oAv@ zq`Li)ntTT77ef4Kt6wnfB1+6F@uOmL5A;Ok_F@OVQc{t(GWysVjc0+K2YSld2(`fz z%h`|A;latKqj#f=d#^VW@3=}QiyZIOfS^Ymd49@nk}H1uLff!F`bWrv1e%e)m4m`v zzlCLwvxA9HsUJHFw^d8ue^EP_QhVfK=~78tb9gk#o=l@Mp#~Sy&_im3{V{NjI zGW}hGGd(Ksq-tl@$I9o=$n^_y%$ZYfqEnW3xISAJX|ph>;n}&&)wW5d|7_&ale}f2 zF+0I4iSOmzXfqS^g660Sxl>#td#_GNiZDay5yC#Uv(}HJxNfJ?@n`Kq+8n)YVUZ;h zw-N@Gyhm|o-u$HT29LjQp#Ld~bGR@Q7U@fsa~I8Y4<~n}a7A4d<#P0Ow&(J8AJ^m0 zzu~YO?siM(x%bHq<|qf@ByIMg`Q@_C5l;Qg_6)rzpHvg(3X-1Q&YhrTZ2=#49@`^3 zao$a>N_gVak@UuFpi2GAFZkLUQb8CF9Gv7elB^rTBeluGu{BUQP`NUBAhz!v6!efN=0r zri1>N8RwV)ugF`$tZrsA<2i5Vi~caqnJZ|7Y%jAtc!$MxBTO-4?oy$X(eg^nwCw~Q z!_5T(nE{CuJ1E^anipRe&6B7!EQv<_BenaF@w?heNF1UG^#DYI*!A1eH-Rj^yzs;E zp;X{@x3xCe@M7OE+WKp!4j=ov`7LL^8wVhnT~Q8TcJn2tU7P<2vj^6duQS{~R;QDDYv=|6}g|04mU}( namespace.put_collection( &Name::from_str(table_name)?, - Collection::try_from((datasource, column_infos))?.collection, + Collection::try_from((datasource, column_infos.clone()))?.collection, )?; + + let module = semantic_detection::module::dummy::module(); + let values = task::block_on(datasource.get_deterministic_samples(table_name))?; + let mut pivoted = HashMap::new(); + for value in values.iter() { + let row = value.as_object().unwrap(); + for (column, field) in row.iter() { + if let Some(content) = field.as_str() { + pivoted + .entry(column.to_string()) + .or_insert_with(Vec::new) + .push(Some(content)); + } + + if field.is_null() { + if let Some(values) = pivoted.get_mut(column) { + values.push(None); + } + } + } + } + + let column_infos = column_infos + .into_iter() + .map(|ci| (ci.column_name.to_string(), ci)) + .collect::>(); + let pivoted = pivoted.into_iter().map(|(k, v)| (k, Arc::new(StringArray::from(v)) as ArrayRef)); + let record_batch = RecordBatch::try_from_iter(pivoted).unwrap(); + let target = module.forward(&record_batch).unwrap(); + for (column, generator) in target { + let column_meta = column_infos.get(&column).unwrap(); + if let Some(generator) = generator { + let field_ref = FieldRef::new(& if column_meta.is_nullable { + format!("{}.content.{}.0", table_name, &column_meta.column_name) + } else { + format!("{}.content.{}", table_name, &column_meta.column_name) + })?; + if let Content::String(string_content) = namespace.get_s_node_mut(&field_ref)? { + *string_content = StringContent::Faker(FakerContent { + generator: generator.to_string(), + locales: vec![], + args: FakerArgs { + locales: vec![Locale::EN] + } + }) + } + } + } } Ok(()) diff --git a/synth/src/cli/mongo.rs b/synth/src/cli/mongo.rs index 6e5e67497..92e5ffaa8 100644 --- a/synth/src/cli/mongo.rs +++ b/synth/src/cli/mongo.rs @@ -8,7 +8,6 @@ use mongodb::{bson::Document, options::ClientOptions, sync::Client}; use serde_json::Value; use std::collections::BTreeMap; use std::convert::TryInto; -use std::iter::FromIterator; use std::str::FromStr; use synth_core::graph::prelude::content::number_content::U64; use synth_core::graph::prelude::number_content::I64; @@ -16,7 +15,7 @@ use synth_core::graph::prelude::{NumberContent, ObjectContent, RangeStep}; use synth_core::schema::number_content::F64; use synth_core::schema::{ ArrayContent, BoolContent, Categorical, ChronoValueType, DateTimeContent, FieldContent, - OneOfContent, RegexContent, StringContent, + RegexContent, StringContent, }; use synth_core::{Content, Name, Namespace}; @@ -126,7 +125,7 @@ fn bson_to_content(bson: &Bson) -> Content { Content::Array(ArrayContent { length: Box::new(length), - content: Box::new(Content::OneOf(OneOfContent::from_iter(content_iter))), + content: Box::new(Content::OneOf(content_iter.collect())), }) } Bson::Document(doc) => doc_to_content(doc), diff --git a/synth/src/datasource/relational_datasource.rs b/synth/src/datasource/relational_datasource.rs index a350fc31b..40cee94af 100644 --- a/synth/src/datasource/relational_datasource.rs +++ b/synth/src/datasource/relational_datasource.rs @@ -8,7 +8,7 @@ use synth_core::Content; const DEFAULT_INSERT_BATCH_SIZE: usize = 1000; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ColumnInfo { pub(crate) column_name: String, pub(crate) ordinal_position: i32,