// Layout: overall <svg> size plus the margins kept free around the plot.
var margin = { top: 20, right: 0, bottom: 50, left: 85 };
var svg_dx = 500;
var svg_dy = 400;

// Dimensions left for plotting once the margins are taken off.
var plot_dx = svg_dx - (margin.left + margin.right);
var plot_dy = svg_dy - (margin.top + margin.bottom);

// Pixel scales for data coordinates; their domains are set after the CSV has
// loaded. The y range runs from plot_dy up to margin.top so larger y values are
// drawn higher on screen.
// NOTE(review): the x range ends at plot_dx (a width), not at
// margin.left + plot_dx — confirm the resulting right-hand gap is intended.
var xPos = d3.scaleLinear().range([margin.left, plot_dx]);
var yPos = d3.scaleLinear().range([plot_dy, margin.top]);

// Root <svg> inside #plot; axes, points and the boundary line are appended to it.
var svg = d3.select("#plot").append("svg");
svg.attr("width", svg_dx);
svg.attr("height", svg_dy);

// Load the sample points, draw the axes and the draggable scatter plot, then
// start an initial gradient-descent fit.
// NOTE(review): the one-argument callback relies on the d3 v4 (d3-request)
// d3.csv(url, callback) form, which passes null instead of the rows when the
// request fails — confirm the d3 version in use.
d3.csv("/data/logistic_reg_grad_descent.csv", data => {

    // Without rows there is nothing to scale or fit; d3.extent(null, ...) would throw.
    if (!data || data.length === 0) {
        console.error("No points loaded from /data/logistic_reg_grad_descent.csv");
        return;
    }

    // CSV values arrive as strings, so coerce explicitly before using them as numbers.
    var d_extent_x = d3.extent(data, row => +row.x),
        d_extent_y = d3.extent(data, row => +row.y);

    xPos.domain(d_extent_x);
    yPos.domain(d_extent_y);

    var axis_x = d3.axisBottom(xPos),
        axis_y = d3.axisLeft(yPos);

    svg.append("g")
       .attr("id", "axis_x")
       .attr("transform", "translate(0," + (plot_dy + margin.bottom / 2) + ")")
       .call(axis_x);

    svg.append("g")
       .attr("id", "axis_y")
       .attr("transform", "translate(" + (margin.left / 2) + ", 0)")
       .call(axis_y);

    // Group "1" rows are drawn as circles (class group1), all others as crosses (group2).
    var isGroup1 = row => row.group === "1";

    // One draggable symbol per row; its position lives in the transform attribute,
    // which runGradientDescent reads back.
    svg.append("g")
       .selectAll("path")
       .data(data)
       .enter()
       .append("path")
       .attr("class", row => isGroup1(row) ? "pts group1" : "pts group2")
       .attr("d", d3.symbol().type(row => isGroup1(row) ? d3.symbolCircle : d3.symbolCross))
       .attr("transform", row => "translate(" + xPos(+row.x) + "," + yPos(+row.y) + ")")
       .call(d3.drag()
               .on("start", dragstarted)
               .on("drag", dragged));

    runGradientDescent(400, 0.0004, -24.0, 0.5, 0.2);
});

/**
 * Drag "start" handler: re-inserts the grabbed element as the last child of
 * its parent (d3 `raise`) so it is painted above its siblings while dragged.
 */
function dragstarted() {
    var grabbed = d3.select(this);
    grabbed.raise();
}

/**
 * Drag handler: moves the dragged point by the pointer movement since the
 * previous drag event.
 *
 * Uses d3.event.dx / d3.event.dy, which d3-drag reports for both mouse and
 * touch input, instead of the source event's offsetX / offsetY. A TouchEvent
 * has no offsetX / offsetY, so the old code wrote "translate(undefined,undefined)".
 *
 * @param d - datum bound to the dragged element (unused).
 */
function dragged(d) {

    var point = d3.select(this),
        // Current offsets from "translate(x,y)"; signs, decimals and exponents allowed.
        match = /translate\(\s*([^\s,)]+)(?:\s*,\s*|\s+)([^\s,)]+)\s*\)/.exec(point.attr("transform") || ""),
        x = match ? Number(match[1]) : 0,
        y = match ? Number(match[2]) : 0;

    point.attr("transform", "translate(" + (x + d3.event.dx) + "," + (y + d3.event.dy) + ")");
}

/**
 * Logistic (sigmoid) function 1 / (1 + e^-z).
 *
 * @param {number} z - input value.
 * @returns {number} value in [0, 1]; saturates to exactly 0 or 1 for large |z|.
 */
function sigmoid(z) {
    // Math.exp is the direct form of Math.pow(Math.E, x), and it is more
    // accurate because it does not raise a rounded e to a power.
    return 1 / (1 + Math.exp(-z));
}

/**
 * Gradient of the logistic-regression cost with respect to theta.
 *
 * Vectorised math.js form of the Octave expression
 * grad = (1 / m) * (h - y)' * X.
 *
 * @param {number} m - number of training examples (rows of X).
 * @param y - math.js vector (or array) of labels, one per example.
 * @param h - math.js vector of predicted probabilities, one per example.
 * @param X - math.js design matrix with one row per example.
 * @returns math.js vector with one entry per column of X.
 */
