/**
 * Form handler: reads the four numeric gradient-descent settings from the
 * submitted form and restarts the animation with them.
 *
 * @param {HTMLFormElement} form - form exposing `iterationNumber`, `alpha`,
 *   `theta0` and `theta1` controls (each read via its `.value`).
 */
function updateParams(form) {
  // Number() performs the same string-to-number coercion as unary plus.
  var readNumber = function (fieldName) {
    return Number(form[fieldName].value);
  };
  runGradientDescent(
    readNumber("iterationNumber"),
    readNumber("alpha"),
    readNumber("theta0"),
    readNumber("theta1")
  );
}
// Outer SVG is 550x500 px; the margins leave room for the axis labels, and
// width/height describe the inner drawing area.
var margin = { top: 20, right: 20, bottom: 50, left: 50 };
var width = 550 - margin.left - margin.right;
var height = 500 - margin.top - margin.bottom;

// Fixed-point formatter with three decimals (used for the theta values).
var format = d3.format(".3f");

// Horizontal scale and its bottom axis.
var x = d3.scale.linear().range([0, width]);
var x_axis = d3.svg.axis().scale(x).orient("bottom");

// Vertical scale (inverted so larger values sit higher) and its left axis.
var y = d3.scale.linear().range([height, 0]);
var y_axis = d3.svg.axis().scale(y).orient("left");

// Root SVG plus an inner group shifted by the top/left margins.
var svg = d3.select("#plot")
  .append("svg")
  .attr("width", width + margin.left + margin.right)
  .attr("height", height + margin.top + margin.bottom)
  .append("g")
  .attr("transform", "translate(" + [margin.left, margin.top].join(",") + ")");

// Text element that shows the current hypothesis.
var hyp = svg.append("text").attr("id", "hypothesis_fx").attr("x", 150).attr("y", 50);

// Group hosting the small cost-versus-iteration inset chart.
var cost_plot = svg.append("g").attr("id", "cost_plot").attr("transform", "translate(350, 260)");
// Load the training set, draw both axes and the scatterplot, then start an
// initial gradient-descent run with default parameters.
d3.csv("/data/data_gradient_descent.csv", function(error, data) {
  if (error) {
    // On failure d3.csv passes the error and no usable rows; bail out
    // instead of crashing on data.forEach below.
    console.error("Failed to load /data/data_gradient_descent.csv:", error);
    return;
  }

  // CSV fields arrive as strings; coerce the two columns used below.
  data.forEach(function(d) {
    d.population = +d.population;
    d.profit = +d.profit;
  });

  x.domain([0, d3.max(data, function(d) { return d.population; })]).nice();
  y.domain(d3.extent(data, function(d) { return d.profit; })).nice();

  svg.append("g")
    .attr("class", "x axis")
    .attr("transform", "translate(0," + height + ")")
    .call(x_axis)
    .append("text")
    .attr("x", width / 2)
    .attr("y", 40)
    .text("Population of City in 10,000s");

  svg.append("g")
    .attr("class", "y axis")
    .call(y_axis)
    .append("text")
    .attr("transform", "rotate(-90)")
    .attr("x", -height / 2)
    .attr("y", -50)
    .attr("dy", ".71em")
    .text("Profit in $10,000s");

  // One circle per row; runGradientDescent reads the bound data back from these.
  svg.append("g")
    .attr("id", "scatterplot")
    .selectAll(".dot")
    .data(data)
    .enter()
    .append("circle")
    .attr("class", "dot")
    .attr("r", 3.5)
    .attr("cx", function(d) { return x(d.population); })
    .attr("cy", function(d) { return y(d.profit); });

  runGradientDescent(100, 0.001, 0, 0);
});
/**
 * Clears the artifacts of a previous run: the regression line, the cost
 * curve (both selected by id) and every cost-chart axis group.
 */
function resetPlot() {
  ["#line", "#cost_line"].forEach(function (selector) {
    d3.select(selector).remove();
  });
  d3.selectAll(".axis_cost").remove();
}
// Identifier of the most recent runGradientDescent() call. A timer callback
// that belongs to an older run sees a different value and stops itself.
var gradientDescentRunId = 0;

/**
 * Animates batch gradient descent for the hypothesis hθ(x) = θ0 + θ1·x
 * over the points bound to the scatterplot circles.
 *
 * d3.timer invokes the callback repeatedly (starting after 200 ms) until it
 * returns true. Each tick records the cost J(θ) for the current parameters.
 * It then either stops (once `iterationNumber` steps are done, handing the
 * recorded costs to plotCost) or applies one simultaneous update of θ0/θ1
 * and redraws the regression line and label.
 *
 * @param {number} iterationNumber - number of gradient steps; fractions are
 *   floored, and NaN / negative / infinite values are treated as 0.
 * @param {number} alpha - learning rate.
 * @param {number} theta0 - initial intercept.
 * @param {number} theta1 - initial slope.
 */
function runGradientDescent(iterationNumber, alpha, theta0, theta1) {
  resetPlot();

  // Claim this run so an animation still in flight from a previous call
  // stops, instead of overwriting the label or drawing a second cost curve.
  var runId = ++gradientDescentRunId;

  // Without sanitizing, NaN or fractional counts never satisfy the stop
  // condition and the timer would run forever.
  var steps = Math.floor(iterationNumber);
  if (!isFinite(steps) || steps < 0) {
    steps = 0;
  }

  // Training points are read back from the circles bound in the scatterplot.
  var data = d3.selectAll("#scatterplot .dot").data();
  var m = data.length;
  if (m === 0) {
    // Data not loaded yet (or empty): the 1/m factors would produce NaN.
    return;
  }

  var xMin = x.domain()[0],
      xMax = x.domain()[1];

  var line = svg.append("line")
    .attr("class", "line")
    .attr("id", "line");

  // Squared-error cost J(θ) = 1/(2m) · Σ (hθ(x) − y)².
  function computeCost(points, t0, t1) {
    var cost = 0;
    points.forEach(function(d) {
      cost += Math.pow(t1 * d.population + t0 - d.profit, 2);
    });
    return cost / (2 * m);
  }

  // Redraw the regression line across the x domain and refresh the label.
  function drawHypothesis() {
    line.attr("x1", x(xMin))
      .attr("y1", y(theta1 * xMin + theta0))
      .attr("x2", x(xMax))
      .attr("y2", y(theta1 * xMax + theta0));
    hyp.text("hθ(x) = " + format(theta0) + " + " + format(theta1) + "x");
  }

  // Show the actual starting parameters, which are not necessarily zero.
  drawHypothesis();

  var max_cost = computeCost(data, theta0, theta1);
  var d_cost = [];
  var iteration = 0;

  d3.timer(function() {
    if (runId !== gradientDescentRunId) {
      return true; // superseded by a newer run
    }

    // Cost after `iteration` completed steps (iteration 0 = initial θ).
    d_cost.push({ "iteration": iteration,
                  "cost": computeCost(data, theta0, theta1)
                });

    if (iteration >= steps) {
      plotCost(d_cost, steps, max_cost);
      return true;
    }

    // Simultaneous update: both gradients are evaluated at the current θ.
    var temp0 = theta0 - alpha * (1 / m) * d3.sum(data.map(function(d) {
      return (theta1 * d.population + theta0) - d.profit;
    }));
    var temp1 = theta1 - alpha * (1 / m) * d3.sum(data.map(function(d) {
      return ((theta1 * d.population + theta0) - d.profit) * d.population;
    }));
    theta0 = temp0;
    theta1 = temp1;

    drawHypothesis();
    iteration++;
    return false;
  }, 200);
}
/**
 * Draws a 100x100 px cost-versus-iteration chart inside the `cost_plot` group.
 *
 * @param {Array<{iteration: number, cost: number}>} d_cost - recorded costs.
 * @param {number} iterationNumber - upper bound of the iteration (x) axis.
 * @param {number} max_cost - upper bound of the cost (y) axis.
 */
function plotCost(d_cost, iterationNumber, max_cost) {
  var x_cost = d3.scale.linear()
    .domain([0, iterationNumber])
    .range([0, 100]);

  var x_axis_cost = d3.svg.axis()
    .scale(x_cost)
    .orient("bottom")
    .ticks(3);

  // Range is inverted so larger costs are drawn higher up.
  var y_cost = d3.scale.linear()
    .domain([0, max_cost])
    .range([100, 0]);

  var y_axis_cost = d3.svg.axis()
    .scale(y_cost)
    .orient("left")
    .ticks(3);

  // Declared locally: assigning an undeclared name leaked a global and
  // throws a ReferenceError in strict mode. The interpolation mode is a
  // string key, not an array.
  var cost_line = d3.svg.line()
    .x(function(d) { return x_cost(d.iteration); })
    .y(function(d) { return y_cost(d.cost); })
    .interpolate("basis");

  cost_plot.append("path")
    .datum(d_cost)
    .attr("id", "cost_line")
    .attr("d", cost_line)
    .attr("stroke", "black")
    .attr("fill", "none");

  // The `axis_cost` class lets these axes be removed together on reset.
  cost_plot.append("g")
    .attr("class", "y axis axis_cost")
    .call(y_axis_cost)
    .append("text")
    .attr("transform", "rotate(-90)")
    .attr("x", -55)
    .attr("y", -45)
    .attr("dy", ".71em")
    .text("J(θ)");

  cost_plot.append("g")
    .attr("class", "x axis axis_cost")
    .attr("transform", "translate(0, 100)")
    .call(x_axis_cost)
    .append("text")
    .attr("x", 55)
    .attr("y", 40)
    .text("Num. of Iterations");
}