
[Solved] Plotting an SVM decision boundary using go.Contour

I am having a lot of trouble plotting the decision boundary for the results of an SVM classification.

Using the following chunk of code:

import numpy as np
import plotly
import plotly.graph_objs as go

contmap = go.Contour(z=pred_grid, showlegend=True,
                     name="Prediction",
                     text=np.asarray(list(map(lambda x: "Accepted" if x == 1 else "Rejected",
                                              pred_grid.flatten()))).reshape(x0m.shape),
                     hoverinfo="name+x+y+text",
                     ncontours=5,
                     line=dict(smoothing=1.3, width=2),
                     contours=dict(coloring="lines"),
                     showscale=False)

plotly.offline.iplot({
        "data": [contmap],
        "layout" : go.Layout(title = "Test",
                             xaxis = dict(title="X1"),
                             yaxis = dict(title="X2"))
})
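As an aside, the text argument above builds the hover labels with map and a lambda, then flattens and reshapes the result. A vectorized equivalent is np.where, which maps each 0/1 entry to a label and keeps the 2D shape, so no flatten()/reshape() round trip is needed. This is a minimal sketch; the small grid below is made-up illustration data, not the poster's pred_grid.

```python
import numpy as np

# Made-up 0/1 prediction grid for illustration (not the poster's data)
pred_grid = np.array([[0, 0, 1],
                      [0, 1, 1],
                      [1, 1, 1]], dtype=np.uint8)

# Element-wise label lookup; the result has the same 2D shape as pred_grid
hover_text = np.where(pred_grid == 1, "Accepted", "Rejected")
```

The result could then be passed as text=hover_text in the go.Contour call.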

I am able to see a decision boundary line that looks like this:


However, when I add the x and y arguments to go.Contour, as in the code below, the decision boundary disappears:

contmap = go.Contour(z=pred_grid, x=x0m, y=x1m, showlegend=True,
                     name="Prediction",
                     text=np.asarray(list(map(lambda x: "Accepted" if x == 1 else "Rejected",
                                              pred_grid.flatten()))).reshape(x0m.shape),
                     hoverinfo="name+x+y+text",
                     ncontours=5,
                     line=dict(smoothing=1.3, width=2),
                     contours=dict(coloring="lines"),
                     showscale=False)

plotly.offline.iplot({
        "data": [contmap],
        "layout" : go.Layout(title = "Test",
                             xaxis = dict(title="X1"),
                             yaxis = dict(title="X2"))
})

That code produces just a blank plot:


For reference, these are the values of the variables:

pred_grid
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
       [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
       [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
       [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
       [0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
       [0, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=uint8)

x1m
array([[ 1.6177    ,  1.6177    ,  1.6177    ,  1.6177    ,  1.6177    ,
         1.6177    ,  1.6177    ,  1.6177    ,  1.6177    ,  1.6177    ],
       [ 1.95086667,  1.95086667,  1.95086667,  1.95086667,  1.95086667,
         1.95086667,  1.95086667,  1.95086667,  1.95086667,  1.95086667],
       [ 2.28403333,  2.28403333,  2.28403333,  2.28403333,  2.28403333,
         2.28403333,  2.28403333,  2.28403333,  2.28403333,  2.28403333],
       [ 2.6172    ,  2.6172    ,  2.6172    ,  2.6172    ,  2.6172    ,
         2.6172    ,  2.6172    ,  2.6172    ,  2.6172    ,  2.6172    ],
       [ 2.95036667,  2.95036667,  2.95036667,  2.95036667,  2.95036667,
         2.95036667,  2.95036667,  2.95036667,  2.95036667,  2.95036667],
       [ 3.28353333,  3.28353333,  3.28353333,  3.28353333,  3.28353333,
         3.28353333,  3.28353333,  3.28353333,  3.28353333,  3.28353333],
       [ 3.6167    ,  3.6167    ,  3.6167    ,  3.6167    ,  3.6167    ,
         3.6167    ,  3.6167    ,  3.6167    ,  3.6167    ,  3.6167    ],
       [ 3.94986667,  3.94986667,  3.94986667,  3.94986667,  3.94986667,
         3.94986667,  3.94986667,  3.94986667,  3.94986667,  3.94986667],
       [ 4.28303333,  4.28303333,  4.28303333,  4.28303333,  4.28303333,
         4.28303333,  4.28303333,  4.28303333,  4.28303333,  4.28303333],
       [ 4.6162    ,  4.6162    ,  4.6162    ,  4.6162    ,  4.6162    ,
         4.6162    ,  4.6162    ,  4.6162    ,  4.6162    ,  4.6162    ]])

x0m
array([[ 0.086405  ,  0.52291556,  0.95942611,  1.39593667,  1.83244722,
         2.26895778,  2.70546833,  3.14197889,  3.57848944,  4.015     ],
       [ 0.086405  ,  0.52291556,  0.95942611,  1.39593667,  1.83244722,
         2.26895778,  2.70546833,  3.14197889,  3.57848944,  4.015     ],
       [ 0.086405  ,  0.52291556,  0.95942611,  1.39593667,  1.83244722,
         2.26895778,  2.70546833,  3.14197889,  3.57848944,  4.015     ],
       [ 0.086405  ,  0.52291556,  0.95942611,  1.39593667,  1.83244722,
         2.26895778,  2.70546833,  3.14197889,  3.57848944,  4.015     ],
       [ 0.086405  ,  0.52291556,  0.95942611,  1.39593667,  1.83244722,
         2.26895778,  2.70546833,  3.14197889,  3.57848944,  4.015     ],
       [ 0.086405  ,  0.52291556,  0.95942611,  1.39593667,  1.83244722,
         2.26895778,  2.70546833,  3.14197889,  3.57848944,  4.015     ],
       [ 0.086405  ,  0.52291556,  0.95942611,  1.39593667,  1.83244722,
         2.26895778,  2.70546833,  3.14197889,  3.57848944,  4.015     ],
       [ 0.086405  ,  0.52291556,  0.95942611,  1.39593667,  1.83244722,
         2.26895778,  2.70546833,  3.14197889,  3.57848944,  4.015     ],
       [ 0.086405  ,  0.52291556,  0.95942611,  1.39593667,  1.83244722,
         2.26895778,  2.70546833,  3.14197889,  3.57848944,  4.015     ],
       [ 0.086405  ,  0.52291556,  0.95942611,  1.39593667,  1.83244722,
         2.26895778,  2.70546833,  3.14197889,  3.57848944,  4.015     ]])

All three arrays (pred_grid, x0m and x1m) have exactly the same shape, (10, 10).

I am really puzzled about why my plot isn't working, since I am using the same code that worked for a contour plot I generated successfully in an earlier project, which can be seen here.

I just realized the cause: I was passing 2D arrays into the x and y parameters of go.Contour, when they should have been 1D. Passing in the 1D versions fixed the problem!
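If you only have the 2D meshgrids at hand, the 1D axis vectors can be recovered from them. A minimal sketch, assuming the grids were built with np.meshgrid using its default indexing='xy' (the printouts above match that pattern: x0m repeats the same row, x1m repeats the same column). The helper name axis_vectors is hypothetical, and the grid is rebuilt from the endpoints in the printouts, which are consistent with 10 evenly spaced points per axis.

```python
import numpy as np


def axis_vectors(xm, ym):
    """Hypothetical helper: recover 1D x and y vectors from 2D grids built by
    np.meshgrid with indexing='xy' (x varies along each row, y down each column).

    Raises ValueError if the grids are not rectilinear, i.e. cannot be
    described by a single 1D x vector and a single 1D y vector.
    """
    xm, ym = np.asarray(xm), np.asarray(ym)
    if xm.ndim != 2 or xm.shape != ym.shape:
        raise ValueError("expected two 2D arrays of the same shape")
    # Every row of xm must equal its first row; every column of ym its first column
    if not (np.allclose(xm, xm[0, :]) and np.allclose(ym, ym[:, [0]])):
        raise ValueError("grid is not rectilinear; 1D axes cannot describe it")
    return xm[0, :], ym[:, 0]


# Assumption: grids rebuilt from the endpoints shown in the x0m/x1m printouts
x0m, x1m = np.meshgrid(np.linspace(0.086405, 4.015, 10),
                       np.linspace(1.6177, 4.6162, 10))
x, y = axis_vectors(x0m, x1m)  # 1D: len(x) == number of columns, len(y) == rows
```

The check matters because go.Contour with 1D x and y assumes z[i][j] sits at (x[j], y[i]); a grid that is not rectilinear cannot be represented that way.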

x1m
array([ 1.3177    ,  1.71753333,  2.11736667,  2.5172    ,  2.91703333,
        3.31686667,  3.7167    ,  4.11653333,  4.51636667,  4.9162    ])

x0m
array([-0.213595  ,  0.28958222,  0.79275944,  1.29593667,  1.79911389,
        2.30229111,  2.80546833,  3.30864556,  3.81182278,  4.315     ])

With pred_grid the same as above, these 1D arrays give the correct plot: