gophernotes/kernel.go

385 lines
9.0 KiB
Go
Raw Normal View History

2017-07-22 16:49:22 -04:00
package main
import (
"bufio"
"bytes"
2017-07-22 16:49:22 -04:00
"encoding/json"
"fmt"
"io"
2017-07-22 16:49:22 -04:00
"io/ioutil"
"log"
"os"
"strings"
2017-07-22 16:49:22 -04:00
"github.com/cosmos72/gomacro/base"
2017-07-22 16:49:22 -04:00
"github.com/cosmos72/gomacro/classic"
zmq "github.com/pebbe/zmq4"
)
// ExecCounter counts the non-silent execute_request messages handled so far.
// Its current value is reported as the execution count of every reply.
var (
	ExecCounter int
)
// ConnectionInfo mirrors the JSON connection file that Jupyter hands to the
// kernel at launch. It carries the transport and IP address to bind, one
// port per channel (shell, control, stdin, IOPub and heartbeat), and the key
// and scheme used to sign messages.
type ConnectionInfo struct {
	SignatureScheme string `json:"signature_scheme"`
	Transport       string `json:"transport"`
	StdinPort       int    `json:"stdin_port"`
	ControlPort     int    `json:"control_port"`
	IOPubPort       int    `json:"iopub_port"`
	HBPort          int    `json:"hb_port"`
	ShellPort       int    `json:"shell_port"`
	Key             string `json:"key"`
	IP              string `json:"ip"`
}
// SocketGroup bundles the ZMQ sockets the kernel talks through (shell,
// control, stdin and IOPub) together with the key used to sign outgoing
// messages and to verify incoming ones.
type SocketGroup struct {
	ShellSocket   *zmq.Socket
	ControlSocket *zmq.Socket
	StdinSocket   *zmq.Socket
	IOPubSocket   *zmq.Socket
	Key           []byte
}
// kernelInfo is the content of a kernel_info_reply message. It advertises the
// messaging protocol version the kernel speaks and the language it runs.
type kernelInfo struct {
	ProtocolVersion []int  `json:"protocol_version"`
	Language        string `json:"language"`
}
// kernelStatus is the content of a "status" broadcast. It names the
// execution state the kernel is currently in (for example "idle").
type kernelStatus struct {
	ExecutionState string `json:"execution_state"`
}
// shutdownReply is the content of a shutdown_reply message. It records
// whether the shutdown was requested as part of a restart.
type shutdownReply struct {
	Restart bool `json:"restart"`
}
// runKernel is the main entry point to start the kernel.
func runKernel(connectionFile string) {
// Set up the "Session" with the replpkg.
ir := classic.New()
// Parse the connection info.
var connInfo ConnectionInfo
connData, err := ioutil.ReadFile(connectionFile)
if err != nil {
log.Fatal(err)
}
if err = json.Unmarshal(connData, &connInfo); err != nil {
log.Fatal(err)
}
// Set up the ZMQ sockets through which the kernel will communicate.
sockets, err := prepareSockets(connInfo)
if err != nil {
log.Fatal(err)
}
poller := zmq.NewPoller()
poller.Add(sockets.ShellSocket, zmq.POLLIN)
poller.Add(sockets.StdinSocket, zmq.POLLIN)
poller.Add(sockets.ControlSocket, zmq.POLLIN)
// msgParts will store a received multipart message.
var msgParts [][]byte
// Start a message receiving loop.
for {
2017-07-22 16:49:22 -04:00
polled, err := poller.Poll(-1)
if err != nil {
log.Fatal(err)
}
for _, item := range polled {
// Handle various types of messages.
switch socket := item.Socket; socket {
// Handle shell messages.
case sockets.ShellSocket:
msgParts, err = sockets.ShellSocket.RecvMessageBytes(0)
2017-07-22 16:49:22 -04:00
if err != nil {
log.Println(err)
}
msg, ids, err := WireMsgToComposedMsg(msgParts, sockets.Key)
2017-07-22 16:49:22 -04:00
if err != nil {
log.Println(err)
return
}
handleShellMsg(ir, msgReceipt{msg, ids, sockets})
// TODO Handle stdin socket.
case sockets.StdinSocket:
sockets.StdinSocket.RecvMessageBytes(0)
// Handle control messages.
case sockets.ControlSocket:
msgParts, err = sockets.ControlSocket.RecvMessageBytes(0)
2017-07-22 16:49:22 -04:00
if err != nil {
log.Println(err)
return
}
msg, ids, err := WireMsgToComposedMsg(msgParts, sockets.Key)
2017-07-22 16:49:22 -04:00
if err != nil {
log.Println(err)
return
}
handleShellMsg(ir, msgReceipt{msg, ids, sockets})
}
}
}
}
// prepareSockets creates the kernel's ZMQ sockets and binds each one to the
// port listed for it in connInfo. Shell, control and stdin are ROUTER
// sockets; IOPub is a PUB socket. It also copies the message signing key
// from connInfo into the returned group.
//
// Errors from creating the context or a socket, and from binding a socket,
// are returned. In particular, a port that is already in use is reported
// instead of silently leaving a socket unbound. Each endpoint has the form
// "<transport>://<ip>:<port>".
func prepareSockets(connInfo ConnectionInfo) (SocketGroup, error) {
	// Initialize the context.
	zctx, err := zmq.NewContext()
	if err != nil {
		return SocketGroup{}, err
	}

	// Initialize the socket group.
	var sg SocketGroup
	sg.ShellSocket, err = zctx.NewSocket(zmq.ROUTER)
	if err != nil {
		return sg, err
	}
	sg.ControlSocket, err = zctx.NewSocket(zmq.ROUTER)
	if err != nil {
		return sg, err
	}
	sg.StdinSocket, err = zctx.NewSocket(zmq.ROUTER)
	if err != nil {
		return sg, err
	}
	sg.IOPubSocket, err = zctx.NewSocket(zmq.PUB)
	if err != nil {
		return sg, err
	}

	// Bind the sockets. Each endpoint is formatted in one step, so a '%' in
	// the IP (e.g. an IPv6 zone) cannot be misread as a format verb.
	bindings := []struct {
		name   string
		socket *zmq.Socket
		port   int
	}{
		{"shell", sg.ShellSocket, connInfo.ShellPort},
		{"control", sg.ControlSocket, connInfo.ControlPort},
		{"stdin", sg.StdinSocket, connInfo.StdinPort},
		{"iopub", sg.IOPubSocket, connInfo.IOPubPort},
	}
	for _, b := range bindings {
		endpoint := fmt.Sprintf("%v://%v:%v", connInfo.Transport, connInfo.IP, b.port)
		if err := b.socket.Bind(endpoint); err != nil {
			return sg, fmt.Errorf("binding %s socket to %s: %v", b.name, endpoint, err)
		}
	}

	// Set the message signing key.
	sg.Key = []byte(connInfo.Key)
	return sg, nil
}
// handleShellMsg responds to a message on the shell ROUTER socket.
func handleShellMsg(ir *classic.Interp, receipt msgReceipt) {
switch receipt.Msg.Header.MsgType {
case "kernel_info_request":
if err := sendKernelInfo(receipt); err != nil {
log.Fatal(err)
}
case "execute_request":
if err := handleExecuteRequest(ir, receipt); err != nil {
log.Fatal(err)
}
case "shutdown_request":
handleShutdownRequest(receipt)
2017-07-22 16:49:22 -04:00
default:
log.Println("Unhandled shell message: ", receipt.Msg.Header.MsgType)
}
}
// sendKernelInfo answers a kernel_info_request with a kernel_info_reply on
// the receipt's shell socket. The reply advertises protocol version 4.0 and
// the language "go". It returns any error from building or sending the
// reply.
func sendKernelInfo(receipt msgReceipt) error {
	reply, err := NewMsg("kernel_info_reply", receipt.Msg)
	if err != nil {
		return err
	}
	reply.Content = kernelInfo{
		ProtocolVersion: []int{4, 0},
		Language:        "go",
	}
	return receipt.SendResponse(receipt.Sockets.ShellSocket, reply)
}
// handleExecuteRequest runs code from an execute_request method,
// and sends the various reply messages.
func handleExecuteRequest(ir *classic.Interp, receipt msgReceipt) error {
// Prepare the reply message.
2017-07-22 16:49:22 -04:00
reply, err := NewMsg("execute_reply", receipt.Msg)
if err != nil {
return err
}
content := make(map[string]interface{})
reqcontent := receipt.Msg.Content.(map[string]interface{})
code := reqcontent["code"].(string)
in := bufio.NewReader(strings.NewReader(code))
2017-07-22 16:49:22 -04:00
silent := reqcontent["silent"].(bool)
if !silent {
ExecCounter++
}
content["execution_count"] = ExecCounter
// Redirect the standard out from the REPL.
2017-08-04 10:30:31 -04:00
oldStdout := os.Stdout
rOut, wOut, err := os.Pipe()
if err != nil {
return err
}
2017-08-04 10:30:31 -04:00
os.Stdout = wOut
// Redirect the standard error from the REPL.
rErr, wErr, err := os.Pipe()
if err != nil {
return err
}
ir.Stderr = wErr
// Prepare and perform the multiline evaluation.
env := ir.Env
env.Options &^= base.OptShowPrompt
env.Line = 0
// Perform the first iteration manually, to collect comments
var comments string
str, firstToken := env.ReadMultiline(in, base.ReadOptCollectAllComments)
if firstToken >= 0 {
comments = str[0:firstToken]
if firstToken > 0 {
str = str[firstToken:]
env.IncLine(comments)
}
}
if ir.ParseEvalPrint(str, in) {
ir.Repl(in)
}
2017-08-04 10:30:31 -04:00
// Copy the stdout in a separate goroutine to prevent
// blocking on printing.
2017-08-04 10:30:31 -04:00
outStdout := make(chan string)
go func() {
var buf bytes.Buffer
2017-08-04 10:30:31 -04:00
io.Copy(&buf, rOut)
outStdout <- buf.String()
}()
2017-08-04 10:30:31 -04:00
// Return stdout back to normal state.
wOut.Close()
os.Stdout = oldStdout
val := <-outStdout
2017-07-22 16:49:22 -04:00
2017-08-04 10:30:31 -04:00
// Copy the stderr in a separate goroutine to prevent
// blocking on printing.
2017-08-04 10:30:31 -04:00
outStderr := make(chan string)
go func() {
var buf bytes.Buffer
io.Copy(&buf, rErr)
2017-08-04 10:30:31 -04:00
outStderr <- buf.String()
}()
wErr.Close()
2017-08-04 10:30:31 -04:00
stdErr := <-outStderr
if len(val) > 0 {
2017-07-22 16:49:22 -04:00
content["status"] = "ok"
content["payload"] = make([]map[string]interface{}, 0)
content["user_variables"] = make(map[string]string)
content["user_expressions"] = make(map[string]string)
if !silent {
2017-07-22 16:49:22 -04:00
var outContent OutputMsg
2017-08-30 14:14:12 -04:00
out, err := NewMsg("execute_result", receipt.Msg)
if err != nil {
return err
}
2017-07-22 16:49:22 -04:00
outContent.Execcount = ExecCounter
outContent.Data = make(map[string]string)
outContent.Data["text/plain"] = val
2017-07-22 16:49:22 -04:00
outContent.Metadata = make(map[string]interface{})
out.Content = outContent
receipt.SendResponse(receipt.Sockets.IOPubSocket, out)
}
}
if len(stdErr) > 0 {
content["status"] = "error"
content["ename"] = "ERROR"
content["evalue"] = stdErr
content["traceback"] = nil
errormsg, err := NewMsg("pyerr", receipt.Msg)
if err != nil {
return err
}
errormsg.Content = ErrMsg{"Error", stdErr, []string{stdErr}}
receipt.SendResponse(receipt.Sockets.IOPubSocket, errormsg)
}
2017-07-22 16:49:22 -04:00
// Send the output back to the notebook.
reply.Content = content
if err := receipt.SendResponse(receipt.Sockets.ShellSocket, reply); err != nil {
return err
}
idle, err := NewMsg("status", receipt.Msg)
if err != nil {
return err
}
idle.Content = kernelStatus{"idle"}
2017-07-22 16:49:22 -04:00
if err := receipt.SendResponse(receipt.Sockets.IOPubSocket, idle); err != nil {
return err
}
return nil
2017-07-22 16:49:22 -04:00
}
// handleShutdownRequest sends a "shutdown" message
func handleShutdownRequest(receipt msgReceipt) {
2017-07-22 16:49:22 -04:00
reply, err := NewMsg("shutdown_reply", receipt.Msg)
if err != nil {
log.Fatal(err)
2017-07-22 16:49:22 -04:00
}
content := receipt.Msg.Content.(map[string]interface{})
restart := content["restart"].(bool)
reply.Content = shutdownReply{restart}
2017-07-22 16:49:22 -04:00
if err := receipt.SendResponse(receipt.Sockets.ShellSocket, reply); err != nil {
log.Fatal(err)
2017-07-22 16:49:22 -04:00
}
log.Println("Shutting down in response to shutdown_request")
os.Exit(0)
}