const tf = require('@tensorflow/tfjs');
const base64js = require('base64-js')
const opt = require('optimization-js'
)
// Registry of the classes that can be serialized: maps a class name to
// its constructor. Each sub-module is required and its classes are
// registered right away, one sub-module after another.
const base_estimators = {}
const prep = require('./preprocessing')
Object.assign(base_estimators, {
    TableFeaturesTransformer: prep.TableFeaturesTransformer,
    StandardScaler: prep.StandardScaler,
    LabelBinarizer: prep.LabelBinarizer
})
const tree = require('./tree')
Object.assign(base_estimators, {
    DecisionTreeClassifier: tree.DecisionTreeClassifier,
    DecisionTreeRegressor: tree.DecisionTreeRegressor
})
const nn = require('./nn')
Object.assign(base_estimators, {
    MLPRegressor: nn.MLPRegressor,
    MLPClassifier: nn.MLPClassifier
})
const linear_model = require('./linear_model')
Object.assign(base_estimators, {
    SGDRegressor: linear_model.SGDRegressor,
    SGDClassifier: linear_model.SGDClassifier
})
const ensemble = require('./ensemble')
Object.assign(base_estimators, {
    GradientBoostingRegressor: ensemble.GradientBoostingRegressor,
    GradientBoostingClassifier: ensemble.GradientBoostingClassifier
})
const model_selection = require('./model_selection')
Object.assign(base_estimators, {
    OMGSearchCV: model_selection.OMGSearchCV
})
// expose the registry to the users of this module
module.exports.base_estimators = base_estimators
// Typed array constructors that can be serialized, keyed by
// constructor name. The name is what gets stored in a blueprint.
const js_array_types = {
    Float32Array,
    Float64Array,
    Int8Array,
    Int16Array,
    Int32Array,
    Uint8Array,
    Uint16Array,
    Uint32Array
}
// Classes of optimization-js that can be serialized, keyed by
// class name.
const opt_js_types = {}
for(const opt_class_name of ['Real', 'Integer', 'Categorical']){
    opt_js_types[opt_class_name] = opt[opt_class_name]
}
/**
 * Convert various objects to json format. Useful for
 * serialization and deserialization of aitable objects.
 * The result is always a blueprint of the form
 * {'type': ..., 'value': ...}, where 'type' names the encoding
 * that was used for 'value'.
 * @param {Any} obj Instance of an object to be converted
 * into a serializable json. Can be a json serializable
 * value, a class that can be serialized, or instance
 * of tensorflowjs objects.
 * @returns {Promise<Object>} Blueprint of the object.
 * @throws {Error} If the type or the class of the object is
 * not supported.
 */
async function dumpjson(obj){
    const type = typeof obj

    if(['number', 'string', 'boolean'].includes(type) || obj === null){
        // native type, serializable
        return {
            'type': 'native',
            'value': obj
        }
    }

    if(type !== 'object'){
        // String() also works for symbols, where string concatenation
        // alone would raise a TypeError while building the message
        throw Error('Unsupported object for serialization: ' + String(obj))
    }

    // the class name of the object selects the encoding
    const cname = obj.constructor.name

    if(cname === 'Array'){ // list type
        const items = []
        for(const item of obj){
            items.push(await dumpjson(item))
        }
        return {
            'type': 'list',
            'value': items
        }
    }

    if(cname === 'Object'){ // assumes the case of dictionary
        const pairs = [] // list of pairs of key, value
        for(const key in obj){
            pairs.push([
                await dumpjson(key), await dumpjson(obj[key])
            ])
        }
        return {
            'type': 'dict',
            'value': pairs
        }
    }

    if(cname === 'Tensor'){
        const data = await obj.data()
        const shape = obj.shape
        const dtype = obj.dtype
        return {
            'type': 'tf_tensor',
            'value': {
                'data': await dumpjson(data),
                'shape': shape,
                'dtype': dtype
            }
        }
    }

    if(cname === 'Sequential'){ // convert tf model
        // the save handler receives the artifacts of the model;
        // they are captured here instead of being written anywhere
        const captured = []
        const handleSave = function (artifacts){
            captured.push(artifacts.modelTopology)
            captured.push(artifacts.weightSpecs)
            captured.push(artifacts.weightData)
        }
        await obj.save(tf.io.withSaveHandler(handleSave));
        return {
            'type': 'tf_model',
            'value': {
                modelTopology: captured[0],
                weightSpecs: captured[1],
                weightData: await dumpjson(new Uint8Array(captured[2]))
            }
        }
    }

    if(cname in opt_js_types){
        const dimension = {
            'class': cname
        }
        if(cname === 'Categorical'){
            dimension['categories'] = obj.categories
        }else{
            dimension['low'] = obj.low
            dimension['high'] = obj.high
        }
        return {
            'type': "optimization-js",
            'value': dimension
        }
    }

    if(cname in js_array_types){ // convert to supported array
        // A typed array can be a view on a part of a larger buffer.
        // Only the bytes that belong to the view are serialized.
        const bytes = new Uint8Array(obj.buffer, obj.byteOffset, obj.byteLength)
        return {
            'type': 'native_array',
            'value': {
                'class': cname,
                'data': base64js.fromByteArray(bytes)
            }
        }
    }

    if(cname in base_estimators){ // convert estimator to json
        return {
            'type': 'estimator',
            'value': {
                'params': await dumpjson(obj.params),
                'state': await dumpjson(obj.state),
                'class': cname
            }
        }
    }

    // fail loudly instead of silently producing an undefined blueprint
    throw Error('Unsupported object for serialization: instance of ' + cname)
}
module.exports.dumpjson = dumpjson
/**
 * Inverse of dumpjson.
 * @param {Any} json Blueprint of the object, to be deserialized.
 * Has the form {'type': ..., 'value': ...}.
 * @returns {Promise<Any>} The object described by the blueprint.
 * @throws {Error} If the type of the blueprint, or the class name
 * stored in it, is not known.
 */
