diff --git a/Cargo.lock b/Cargo.lock index 1a7e56e94..9e162d6dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -118,6 +118,29 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" +[[package]] +name = "arrow" +version = "5.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddf189dff0c7e0f40588fc25adbe5bb6837b82fc61bb7cadf5d76de030f710bb" +dependencies = [ + "bitflags", + "chrono", + "csv", + "flatbuffers", + "hex", + "indexmap", + "lazy_static", + "lexical-core", + "multiversion", + "num 0.4.0", + "rand 0.8.4", + "regex", + "serde", + "serde_derive", + "serde_json", +] + [[package]] name = "async-attributes" version = "1.1.2" @@ -557,6 +580,18 @@ dependencies = [ "uuid", ] +[[package]] +name = "bstr" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90682c8d613ad3373e66de8c6411e0ae2ab2571e879d2efbf73558cc66f21279" +dependencies = [ + "lazy_static", + "memchr", + "regex-automata", + "serde", +] + [[package]] name = "build_const" version = "0.2.2" @@ -603,6 +638,27 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" +[[package]] +name = "bzip2" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6afcd980b5f3a45017c57e57a2fcccbb351cc43a356ce117ef760ef8052b89b0" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.11+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "cache-padded" version = "1.1.1" @@ -772,6 +828,15 @@ dependencies = [ "build_const", ] +[[package]] +name = "crc32fast" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81156fece84ab6a9f2afdb109ce3ae577e42b1228441eded99bd77f627953b1a" +dependencies = [ + "cfg-if 1.0.0", +] + [[package]] name = "crossbeam-channel" version = "0.5.1" @@ -856,6 +921,28 @@ dependencies = [ "subtle 2.4.1", ] +[[package]] +name = "csv" +version = "1.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22813a6dc45b335f9bade10bf7271dc477e81113e89eb251a0bc2a8a81c536e1" +dependencies = [ + "bstr", + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2466559f260f48ad25fe6317b3c8dac77b5bdb5763ac7d9d6103530663bc90" +dependencies = [ + "memchr", +] + [[package]] name = "ctor" version = "0.1.20" @@ -885,6 +972,36 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "curl" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "003cb79c1c6d1c93344c7e1201bb51c2148f24ec2bd9c253709d6b2efb796515" +dependencies = [ + "curl-sys", + "libc", + "openssl-probe", + "openssl-sys", + "schannel", + "socket2 0.4.1", + "winapi 0.3.9", +] + +[[package]] +name = "curl-sys" +version = "0.4.45+curl-7.78.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de9e5a72b1c744eb5dd20b2be4d7eb84625070bb5c4ab9b347b70464ab1e62eb" +dependencies = [ + "cc", + "libc", + "libz-sys", + "openssl-sys", + "pkg-config", + "vcpkg", + "winapi 0.3.9", +] + 
[[package]] name = "darling" version = "0.13.0" @@ -1133,6 +1250,29 @@ dependencies = [ "web-sys", ] +[[package]] +name = "flatbuffers" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef4c5738bcd7fad10315029c50026f83c9da5e4a21f8ed66826f43e0e2bde5f6" +dependencies = [ + "bitflags", + "smallvec", + "thiserror", +] + +[[package]] +name = "flate2" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd3aec53de10fe96d7d8c565eb17f2c687bb5518a2ec453b5b1252964526abe0" +dependencies = [ + "cfg-if 1.0.0", + "crc32fast", + "libc", + "miniz_oxide", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1409,6 +1549,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62aca2aba2d62b4a7f5b33f3712cb1b0692779a56fb510499d5c0aa594daeaf3" + [[package]] name = "hashbrown" version = "0.11.2" @@ -1678,6 +1824,29 @@ dependencies = [ "regex", ] +[[package]] +name = "indoc" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47741a8bc60fb26eb8d6e0238bbb26d8575ff623fdc97b1a2c00c050b9684ed8" +dependencies = [ + "indoc-impl", + "proc-macro-hack", +] + +[[package]] +name = "indoc-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce046d161f000fffde5f432a0d034d0341dc152643b2598ed5bfce44c4f3a8f0" +dependencies = [ + "proc-macro-hack", + "proc-macro2", + "quote", + "syn", + "unindent", +] + [[package]] name = "infer" version = "0.2.3" @@ -1898,6 +2067,15 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08" +[[package]] +name = "matrixmultiply" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a8a15b776d9dfaecd44b03c5828c2199cddff5247215858aac14624f8d6b741" +dependencies = [ + "rawpointer", +] + [[package]] name = "md-5" version = "0.8.0" @@ -2048,6 +2226,26 @@ dependencies = [ "webpki-roots", ] +[[package]] +name = "multiversion" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "025c962a3dd3cc5e0e520aa9c612201d127dcdf28616974961a649dca64f5373" +dependencies = [ + "multiversion-macros", +] + +[[package]] +name = "multiversion-macros" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8a3e2bde382ebf960c1f3e79689fa5941625fe9bf694a1cb64af3e85faff3af" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "native-tls" version = "0.2.7" @@ -2066,6 +2264,19 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.15.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08e854964160a323e65baa19a0b1a027f76d590faba01f05c0cbc3187221a8c9" +dependencies = [ + "matrixmultiply", + "num-complex 0.4.0", + "num-integer", + "num-traits", + "rawpointer", +] + [[package]] name = "net2" version = "0.2.37" @@ -2118,10 +2329,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b7a8e9be5e039e2ff869df49155f1c06bd01ade2117ec783e56ab0932b67a8f" dependencies = [ "num-bigint 0.3.2", - "num-complex", + "num-complex 0.3.1", + "num-integer", + "num-iter", + "num-rational 0.3.2", + "num-traits", +] + +[[package]] +name = "num" +version = "0.4.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "43db66d1170d347f9a065114077f7dccb00c1b9478c89384490a3425279a4606" +dependencies = [ + "num-bigint 0.4.0", + "num-complex 0.4.0", "num-integer", "num-iter", - "num-rational", + "num-rational 0.4.0", "num-traits", ] @@ -2174,6 +2399,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-complex" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085" +dependencies = [ + "num-traits", +] + [[package]] name = "num-integer" version = "0.1.44" @@ -2207,6 +2441,18 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-rational" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d41702bd167c2df5520b384281bc111a4b5efcf7fbc4c9c222c815b07e0a6a6a" +dependencies = [ + "autocfg 1.0.1", + "num-bigint 0.4.0", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.14" @@ -2345,6 +2591,25 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "paste" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45ca20c77d80be666aef2b45486da86238fabe33e38306bd3118fe4af33fa880" +dependencies = [ + "paste-impl", + "proc-macro-hack", +] + +[[package]] +name = "paste-impl" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d95a7db200b97ef370c8e6de0088252f7e0dfff7d047a28528e47456c0fc98b6" +dependencies = [ + "proc-macro-hack", +] + [[package]] name = "pbkdf2" version = "0.3.0" @@ -2513,6 +2778,54 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "pyo3" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af205762ba65eec9f27a2fa1a57a40644e8e3368784b8c8b2f2de48f6e8ddd96" +dependencies = [ + "cfg-if 1.0.0", + "indoc", + "libc", + "parking_lot", + "paste", + "pyo3-build-config", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "755944027ce803c7238e59c5a18e59c1d0a4553db50b23e9ba209a568353028d" +dependencies = [ + "once_cell", +] + +[[package]] +name = "pyo3-macros" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd31b36bccfd902c78804bd96c28ea93eac6fa0ca311f9d21ef2230b6665b29a" +dependencies = [ + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c21c59ba36db9c823e931c662766b0dd01a030b1d96585b67d8857a96a56b972" +dependencies = [ + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + [[package]] name = "quick-error" version = "1.2.3" @@ -2635,6 +2948,12 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.5.1" @@ -2690,6 +3009,12 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" + [[package]] name = "regex-syntax" version = "0.6.25" @@ -2890,6 +3215,18 @@ dependencies = [ "libc", ] 
+[[package]] +name = "semantic-detection" +version = "0.1.0" +dependencies = [ + "arrow", + "fake", + "ndarray", + "pyo3", + "rand 0.8.4", + "tch", +] + [[package]] name = "semver" version = "0.9.0" @@ -2927,9 +3264,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.65" +version = "1.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28c5e91e4240b46c4c19219d6cc84784444326131a4210f496f948d5cc827a29" +checksum = "336b10da19a12ad094b59d870ebde26a45402e5b470add4b5fd03c5048a32127" dependencies = [ "indexmap", "itoa", @@ -3393,6 +3730,7 @@ name = "synth" version = "0.5.3" dependencies = [ "anyhow", + "arrow", "async-std", "async-trait", "beau_collector", @@ -3414,6 +3752,7 @@ dependencies = [ "rand 0.8.4", "regex", "rust_decimal", + "semantic-detection", "serde", "serde_json", "sqlx", @@ -3442,7 +3781,7 @@ dependencies = [ "humantime-serde", "lazy_static", "log", - "num", + "num 0.3.1", "num_cpus", "rand 0.8.4", "rand_regex", @@ -3512,6 +3851,22 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "tch" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26115a35496899782cc16c8e5a51c9375e7f3c0e7dcafa9f1883338e938a04f" +dependencies = [ + "half", + "lazy_static", + "libc", + "ndarray", + "rand 0.8.4", + "thiserror", + "torch-sys", + "zip", +] + [[package]] name = "tempfile" version = "3.2.0" @@ -3741,6 +4096,19 @@ dependencies = [ "tokio 1.9.0", ] +[[package]] +name = "torch-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88ade099a9f752b2e57c7852fe022e3542ec7e196675acbc9bb0d3a380e18c62" +dependencies = [ + "anyhow", + "cc", + "curl", + "libc", + "zip", +] + [[package]] name = "tower-service" version = "0.3.1" @@ -3872,6 +4240,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" +[[package]] +name = "unindent" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f14ee04d9415b52b3aeab06258a3f07093182b88ba0f9b8d203f211a7a7d41c7" + [[package]] name = "universal-hash" version = "0.4.1" @@ -4186,3 +4560,17 @@ dependencies = [ "syn", "synstructure", ] + +[[package]] +name = "zip" +version = "0.5.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93ab48844d61251bb3835145c521d88aa4031d7139e8485990f60ca911fa0815" +dependencies = [ + "byteorder", + "bzip2", + "crc32fast", + "flate2", + "thiserror", + "time 0.1.44", +] diff --git a/Cargo.toml b/Cargo.toml index 60ca881e6..8fbe9724f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "gen", "core", + "semdet", "synth", "dist/playground" ] diff --git a/core/src/graph/string/faker.rs b/core/src/graph/string/faker.rs index f1f7b68b7..b602f1c1d 100644 --- a/core/src/graph/string/faker.rs +++ b/core/src/graph/string/faker.rs @@ -32,7 +32,7 @@ impl Default for Locale { #[derive(Clone, Default, Deserialize, Debug, Serialize, PartialEq, Eq)] pub struct FakerArgs { #[serde(default)] - locales: Vec, + pub locales: Vec, } type FakerFunction = for<'r> fn(&'r mut dyn RngCore, &FakerArgs) -> String; diff --git a/semdet/Cargo.toml b/semdet/Cargo.toml new file mode 100644 index 000000000..c0d047ee7 --- /dev/null +++ b/semdet/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = 
"semantic-detection" +version = "0.1.0" +edition = "2018" +authors = [ + "Damien Broka ", +] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +name = "semantic_detection" +crate-type=["lib", "dylib"] + +[features] +default = [ "dummy" ] +train = [ "pyo3" ] +dummy = [ ] +torch = [ "tch" ] + +[dependencies.arrow] +version = "5.1.0" + +[dependencies.fake] +version = "2.4.1" +features = ["http"] + +[dependencies.pyo3] +version = "0.14.2" +optional = true +features = [ "extension-module" ] + +[dependencies.tch] +version = "0.5.0" +optional = true + +[dependencies.ndarray] +version = "0.15.3" + +[dev-dependencies.rand] +version = "0.8.4" \ No newline at end of file diff --git a/semdet/build.rs b/semdet/build.rs new file mode 100644 index 000000000..6af6f37b6 --- /dev/null +++ b/semdet/build.rs @@ -0,0 +1,22 @@ +use std::env; +use std::fs; +use std::io::Result; +use std::path::{Path, PathBuf}; + +fn main() -> Result<()> { + let pretrained_path = env::var_os("PRETRAINED") + .map(PathBuf::from) + .unwrap_or_else(|| Path::new("train").join("dummy.tch")); + let target_path = PathBuf::from(env::var_os("OUT_DIR").unwrap()).join("pretrained.tch"); + eprintln!( + "attempting to copy pretrained weights:\n\t<- {}\n\t-> {}", + pretrained_path.to_str().unwrap(), + target_path.to_str().unwrap() + ); + fs::copy(&pretrained_path, &target_path)?; + println!( + "cargo:rustc-env=PRETRAINED={}", + target_path.to_str().unwrap() + ); + Ok(()) +} diff --git a/semdet/src/decode.rs b/semdet/src/decode.rs new file mode 100644 index 000000000..5060bfe44 --- /dev/null +++ b/semdet/src/decode.rs @@ -0,0 +1,102 @@ +use ndarray::{ArrayView, Ix1}; + +use std::convert::Infallible; + +/// Trait for functions that produce a value from an input [`Array`](ndarray::Array) of prescribed +/// shape. +/// +/// The type parameter `D` should probably be a [`Dimension`](ndarray::Dimension) for implementations +/// to be useful. +pub trait Decoder { + type Err: std::error::Error + 'static; + + /// The type of values returned. + type Value; + + /// Compute and return a [`Self::Value`](Decoder::Value) from the input `tensor`. + /// + /// Implementations are allowed to panic if `tensor.shape() != self.shape()`. + fn decode(&self, tensor: ArrayView) -> Result; + + /// The shape that is required of a valid input of this decoder. + fn shape(&self) -> D; +} + +impl<'d, D, Dm> Decoder for &'d D +where + D: Decoder, +{ + type Err = D::Err; + type Value = D::Value; + + fn decode(&self, tensor: ArrayView) -> Result { + >::decode(self, tensor) + } + + fn shape(&self) -> Dm { + >::shape(self) + } +} + +pub struct MaxIndexDecoder { + index: Vec, +} + +impl MaxIndexDecoder { + /// # Panics + /// + /// If `index` is empty. + pub fn from_vec(index: Vec) -> Self { + assert!( + !index.is_empty(), + "passed `index` to `from_values` must not be empty" + ); + Self { index } + } +} + +impl Decoder for MaxIndexDecoder +where + S: Clone, +{ + type Err = Infallible; + type Value = Option; + + fn decode(&self, tensor: ArrayView) -> Result { + let (idx, by) = tensor + .iter() + .enumerate() + .max_by(|(_, l), (_, r)| l.total_cmp(r)) + .unwrap(); + if *by > (1. 
/ tensor.len() as f32) { + let value = self.index.get(idx).unwrap().clone(); + Ok(Some(value)) + } else { + Ok(None) + } + } + + fn shape(&self) -> Ix1 { + Ix1(self.index.len()) + } +} + +#[cfg(test)] +pub mod tests { + use super::Decoder; + use super::MaxIndexDecoder; + + use ndarray::{Array, Ix1}; + + #[test] + fn decoder_max_index() { + let decoder = MaxIndexDecoder::from_vec((0..10).collect()); + + for idx in 0..10 { + let mut input = Array::zeros(Ix1(10)); + *input.get_mut(idx).unwrap() = 1.; + let output = decoder.decode(input.view()).unwrap(); + assert_eq!(output, Some(idx)); + } + } +} diff --git a/semdet/src/encode.rs b/semdet/src/encode.rs new file mode 100644 index 000000000..6cd923be4 --- /dev/null +++ b/semdet/src/encode.rs @@ -0,0 +1,219 @@ +use arrow::array::{GenericStringArray, StringOffsetSizeTrait}; +use arrow::datatypes::DataType; +use arrow::record_batch::RecordBatch; + +use ndarray::{ArrayViewMut, Axis, Ix1, Slice}; + +use std::collections::HashSet; +use std::convert::Infallible; + +/// Trait for functions that compute an [`Array`](ndarray::Array) of prescribed shape from an input +/// [`RecordBatch`](arrow::record_batch::RecordBatch). +/// +/// The type parameter `D` should probably be a [`Dimension`](ndarray::Dimension) for +/// implementations to be useful. +pub trait Encoder { + type Err: std::error::Error + 'static; + + /// Compute the values from the `input` and store the result in the initialized mutable `buffer`. + /// + /// Implementations are allowed to panic if `buffer.shape() != self.shape()`. + fn encode<'f>( + &self, + input: &RecordBatch, + buffer: ArrayViewMut<'f, f32, D>, + ) -> Result<(), Self::Err>; + + /// The shape of the output of this encoder. + fn shape(&self) -> D; +} + +impl<'e, E, Dm> Encoder for &'e E +where + E: Encoder, +{ + type Err = E::Err; + + fn encode<'f>( + &self, + input: &RecordBatch, + buffer: ArrayViewMut<'f, f32, Dm>, + ) -> Result<(), Self::Err> { + >::encode(self, input, buffer) + } + + fn shape(&self) -> Dm { + >::shape(self) + } +} + +/// An [`Encoder`](Encoder) that simply counts how many rows in the input are present in a dictionary. +#[derive(Debug)] +pub struct Dictionary { + dict: HashSet, +} + +impl Dictionary { + /// Create a new dictionary from a collection of [`&str`](str). + pub fn new<'a, I>(values: I) -> Self + where + I: IntoIterator, + { + Self { + dict: values.into_iter().map(|s| s.to_string()).collect(), + } + } + + fn count(&self, data: &GenericStringArray) -> u64 { + data.into_iter() + .filter_map(|m_s| m_s.and_then(|s| self.dict.get(s))) + .count() as u64 + } +} + +impl Encoder for Dictionary { + type Err = Infallible; + + /// In the first column of `input`, count the number of rows which match an element of the + /// dictionary. + fn encode<'f>( + &self, + input: &RecordBatch, + mut buffer: ArrayViewMut<'f, f32, Ix1>, + ) -> Result<(), Self::Err> { + let column = input.column(0); + let count = match column.data_type() { + DataType::Utf8 => { + let sar: &GenericStringArray = column.as_any().downcast_ref().unwrap(); + Some(self.count(sar)) + } + DataType::LargeUtf8 => { + let sar: &GenericStringArray = column.as_any().downcast_ref().unwrap(); + Some(self.count(sar)) + } + _ => None, + } + .and_then(|matches| { + if column.len() != 0 { + Some(matches as f32) + } else { + None + } + }); + *buffer.get_mut(0).unwrap() = count.unwrap_or(f32::NAN); + Ok(()) + } + + fn shape(&self) -> Ix1 { + Ix1(1) + } +} + +/// An [`Encoder`](Encoder) that horizontally stacks the output of a collection of [`Encoder`](Encoder)s. 
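// Illustrative sketch: a std-only distillation of the two pieces above. The
// `Dictionary` encoder counts how many sampled values fall inside a known
// word list, and `MaxIndexDecoder` accepts the best-scoring label only when
// it beats the uniform baseline `1 / len`. Names and data below are invented
// for illustration; the real decoder runs on softmax scores and uses the
// (then-nightly) `total_cmp`, so `partial_cmp` here is just a stable stand-in.
use std::collections::HashSet;

fn dictionary_hits(column: &[Option<&str>], dict: &HashSet<&str>) -> f32 {
    // Nulls count as misses, mirroring the Arrow-based `Dictionary::count`.
    column
        .iter()
        .filter(|cell| cell.map_or(false, |s| dict.contains(s)))
        .count() as f32
}

fn decode_max(scores: &[f32], labels: &[&'static str]) -> Option<&'static str> {
    let (idx, best) = scores
        .iter()
        .enumerate()
        .max_by(|(_, l), (_, r)| l.partial_cmp(r).unwrap())?;
    // Only accept the winner when it clears the uniform baseline 1/len.
    (*best > 1.0 / scores.len() as f32).then(|| labels[idx])
}

fn main() {
    let last_names = HashSet::from(["Smith", "Jones", "Garcia"]);
    let column = [Some("Smith"), Some("Garcia"), None, Some("acme-widget")];
    let scores = [dictionary_hits(&column, &last_names), 0.0];
    assert_eq!(
        decode_max(&scores, &["name_last_name", "job_field"]),
        Some("name_last_name")
    );
}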
+pub struct StackedEncoder { + stack: Vec, + shape: Ix1, +} + +impl StackedEncoder +where + D: Encoder, +{ + /// Construct a new [`StackedEncoder`](StackedEncoder) from a collection of [`Encoder`](Encoder)s. + pub fn from_vec(stack: Vec) -> Self { + let shape = stack.iter().map(|encoder| encoder.shape()[0]).sum(); + Self { + stack, + shape: Ix1(shape), + } + } +} + +impl Encoder for StackedEncoder +where + D: Encoder, +{ + type Err = D::Err; + + fn encode<'f>( + &self, + input: &RecordBatch, + mut buffer: ArrayViewMut<'f, f32, Ix1>, + ) -> Result<(), Self::Err> { + let mut idx = 0usize; + for encoder in self.stack.iter() { + let next = idx + encoder.shape()[0]; + let sliced = buffer.slice_axis_mut(Axis(0), Slice::from(idx..next)); + encoder.encode(input, sliced)?; + idx = next; + } + Ok(()) + } + + fn shape(&self) -> Ix1 { + self.shape + } +} + +#[cfg(test)] +pub mod tests { + use ndarray::{array, Array, Ix2}; + + use std::iter::once; + + use super::{Dictionary, Encoder, StackedEncoder}; + use crate::tests::*; + + macro_rules! encode { + ($encoder:ident, $input:ident, $output:expr) => {{ + $encoder.encode(&$input, $output).unwrap(); + }}; + ($encoder:ident, $input:ident) => {{ + let mut output = Array::zeros($encoder.shape()); + encode!($encoder, $input, output.view_mut()); + output + }}; + } + + #[test] + fn encoder_dictionary() { + let last_names = ::NAME_LAST_NAME; + let encoder = Dictionary::new(last_names.iter().copied()); + let names_array = string_array_of(fake::faker::name::en::LastName(), 1000); + let input = record_batch_of(once(names_array)); + let output = encode!(encoder, input); + assert_eq!(output, array![511.]); + } + + #[test] + fn encoder_stacked() { + let data = [ + ::NAME_LAST_NAME, + ::ADDRESS_COUNTRY_CODE, + ::JOB_FIELD, + ]; + let encoder = StackedEncoder::from_vec( + data.iter() + .map(|values| Dictionary::new(values.iter().copied())) + .collect(), + ); + + let mut output = Array::zeros(Ix2(encoder.stack.len(), encoder.shape()[0])); + let slices = output.outer_iter_mut(); + + let inputs = vec![ + string_array_of(fake::faker::name::en::LastName(), 1000), + string_array_of(fake::faker::address::en::CountryCode(), 1000), + string_array_of(fake::faker::job::en::Field(), 1000), + ]; + for (slice, array) in slices.zip(inputs) { + let input = record_batch_of(vec![array]); + encode!(encoder, input, slice); + } + + assert_eq!( + output, + array![[511.0, 0.0, 0.0], [0.0, 511.0, 1.0], [0.0, 21.0, 500.0]] + ); + } +} diff --git a/semdet/src/error.rs b/semdet/src/error.rs new file mode 100644 index 000000000..f880f7bb8 --- /dev/null +++ b/semdet/src/error.rs @@ -0,0 +1,53 @@ +use arrow::error::ArrowError; +use ndarray::ShapeError; + +#[derive(Debug)] +pub enum Error { + Arrow(ArrowError), + Shape(ShapeError), + Encoder(Box), + Model(Box), + Decoder(Box), + Implementation(String), +} + +impl Error { + pub fn encoder(err: E) -> Self { + Self::Encoder(Box::new(err)) + } + + pub fn model(err: E) -> Self { + Self::Model(Box::new(err)) + } + + pub fn decoder(err: E) -> Self { + Self::Decoder(Box::new(err)) + } +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Arrow(arrow) => write!(f, "arrow error: {}", arrow), + Self::Shape(shape) => write!(f, "shape error: {}", shape), + Self::Encoder(encoder) => write!(f, "encoder error: {}", encoder), + Self::Model(model) => write!(f, "model error: {}", model), + Self::Decoder(decoder) => write!(f, "decoder error: {}", decoder), + Self::Implementation(msg) => 
write!(f, "implementation error: {}", msg), + } + } +} + +impl std::error::Error for Error {} + +impl From for Error { + fn from(arrow: ArrowError) -> Self { + Self::Arrow(arrow) + } +} + +impl From for Error { + fn from(shape: ShapeError) -> Self { + Self::Shape(shape) + } +} diff --git a/semdet/src/lib.rs b/semdet/src/lib.rs new file mode 100644 index 000000000..bed04f2f1 --- /dev/null +++ b/semdet/src/lib.rs @@ -0,0 +1,74 @@ +#![feature(total_cmp)] + +pub mod encode; +pub use encode::Encoder; + +pub mod decode; +pub use decode::Decoder; + +pub mod module; +pub use module::Module; + +pub mod error; +pub use error::Error; + +#[cfg(feature = "train")] +#[pyo3::proc_macro::pymodule] +fn semantic_detection(py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { + let dummy = module::python_bindings::bind(py)?; + m.add_submodule(dummy)?; + Ok(()) +} + +#[cfg(test)] +pub mod tests { + use fake::{Dummy, Fake}; + use rand::{rngs::StdRng, SeedableRng}; + + use arrow::array::{ArrayRef, StringArray, StringBuilder}; + use arrow::record_batch::RecordBatch; + + use std::sync::Arc; + + pub fn rng() -> StdRng { + ::seed_from_u64(0xAAAAAAAAAAAAAAAA) + } + + pub fn string_array_of(f: F, len: usize) -> StringArray + where + String: Dummy, + { + let mut builder = StringBuilder::new(len); + let mut rng = rng(); + (0..len) + .into_iter() + .try_for_each(|_| builder.append_option(f.fake_with_rng::, _>(&mut rng))) + .unwrap(); + builder.finish() + } + + pub fn record_batch_of_with_names(iter: I) -> RecordBatch + where + I: IntoIterator, + S: AsRef, + A: arrow::array::Array + 'static, + { + RecordBatch::try_from_iter( + iter.into_iter() + .map(|(idx, array)| (idx.as_ref().to_string(), Arc::new(array) as ArrayRef)), + ) + .unwrap() + } + + pub fn record_batch_of(iter: I) -> RecordBatch + where + I: IntoIterator, + A: arrow::array::Array + 'static, + { + record_batch_of_with_names( + iter.into_iter() + .enumerate() + .map(|(k, v)| (k.to_string(), v)), + ) + } +} diff --git a/semdet/src/module.rs b/semdet/src/module.rs new file mode 100644 index 000000000..7d3aa8898 --- /dev/null +++ b/semdet/src/module.rs @@ -0,0 +1,426 @@ +use arrow::datatypes::Schema; +use arrow::record_batch::RecordBatch; + +use ndarray::{Array, Array2, ArrayView, Dimension, Ix1, Ix2}; + +use std::collections::HashMap; + +use std::sync::Arc; + +use crate::{Decoder, Encoder, Error}; + +/// A builder for [`Module`](Module). +#[derive(Default)] +pub struct ModuleBuilder { + encoder: Option, + model: Option, + decoder: Option, +} + +impl ModuleBuilder { + pub fn new() -> Self { + Self::default() + } +} + +impl ModuleBuilder { + /// Set the [`Module`](Module)'s [`Encoder`](crate::Encoder). + pub fn encoder>(self, encoder: EE) -> ModuleBuilder { + ModuleBuilder:: { + encoder: Some(encoder), + model: self.model, + decoder: self.decoder, + } + } + + /// Set the [`Module`](Module)'s [`Decoder`](crate::Decoder). + pub fn decoder>(self, decoder: DD) -> ModuleBuilder { + ModuleBuilder:: { + encoder: self.encoder, + model: self.model, + decoder: Some(decoder), + } + } + + /// Set the [`Module`](Module)'s [`Model`](Model). + pub fn model>(self, model: MM) -> ModuleBuilder { + ModuleBuilder:: { + encoder: self.encoder, + model: Some(model), + decoder: self.decoder, + } + } +} + +impl ModuleBuilder +where + E: Encoder, + M: Model, + D: Decoder, +{ + /// Build the [`Module`](Module). + /// + /// # Panics + /// If one of encoder, model or decoder has not been set. 
+ fn build(self) -> Result, Error> { + let encoder = self.encoder.expect("missing an encoder"); + + let model = self.model.expect("missing a model"); + + let decoder = self.decoder.expect("missing a decoder"); + + Ok(Module { + encoder, + model, + decoder, + }) + } +} + +/// Trait for functions that compute an output [`Array`](ndarray::Array) of variable shape from +/// an input [`ArrayView`](ndarray::ArrayView) of variable shape. +pub trait Model { + type Err: std::error::Error; + type Output: Dimension; + + fn forward(&self, input: ArrayView) -> Result, Self::Err>; +} + +/// Put together an [`Encoder`](crate::Encoder), a [`Model`](Model) and a [`Decoder`](crate::Decoder) +/// into a pipeline that computes outputs from an input [`RecordBatch`](arrow::record_batch::RecordBatch). +pub struct Module { + encoder: E, + model: M, + decoder: D, +} + +impl Module { + /// Create a new [`Module`](Module) builder. + pub fn builder() -> ModuleBuilder { + ModuleBuilder::new() + } +} + +impl Module +where + E: Encoder, + M: Model, + D: Decoder, +{ + /// Compute the result of passing an input [`RecordBatch`](arrow::record_batch::RecordBatch) + /// through the pipeline. + /// + /// Each column of `input` is passed separately to the [`Encoder`](crate::Encoder), [`Model`](Model) + /// and [`Decoder`](crate::Decoder). The output [`Decoder::Value`](crate::Decoder::Value) are + /// then assembled into a [`HashMap`](HashMap) with keys corresponding to `input` column names. + pub fn forward(&self, input: &RecordBatch) -> Result, Error> { + let fields = input.schema().fields().clone(); + let columns = input.columns().iter().cloned(); + + let mut buffer = Array2::zeros(Ix2(columns.len(), self.encoder.shape()[0])); + let slices = buffer.outer_iter_mut(); + + for ((field, column), slice) in fields.iter().cloned().zip(columns).zip(slices) { + let column_input = + RecordBatch::try_new(Arc::new(Schema::new(vec![field])), vec![column])?; + self.encoder + .encode(&column_input, slice) + .map_err(Error::encoder)?; + } + + let output = self.model.forward(buffer.view())?; + + let slices = output.outer_iter(); + let mut decoded = HashMap::new(); + for (field, row) in fields.into_iter().zip(slices) { + decoded.insert( + field.name().to_string(), + self.decoder.decode(row).map_err(Error::decoder)?, + ); + } + Ok(decoded) + } +} + +pub mod dummy { + //! A simple proof-of-concept pipeline + use ndarray::{Array, ArrayView, Axis, Ix1, Ix2}; + + use super::{Model, Module}; + use crate::decode::{Decoder, MaxIndexDecoder}; + use crate::encode::{Dictionary, Encoder, StackedEncoder}; + + #[cfg(feature = "torch")] + use ndarray::ArrayD; + + #[cfg(feature = "torch")] + use tch::{jit::CModule, nn::Module as NnModule, TchError, Tensor}; + + #[cfg(feature = "torch")] + use std::{ + convert::{TryFrom, TryInto}, + io::{Cursor, Read}, + }; + + #[cfg(feature = "torch")] + static PRETRAINED: &[u8] = include_bytes!(env!("PRETRAINED")); + + use std::convert::Infallible; + + use crate::Error; + + macro_rules! dummy_feature { + ($id:ident, $locale:ident) => { + Dictionary::new( + ::$id + .into_iter() + .map(|s| *s), + ) + }; + } + + macro_rules! 
dummies { + {$(($id:ident, $locale:ident, $name:literal, $ty:path)$(,)?)+} => { + pub fn encoder() -> impl Encoder + Send { + StackedEncoder::from_vec(vec![ + $(dummy_feature!($id, $locale),)+ + ]) + } + + pub fn decoder() -> impl Decoder> + Send { + MaxIndexDecoder::from_vec(vec![ + $($name,)+ + ]) + } + + #[cfg(test)] + pub mod tests { + use arrow::record_batch::RecordBatch; + + use crate::tests::*; + + pub fn data(len: usize) -> RecordBatch { + record_batch_of_with_names(vec![ + $(($name, string_array_of($ty(fake::locales::$locale), len)),)+ + ]) + } + } + } + } + + dummies! { + (NAME_FIRST_NAME, EN, "name_first_name", fake::faker::name::raw::FirstName), + (NAME_LAST_NAME, EN, "name_last_name", fake::faker::name::raw::LastName), + (JOB_FIELD, EN, "job_field", fake::faker::job::raw::Field), + (ADDRESS_COUNTRY, EN, "address_country", fake::faker::address::raw::CountryName), + (ADDRESS_COUNTRY_CODE, EN, "address_country_code", fake::faker::address::raw::CountryCode), + (ADDRESS_TIME_ZONE, EN, "address_time_zone", fake::faker::address::raw::TimeZone), + (ADDRESS_STATE, EN, "address_state", fake::faker::address::raw::StateName), + (ADDRESS_STATE_ABBR, EN, "address_state_abbr", fake::faker::address::raw::StateAbbr), + (CURRENCY_NAME, EN, "currency_name", fake::faker::currency::raw::CurrencyName), + (CURRENCY_CODE, EN, "currency_code", fake::faker::currency::raw::CurrencyCode) + } + + #[cfg(feature = "torch")] + pub struct DummyCModule(CModule); + + #[cfg(feature = "torch")] + impl DummyCModule { + pub fn load_data(buffer: &mut R) -> Result { + let c_module = CModule::load_data(buffer)?; + Ok(Self(c_module)) + } + + pub fn pretrained() -> Self { + Self::load_data(&mut Cursor::new(PRETRAINED)) + .expect("failed to load pretrained module (torch backend)") + } + } + + #[cfg(feature = "torch")] + impl Model for DummyCModule { + type Err = Error; + type Output = Ix2; + + fn forward( + &self, + input: ArrayView, + ) -> Result, Self::Err> { + let input_t = Tensor::try_from(input).map_err(Error::model)?; + let output_t = self.0.forward(&input_t); + let output: ArrayD = (&output_t).try_into()?; + let output = output.into_dimensionality::()?; + Ok(output) + } + } + + pub struct DummyNativeModule; + + impl Model for DummyNativeModule { + type Err = Error; + type Output = Ix2; + + fn forward( + &self, + input: ArrayView, + ) -> Result, Self::Err> { + let max = input.fold_axis(Axis(0), f32::NEG_INFINITY, |m, v| m.max(*v)); + let mut exp = ((-max) + input).mapv(f32::exp); + let total = exp + .map_axis(Axis(1), |view| view.sum()) + .into_shape(Ix2(exp.shape()[0], 1)) + .unwrap(); + exp = exp / total; + Ok(exp) + } + } + + pub fn module() -> Module< + impl Encoder, + impl Model, + impl Decoder>, + > { + Module::builder() + .encoder(encoder()) + .decoder(decoder()) + .model(DummyNativeModule) + .build() + .expect("failed to build dummy module (torch backend)") + } + + #[cfg(feature = "torch")] + pub fn module_torch() -> Module< + impl Encoder, + impl Model, + impl Decoder>, + > { + Module::builder() + .encoder(encoder()) + .decoder(decoder()) + .model(DummyCModule::pretrained()) + .build() + .expect("failed to build dummy module (native backend)") + } +} + +#[cfg(feature = "train")] +pub mod python_bindings { + use arrow::{ + array::{Array as ArrowArrayTrait, StructArray}, + ffi::{ArrowArray, ArrowArrayRef}, + record_batch::RecordBatch, + }; + + use ndarray::{Array, Dimension}; + + use pyo3::prelude::*; + + use std::convert::Infallible; + + use super::*; + + trait IntoPyRes { + type Ok; + fn 
or_else_raise(self) -> PyResult; + } + + impl IntoPyRes for Result + where + E: std::error::Error, + { + type Ok = T; + fn or_else_raise(self) -> PyResult<::Ok> { + self.map_err(|arr_err| pyo3::panic::PanicException::new_err(arr_err.to_string())) + } + } + + pub unsafe fn import_record_batch(record_batch: &PyAny) -> PyResult { + let (p_array, p_schema) = ArrowArray::into_raw(ArrowArray::empty()); + record_batch + .getattr("_export_to_c")? + .call1((p_array as usize, p_schema as usize))?; + let c_array = ArrowArray::try_from_raw(p_array, p_schema).or_else_raise()?; + let s_array = StructArray::from(c_array.to_data().or_else_raise()?); + Ok(RecordBatch::from(&s_array)) + } + + pub unsafe fn export_record_batch( + pyarrow: &PyModule, + record_batch: RecordBatch, + ) -> PyResult<&PyAny> { + let struct_array = StructArray::from(record_batch); + let (p_array, p_schema) = struct_array.to_raw().or_else_raise()?; + let output = pyarrow + .getattr("RecordBatch")? + .getattr("_import_from_c")? + .call1((p_array as usize, p_schema as usize))?; + Ok(output) + } + + #[pyclass(name = "Encoder")] + struct BoundEncoder { + encoder: Box + Send>, + } + + #[pymethods] + impl BoundEncoder { + #[new] + fn new() -> Self { + Self { + encoder: Box::new(dummy::encoder()), + } + } + + fn encode<'p>(&self, py: Python<'p>, record_batch: &PyAny) -> PyResult<&'p PyAny> { + let record_batch = unsafe { import_record_batch(record_batch) }?; + let shape = self.encoder.shape(); + let mut buffer = Array::zeros(shape); + self.encoder + .encode(&record_batch, buffer.view_mut()) + .or_else_raise()?; + let as_py = buffer.into_raw_vec().into_py(py); + let py_tensor = py + .import("torch")? + .getattr("FloatTensor")? + .call1((as_py,))?; + py_tensor.getattr("reshape")?.call1((shape.into_pattern(),)) + } + } + + pub fn bind(py: Python) -> PyResult<&PyModule> { + let module = PyModule::new(py, "dummy")?; + module.add_class::()?; + Ok(module) + } +} + +#[cfg(test)] +pub mod tests { + use super::*; + + fn dummy_end_to_end(module: Module) + where + E: Encoder, + M: Model, + D: Decoder>, + { + let input = dummy::tests::data(1000); + let output = module.forward(&input).unwrap(); + for (column, target) in output.into_iter() { + assert_eq!(Some(column.as_str()), target) + } + } + + #[test] + fn dummy_end_to_end_native() { + let module = dummy::module(); + dummy_end_to_end(module) + } + + #[cfg(feature = "torch")] + #[test] + fn dummy_end_to_end_torch() { + let module = dummy::module_torch(); + dummy_end_to_end(module) + } +} diff --git a/semdet/train/dummy.tch b/semdet/train/dummy.tch new file mode 100644 index 000000000..0b97f488e Binary files /dev/null and b/semdet/train/dummy.tch differ diff --git a/synth/Cargo.toml b/synth/Cargo.toml index 4be4c4c68..4138043bd 100644 --- a/synth/Cargo.toml +++ b/synth/Cargo.toml @@ -53,6 +53,9 @@ ctrlc = { version = "3.0", features = ["termination"] } synth-core = { path = "../core" } synth-gen = { path = "../gen" } +arrow = "5.1.0" +semantic-detection = { path = "../semdet" } + rust_decimal = "1.10.3" indicatif = "0.15.0" diff --git a/synth/src/cli/import_utils.rs b/synth/src/cli/import_utils.rs index 43c6855b8..ec183c50f 100644 --- a/synth/src/cli/import_utils.rs +++ b/synth/src/cli/import_utils.rs @@ -9,9 +9,14 @@ use std::str::FromStr; use synth_core::schema::content::number_content::U64; use synth_core::schema::{ ArrayContent, FieldRef, Id, NumberContent, ObjectContent, OptionalMergeStrategy, RangeStep, - SameAsContent, + SameAsContent, StringContent, FakerContent }; +use 
synth_core::graph::string::{FakerArgs, Locale}; use synth_core::{Content, Name, Namespace}; +use arrow::record_batch::RecordBatch; +use arrow::array::{StringArray, ArrayRef}; +use std::collections::HashMap; +use std::sync::Arc; #[derive(Debug)] pub(crate) struct Collection { @@ -53,8 +58,56 @@ fn populate_namespace_collections( namespace.put_collection( &Name::from_str(table_name)?, - Collection::try_from((datasource, column_infos))?.collection, + Collection::try_from((datasource, column_infos.clone()))?.collection, )?; + + let module = semantic_detection::module::dummy::module(); + let values = task::block_on(datasource.get_deterministic_samples(table_name))?; + let mut pivoted = HashMap::new(); + for value in values.iter() { + let row = value.as_object().unwrap(); + for (column, field) in row.iter() { + if let Some(content) = field.as_str() { + pivoted + .entry(column.to_string()) + .or_insert_with(Vec::new) + .push(Some(content)); + } + + if field.is_null() { + if let Some(values) = pivoted.get_mut(column) { + values.push(None); + } + } + } + } + + let column_infos = column_infos + .into_iter() + .map(|ci| (ci.column_name.to_string(), ci)) + .collect::>(); + let pivoted = pivoted.into_iter().map(|(k, v)| (k, Arc::new(StringArray::from(v)) as ArrayRef)); + let record_batch = RecordBatch::try_from_iter(pivoted).unwrap(); + let target = module.forward(&record_batch).unwrap(); + for (column, generator) in target { + let column_meta = column_infos.get(&column).unwrap(); + if let Some(generator) = generator { + let field_ref = FieldRef::new(& if column_meta.is_nullable { + format!("{}.content.{}.0", table_name, &column_meta.column_name) + } else { + format!("{}.content.{}", table_name, &column_meta.column_name) + })?; + if let Content::String(string_content) = namespace.get_s_node_mut(&field_ref)? { + *string_content = StringContent::Faker(FakerContent { + generator: generator.to_string(), + locales: vec![], + args: FakerArgs { + locales: vec![Locale::EN] + } + }) + } + } + } } Ok(()) diff --git a/synth/src/datasource/relational_datasource.rs b/synth/src/datasource/relational_datasource.rs index a350fc31b..40cee94af 100644 --- a/synth/src/datasource/relational_datasource.rs +++ b/synth/src/datasource/relational_datasource.rs @@ -8,7 +8,7 @@ use synth_core::Content; const DEFAULT_INSERT_BATCH_SIZE: usize = 1000; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ColumnInfo { pub(crate) column_name: String, pub(crate) ordinal_position: i32,
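// Illustrative sketch: the import path above pivots row-oriented samples into
// per-column vectors of optional strings, builds an Arrow RecordBatch, and
// hands it to `dummy::module().forward(..)`; columns that come back with a
// detected generator are then rewritten to `faker` nodes in the namespace.
// The helper below shows just the pivot step with plain std types (the real
// code walks serde_json rows and skips non-string fields); the function name
// and row representation are invented for illustration.
use std::collections::HashMap;
use std::sync::Arc;

use arrow::array::{ArrayRef, StringArray};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

/// Assumes every row carries the same string-valued columns, so all pivoted
/// vectors end up the same length.
fn pivot(rows: &[HashMap<String, Option<String>>]) -> Result<RecordBatch, ArrowError> {
    let mut columns: HashMap<String, Vec<Option<&str>>> = HashMap::new();
    for row in rows {
        for (name, value) in row {
            columns
                .entry(name.clone())
                .or_insert_with(Vec::new)
                .push(value.as_deref()); // keep database NULLs as None
        }
    }
    RecordBatch::try_from_iter(
        columns
            .into_iter()
            .map(|(name, values)| (name, Arc::new(StringArray::from(values)) as ArrayRef)),
    )
}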