blob: 87b7da8c57a4001828a87a44363d2b607a292bb2 [file] [log] [blame]
Craig Citro15744b12015-03-02 13:34:32 -08001#!/usr/bin/env python
Joe Gregorio652898b2011-05-02 21:07:43 -04002# -*- coding: utf-8 -*-
3#
Craig Citro751b7fb2014-09-23 11:20:38 -07004# Copyright 2014 Google Inc. All Rights Reserved.
Joe Gregorio652898b2011-05-02 21:07:43 -04005#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
18"""Simple command-line sample for the Google Prediction API
19
Joe Gregorio65826f92011-06-03 11:20:29 -040020Command-line application that trains on your input data. This sample does
21the same thing as the Hello Prediction! example. You might want to run
22the setup.sh script to load the sample data to Google Storage.
Joe Gregorio652898b2011-05-02 21:07:43 -040023
24Usage:
Antoine Picard7ba3c3f2014-05-09 14:55:58 -070025 $ python prediction.py "bucket/object" "model_id" "project_id"
Joe Gregorio652898b2011-05-02 21:07:43 -040026
27You can also get help on all the command-line flags the program understands
28by running:
29
30 $ python prediction.py --help
31
32To get detailed log output run:
33
34 $ python prediction.py --logging_level=DEBUG
35"""
INADA Naokie8d87822014-08-20 15:25:24 +090036from __future__ import print_function
Joe Gregorio652898b2011-05-02 21:07:43 -040037
Joe Gregorio968a9582012-03-07 14:52:52 -050038__author__ = ('jcgregorio@google.com (Joe Gregorio), '
39 'marccohen@google.com (Marc Cohen)')
Joe Gregorio652898b2011-05-02 21:07:43 -040040
Joe Gregorioe8391152013-06-28 01:30:57 -040041import argparse
Joe Gregorio652898b2011-05-02 21:07:43 -040042import pprint
43import sys
Joe Gregorio968a9582012-03-07 14:52:52 -050044import time
Joe Gregorio652898b2011-05-02 21:07:43 -040045
Antoine Picard7ba3c3f2014-05-09 14:55:58 -070046from apiclient import sample_tools
Joe Gregorioe8391152013-06-28 01:30:57 -040047from oauth2client import client
Joe Gregorio652898b2011-05-02 21:07:43 -040048
Joe Gregorio968a9582012-03-07 14:52:52 -050049
# Seconds to sleep between consecutive polls of a model's training status.
SLEEP_TIME = 10


# Positional command-line flags specific to this sample. The parser is built
# with add_help=False because main() hands it to sample_tools.init as a parent
# parser (parents=[argparser]); argparse parent parsers conventionally omit
# their own -h/--help so the child parser does not get a conflicting option.
argparser = argparse.ArgumentParser(add_help=False)
for _flag_name, _flag_help in (
    ('object_name',
     'Full Google Storage path of csv data (ex bucket/object)'),
    ('model_id',
     'Model Id of your choosing to name trained model'),
    ('project_id',
     'Project Id of your Google Cloud Project'),
):
    argparser.add_argument(_flag_name, help=_flag_help)
# Do not leave the loop variables behind as module attributes.
del _flag_name, _flag_help
Joe Gregorioe8391152013-06-28 01:30:57 -040062
63
def print_header(line):
    """Print ``line`` framed by rules of '=' characters.

    A blank line is written first to separate the header from earlier
    output. The rule above and below the text is exactly ``len(line)``
    characters wide.

    Args:
      line: The header text to display.
    """
    rule = '=' * len(line)
    print('\n%s\n%s\n%s' % (rule, line, rule))
Joe Gregorioe8391152013-06-28 01:30:57 -040071
72
def _wait_for_training(papi, model_id, project_id):
    """Poll a model until its training finishes and return the final status.

    While the reported state is 'RUNNING', sleeps SLEEP_TIME seconds between
    polls. Note that there is no overall timeout.

    Args:
      papi: The ``trainedmodels()`` collection of the Prediction API service.
      model_id: Id of the model being trained.
      project_id: Id of the Google Cloud project that owns the model.

    Returns:
      The response of ``papi.get(...).execute()`` whose ``trainingStatus``
      is 'DONE'.

    Raises:
      RuntimeError: If ``trainingStatus`` is neither 'RUNNING' nor 'DONE'.
    """
    while True:
        status = papi.get(id=model_id, project=project_id).execute()
        state = status['trainingStatus']
        print('Training state: ' + state)
        if state == 'DONE':
            return status
        if state != 'RUNNING':
            raise RuntimeError('Training Error: ' + state)
        time.sleep(SLEEP_TIME)


def main(argv):
    """Walk through the full lifecycle of a Prediction API model.

    Lists up to ten existing models and submits a training request for the
    CSV data named on the command line. It then waits for training to finish
    and reports the final status. Finally it prints the model analysis, runs
    two sample predictions and deletes the model.

    Args:
      argv: The full command line, passed to ``sample_tools.init``. That call
        returns the API service object and the parsed flags (``object_name``,
        ``model_id``, ``project_id``).

    Raises:
      RuntimeError: If training ends in a state other than 'DONE'.
    """
    # If you previously ran this app with an earlier version of the API
    # or if you change the list of scopes below, revoke your app's permission
    # here: https://accounts.google.com/IssuedAuthSubTokens
    # Then re-run the app to re-authorize it.
    service, flags = sample_tools.init(
        argv, 'prediction', 'v1.6', __doc__, __file__, parents=[argparser],
        scope=(
            'https://www.googleapis.com/auth/prediction',
            'https://www.googleapis.com/auth/devstorage.read_only'))

    try:
        # Get access to the Prediction API.
        papi = service.trainedmodels()

        # List models.
        print_header('Fetching list of first ten models')
        result = papi.list(maxResults=10, project=flags.project_id).execute()
        print('List results:')
        pprint.pprint(result)

        # Start training request on a data set.
        print_header('Submitting model training request')
        body = {
            'id': flags.model_id,
            'storageDataLocation': flags.object_name,
        }
        start = papi.insert(body=body, project=flags.project_id).execute()
        print('Training results:')
        pprint.pprint(start)

        # Wait for the training to complete, then report the final status.
        print_header('Waiting for training to complete')
        status = _wait_for_training(papi, flags.model_id, flags.project_id)
        print('Training completed:')
        pprint.pprint(status)

        # Describe model.
        print_header('Fetching model description')
        result = papi.analyze(id=flags.model_id, project=flags.project_id).execute()
        print('Analyze results:')
        pprint.pprint(result)

        # Make some predictions using the newly trained model.
        print_header('Making some predictions')
        for sample_text in ['mucho bueno', 'bonjour, mon cher ami']:
            body = {'input': {'csvInstance': [sample_text]}}
            result = papi.predict(
                body=body, id=flags.model_id, project=flags.project_id).execute()
            print('Prediction results for "%s"...' % sample_text)
            pprint.pprint(result)

        # Delete model. The delete response is not needed.
        print_header('Deleting model')
        papi.delete(id=flags.model_id, project=flags.project_id).execute()
        print('Model deleted.')

    except client.AccessTokenRefreshError:
        print('The credentials have been revoked or expired, please re-run '
              'the application to re-authorize.')
Joe Gregorio652898b2011-05-02 21:07:43 -0400143
Joe Gregorioe8391152013-06-28 01:30:57 -0400144
if __name__ == '__main__':
    # Run the sample only when executed as a script. The complete argv
    # (program name included) is forwarded to main().
    main(argv=sys.argv)