function computeGradient(m, y, h, X) {
    // Element-wise residuals. The previous per-element math.subset(y, math.index(i))
    // returned a 1-element matrix and only worked through implicit JS coercion.
    var residual = math.subtract(h, y);

    return math.multiply(math.multiply(residual, X), 1 / m);
}

/**
 * Form handler: reads the gradient-descent settings from `form`, removes the
 * previous decision boundary and starts a new run.
 *
 * @param form - form exposing iterationNumber, alpha, theta0, theta1 and
 *   theta2 fields, each with a `value` string.
 */
function updateParams(form) {

    var iterationNumber = +form.iterationNumber.value,
                  alpha = +form.alpha.value,
                 theta0 = +form.theta0.value,
                 theta1 = +form.theta1.value,
                 theta2 = +form.theta2.value;

    // Non-numeric input becomes NaN. A NaN iterationNumber means
    // runGradientDescent's stop condition never fires, so its timer runs forever.
    // NaN parameters draw a NaN boundary. Reject such input instead.
    if (![iterationNumber, alpha, theta0, theta1, theta2].every(Number.isFinite)) {
        console.warn("updateParams: all parameters must be finite numbers");
        return;
    }

    // remove previous decision boundary (selectAll in case more than one exists)
    d3.selectAll("#dec_boundary").remove();

    runGradientDescent(iterationNumber, alpha, theta0, theta1, theta2);
}

/**
 * Fits a logistic-regression decision boundary to the points currently on the
 * plot with batch gradient descent, animating the boundary line as it goes.
 *
 * How it works:
 * - Point positions are read back from each ".pts" element's `transform`
 *   attribute, so dragged points are used at their new position.
 * - They are mapped to data coordinates with xPos.invert / yPos.invert.
 * - Each d3.timer tick takes one gradient step.
 * - Each tick redraws the line theta0 + theta1 * x + theta2 * y = 0 across the
 *   x extent of the points.
 *
 * @param {number} iterationNumber - number of gradient steps (at least one step is taken).
 * @param {number} alpha - learning rate.
 * @param {number} theta0 - initial intercept.
 * @param {number} theta1 - initial weight of x.
 * @param {number} theta2 - initial weight of y.
 */
function runGradientDescent(iterationNumber, alpha, theta0, theta1, theta2) {

    // Each point's group label and current position in data space.
    var d = [];

    d3.selectAll(".pts")
      .each(function(datum) {
        var xy = parseTranslate(d3.select(this).attr("transform"));

        if (xy === null) {
            console.warn("runGradientDescent: skipping point with unparsable transform");
            return;
        }

        d.push({ "group": datum.group,
                     "x": xPos.invert(xy[0]),
                     "y": yPos.invert(xy[1]) });
      });

    // Nothing to fit (e.g. the data has not loaded yet); math.multiply would throw.
    if (d.length === 0) {
        return;
    }

    var d_extent_x = d3.extent(d, pt => pt.x);

    // Design matrix with a leading bias column, and the label vector built from
    // the numeric group values (presumably 0/1 — confirm against the CSV).
    var X = math.matrix(d.map(pt => [1, pt.x, pt.y])),
        y = math.matrix(d.map(pt => +pt.group)),
        m = d.length,
        theta = math.matrix([theta0, theta1, theta2]),
        iteration = 0;

    var dec_bnd = svg.append("line")
                     .attr("id", "dec_boundary");

    var iterate = d3.timer(() => {

        var h = math.multiply(X, theta).map(z => sigmoid(z)),
            grad = computeGradient(m, y, h, X);

        // theta := theta - alpha * grad, vectorised so no element extraction is needed.
        theta = math.subtract(theta, math.multiply(alpha, grad));

        var t = theta.toArray();

        // The boundary solves t0 + t1 * x + t2 * y = 0 for y. With t2 = 0 it
        // cannot be written as y(x), so keep the previous line for this frame.
        if (t[2] !== 0) {
            var boundaryY = x => (-1 / t[2]) * (t[1] * x + t[0]);

            // Each endpoint's y is evaluated at the same x it is plotted at, so the
            // drawn segment lies on the boundary.
            dec_bnd.attr("x1", xPos(d_extent_x[0]))
                   .attr("y1", yPos(boundaryY(d_extent_x[0])))
                   .attr("x2", xPos(d_extent_x[1]))
                   .attr("y2", yPos(boundaryY(d_extent_x[1])));
        }

        // Stop once iterationNumber steps have been taken.
        iteration += 1;
        if (iteration >= iterationNumber) {
            iterate.stop();
        }
    }, 200);
}

/**
 * Extracts the offsets from an SVG transform string of the form
 * "translate(x,y)" (comma or whitespace separated; signs, decimals and
 * exponents allowed).
 *
 * @param {string|null} transform - value of a `transform` attribute.
 * @returns {number[]|null} [x, y], or null when no finite translate pair is found.
 */
function parseTranslate(transform) {
    var match = /translate\(\s*([^\s,)]+)(?:\s*,\s*|\s+)([^\s,)]+)\s*\)/.exec(transform || "");

    if (match === null) {
        return null;
    }

    var x = Number(match[1]),
        y = Number(match[2]);

    return Number.isFinite(x) && Number.isFinite(y) ? [x, y] : null;
}