1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
use variable::Variable;
use value::{Value, ValueInner};

// Raw C++ prelude for the inline `cpp!` closures in this file: the CNTK library header plus the
// std headers, and `using` directives so the closures can write
// `unordered_map<Variable, ValuePtr>` without `CNTK::` / `std::` qualifiers.
// NOTE(review): none of the closures visible in this file use <cstdio> or <vector>; confirm
// nothing else relies on them before removing.
cpp! {{
  #include <CNTKLibrary.h>
  #include <cstdio>
  #include <vector>
  #include <unordered_map>

  using namespace CNTK;
  using namespace std;
}}

/// Opaque Rust-side stand-in for the C++ `unordered_map<Variable, ValuePtr>`.
///
/// Only ever used behind a raw pointer that is handed back to C++ and reinterpreted there; the
/// array contents are never read from Rust. NOTE(review): the `[u64; 1]` size appears to be a
/// placeholder rather than the real map layout — confirm before relying on it.
type DataMapInner = [u64; 1usize];

/// Owning wrapper around a heap-allocated C++ `unordered_map<Variable, ValuePtr>`, used to pass
/// variable-to-value bindings to `Function` evaluation.
///
/// A binding may hold a null `ValuePtr` (see `add_null`); `get` reports such bindings as absent.
pub struct DataMap {
    /// Pointer to the C++ map: allocated with `new` in `DataMap::new` and deleted in `Drop`.
    pub(super) payload: *mut DataMapInner
}

impl DataMap {
    /// Creates an empty DataMap
    pub fn new() -> DataMap {
        DataMap {
            payload: unsafe {
                cpp!([] -> *mut DataMapInner as "unordered_map<Variable, ValuePtr>*" {
                    return new unordered_map<Variable, ValuePtr>;
                })
            }
        }
    }

    /// Adds binding to DataMap. If mapping for given Variable exists, it will be overwritten.
    pub fn add<T: Into<Variable>>(&mut self, variable: T, value: &Value) {
        let v = variable.into();
        let var_payload = v.payload;
        let val_payload = value.payload;
        let mut payload = self.payload;

        unsafe {
            cpp!([mut payload as "unordered_map<Variable, ValuePtr>*", var_payload as "Variable", val_payload as "ValuePtr"] {
                payload->insert({var_payload, val_payload});
            })
        }
    }

    /// Adds binding to null to DataMap. Useful, when we want function evaluation to create the Value.
    pub fn add_null<T: Into<Variable>>(&mut self, variable: T) {
        let v = variable.into();
        let var_payload = v.payload;
        let mut payload = self.payload;

        unsafe {
            cpp!([mut payload as "unordered_map<Variable, ValuePtr>*", var_payload as "Variable"] {
                payload->insert({var_payload, nullptr});
            })
        }
    }

    pub fn get<T: Into<Variable>>(&self, variable: T) -> Option<Value> {
        let v = variable.into();
        let var_payload = v.payload;
        let payload = self.payload;

        let has_var = unsafe {
            cpp!([payload as "unordered_map<Variable, ValuePtr>*", var_payload as "Variable"] -> bool as "bool" {
                auto it = payload->find(var_payload);
                if (it == payload->end()) return false;
                if (it->second.get() == NULL) return false;
                return true;
            })
        };

        if has_var {
            Some(
                Value { payload: unsafe {
                    cpp!([payload as "unordered_map<Variable, ValuePtr>*", var_payload as "Variable"] -> ValueInner as "ValuePtr" {
                        return payload->find(var_payload)->second;
                    })
                }}
            )
        } else {
            None
        }
    }
}

impl Drop for DataMap {
    /// Deletes the C++ map that `DataMap::new` allocated.
    fn drop(&mut self) {
        let map_ptr = self.payload;
        // SAFETY: `payload` is the pointer returned by `new unordered_map<...>` in
        // `DataMap::new`, and nothing in this file duplicates it, so it is deleted exactly once.
        unsafe {
            cpp!([map_ptr as "unordered_map<Variable, ValuePtr>*"] {
                delete map_ptr;
            });
        }
    }
}

/// Builds a `DataMap` from `variable => &value` pairs.
///
/// Each pair expands to `DataMap::add(variable, value)`, in the order written; a trailing comma
/// is accepted and `datamap!()` yields an empty map. The expansion names `DataMap` unqualified,
/// so it must be in scope at the call site.
#[macro_export]
macro_rules! datamap {
    // Internal helpers: `@single` turns any token tree into `()`, and `@count` uses it to count
    // the given expressions.
    (@single $($x:tt)*) => (());
    (@count $($rest:expr),*) => (<[()]>::len(&[$(datamap!(@single $rest)),*]));

    // Trailing comma: strip it and re-dispatch to the main rule.
    ($($key:expr => $value:expr,)+) => { datamap!($($key => $value),+) };
    ($($key:expr => $value:expr),*) => {
        {
            // `mut` goes unused when the macro is invoked with no pairs.
            #[allow(unused_mut)]
            let mut _map = DataMap::new();
            $(
                _map.add($key, $value);
            )*
            _map
        }
    };
}

/// Builds a `DataMap` that binds each listed variable to a null value.
///
/// Each variable expands to `DataMap::add_null(variable)`, in the order written. This is meant
/// for outputs whose `Value` should be created by function evaluation. A trailing comma is
/// accepted, and `outdatamap!()` yields an empty map. The expansion names `DataMap` unqualified,
/// so it must be in scope at the call site.
#[macro_export]
macro_rules! outdatamap {
    // Internal helpers: `@single` turns any token tree into `()`, and `@count` uses it to count
    // the given expressions.
    (@single $($x:tt)*) => (());
    (@count $($rest:expr),*) => (<[()]>::len(&[$(outdatamap!(@single $rest)),*]));

    // Trailing comma: strip it and re-dispatch to the main rule.
    ($($key:expr,)+) => { outdatamap!($($key),+) };
    ($($key:expr),*) => {
        {
            // `mut` goes unused when the macro is invoked with no variables.
            #[allow(unused_mut)]
            let mut _set = DataMap::new();
            $(
                _set.add_null($key);
            )*
            _set
        }
    };
}

#[cfg(test)]
mod tests {
    use super::*;
    use variable::*;
    use value::*;
    use device::*;
    use shape::Shape;

    #[test]
    fn test_create() {
        let _map = DataMap::new();
    }

    #[test]
    fn test_add_and_get() {
        let mut map = DataMap::new();
        let var = Variable::input_variable(&Shape::scalar());
        let other = Variable::input_variable(&Shape::scalar());

        let data: Vec<f32> = vec![11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 110.0];
        let val = Value::batch_from_vec(&var.shape(), &data, DeviceDescriptor::cpu());

        map.add(var.clone(), &val);
        assert!(map.get(var).is_some());
        // A variable that was never bound must not be found.
        assert!(map.get(other).is_none());
    }

    #[test]
    fn test_null_binding_reads_as_none() {
        let var = Variable::input_variable(&Shape::scalar());
        let mut map = DataMap::new();

        map.add_null(var.clone());
        // `get` treats a null binding the same as a missing one.
        assert!(map.get(var).is_none());
    }

    #[test]
    fn test_macros() {
        let var = Variable::input_variable(&Shape::scalar());
        let data: Vec<f32> = vec![1.0, 2.0];
        let val = Value::batch_from_vec(&var.shape(), &data, DeviceDescriptor::cpu());

        let inputs = datamap!(var.clone() => &val);
        assert!(inputs.get(var.clone()).is_some());

        let outputs = outdatamap!(var.clone());
        assert!(outputs.get(var).is_none());
    }
}