async function loadjson(json){
    const type = json['type']
    const value = json['value']

    if(type === 'native'){
        return value
    }

    if(type === 'list'){
        const items = []
        for(const entry of value){
            items.push(await loadjson(entry))
        }
        return items
    }

    if(type === 'dict'){
        // value is a list of pairs of serialized key, serialized value
        const dict = {}
        for(const pair of value){
            const key = await loadjson(pair[0])
            dict[key] = await loadjson(pair[1])
        }
        return dict
    }

    if(type === 'tf_tensor'){
        const dtype = value['dtype']
        const shape = value['shape']
        const data = await loadjson(value['data'])
        // tf.tensor is a factory function, it is called without `new`
        return tf.tensor(data, shape, dtype)
    }

    if(type === 'tf_model'){
        const modelTopology = value['modelTopology']
        const weightSpecs = value['weightSpecs']
        // the weights are stored as a serialized Uint8Array;
        // the loader takes the underlying buffer
        const weightData = (await loadjson(value['weightData'])).buffer
        // load the model
        const model = await tf.loadModel(
            tf.io.fromMemory(modelTopology, weightSpecs, weightData)
        );
        return model
    }

    if(type === 'optimization-js'){
        const cname = value['class']
        if(cname === 'Categorical'){
            return new opt.Categorical(value['categories'])
        }
        if(cname === 'Real'){
            return new opt.Real(value['low'], value['high'])
        }
        if(cname === 'Integer'){
            return new opt.Integer(value['low'], value['high'])
        }
        // do not silently turn an unknown class into an Integer
        throw Error('Unknown optimization-js class for deserialization: ' + cname)
    }

    if(type === 'native_array'){
        const cname = value['class']
        if(!Object.prototype.hasOwnProperty.call(js_array_types, cname)){
            throw Error('Unknown array class for deserialization: ' + cname)
        }
        const ctype = js_array_types[cname]
        // should return Uint8Array; Convert it to appropriate format
        const bytes = base64js.toByteArray(value['data'])
        return new ctype(bytes.buffer)
    }

    if(type === 'estimator'){
        const cname = value['class']
        const params = await loadjson(value['params'])
        const state = await loadjson(value['state'])
        // get the estimator class
        if(!Object.prototype.hasOwnProperty.call(base_estimators, cname)){
            throw Error('Unknown estimator class for deserialization: ' + cname)
        }
        const ctype = base_estimators[cname]
        const estimator = new ctype(params)
        estimator.state = state
        return estimator
    }

    throw Error('Unknown type for deserialization: ' + type + ', value: ' + value)
}
module.exports.loadjson = loadjson
/**
 * Makes a copy of an object: the object is serialized to
 * a json blueprint, and a new object is deserialized from it.
 * @param {Any} obj object to be cloned.
 * @param {Boolean} stringify whether to pass the blueprint
 * through JSON.stringify and JSON.parse before it is
 * deserialized. Intended for tests that check that the
 * blueprint can be converted to a string, or for the case
 * when the blueprint might still point to nested objects of
 * the original object, which would otherwise be shared with
 * the copy instead of being cloned.
 */
async function clone(obj, stringify=false){
    const dumped = await dumpjson(obj)
    // optional round trip through a string detaches the blueprint
    // from the original object
    const blueprint = stringify ? JSON.parse(JSON.stringify(dumped)) : dumped
    return await loadjson(blueprint)
}
module.exports.clone = clone